View Javadoc

1   package com.atlassian.sal.core.rdbms;
2   
3   import com.atlassian.sal.api.rdbms.ConnectionCallback;
4   import com.atlassian.sal.spi.HostConnectionAccessor;
5   import org.junit.Before;
6   import org.junit.Rule;
7   import org.junit.rules.ExpectedException;
8   import org.mockito.Mock;
9   import org.mockito.MockitoAnnotations;
10  import org.mockito.invocation.InvocationOnMock;
11  import org.mockito.stubbing.Answer;
12  
13  import java.sql.Connection;
14  
15  import static org.mockito.Matchers.any;
16  import static org.mockito.Mockito.when;
17  
18  public class TestDefaultTransactionalExecutorBase
19  {
20      @Rule
21      public ExpectedException expectedException = ExpectedException.none();
22  
23      protected DefaultTransactionalExecutor defaultTransactionalExecutor;
24  
25      @Mock
26      protected HostConnectionAccessor hostConnectionAccessor;
27      @Mock
28      protected Connection connection;
29      @Mock
30      protected ConnectionCallback<Object> callback;
31      @Mock
32      protected Object result;
33  
34      protected WrappedConnection wrappedConnection;
35  
36      protected Throwable callbackThrows;
37  
38      protected static class CallbackException extends RuntimeException {
39          public CallbackException(final String message)
40          {
41              super(message);
42          }
43      }
44  
45      @Before
46      public void before()
47      {
48          MockitoAnnotations.initMocks(this);
49  
50          defaultTransactionalExecutor = new DefaultTransactionalExecutor(hostConnectionAccessor, false, false);
51  
52          callbackThrows = null;
53  
54          // grab a hold of the wrapped connection each time
55          when(callback.execute(any(WrappedConnection.class))).thenAnswer(new Answer<Object>()
56          {
57              @Override
58              public Object answer(final InvocationOnMock invocation) throws Throwable
59              {
60                  wrappedConnection = (WrappedConnection) invocation.getArguments()[0];
61                  if (callbackThrows != null)
62                  {
63                      throw callbackThrows;
64                  }
65                  else
66                  {
67                      return result;
68                  }
69              }
70          });
71      }
72  }