1   package com.atlassian.util.concurrent;
2   
3   import static com.atlassian.util.concurrent.Util.pause;
4   import static org.junit.Assert.assertEquals;
5   import static org.junit.Assert.assertFalse;
6   import static org.junit.Assert.assertNotNull;
7   import static org.junit.Assert.assertNull;
8   import static org.junit.Assert.assertSame;
9   import static org.junit.Assert.assertTrue;
10  import static org.junit.Assert.fail;
11  
12  import org.junit.Test;
13  
14  import java.util.ArrayList;
15  import java.util.List;
16  import java.util.concurrent.Callable;
17  import java.util.concurrent.CountDownLatch;
18  import java.util.concurrent.ExecutorService;
19  import java.util.concurrent.Executors;
20  import java.util.concurrent.Future;
21  import java.util.concurrent.atomic.AtomicInteger;
22  import java.util.concurrent.atomic.AtomicReference;
23  
24  public class LazyReferenceTest {
25  
26      /**
27       * Used to pound the tests
28       * 
29       * @param args ignored
30       * @throws Exception
31       */
32      public static void main(final String[] args) throws Exception {
33          final LazyReferenceTest test = new LazyReferenceTest();
34          for (int i = 0; i < 10000; i++) {
35              //test.concurrentCreate();
36              //test.getInterruptibly();
37              test.getNotInterruptable();
38          }
39      }
40  
41      @Test public void concurrentCreate() throws Exception {
42          final int nThreads = 40;
43          final Object[] results = new Object[nThreads];
44          final AtomicInteger createCallCount = new AtomicInteger(0);
45          final LazyReference<Object> ref = new LazyReference<Object>() {
46              @Override protected Object create() {
47                  /*
48                   * We are trying to simulate an expensive object construction call. So we do a sleep
49                   * here. The idea is that we will get many threads to call create() at the same
50                   * time, make create "slow" and then ensure that create() method was indeed invoked
51                   * only once.
52                   */
53                  createCallCount.incrementAndGet();
54                  pause();
55                  pause();
56                  pause();
57                  pause();
58                  pause();
59                  return new Object();
60              }
61          };
62  
63          /*
64           * pool size must be large enough to accommodate all Callables running in parallel as they
65           * latch
66           */
67          final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
68          final CountDownLatch latch = new CountDownLatch(nThreads);
69  
70          final List<Callable<Object>> tasks = new ArrayList<Callable<Object>>(nThreads);
71  
72          for (int i = 0; i < nThreads; i++) {
73              final int j = i;
74              tasks.add(new Callable<Object>() {
75                  public Object call() {
76                      /*
77                       * Put in a latch to synchronize all threads and try to get them to call
78                       * ref.get() at the same time (to increase concurrency and make this test more
79                       * useful)
80                       */
81                      try {
82                          latch.countDown();
83                          latch.await();
84                      }
85                      catch (final InterruptedException e) {
86                          throw new RuntimeException(e);
87                      }
88                      results[j] = ref.get();
89                      return results[j];
90                  }
91              });
92          }
93  
94          List<Future<Object>> futures = null;
95          try {
96              futures = pool.invokeAll(tasks);
97          }
98          catch (final InterruptedException e) {
99              throw new RuntimeException(e);
100         }
101 
102         // Ensure the create() method was invoked once
103         assertEquals(1, createCallCount.get());
104 
105         /*
106          * Ensure that all the references are the same, use the futures in case of exception
107          */
108         final Object result = results[0];
109         for (final Future<Object> future : futures) {
110             assertSame(result, future.get());
111         }
112         for (int i = 0; i < results.length; i++) {
113             assertSame("got back different reference in '" + i + "' place", result, results[i]);
114         }
115         pool.shutdown();
116     }
117 
118     @Test public void exception() {
119         final Exception myException = new Exception();
120 
121         final LazyReference<Object> ref = new LazyReference<Object>() {
122             @Override protected Object create() throws Exception {
123                 throw myException;
124             }
125         };
126 
127         try {
128             ref.get();
129             fail("RuntimeException should have been thrown");
130         }
131         catch (final RuntimeException yay) {
132             assertNotNull(yay.getCause());
133             assertTrue(myException == yay.getCause());
134         }
135     }
136 
137     @Test public void getNotInterruptable() throws Exception {
138         final CountDownLatch latch = new CountDownLatch(1);
139 
140         final LazyReference<Integer> ref = new LazyReference<Integer>() {
141             @Override protected Integer create() {
142                 // do not interrupt
143                 while (true) {
144                     try {
145                         latch.await();
146                         return 10;
147                     }
148                     catch (final InterruptedException e) {}
149                 }
150             }
151         };
152 
153         final Thread client = new Thread(new Runnable() {
154             public void run() {
155                 ref.get();
156             }
157         }, this.getClass().getName());
158         client.start();
159 
160         for (int i = 0; i < 10; i++) {
161             pause();
162             if (ref.isInitialized()) {
163                 System.out.println(ref.get());
164             }
165             assertFalse(ref.isInitialized());
166             client.interrupt();
167         }
168         pause();
169         assertFalse(ref.isInitialized());
170 
171         latch.countDown();
172         pause();
173         assertTrue(ref.isInitialized());
174 
175         final int obj = ref.get();
176         assertEquals(10, obj);
177     }
178 
179     @Test public void getInterruptibly() throws Exception {
180         final class Result<T> {
181             final T result;
182             final Exception exception;
183 
184             Result(final T result) {
185                 this.result = result;
186                 this.exception = null;
187             }
188 
189             Result(final Exception exception) {
190                 this.result = null;
191                 this.exception = exception;
192             }
193         }
194         final CountDownLatch latch = new CountDownLatch(1);
195 
196         final LazyReference<Integer> ref = new LazyReference<Integer>() {
197             @Override protected Integer create() {
198                 // do not interrupt
199                 while (true) {
200                     try {
201                         latch.await();
202                         return 10;
203                     }
204                     catch (final InterruptedException e) {}
205                 }
206             }
207         };
208 
209         final AtomicReference<Result<Integer>> result1 = new AtomicReference<Result<Integer>>();
210         final Thread client1 = new Thread(new Runnable() {
211             public void run() {
212                 try {
213                     result1.compareAndSet(null, new Result<Integer>(ref.getInterruptibly()));
214                 }
215                 catch (final Exception e) {
216                     result1.compareAndSet(null, new Result<Integer>(e));
217                 }
218             }
219         }, this.getClass().getName());
220         client1.start();
221 
222         pause();
223         final AtomicReference<Result<Integer>> result2 = new AtomicReference<Result<Integer>>();
224         final Thread client2 = new Thread(new Runnable() {
225             public void run() {
226                 try {
227                     result2.compareAndSet(null, new Result<Integer>(ref.getInterruptibly()));
228                 }
229                 catch (final Exception e) {
230                     result2.compareAndSet(null, new Result<Integer>(e));
231                 }
232             }
233         }, this.getClass().getName());
234         client2.start();
235 
236         for (int i = 0; i < 10; i++) {
237             pause();
238             if (ref.isInitialized()) {
239                 System.out.println(ref.get());
240             }
241             assertFalse(ref.isInitialized());
242             client1.interrupt();
243             client2.interrupt();
244         }
245 
246         assertNull(result1.get());
247         assertNotNull(result2.get().exception);
248         assertEquals(InterruptedException.class, result2.get().exception.getClass());
249         pause();
250         assertFalse(ref.isInitialized());
251 
252         latch.countDown();
253         pause();
254         assertTrue(ref.isInitialized());
255 
256         {
257             final int result = ref.get();
258             assertEquals(10, result);
259         }
260         assertNotNull(result1.get());
261         assertNotNull(result1.get().result);
262         {
263             final int result = result1.get().result;
264             assertEquals(10, result);
265         }
266     }
267 }