1   package com.atlassian.util.concurrent;
2   
3   import static org.junit.Assert.assertEquals;
4   import static org.junit.Assert.assertTrue;
5   
6   import java.util.concurrent.Callable;
7   import java.util.concurrent.TimeUnit;
8   import java.util.concurrent.atomic.AtomicBoolean;
9   import java.util.concurrent.locks.Condition;
10  import java.util.concurrent.locks.Lock;
11  import java.util.concurrent.locks.ReadWriteLock;
12  import java.util.concurrent.locks.ReentrantReadWriteLock;
13  
14  import org.junit.Test;
15  
16  public class ReadWriteManagedLockTest {
17      @Test
18      public void testSupplierReturnsValue() throws Exception {
19          final AtomicBoolean called = new AtomicBoolean();
20          final TrackedReadWriteLock lock = new TrackedReadWriteLock();
21          final ManagedLock.ReadWrite managedLock = ManagedLocks.manageReadWrite(lock);
22          assertEquals("blah", managedLock.read().withLock(new Supplier<String>() {
23              public String get() {
24                  called.set(true);
25                  return "blah";
26              }
27          }));
28          assertTrue(called.get());
29          lock.read.check();
30          called.set(false);
31          assertEquals("blah", managedLock.write().withLock(new Supplier<String>() {
32              public String get() {
33                  called.set(true);
34                  return "blah";
35              }
36          }));
37          assertTrue(called.get());
38          lock.write.check();
39      }
40  
41      @Test
42      public void testCallableReturnsValue() throws Exception {
43          final AtomicBoolean called = new AtomicBoolean();
44          final TrackedReadWriteLock lock = new TrackedReadWriteLock();
45          final ManagedLock.ReadWrite managedLock = ManagedLocks.manageReadWrite(lock);
46  
47          assertEquals("blah", managedLock.read().withLock(new Callable<String>() {
48              public String call() {
49                  called.set(true);
50                  return "blah";
51              }
52          }));
53          assertTrue(called.get());
54          lock.read.check();
55          called.set(false);
56          assertEquals("blah", managedLock.write().withLock(new Callable<String>() {
57              public String call() {
58                  called.set(true);
59                  return "blah";
60              }
61          }));
62          assertTrue(called.get());
63          lock.write.check();
64      }
65  
66      @Test
67      public void testRunnableRuns() throws Exception {
68          final AtomicBoolean called = new AtomicBoolean();
69          final TrackedReadWriteLock lock = new TrackedReadWriteLock();
70          final ManagedLock.ReadWrite managedLock = ManagedLocks.manageReadWrite(lock);
71  
72          managedLock.read().withLock(new Runnable() {
73              public void run() {
74                  called.set(true);
75              }
76          });
77          assertTrue(called.get());
78          lock.read.check();
79          called.set(false);
80          managedLock.write().withLock(new Runnable() {
81              public void run() {
82                  called.set(true);
83              }
84          });
85          assertTrue(called.get());
86          lock.write.check();
87      }
88  
89      static class TrackedReadWriteLock implements ReadWriteLock {
90          private static final long serialVersionUID = 9210941568120426704L;
91  
92          private final ReadWriteLock lock = new ReentrantReadWriteLock();
93          final TrackedLock read = new TrackedLock(lock.readLock());
94          final TrackedLock write = new TrackedLock(lock.writeLock());
95  
96          public Lock readLock() {
97              return read;
98          }
99  
100         public Lock writeLock() {
101             return write;
102         }
103     }
104 
105     static class TrackedLock implements Lock {
106         private final Lock delegate;
107 
108         boolean locked;
109         boolean unlocked;
110 
111         TrackedLock(final Lock delegate) {
112             this.delegate = delegate;
113         }
114 
115         void check() {
116             assertTrue(locked);
117             assertTrue(unlocked);
118         }
119 
120         public void lock() {
121             delegate.lock();
122             locked = true;
123         }
124 
125         public void lockInterruptibly() throws InterruptedException {
126             delegate.lockInterruptibly();
127         }
128 
129         public Condition newCondition() {
130             return delegate.newCondition();
131         }
132 
133         public boolean tryLock() {
134             return delegate.tryLock();
135         }
136 
137         public boolean tryLock(final long time, final TimeUnit unit) throws InterruptedException {
138             return delegate.tryLock(time, unit);
139         }
140 
141         public void unlock() {
142             delegate.unlock();
143             unlocked = true;
144         }
145     }
146 }