View Javadoc

1   package com.atlassian.messagequeue.internal.sqs;
2   
3   import com.amazonaws.services.sqs.AmazonSQS;
4   import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
5   import com.amazonaws.services.sqs.model.MessageNotInflightException;
6   import com.amazonaws.services.sqs.model.OverLimitException;
7   import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
8   import com.amazonaws.services.sqs.model.ReceiveMessageResult;
9   import com.amazonaws.services.sqs.model.SendMessageRequest;
10  import com.atlassian.messagequeue.Message;
11  import com.atlassian.messagequeue.MessageInformationService;
12  import com.atlassian.messagequeue.MessageRunnerKey;
13  import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
14  import com.atlassian.messagequeue.MessageRunnerServiceException;
15  import com.atlassian.messagequeue.MessageSerializationException;
16  import com.atlassian.messagequeue.internal.core.NestedMessage;
17  import com.atlassian.messagequeue.internal.core.messagevalidators.MessageTenantDataIdValidator;
18  import com.atlassian.messagequeue.registry.MessageContext;
19  import com.atlassian.messagequeue.registry.MessageRunner;
20  import com.atlassian.messagequeue.registry.MessageValidator;
21  import com.atlassian.tenant.api.TenantContextProvider;
22  import com.atlassian.tenant.impl.TenantIdSetter;
23  import com.atlassian.workcontext.api.ImmutableWorkContextReference;
24  import com.google.common.util.concurrent.Uninterruptibles;
25  import org.junit.Assume;
26  import org.junit.Test;
27  import org.junit.runner.RunWith;
28  import org.mockito.Mockito;
29  import org.mockito.runners.MockitoJUnitRunner;
30  
31  import java.util.Collections;
32  import java.util.Optional;
33  import java.util.concurrent.CountDownLatch;
34  import java.util.concurrent.ExecutorService;
35  import java.util.concurrent.Executors;
36  import java.util.concurrent.TimeUnit;
37  import java.util.concurrent.atomic.AtomicInteger;
38  import java.util.concurrent.atomic.AtomicReference;
39  import java.util.function.Supplier;
40  import java.util.stream.IntStream;
41  
42  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME;
43  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_DATA_ID_ATTRIBUTE_NAME;
44  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_ID_ATTRIBUTE_NAME;
45  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.APPROXIMATE_RECEIVE_COUNT;
46  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
47  import static com.google.common.base.Throwables.getRootCause;
48  import static com.google.common.base.Throwables.propagate;
49  import static org.hamcrest.CoreMatchers.hasItem;
50  import static org.hamcrest.CoreMatchers.not;
51  import static org.hamcrest.MatcherAssert.assertThat;
52  import static org.hamcrest.Matchers.containsInAnyOrder;
53  import static org.hamcrest.Matchers.hasSize;
54  import static org.hamcrest.core.Is.is;
55  import static org.hamcrest.text.IsEmptyString.isEmptyOrNullString;
56  import static org.junit.Assert.assertNull;
57  import static org.junit.Assert.assertTrue;
58  import static org.mockito.Matchers.any;
59  import static org.mockito.Matchers.anyInt;
60  import static org.mockito.Matchers.anyString;
61  import static org.mockito.Mockito.after;
62  import static org.mockito.Mockito.doCallRealMethod;
63  import static org.mockito.Mockito.doReturn;
64  import static org.mockito.Mockito.doThrow;
65  import static org.mockito.Mockito.mock;
66  import static org.mockito.Mockito.never;
67  import static org.mockito.Mockito.spy;
68  import static org.mockito.Mockito.timeout;
69  import static org.mockito.Mockito.times;
70  import static org.mockito.Mockito.verify;
71  import static org.mockito.Mockito.when;
72  
73  @RunWith(MockitoJUnitRunner.class)
74  public class SQSMessageRunnerServiceTest extends AbstractSQSMessageRunnerServiceTest {
75  
76      @Test
77      public void messagePayloadIsPassedToMessageRunner() throws Exception {
78          String expectedPayload = "payload";
79  
80          messageRunnerService.initialiseMessageConsumers();
81          messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
82  
83          verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
84          assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of(expectedPayload)));
85      }
86  
87      @Test
88      public void messageIsAutoAcknowledgedAfterSuccessfulProcessing() throws Exception {
89          String expectedPayload = "payload";
90  
91          messageRunnerService.initialiseMessageConsumers();
92          messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
93  
94          verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
95      }
96  
97      @Test
98      public void addMessageWithEmptyPayload() throws Exception {
99          messageRunnerService.initialiseMessageConsumers();
100         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, ""));
101 
102         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
103         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
104         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
105         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of("")));
106     }
107 
108     @Test
109     public void addMessageWithNullPayload() throws Exception {
110         messageRunnerService.initialiseMessageConsumers();
111         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, null));
112 
113         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
114         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
115         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
116         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.empty()));
117     }
118 
119     @Test
120     public void addMultipleMessagesSerially() throws Exception {
121         messageRunnerService.initialiseMessageConsumers();
122         int numberOfMessages = 20;
123         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
124 
125         range.get().forEach(i -> {
126             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
127         });
128 
129         String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
130         verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
131         assertThat(payloads, containsInAnyOrder(expectedPayloads));
132     }
133 
134     @Test
135     public void addMultipleMessagesConcurrently() throws Exception {
136         messageRunnerService.initialiseMessageConsumers();
137         int numberOfMessages = 20;
138         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
139 
140         ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
141         try {
142             range.get().forEach(i -> {
143                 pool.execute(() -> {
144                     messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
145                 });
146             });
147 
148             String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
149             verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
150             assertThat(payloads, containsInAnyOrder(expectedPayloads));
151         } finally {
152             pool.shutdown();
153             pool.awaitTermination(1, TimeUnit.SECONDS);
154         }
155     }
156 
157     @Test
158     public void messageNotDeletedIfRunningItThrowsAnException() throws Exception {
159         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("failingRunner");
160         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
161             throw new RuntimeException();
162         });
163 
164         messageRunnerService.initialiseMessageConsumers();
165         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
166 
167         verify(spyingSqsClient, after(1000).never()).deleteMessage(anyString(), anyString());
168     }
169 
170     @Test
171     public void receiveMessageRequestPerformedUsingConfiguredVisibilityTimeoutAndSchedulingBuffer() throws Exception {
172         messageRunnerService.initialiseMessageConsumers();
173         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
174 
175         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
176         assertThat(receiveMessageRequestArgumentCaptor.getValue().getVisibilityTimeout(), is(25));
177     }
178 
179     @Test
180     public void sentTimestampAttributeIsRequestedToAllowMessageLatencyToBeDerived() throws Exception {
181         messageRunnerService.initialiseMessageConsumers();
182         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
183 
184         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
185         assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), hasItem(SENT_TIMESTAMP));
186     }
187 
188     @Test
189     public void approximateReceiveCountAttributeIsRequested() throws Exception {
190         messageRunnerService.initialiseMessageConsumers();
191         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
192 
193         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
194         assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), hasItem(APPROXIMATE_RECEIVE_COUNT));
195     }
196 
197     @Test
198     public void consumerResumesPollingIfNoMessagesReturnedInOneInvocationOfReceiveMessage() throws Exception {
199         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
200         ReceiveMessageResult emptyReceiveMessageResult = mock(ReceiveMessageResult.class);
201         when(emptyReceiveMessageResult.getMessages()).thenReturn(Collections.emptyList());
202         doReturn(emptyReceiveMessageResult).doCallRealMethod().when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
203 
204         messageRunnerService.initialiseMessageConsumers();
205         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
206 
207         verify(spyingSqsClient, after(1000).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
208     }
209 
210     @Test(expected = InvalidMessageContentsException.class)
211     public void addMessageWithPayloadContainingIllegalCharacters() throws Exception {
212         AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
213         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
214         SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
215                 .withAmazonSQSClient(amazonSQSClient)
216                 .withMessageRunnerRegistryHelper(registryService)
217                 .withTenantIdSetter(tenantIdSetter)
218                 .withMessageInformationService(messageInformationService)
219                 .withNestedMessageSerializer(nestedMessageSerializer)
220                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
221                 .build();
222 
223         when(amazonSQSClient.sendMessage(any(SendMessageRequest.class))).thenThrow(InvalidMessageContentsException.class);
224 
225         try {
226             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload with invalid characters"));
227         } catch (MessageRunnerServiceException e) {
228             propagate(getRootCause(e));
229         }
230     }
231 
232     @Test(expected = IllegalStateException.class)
233     public void addMessageFailsWithExceptionIfTenantContextNotPresent() throws Exception {
234         AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
235         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class);
236         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
237         MessageInformationService messageInformationService = new SQSMessageInformationService(
238                 getDefaultMessageRunnerKeyToProducerMapper(),
239                 tenantContextProvider,
240                 nestedMessageSerializer,
241                 tenantDataIdSupplier);
242         SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
243                 .withAmazonSQSClient(amazonSQSClient)
244                 .withMessageRunnerRegistryHelper(registryService)
245                 .withTenantIdSetter(tenantIdSetter)
246                 .withMessageInformationService(messageInformationService)
247                 .withNestedMessageSerializer(nestedMessageSerializer)
248                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
249                 .build();
250         when(tenantContextProvider.getTenantContext()).thenReturn(null);
251 
252         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
253     }
254 
255     @Test(expected = IllegalStateException.class)
256     public void addMessageFailsWithNullTenantId() throws Exception {
257         AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
258         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
259         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
260         MessageInformationService messageInformationService = new SQSMessageInformationService(
261                 getDefaultMessageRunnerKeyToProducerMapper(),
262                 tenantContextProvider,
263                 nestedMessageSerializer,
264                 tenantDataIdSupplier);
265         SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
266                 .withAmazonSQSClient(amazonSQSClient)
267                 .withMessageRunnerRegistryHelper(registryService)
268                 .withTenantIdSetter(tenantIdSetter)
269                 .withMessageInformationService(messageInformationService)
270                 .withNestedMessageSerializer(nestedMessageSerializer)
271                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
272                 .build();
273 
274         when(tenantContextProvider.getTenantContext().getTenantId()).thenReturn(null);
275         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
276     }
277 
278     @Test(expected = IllegalStateException.class)
279     public void addMessageFailsWithEmptyTenantId() throws Exception {
280         AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
281         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
282         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
283         MessageInformationService messageInformationService = new SQSMessageInformationService(
284                 getDefaultMessageRunnerKeyToProducerMapper(),
285                 tenantContextProvider,
286                 nestedMessageSerializer,
287                 tenantDataIdSupplier);
288         SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
289                 .withAmazonSQSClient(amazonSQSClient)
290                 .withMessageRunnerRegistryHelper(registryService)
291                 .withTenantIdSetter(tenantIdSetter)
292                 .withMessageInformationService(messageInformationService)
293                 .withNestedMessageSerializer(nestedMessageSerializer)
294                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
295                 .build();
296 
297         when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId("").build());
298         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
299     }
300 
301     @Test
302     public void consumerPoolCanRecoverFromOverLimitException() throws Exception {
303         // setup sqsClient such that the first invocation of sqsClient.receiveMessage(...) throws an OverLimitException
304         // setup a subsequent invocation to proceed normally
305         doThrow(OverLimitException.class)
306                 .doCallRealMethod()
307                     .when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
308 
309         messageRunnerService.initialiseMessageConsumers();
310         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
311 
312         // ensure message is eventually processed after we recover from the OverLimitException
313         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(any(MessageContext.class));
314         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString()); // wait until MessageRunner.processMessage(...) invocation has completed to reduce the flakiness of the following assertions
315         assertThat(payloads, hasSize(1));
316         assertThat(payloads, hasItem("1"));
317     }
318 
319     @Test
320     public void consumerPoolCanRecoverFromError() throws Exception {
321         final int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to enable message redelivery after failure
322         final int schedulingBufferSeconds = 1;
323         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
324                 .withAmazonSQSClient(spyingSqsClient)
325                 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
326                 .withMessageRunnerRegistryHelper(registryService)
327                 .withTenantIdSetter(tenantIdSetter)
328                 .withMessageInformationService(messageInformationService)
329                 .withNestedMessageSerializer(nestedMessageSerializer)
330                 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
331                         spyingSqsClient, 1, schedulingBufferSeconds))
332                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
333                 .build();
334 
335         // setup messageRunner such that the first invocation of messageRunner.processMessage(...) throws a LinkageError
336         // setup a subsequent invocation to proceed normally
337         doThrow(LinkageError.class)
338                 .doCallRealMethod()
339                     .when(messageRunner).processMessage(any(MessageContext.class));
340 
341         messageRunnerService.initialiseMessageConsumers();
342         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
343 
344         verify(spyingSqsClient, timeout(10000)).deleteMessage(anyString(), anyString());
345         assertThat(payloads, hasSize(1));
346         assertThat(payloads, hasItem("1"));
347     }
348 
349     @Test
350     public void ensureWorkContextAvailableForMessageProcessing() throws Exception {
351         CountDownLatch latch = new CountDownLatch(1);
352         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("workContext");
353         AtomicReference<Exception> exception = new AtomicReference<>();
354         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
355             try {
356                 new ImmutableWorkContextReference<>(() -> "foobar").get();
357             } catch (IllegalStateException e) {
358                 exception.set(e);
359             } finally {
360                 latch.countDown();
361             }
362         });
363 
364         messageRunnerService.initialiseMessageConsumers();
365         messageRunnerService.addMessage(Message.create(messageRunnerKey, "1"));
366 
367         assertTrue("test did not complete within timeout", latch.await(1, TimeUnit.SECONDS));
368         assertNull("No work context available for message processing.", exception.get());
369     }
370 
371     @Test
372     public void ensureTenantIdIsSentByProducer() throws Exception {
373         messageRunnerService.initialiseMessageConsumers();
374         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
375 
376         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
377         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
378 
379         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
380         assertThat(nestedMessage.getAttribute(TENANT_ID_ATTRIBUTE_NAME), is(TENANT_ID));
381     }
382 
383     @Test
384     public void ensureMessageRunnerKeyIsSentByProducer() throws Exception {
385         messageRunnerService.initialiseMessageConsumers();
386         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
387 
388         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
389         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
390 
391         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
392         assertThat(nestedMessage.getAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME), is(MESSAGE_RUNNER_KEY.toString()));
393     }
394 
395     @Test
396     public void ensureTenantDataIdIsSentByProducer() throws Exception {
397         messageRunnerService.initialiseMessageConsumers();
398 
399         tenantDataId = "New Id";
400 
401         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
402 
403         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
404         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
405 
406         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
407         assertThat(nestedMessage.getAttribute(TENANT_DATA_ID_ATTRIBUTE_NAME), is(tenantDataId));
408     }
409 
410     @Test
411     public void ensurePayloadIsSentByProducer() throws Exception {
412         String payload = "payload";
413         messageRunnerService.initialiseMessageConsumers();
414         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, payload));
415 
416         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
417         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
418 
419         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
420         assertThat(nestedMessage.getPayload(), is(payload));
421     }
422 
423     @Test
424     public void dontProcessAndDeleteMessagesThatCantBeDeserialized() throws Exception {
425         doThrow(MessageSerializationException.class).when(nestedMessageSerializer).deserialize(any(String.class));
426 
427         messageRunnerService.initialiseMessageConsumers();
428         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
429 
430         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
431         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
432     }
433 
434     @Test
435     public void deleteMessageWithoutTenantId() throws Exception {
436         doReturn(new NestedMessage()).when(nestedMessageSerializer).deserialize(anyString());
437 
438         messageRunnerService.initialiseMessageConsumers();
439         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
440 
441         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
442         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
443     }
444 
445     @Test
446     public void doNotProcessInvalidMessageWithValidator() throws Exception {
447         final MessageValidator messageTenantDataIdValidator = new MessageTenantDataIdValidator(tenantDataIdSupplier);
448         messageValidatorRegistry.registerMessageValidator(MessageTenantDataIdValidator.KEY, messageTenantDataIdValidator);
449 
450         tenantDataId = "old";
451 
452         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
453 
454         tenantDataId = "new";
455 
456         messageRunnerService.initialiseMessageConsumers();
457 
458         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
459     }
460 
461     @Test
462     public void doProcessValidMessageWithValidator() throws Exception {
463         final MessageValidator messageTenantDataIdValidator = new MessageTenantDataIdValidator(tenantDataIdSupplier);
464         messageValidatorRegistry.registerMessageValidator(MessageTenantDataIdValidator.KEY, messageTenantDataIdValidator);
465 
466         tenantDataId = "old";
467 
468         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
469         messageRunnerService.initialiseMessageConsumers();
470 
471         verify(messageRunner, after(1000).times(1)).processMessage(any(MessageContext.class));
472     }
473 
474     @Test
475     public void deleteMessageWithoutMessageRunnerKey() throws Exception {
476         NestedMessage nestedMessage = new NestedMessage().addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123");
477         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
478 
479         messageRunnerService.initialiseMessageConsumers();
480         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
481 
482         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
483         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
484     }
485 
486     @Test
487     public void dontProcessMessageWithoutMessageRunner() throws Exception {
488         NestedMessage nestedMessage = new NestedMessage()
489                 .addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123")
490                 .addAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME, MESSAGE_RUNNER_KEY.toString());
491         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
492         doCallRealMethod()
493                 .doReturn(Optional.empty())
494                 .when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
495 
496         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
497         messageRunnerService.initialiseMessageConsumers();
498 
499         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
500         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
501     }
502 
503     @Test(expected = MessageSerializationException.class)
504     public void failFastIfMessageCantBeSerialized() throws Exception {
505         doThrow(MessageSerializationException.class).when(messageInformationService).toPayload(any(Message.class));
506 
507         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
508     }
509 
510     @Test
511     public void ensureTenantContextEstablishedOnConsumer() throws Exception {
512         messageRunnerService.initialiseMessageConsumers();
513         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
514 
515         verify(tenantIdSetter, timeout(VERIFY_TIMEOUT_MILLIS)).setTenantId(TENANT_ID);
516     }
517 
518     @Test
519     public void messageNotProcessedOrDeletedIfTenantContextEstablishmentThrowsException() throws Exception {
520         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
521         doThrow(new RuntimeException("Error creating TenantContext from a tenantId (perhaps tenant context service is currently unavailable?)")).when(tenantIdSetter).setTenantId(anyString());
522 
523         messageRunnerService.initialiseMessageConsumers();
524         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
525         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
526         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
527         verify(messageRunner, never()).processMessage(any(MessageContext.class));
528     }
529 
530     @Test
531     public void supportEarlyAcknowledgementOfMessage() throws Exception {
532         AtomicInteger deliveryCount = new AtomicInteger(0);
533         int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to encourage message redelivery (and therefore have a chance to prevent it)
534         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
535                 .withAmazonSQSClient(spyingSqsClient)
536                 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
537                 .withMessageRunnerRegistryHelper(registryService)
538                 .withTenantIdSetter(tenantIdSetter)
539                 .withMessageInformationService(messageInformationService)
540                 .withNestedMessageSerializer(nestedMessageSerializer)
541                 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(spyingSqsClient))
542                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
543                 .build();
544         MessageRunnerKey fooMessageRunnerKey = MessageRunnerKey.of("fooMessageRunnerKey");
545         registryService.registerMessageRunner(fooMessageRunnerKey, messageContext -> {
546             deliveryCount.incrementAndGet();
547             messageContext.acknowledge(); // acknowledge message at start of message runner before any processing begins
548             throw new RuntimeException("message runner error");
549         });
550 
551         messageRunnerService.initialiseMessageConsumers();
552         messageRunnerService.addMessage(Message.create(fooMessageRunnerKey, "payload"));
553 
554         Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
555         assertThat(deliveryCount.get(), is(1));
556     }
557 
558     @Test(expected = MessageRunnerNotRegisteredException.class)
559     public void failFastWhenAddingMessageWithNoRegisteredMessageRunner() throws Exception {
560         doReturn(Optional.empty()).when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
561 
562         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
563     }
564 
565     @Test
566     public void messageRunnerResponsiveToCancellationViaShutdown() throws Exception {
567         final CountDownLatch enterLatch = new CountDownLatch(1);
568         final CountDownLatch exitLatch = new CountDownLatch(1);
569         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
570         final int sleepMillis = 100;
571         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
572             enterLatch.countDown();
573             while (!messageContext.isCancellationRequested()) {
574                 Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
575             }
576             exitLatch.countDown();
577         });
578         messageRunnerService.initialiseMessageConsumers();
579         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
580         assertThat("MessageRunner did not start. Cannot assert cancellation on a MessageRunner that did not start",
581                 enterLatch.await(1, TimeUnit.SECONDS), is(true));
582 
583         messageRunnerService.shutdown();
584 
585         assertThat("MessageRunner did not respond to cancellation", exitLatch.await(sleepMillis * 2, TimeUnit.MILLISECONDS), is(true));
586     }
587 
588     @Test
589     public void cancelAutoAcknowledgement() throws Exception {
590         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("cancelAckMessageRunnerKey");
591         registryService.registerMessageRunner(messageRunnerKey, MessageContext::cancelAutoAcknowledgementOfMessage);
592 
593         messageRunnerService.initialiseMessageConsumers();
594         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
595 
596         verify(spyingSqsClient, after(VERIFY_TIMEOUT_MILLIS).never()).deleteMessage(anyString(), anyString());
597     }
598 
599     @Test
600     public void autoExtensionOfVisibilityTimeout() throws Exception {
601         final int visibilityTimeoutSeconds = 1;
602 
603         // use more than one consumer to allow the a particular message to be processed concurrently
604         // this should cause this test to fail if the visibility timeout extension code is not working properly
605         final int longRunningTaskDurationMillis = 2500;
606 
607         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
608         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
609             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
610         }));
611         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
612 
613         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
614                 .withAmazonSQSClient(spyingSqsClient)
615                 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
616                 .withMessageRunnerRegistryHelper(registryService)
617                 .withTenantIdSetter(tenantIdSetter)
618                 .withMessageInformationService(messageInformationService)
619                 .withNestedMessageSerializer(nestedMessageSerializer)
620                 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
621                         spyingSqsClient,
622                         2,
623                         5))
624                 .withMessageValidatorRegistryHelper(messageValidatorRegistry).build();
625 
626         messageRunnerService.initialiseMessageConsumers();
627         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
628 
629         verify(longRunningMessageRunner, after(longRunningTaskDurationMillis + 500).times(1)).processMessage(any(MessageContext.class));
630         verify(spyingSqsClient, times(2)).changeMessageVisibility(anyString(), anyString(), anyInt());
631     }
632 
633     /**
634      * When the visibility timeout of a message expires before we've managed to extend it, then
635      * the ChangeMessageVisibility request will throw a MessageNotInflightException
636      * .*
637      * <p>Also, since the visibility timeout has expired, we expect another MessageRunner to start executing the message
638      *
639      * <p>We should log this incident appropriately and tune the scheduling buffer time to be big enough to make this very unlikely
640      * .*/
641     @Test
642     public void autoExtensionOfVisibilityTimeoutFailsBecauseMessageNoLongerInFlight() throws Exception {
643         final int visibilityTimeoutSeconds = 1;
644 
645         // use more than one consumer to allow the a particular message to be processed concurrently
646         // this should cause this test to fail if the visibility timeout extension code is not working properly
647         final int longRunningTaskDurationMillis = 2500;
648         final int schedulingBufferSeconds = 1;
649 
650         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
651         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
652             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
653         }));
654         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
655 
656         doThrow(MessageNotInflightException.class).when(spyingSqsClient).changeMessageVisibility(anyString(), anyString(), anyInt());
657 
658         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
659                 .withAmazonSQSClient(spyingSqsClient)
660                 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
661                 .withMessageRunnerRegistryHelper(registryService)
662                 .withTenantIdSetter(tenantIdSetter)
663                 .withMessageInformationService(messageInformationService)
664                 .withNestedMessageSerializer(nestedMessageSerializer)
665                 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
666                         spyingSqsClient,
667                         2,
668                         schedulingBufferSeconds))
669                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
670                 .build();
671 
672         messageRunnerService.initialiseMessageConsumers();
673         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
674 
675         verify(longRunningMessageRunner, after(3000).times(2)).processMessage(any(MessageContext.class));
676     }
677 
678     @Test
679     public void autoExtensionOfVisibilityTimeoutCancelledOnMessageProcessingException() throws Exception {
680         final int visibilityTimeoutSeconds = 1;
681         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("erroneousMessageRunnerKey");
682         final MessageRunner erroneousMessageRunner = mock(MessageRunner.class);
683         doThrow(RuntimeException.class).when(erroneousMessageRunner).processMessage(any(MessageContext.class));
684         registryService.registerMessageRunner(messageRunnerKey, erroneousMessageRunner);
685 
686         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig()).withAmazonSQSClient(spyingSqsClient).withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS).withMessageRunnerRegistryHelper(registryService).withTenantIdSetter(tenantIdSetter).withMessageInformationService(messageInformationService).withNestedMessageSerializer(nestedMessageSerializer).withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(spyingSqsClient,
687                 2)).withMessageValidatorRegistryHelper(messageValidatorRegistry).build();
688 
689         messageRunnerService.initialiseMessageConsumers();
690         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
691 
692         // the first visibility timeout extension occurs after the visibilityTimeout expires (that is 1 second in this test)
693         // wait a little more than this time to allow an extension to happen that would fail this test if this test case is not addressed in the code
694         final int waitMillis = visibilityTimeoutSeconds * 2 * 1000;
695         verify(spyingSqsClient, after(waitMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
696     }
697 
698     @Test
699     public void autoExtensionOfVisibilityTimeoutCancelledOnEarlyAcknowledgement() throws Exception {
700         final int visibilityTimeoutSeconds = 1;
701         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
702         final int longRunningTaskDurationMillis = 1100;
703         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
704             // acknowledge early in the message runner (we should cancel any future extensions of the visibility timeout
705             // after this since it does not make sense to extend the visibility timeout of a deleted message)
706             context.acknowledge();
707             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS); // need the original 10 seconds plus two 10 second extensions to complete
708         }));
709         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
710 
711         messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
712                 .withAmazonSQSClient(spyingSqsClient)
713                 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
714                 .withMessageRunnerRegistryHelper(registryService)
715                 .withTenantIdSetter(tenantIdSetter)
716                 .withMessageInformationService(messageInformationService)
717                 .withNestedMessageSerializer(nestedMessageSerializer)
718                 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
719                         spyingSqsClient,
720                         2))
721                 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
722                 .build();
723 
724         messageRunnerService.initialiseMessageConsumers();
725         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
726 
727         verify(spyingSqsClient, after(longRunningTaskDurationMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
728     }
729 }