View Javadoc

1   package com.atlassian.messagequeue.internal.sqs;
2   
3   import com.amazonaws.auth.BasicAWSCredentials;
4   import com.amazonaws.services.sqs.AmazonSQSClient;
5   import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
6   import com.amazonaws.services.sqs.model.MessageNotInflightException;
7   import com.amazonaws.services.sqs.model.OverLimitException;
8   import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
9   import com.amazonaws.services.sqs.model.ReceiveMessageResult;
10  import com.amazonaws.services.sqs.model.SendMessageRequest;
11  import com.atlassian.messagequeue.Message;
12  import com.atlassian.messagequeue.MessageInformationService;
13  import com.atlassian.messagequeue.MessageRunnerKey;
14  import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
15  import com.atlassian.messagequeue.MessageRunnerServiceException;
16  import com.atlassian.messagequeue.MessageSerializationException;
17  import com.atlassian.messagequeue.internal.core.DefaultMessageInformationService;
18  import com.atlassian.messagequeue.internal.core.JacksonNestedMessageSerializer;
19  import com.atlassian.messagequeue.internal.core.NestedMessage;
20  import com.atlassian.messagequeue.internal.core.NestedMessageSerializer;
21  import com.atlassian.messagequeue.internal.core.DefaultMessageRunnerRegistryService;
22  import com.atlassian.messagequeue.registry.MessageContext;
23  import com.atlassian.messagequeue.registry.MessageRunner;
24  import com.atlassian.tenant.api.TenantContext;
25  import com.atlassian.tenant.api.TenantContextProvider;
26  import com.atlassian.tenant.impl.TenantIdSetter;
27  import com.atlassian.workcontext.api.ImmutableWorkContextReference;
28  import com.atlassian.workcontext.api.WorkContextDoorway;
29  import com.google.common.util.concurrent.Uninterruptibles;
30  import org.elasticmq.rest.sqs.SQSRestServer;
31  import org.elasticmq.rest.sqs.SQSRestServerBuilder;
32  import org.junit.After;
33  import org.junit.Assume;
34  import org.junit.Before;
35  import org.junit.Test;
36  import org.junit.runner.RunWith;
37  import org.mockito.ArgumentCaptor;
38  import org.mockito.Captor;
39  import org.mockito.Mock;
40  import org.mockito.Mockito;
41  import org.mockito.runners.MockitoJUnitRunner;
42  
43  import java.util.Collections;
44  import java.util.Optional;
45  import java.util.Queue;
46  import java.util.concurrent.BlockingQueue;
47  import java.util.concurrent.CountDownLatch;
48  import java.util.concurrent.ExecutorService;
49  import java.util.concurrent.Executors;
50  import java.util.concurrent.LinkedBlockingQueue;
51  import java.util.concurrent.TimeUnit;
52  import java.util.concurrent.atomic.AtomicInteger;
53  import java.util.concurrent.atomic.AtomicReference;
54  import java.util.function.Supplier;
55  import java.util.stream.IntStream;
56  
57  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME;
58  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_ID_ATTRIBUTE_NAME;
59  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
60  import static com.google.common.base.Throwables.getRootCause;
61  import static com.google.common.base.Throwables.propagate;
62  import static java.util.Objects.requireNonNull;
63  import static org.hamcrest.CoreMatchers.hasItem;
64  import static org.hamcrest.CoreMatchers.not;
65  import static org.hamcrest.MatcherAssert.assertThat;
66  import static org.hamcrest.Matchers.contains;
67  import static org.hamcrest.Matchers.containsInAnyOrder;
68  import static org.hamcrest.Matchers.hasSize;
69  import static org.hamcrest.core.Is.is;
70  import static org.hamcrest.text.IsEmptyString.isEmptyOrNullString;
71  import static org.junit.Assert.assertNull;
72  import static org.junit.Assert.assertTrue;
73  import static org.mockito.Matchers.any;
74  import static org.mockito.Matchers.anyInt;
75  import static org.mockito.Matchers.anyString;
76  import static org.mockito.Mockito.after;
77  import static org.mockito.Mockito.doCallRealMethod;
78  import static org.mockito.Mockito.doReturn;
79  import static org.mockito.Mockito.doThrow;
80  import static org.mockito.Mockito.mock;
81  import static org.mockito.Mockito.never;
82  import static org.mockito.Mockito.spy;
83  import static org.mockito.Mockito.timeout;
84  import static org.mockito.Mockito.times;
85  import static org.mockito.Mockito.verify;
86  import static org.mockito.Mockito.when;
87  
88  @RunWith(MockitoJUnitRunner.class)
89  public class SQSMessageRunnerServiceTest {
90      private static final MessageRunnerKey MESSAGE_RUNNER_KEY = MessageRunnerKey.of("testMessageRunner");
91      private static final int VISIBILITY_TIMEOUT_SECONDS = 10;
92      private static final int RECEIVE_WAIT_TIME_SECONDS = 0;
93      private static final String QUEUE_NAME = "testQueue";
94      private static final int CONCURRENT_CONSUMERS = 1;
95      private static final String TENANT_ID = "tenant-123";
96      private static final int VERIFY_TIMEOUT_MILLIS = 1000;
97      private static final int SCHEDULING_BUFFER_SECONDS = 5;
98  
99      private SQSMessageRunnerService messageRunnerService;
100     private BlockingQueue<String> payloads;
101     private SQSRestServer sqsServer;
102     private AmazonSQSClient spyingSqsClient;
103     private WorkContextDoorway workContextDoorway;
104 
105     @Mock
106     private TenantContextProvider tenantContextProvider;
107     @Mock
108     private TenantIdSetter tenantIdSetter;
109 
110     @Captor
111     private ArgumentCaptor<ReceiveMessageRequest> receiveMessageRequestArgumentCaptor;
112     @Captor
113     private ArgumentCaptor<SendMessageRequest> sendMessageRequestArgumentCaptor;
114     @Captor
115     private ArgumentCaptor<MessageContext> messageContextArgumentCaptor;
116 
117     private MessageRunner messageRunner;
118     private DefaultMessageRunnerRegistryService registryService;
119     private MessageInformationService messageInformationService;
120     private String queueUrl;
121     private NestedMessageSerializer nestedMessageSerializer;
122     private TenantContext.Builder tenantContextBuilder;
123     private SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager;
124 
125     @Before
126     public void setUp() throws Exception {
127         sqsServer = SQSRestServerBuilder.start();
128 
129         String endpoint = "http://localhost:9324";
130 
131         AmazonSQSClient sqsClient = new AmazonSQSClient(new BasicAWSCredentials("x", "x")).withEndpoint(endpoint);
132         sqsClient.createQueue(QUEUE_NAME);
133         queueUrl = sqsClient.getQueueUrl(QUEUE_NAME).getQueueUrl();
134         spyingSqsClient = spy(sqsClient);
135 
136         nestedMessageSerializer = spy(new JacksonNestedMessageSerializer());
137         registryService = spy(new DefaultMessageRunnerRegistryService());
138         messageInformationService = spy(new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer));
139         sqsMessageVisibilityTimeoutManager = new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, VISIBILITY_TIMEOUT_SECONDS, SCHEDULING_BUFFER_SECONDS, 2);
140 
141         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
142                 CONCURRENT_CONSUMERS, registryService,
143                 tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
144         payloads = new LinkedBlockingQueue<>();
145 
146         messageRunner = spy(new DefaultMessageRunner(payloads));
147         registryService.registerMessageRunner(MESSAGE_RUNNER_KEY, messageRunner);
148 
149         tenantContextBuilder = new TenantContext.Builder()
150                 .jdbcUrl("jdbcUrl")
151                 .jdbcUsername("username")
152                 .jdbcPassword("password");
153         when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId(TENANT_ID).build());
154 
155         workContextDoorway = new WorkContextDoorway();
156         workContextDoorway.open();
157     }
158 
159     @After
160     public void tearDown() throws Exception {
161         workContextDoorway.close();
162         messageRunnerService.shutdown();
163         sqsServer.stopAndWait();
164     }
165 
166     @Test
167     public void messagePayloadIsPassedToMessageRunner() throws Exception {
168         String expectedPayload = "payload";
169 
170         messageRunnerService.initialiseMessageConsumers();
171         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
172 
173         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
174         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of(expectedPayload)));
175     }
176 
177     @Test
178     public void messageIsAutoAcknowledgedAfterSuccessfulProcessing() throws Exception {
179         String expectedPayload = "payload";
180 
181         messageRunnerService.initialiseMessageConsumers();
182         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
183 
184         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
185     }
186 
187     @Test
188     public void addMessageWithEmptyPayload() throws Exception {
189         messageRunnerService.initialiseMessageConsumers();
190         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, ""));
191 
192         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
193         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
194         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
195         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of("")));
196     }
197 
198     @Test
199     public void addMessageWithNullPayload() throws Exception {
200         messageRunnerService.initialiseMessageConsumers();
201         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, null));
202 
203         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
204         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
205         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
206         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.empty()));
207     }
208 
209     @Test
210     public void addMultipleMessagesSerially() throws Exception {
211         messageRunnerService.initialiseMessageConsumers();
212         int numberOfMessages = 20;
213         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
214 
215         range.get().forEach(i -> {
216             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
217         });
218 
219         String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
220         verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
221         assertThat(payloads, containsInAnyOrder(expectedPayloads));
222     }
223 
224     @Test
225     public void addMultipleMessagesConcurrently() throws Exception {
226         messageRunnerService.initialiseMessageConsumers();
227         int numberOfMessages = 20;
228         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
229 
230         ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
231         try {
232             range.get().forEach(i -> {
233                 pool.execute(() -> {
234                     messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
235                 });
236             });
237 
238             String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
239             verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
240             assertThat(payloads, containsInAnyOrder(expectedPayloads));
241         } finally {
242             pool.shutdown();
243             pool.awaitTermination(1, TimeUnit.SECONDS);
244         }
245     }
246 
247     @Test
248     public void messageNotDeletedIfRunningItThrowsAnException() throws Exception {
249         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("failingRunner");
250         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
251             throw new RuntimeException();
252         });
253 
254         messageRunnerService.initialiseMessageConsumers();
255         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
256 
257         verify(spyingSqsClient, after(1000).never()).deleteMessage(anyString(), anyString());
258     }
259 
260     @Test
261     public void receiveMessageRequestPerformedUsingConfiguredVisibilityTimeoutAndSchedulingBuffer() throws Exception {
262         messageRunnerService.initialiseMessageConsumers();
263         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
264 
265         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
266         assertThat(receiveMessageRequestArgumentCaptor.getValue().getVisibilityTimeout(), is(VISIBILITY_TIMEOUT_SECONDS + SCHEDULING_BUFFER_SECONDS));
267     }
268 
269     @Test
270     public void sentTimestampAttributeIsRequestedToAllowMessageLatencyToBeDerived() throws Exception {
271         messageRunnerService.initialiseMessageConsumers();
272         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
273 
274         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
275         assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), contains(SENT_TIMESTAMP));
276     }
277 
278     @Test
279     public void consumerResumesPollingIfNoMessagesReturnedInOneInvocationOfReceiveMessage() throws Exception {
280         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
281         ReceiveMessageResult emptyReceiveMessageResult = mock(ReceiveMessageResult.class);
282         when(emptyReceiveMessageResult.getMessages()).thenReturn(Collections.emptyList());
283         doReturn(emptyReceiveMessageResult).doCallRealMethod().when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
284 
285         messageRunnerService.initialiseMessageConsumers();
286         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
287 
288         verify(spyingSqsClient, after(1000).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
289     }
290 
291     @Test(expected = InvalidMessageContentsException.class)
292     public void addMessageWithPayloadContainingIllegalCharacters() throws Exception {
293         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
294         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
295         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "test", registryService,
296                 tenantIdSetter, messageInformationService, nestedMessageSerializer);
297 
298         when(amazonSQSClient.sendMessage(any(SendMessageRequest.class))).thenThrow(InvalidMessageContentsException.class);
299 
300         try {
301             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload with invalid characters"));
302         } catch (MessageRunnerServiceException e) {
303             propagate(getRootCause(e));
304         }
305     }
306 
307     @Test(expected = IllegalStateException.class)
308     public void addMessageFailsWithExceptionIfTenantContextNotPresent() throws Exception {
309         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
310         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class);
311         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
312         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
313         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
314                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
315         when(tenantContextProvider.getTenantContext()).thenReturn(null);
316 
317         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
318     }
319 
320     @Test(expected = IllegalStateException.class)
321     public void addMessageFailsWithNullTenantId() throws Exception {
322         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
323         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
324         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
325         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
326         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
327                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
328 
329         when(tenantContextProvider.getTenantContext().getTenantId()).thenReturn(null);
330         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
331     }
332 
333     @Test(expected = IllegalStateException.class)
334     public void addMessageFailsWithEmptyTenantId() throws Exception {
335         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
336         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
337         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
338         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
339         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
340                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
341 
342         when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId("").build());
343         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
344     }
345 
346     @Test
347     public void consumerPoolCanRecoverFromOverLimitException() throws Exception {
348         // setup sqsClient such that the first invocation of sqsClient.receiveMessage(...) throws an OverLimitException
349         // setup a subsequent invocation to proceed normally
350         doThrow(OverLimitException.class)
351                 .doCallRealMethod()
352                     .when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
353 
354         messageRunnerService.initialiseMessageConsumers();
355         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
356 
357         // ensure message is eventually processed after we recover from the OverLimitException
358         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(any(MessageContext.class));
359         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString()); // wait until MessageRunner.processMessage(...) invocation has completed to reduce the flakiness of the following assertions
360         assertThat(payloads, hasSize(1));
361         assertThat(payloads, hasItem("1"));
362     }
363 
364     @Test
365     public void consumerPoolCanRecoverFromError() throws Exception {
366         final int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to enable message redelivery after failure
367         final int schedulingBufferSeconds = 1;
368         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
369                 CONCURRENT_CONSUMERS, registryService,
370                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
371                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 1));
372 
373         // setup messageRunner such that the first invocation of messageRunner.processMessage(...) throws a LinkageError
374         // setup a subsequent invocation to proceed normally
375         doThrow(LinkageError.class)
376                 .doCallRealMethod()
377                     .when(messageRunner).processMessage(any(MessageContext.class));
378 
379         messageRunnerService.initialiseMessageConsumers();
380         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
381 
382         verify(spyingSqsClient, timeout(2500)).deleteMessage(anyString(), anyString());
383         assertThat(payloads, hasSize(1));
384         assertThat(payloads, hasItem("1"));
385     }
386 
387     @Test
388     public void ensureWorkContextAvailableForMessageProcessing() throws Exception {
389         CountDownLatch latch = new CountDownLatch(1);
390         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("workContext");
391         AtomicReference<Exception> exception = new AtomicReference<>();
392         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
393             try {
394                 new ImmutableWorkContextReference<>(() -> "foobar").get();
395             } catch (IllegalStateException e) {
396                 exception.set(e);
397             } finally {
398                 latch.countDown();
399             }
400         });
401 
402         messageRunnerService.initialiseMessageConsumers();
403         messageRunnerService.addMessage(Message.create(messageRunnerKey, "1"));
404 
405         assertTrue("test did not complete within timeout", latch.await(1, TimeUnit.SECONDS));
406         assertNull("No work context available for message processing.", exception.get());
407     }
408 
409     @Test
410     public void ensureTenantIdIsSentByProducer() throws Exception {
411         messageRunnerService.initialiseMessageConsumers();
412         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
413 
414         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
415         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
416 
417         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
418         assertThat(nestedMessage.getAttribute(TENANT_ID_ATTRIBUTE_NAME), is(TENANT_ID));
419     }
420 
421     @Test
422     public void ensureMessageRunnerKeyIsSentByProducer() throws Exception {
423         messageRunnerService.initialiseMessageConsumers();
424         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
425 
426         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
427         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
428 
429         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
430         assertThat(nestedMessage.getAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME), is(MESSAGE_RUNNER_KEY.toString()));
431     }
432 
433     @Test
434     public void ensurePayloadIsSentByProducer() throws Exception {
435         String payload = "payload";
436         messageRunnerService.initialiseMessageConsumers();
437         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, payload));
438 
439         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
440         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
441 
442         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
443         assertThat(nestedMessage.getPayload(), is(payload));
444     }
445 
446     @Test
447     public void dontProcessAndDeleteMessagesThatCantBeDeserialized() throws Exception {
448         doThrow(MessageSerializationException.class).when(nestedMessageSerializer).deserialize(any(String.class));
449 
450         messageRunnerService.initialiseMessageConsumers();
451         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
452 
453         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
454         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
455     }
456 
457     @Test
458     public void deleteMessageWithoutTenantId() throws Exception {
459         doReturn(new NestedMessage()).when(nestedMessageSerializer).deserialize(anyString());
460 
461         messageRunnerService.initialiseMessageConsumers();
462         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
463 
464         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
465         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
466     }
467 
468     @Test
469     public void deleteMessageWithoutMessageRunnerKey() throws Exception {
470         NestedMessage nestedMessage = new NestedMessage().addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123");
471         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
472 
473         messageRunnerService.initialiseMessageConsumers();
474         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
475 
476         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
477         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
478     }
479 
480     @Test
481     public void dontProcessMessageWithoutMessageRunner() throws Exception {
482         NestedMessage nestedMessage = new NestedMessage()
483                 .addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123")
484                 .addAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME, MESSAGE_RUNNER_KEY.toString());
485         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
486         doCallRealMethod()
487                 .doReturn(Optional.empty())
488                 .when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
489 
490         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
491         messageRunnerService.initialiseMessageConsumers();
492 
493         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
494         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
495     }
496 
497     @Test(expected = MessageSerializationException.class)
498     public void failFastIfMessageCantBeSerialized() throws Exception {
499         doThrow(MessageSerializationException.class).when(messageInformationService).toPayload(any(Message.class));
500 
501         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
502     }
503 
504     @Test
505     public void ensureTenantContextEstablishedOnConsumer() throws Exception {
506         messageRunnerService.initialiseMessageConsumers();
507         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
508 
509         verify(tenantIdSetter, timeout(VERIFY_TIMEOUT_MILLIS)).setTenantId(TENANT_ID);
510     }
511 
512     @Test
513     public void messageNotProcessedOrDeletedIfTenantContextEstablishmentThrowsException() throws Exception {
514         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
515         doThrow(new RuntimeException("Error creating TenantContext from a tenantId (perhaps tenant context service is currently unavailable?)")).when(tenantIdSetter).setTenantId(anyString());
516 
517         messageRunnerService.initialiseMessageConsumers();
518         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
519         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
520         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
521         verify(messageRunner, never()).processMessage(any(MessageContext.class));
522     }
523 
524     @Test(expected = IllegalArgumentException.class)
525     public void concurrentConsumersCantBeNegative() throws Exception {
526         int concurrentConsumers = -1;
527 
528         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
529                 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
530     }
531 
532     @Test(expected = IllegalArgumentException.class)
533     public void concurrentConsumersCantBeZero() throws Exception {
534         int concurrentConsumers = 0;
535 
536         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
537                 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
538     }
539 
540     @Test(expected = IllegalArgumentException.class)
541     public void concurrentConsumersHasUpperBound() throws Exception {
542         int concurrentConsumers = SQSMessageRunnerService.MAX_CONCURRENT_CONSUMERS + 1;
543 
544         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
545                 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
546     }
547 
548     @Test
549     public void supportEarlyAcknowledgementOfMessage() throws Exception {
550         AtomicInteger deliveryCount = new AtomicInteger(0);
551         int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to encourage message redelivery (and therefore have a chance to prevent it)
552         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
553                 CONCURRENT_CONSUMERS, registryService,
554                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
555                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl));
556         MessageRunnerKey fooMessageRunnerKey = MessageRunnerKey.of("fooMessageRunnerKey");
557         registryService.registerMessageRunner(fooMessageRunnerKey, messageContext -> {
558             deliveryCount.incrementAndGet();
559             messageContext.acknowledge(); // acknowledge message at start of message runner before any processing begins
560             throw new RuntimeException("message runner error");
561         });
562 
563         messageRunnerService.initialiseMessageConsumers();
564         messageRunnerService.addMessage(Message.create(fooMessageRunnerKey, "payload"));
565 
566         Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
567         assertThat(deliveryCount.get(), is(1));
568     }
569 
570     @Test(expected = MessageRunnerNotRegisteredException.class)
571     public void failFastWhenAddingMessageWithNoRegisteredMessageRunner() throws Exception {
572         doReturn(Optional.empty()).when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
573 
574         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
575     }
576 
577     @Test
578     public void messageRunnerResponsiveToCancellationViaShutdown() throws Exception {
579         final CountDownLatch enterLatch = new CountDownLatch(1);
580         final CountDownLatch exitLatch = new CountDownLatch(1);
581         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
582         final int sleepMillis = 100;
583         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
584             enterLatch.countDown();
585             while (!messageContext.isCancellationRequested()) {
586                 Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
587             }
588             exitLatch.countDown();
589         });
590         messageRunnerService.initialiseMessageConsumers();
591         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
592         assertThat("MessageRunner did not start. Cannot assert cancellation on a MessageRunner that did not start",
593                 enterLatch.await(1, TimeUnit.SECONDS), is(true));
594 
595         messageRunnerService.shutdown();
596 
597         assertThat("MessageRunner did not respond to cancellation", exitLatch.await(sleepMillis * 2, TimeUnit.MILLISECONDS), is(true));
598     }
599 
600     @Test
601     public void cancelAutoAcknowledgement() throws Exception {
602         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("cancelAckMessageRunnerKey");
603         registryService.registerMessageRunner(messageRunnerKey, MessageContext::cancelAutoAcknowledgementOfMessage);
604 
605         messageRunnerService.initialiseMessageConsumers();
606         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
607 
608         verify(spyingSqsClient, after(VERIFY_TIMEOUT_MILLIS).never()).deleteMessage(anyString(), anyString());
609     }
610 
611     @Test
612     public void autoExtensionOfVisibilityTimeout() throws Exception {
613         final int visibilityTimeoutSeconds = 1;
614 
615         // use more than one consumer to allow the a particular message to be processed concurrently
616         // this should cause this test to fail if the visibility timeout extension code is not working properly
617         final int concurrentConsumers = 2;
618         final int longRunningTaskDurationMillis = 2500;
619 
620         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
621         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
622             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
623         }));
624         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
625 
626         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
627                 concurrentConsumers, registryService,
628                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
629                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
630 
631         messageRunnerService.initialiseMessageConsumers();
632         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
633 
634         verify(longRunningMessageRunner, after(longRunningTaskDurationMillis + 500).times(1)).processMessage(any(MessageContext.class));
635         verify(spyingSqsClient, times(2)).changeMessageVisibility(anyString(), anyString(), anyInt());
636     }
637 
638     /**
639      * When the visibility timeout of a message expires before we've managed to extend it, then
640      * the ChangeMessageVisibility request will throw a MessageNotInflightException.
641      *
642      * <p>Also, since the visibility timeout has expired, we expect another MessageRunner to start executing the message
643      *
644      * <p>We should log this incident appropriately and tune the scheduling buffer time to be big enough to make this very unlikely.
645      */
646     @Test
647     public void autoExtensionOfVisibilityTimeoutFailsBecauseMessageNoLongerInFlight() throws Exception {
648         final int visibilityTimeoutSeconds = 1;
649 
650         // use more than one consumer to allow the a particular message to be processed concurrently
651         // this should cause this test to fail if the visibility timeout extension code is not working properly
652         final int concurrentConsumers = 2;
653         final int longRunningTaskDurationMillis = 2500;
654         final int schedulingBufferSeconds = 1;
655 
656         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
657         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
658             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
659         }));
660         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
661 
662         doThrow(MessageNotInflightException.class).when(spyingSqsClient).changeMessageVisibility(anyString(), anyString(), anyInt());
663 
664         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
665                 concurrentConsumers, registryService,
666                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
667                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 2));
668 
669         messageRunnerService.initialiseMessageConsumers();
670         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
671 
672         verify(longRunningMessageRunner, after(3000).times(2)).processMessage(any(MessageContext.class));
673     }
674 
675     @Test
676     public void autoExtensionOfVisibilityTimeoutCancelledOnMessageProcessingException() throws Exception {
677         final int visibilityTimeoutSeconds = 1;
678         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("erroneousMessageRunnerKey");
679         final MessageRunner erroneousMessageRunner = mock(MessageRunner.class);
680         doThrow(RuntimeException.class).when(erroneousMessageRunner).processMessage(any(MessageContext.class));
681         registryService.registerMessageRunner(messageRunnerKey, erroneousMessageRunner);
682 
683         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
684                 CONCURRENT_CONSUMERS, registryService,
685                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
686                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
687 
688         messageRunnerService.initialiseMessageConsumers();
689         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
690 
691         // the first visibility timeout extension occurs after the visibilityTimeout expires (that is 1 second in this test)
692         // wait a little more than this time to allow an extension to happen that would fail this test if this test case is not addressed in the code
693         final int waitMillis = visibilityTimeoutSeconds * 2 * 1000;
694         verify(spyingSqsClient, after(waitMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
695     }
696 
697     @Test
698     public void autoExtensionOfVisibilityTimeoutCancelledOnEarlyAcknowledgement() throws Exception {
699         final int visibilityTimeoutSeconds = 1;
700         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
701         final int longRunningTaskDurationMillis = 1100;
702         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
703             // acknowledge early in the message runner (we should cancel any future extensions of the visibility timeout
704             // after this since it does not make sense to extend the visibility timeout of a deleted message)
705             context.acknowledge();
706             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS); // need the original 10 seconds plus two 10 second extensions to complete
707         }));
708         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
709 
710         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
711                 CONCURRENT_CONSUMERS, registryService,
712                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
713                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
714 
715         messageRunnerService.initialiseMessageConsumers();
716         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
717 
718         verify(spyingSqsClient, after(longRunningTaskDurationMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
719     }
720 
721     private static TenantContext newTenantContext(String tenantId) {
722         TenantContext tenantContext = mock(TenantContext.class);
723         when(tenantContext.getTenantId()).thenReturn(tenantId);
724 
725         return tenantContext;
726     }
727 
728     private static class DelegatingMessageRunner implements MessageRunner {
729         private final MessageRunner delegate;
730 
731         public DelegatingMessageRunner(MessageRunner delegate) {
732             this.delegate = delegate;
733         }
734 
735         @Override
736         public void processMessage(MessageContext context) {
737             delegate.processMessage(context);
738         }
739     }
740 
741     private static class DefaultMessageRunner implements MessageRunner {
742         private final Queue<String> payloads;
743 
744         public DefaultMessageRunner(Queue<String> payloads) {
745             this.payloads = requireNonNull(payloads);
746         }
747 
748         @Override
749         public void processMessage(MessageContext messageContext) {
750             payloads.offer(messageContext.getPayload().orElse(null));
751         }
752     }
753 }