View Javadoc

1   package com.atlassian.messagequeue.internal.sqs;
2   
3   import com.amazonaws.auth.BasicAWSCredentials;
4   import com.amazonaws.services.sqs.AmazonSQSClient;
5   import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
6   import com.amazonaws.services.sqs.model.MessageNotInflightException;
7   import com.amazonaws.services.sqs.model.OverLimitException;
8   import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
9   import com.amazonaws.services.sqs.model.ReceiveMessageResult;
10  import com.amazonaws.services.sqs.model.SendMessageRequest;
11  import com.atlassian.messagequeue.Message;
12  import com.atlassian.messagequeue.MessageInformationService;
13  import com.atlassian.messagequeue.MessageRunnerKey;
14  import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
15  import com.atlassian.messagequeue.MessageRunnerServiceException;
16  import com.atlassian.messagequeue.MessageSerializationException;
17  import com.atlassian.messagequeue.internal.core.DefaultMessageInformationService;
18  import com.atlassian.messagequeue.internal.core.JacksonNestedMessageSerializer;
19  import com.atlassian.messagequeue.internal.core.NestedMessage;
20  import com.atlassian.messagequeue.internal.core.NestedMessageSerializer;
21  import com.atlassian.messagequeue.internal.core.DefaultMessageRunnerRegistryService;
22  import com.atlassian.messagequeue.registry.MessageContext;
23  import com.atlassian.messagequeue.registry.MessageRunner;
24  import com.atlassian.tenant.api.TenantContext;
25  import com.atlassian.tenant.api.TenantContextProvider;
26  import com.atlassian.tenant.impl.TenantIdSetter;
27  import com.atlassian.workcontext.api.ImmutableWorkContextReference;
28  import com.atlassian.workcontext.api.WorkContextDoorway;
29  import com.google.common.util.concurrent.Uninterruptibles;
30  import org.elasticmq.rest.sqs.SQSRestServer;
31  import org.elasticmq.rest.sqs.SQSRestServerBuilder;
32  import org.junit.After;
33  import org.junit.Assume;
34  import org.junit.Before;
35  import org.junit.Test;
36  import org.junit.runner.RunWith;
37  import org.mockito.ArgumentCaptor;
38  import org.mockito.Captor;
39  import org.mockito.Mock;
40  import org.mockito.Mockito;
41  import org.mockito.runners.MockitoJUnitRunner;
42  
43  import java.util.Collections;
44  import java.util.Optional;
45  import java.util.Queue;
46  import java.util.concurrent.BlockingQueue;
47  import java.util.concurrent.CountDownLatch;
48  import java.util.concurrent.ExecutorService;
49  import java.util.concurrent.Executors;
50  import java.util.concurrent.LinkedBlockingQueue;
51  import java.util.concurrent.TimeUnit;
52  import java.util.concurrent.atomic.AtomicInteger;
53  import java.util.concurrent.atomic.AtomicReference;
54  import java.util.function.Supplier;
55  import java.util.stream.IntStream;
56  
57  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME;
58  import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_ID_ATTRIBUTE_NAME;
59  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
60  import static com.google.common.base.Throwables.getRootCause;
61  import static com.google.common.base.Throwables.propagate;
62  import static java.util.Objects.requireNonNull;
63  import static org.hamcrest.CoreMatchers.hasItem;
64  import static org.hamcrest.CoreMatchers.not;
65  import static org.hamcrest.MatcherAssert.assertThat;
66  import static org.hamcrest.Matchers.contains;
67  import static org.hamcrest.Matchers.containsInAnyOrder;
68  import static org.hamcrest.Matchers.hasSize;
69  import static org.hamcrest.core.Is.is;
70  import static org.hamcrest.text.IsEmptyString.isEmptyOrNullString;
71  import static org.junit.Assert.assertNull;
72  import static org.junit.Assert.assertTrue;
73  import static org.mockito.Matchers.any;
74  import static org.mockito.Matchers.anyInt;
75  import static org.mockito.Matchers.anyString;
76  import static org.mockito.Mockito.after;
77  import static org.mockito.Mockito.doCallRealMethod;
78  import static org.mockito.Mockito.doReturn;
79  import static org.mockito.Mockito.doThrow;
80  import static org.mockito.Mockito.mock;
81  import static org.mockito.Mockito.never;
82  import static org.mockito.Mockito.spy;
83  import static org.mockito.Mockito.timeout;
84  import static org.mockito.Mockito.times;
85  import static org.mockito.Mockito.verify;
86  import static org.mockito.Mockito.when;
87  
88  @RunWith(MockitoJUnitRunner.class)
89  public class SQSMessageRunnerServiceTest {
90      private static final MessageRunnerKey MESSAGE_RUNNER_KEY = MessageRunnerKey.of("testMessageRunner"); // key the default messageRunner spy is registered under in setUp()
        // Visibility timeout passed to the SQSMessageVisibilityTimeoutManager built in setUp().
91      private static final int VISIBILITY_TIMEOUT_SECONDS = 10;
        // Receive wait time passed to the SQSMessageRunnerService built in setUp().
92      private static final int RECEIVE_WAIT_TIME_SECONDS = 0;
        // Name of the queue created on the embedded SQS server in setUp().
93      private static final String QUEUE_NAME = "testQueue";
        // Consumer count passed to the service; tests guarded by Assume.assumeThat(CONCURRENT_CONSUMERS, is(1)) are skipped for any other value.
94      private static final int CONCURRENT_CONSUMERS = 1;
        // Tenant id of the TenantContext returned by the mocked tenantContextProvider.
95      private static final String TENANT_ID = "tenant-123";
        // Default wait used with Mockito's timeout(...) / after(...) verification modes.
96      private static final int VERIFY_TIMEOUT_MILLIS = 1000;
        // Scheduling buffer passed to the SQSMessageVisibilityTimeoutManager built in setUp(); one test expects it to be added to the visibility timeout of receive requests.
97      private static final int SCHEDULING_BUFFER_SECONDS = 5;
98  
99      private SQSMessageRunnerService messageRunnerService; // service under test; built in setUp(), replaced by some tests, shut down in tearDown()
        // Queue handed to the DefaultMessageRunner created in setUp(); several tests assert on its contents.
100     private BlockingQueue<String> payloads;
        // Embedded SQS REST server, started in setUp() and stopped in tearDown().
101     private SQSRestServer sqsServer;
        // Mockito spy around a real SQS client so that calls can be verified or stubbed per test.
102     private AmazonSQSClient spyingSqsClient;
        // Opened in setUp() and closed in tearDown().
103     private WorkContextDoorway workContextDoorway;
104 
        // Plain mocks created by the MockitoJUnitRunner.
105     @Mock
106     private TenantContextProvider tenantContextProvider;
107     @Mock
108     private TenantIdSetter tenantIdSetter;
109 
        // Captors used to inspect the arguments of verified calls.
110     @Captor
111     private ArgumentCaptor<ReceiveMessageRequest> receiveMessageRequestArgumentCaptor;
112     @Captor
113     private ArgumentCaptor<SendMessageRequest> sendMessageRequestArgumentCaptor;
114     @Captor
115     private ArgumentCaptor<MessageContext> messageContextArgumentCaptor;
116 
        // Collaborators created in setUp(); messageRunner, registryService, messageInformationService
        // and nestedMessageSerializer are spies of real implementations.
117     private MessageRunner messageRunner;
118     private DefaultMessageRunnerRegistryService registryService;
119     private MessageInformationService messageInformationService;
        // URL of QUEUE_NAME as reported by the SQS client in setUp().
120     private String queueUrl;
121     private NestedMessageSerializer nestedMessageSerializer;
        // Builder pre-populated with JDBC settings in setUp(); tests set the tenant id on it before building.
122     private TenantContext.Builder tenantContextBuilder;
123     private SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager;
124 
125     @Before
126     public void setUp() throws Exception {
            // Starts an embedded SQS server, creates the test queue and wires the service under test
            // around spies of real collaborators, so tests can run end-to-end and still verify calls.
127         sqsServer = SQSRestServerBuilder.start();
128 
            // NOTE(review): presumably the address the embedded server listens on by default — confirm.
129         String endpoint = "http://localhost:9324";
130 
131         AmazonSQSClient sqsClient = new AmazonSQSClient(new BasicAWSCredentials("x", "x")).withEndpoint(endpoint);
132         sqsClient.createQueue(QUEUE_NAME);
133         queueUrl = sqsClient.getQueueUrl(QUEUE_NAME).getQueueUrl();
            // The queue set-up calls above go through the un-spied client, so they are not recorded on the spy.
134         spyingSqsClient = spy(sqsClient);
135 
            // Real implementations wrapped in spies so individual tests can stub or verify single calls.
136         nestedMessageSerializer = spy(new JacksonNestedMessageSerializer());
137         registryService = spy(new DefaultMessageRunnerRegistryService());
138         messageInformationService = spy(new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer));
139         sqsMessageVisibilityTimeoutManager = new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, VISIBILITY_TIMEOUT_SECONDS, SCHEDULING_BUFFER_SECONDS, 2);
140 
            // Consumers are not started here; each test calls initialiseMessageConsumers() itself when it needs them.
141         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
142                 CONCURRENT_CONSUMERS, registryService,
143                 tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
144         payloads = new LinkedBlockingQueue<>();
145 
            // Default runner (also a spy) registered under MESSAGE_RUNNER_KEY.
146         messageRunner = spy(new DefaultMessageRunner(payloads));
147         registryService.registerMessageRunner(MESSAGE_RUNNER_KEY, messageRunner);
148 
            // The mocked provider reports a tenant context whose id is TENANT_ID.
149         tenantContextBuilder = new TenantContext.Builder()
150                 .jdbcUrl("jdbcUrl")
151                 .jdbcUsername("username")
152                 .jdbcPassword("password");
153         when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId(TENANT_ID).build());
154 
155         workContextDoorway = new WorkContextDoorway();
156         workContextDoorway.open();
157     }
158 
159     @After
160     public void tearDown() throws Exception {
            // Releases everything setUp() acquired: the work context doorway, the service under test
            // and the embedded SQS server, in that order.
            //
            // JUnit 4 runs @After methods even when @Before failed part-way, so any of these fields may
            // still be null. Each step is guarded and chained with try/finally so that a failure in one
            // step cannot skip the later ones and leave the embedded server bound to its port for the
            // remaining tests.
            try {
                if (workContextDoorway != null) {
161                 workContextDoorway.close();
                }
            } finally {
                try {
                    if (messageRunnerService != null) {
162                     messageRunnerService.shutdown();
                    }
                } finally {
                    if (sqsServer != null) {
163                     sqsServer.stopAndWait();
                    }
                }
            }
164     }
165 
166     @Test
167     public void messagePayloadIsPassedToMessageRunner() throws Exception {
            // Sends one message and checks that the registered runner is invoked, within the verify
            // timeout, with a MessageContext whose payload is the one that was sent.
168         String expectedPayload = "payload";
169 
170         messageRunnerService.initialiseMessageConsumers();
171         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
172 
173         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
174         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of(expectedPayload)));
175     }
176 
177     @Test
178     public void messageIsAutoAcknowledgedAfterSuccessfulProcessing() throws Exception {
            // Sends one message to the default runner and expects deleteMessage to be called on the
            // SQS client spy within the verify timeout, without the test acknowledging anything itself.
179         String expectedPayload = "payload";
180 
181         messageRunnerService.initialiseMessageConsumers();
182         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
183 
184         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
185     }
186 
187     @Test
188     public void addMessageWithEmptyPayload() throws Exception {
            // An empty-string payload must still produce a non-empty SQS message body, and the runner
            // must receive the payload back as Optional.of("").
189         messageRunnerService.initialiseMessageConsumers();
190         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, ""));
191 
192         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
193         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
194         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
195         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of("")));
196     }
197 
198     @Test
199     public void addMessageWithNullPayload() throws Exception {
            // A null payload must still produce a non-empty SQS message body, and the runner must
            // see the payload as Optional.empty().
200         messageRunnerService.initialiseMessageConsumers();
201         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, null));
202 
203         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
204         assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
205         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
206         assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.empty()));
207     }
208 
209     @Test
210     public void addMultipleMessagesSerially() throws Exception {
            // Sends 20 messages one after another from the test thread, waits up to 10 seconds for 20
            // deleteMessage calls, then checks the payloads queue holds every payload in any order.
211         messageRunnerService.initialiseMessageConsumers();
212         int numberOfMessages = 20;
            // A supplier is used because an IntStream can only be consumed once and the range is needed twice.
213         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
214 
215         range.get().forEach(i -> {
216             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
217         });
218 
219         String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
220         verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
221         assertThat(payloads, containsInAnyOrder(expectedPayloads));
222     }
223 
224     @Test
225     public void addMultipleMessagesConcurrently() throws Exception {
            // Same expectation as the serial variant, but the 20 addMessage calls are submitted to a
            // fixed thread pool sized to the number of available processors.
226         messageRunnerService.initialiseMessageConsumers();
227         int numberOfMessages = 20;
            // A supplier is used because an IntStream can only be consumed once and the range is needed twice.
228         Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
229 
230         ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
231         try {
232             range.get().forEach(i -> {
233                 pool.execute(() -> {
234                     messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
235                 });
236             });
237 
238             String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
239             verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
240             assertThat(payloads, containsInAnyOrder(expectedPayloads));
241         } finally {
                // Always release the producer threads; the result of awaitTermination is not checked.
242             pool.shutdown();
243             pool.awaitTermination(1, TimeUnit.SECONDS);
244         }
245     }
246 
247     @Test
248     public void messageNotDeletedIfRunningItThrowsAnException() throws Exception {
            // Registers a runner that always throws and checks that, one second after sending a
            // message to it, deleteMessage has never been called on the SQS client spy.
249         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("failingRunner");
250         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
251             throw new RuntimeException();
252         });
253 
254         messageRunnerService.initialiseMessageConsumers();
255         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
256 
            // after(...) waits the full period before verifying, unlike timeout(...), which returns early.
257         verify(spyingSqsClient, after(1000).never()).deleteMessage(anyString(), anyString());
258     }
259 
260     @Test
261     public void receiveMessageRequestPerformedUsingConfiguredVisibilityTimeoutAndSchedulingBuffer() throws Exception {
            // The most recently captured receive request must carry a visibility timeout equal to
            // the configured visibility timeout plus the configured scheduling buffer.
262         messageRunnerService.initialiseMessageConsumers();
263         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
264 
265         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
266         assertThat(receiveMessageRequestArgumentCaptor.getValue().getVisibilityTimeout(), is(VISIBILITY_TIMEOUT_SECONDS + SCHEDULING_BUFFER_SECONDS));
267     }
268 
269     @Test
270     public void sentTimestampAttributeIsRequestedToAllowMessageLatencyToBeDerived() throws Exception {
            // The most recently captured receive request must ask for exactly one attribute, SENT_TIMESTAMP.
271         messageRunnerService.initialiseMessageConsumers();
272         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
273 
274         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
275         assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), contains(SENT_TIMESTAMP));
276     }
277 
278     @Test
279     public void consumerResumesPollingIfNoMessagesReturnedInOneInvocationOfReceiveMessage() throws Exception {
            // Makes the first receiveMessage call return an empty result and checks that the consumer
            // polls again: at least two receiveMessage calls must have happened after one second.
            // Skipped unless exactly one consumer is configured.
280         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
281         ReceiveMessageResult emptyReceiveMessageResult = mock(ReceiveMessageResult.class);
282         when(emptyReceiveMessageResult.getMessages()).thenReturn(Collections.emptyList());
            // First call: empty result; every later call: the real client.
283         doReturn(emptyReceiveMessageResult).doCallRealMethod().when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
284 
285         messageRunnerService.initialiseMessageConsumers();
286         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
287 
288         verify(spyingSqsClient, after(1000).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
289     }
290 
291     @Test(expected = InvalidMessageContentsException.class)
292     public void addMessageWithPayloadContainingIllegalCharacters() throws Exception {
            // Uses a fully mocked SQS client whose sendMessage throws InvalidMessageContentsException
            // and expects that exception to be the root cause of what addMessage throws.
            // The locals below shadow the tenantIdSetter and messageRunnerService fields on purpose;
            // the service built in setUp() is not used by this test.
293         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
294         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
295         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "test", registryService,
296                 tenantIdSetter, messageInformationService, nestedMessageSerializer);
297 
298         when(amazonSQSClient.sendMessage(any(SendMessageRequest.class))).thenThrow(InvalidMessageContentsException.class);
299 
300         try {
301             messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload with invalid characters"));
302         } catch (MessageRunnerServiceException e) {
                // Unwrap so the expected = ... check sees the underlying SQS exception.
303             propagate(getRootCause(e));
304         }
305     }
306 
307     @Test(expected = IllegalStateException.class)
308     public void addMessageFailsWithExceptionIfTenantContextNotPresent() throws Exception {
            // Builds a separate service whose tenant context provider returns null and expects
            // addMessage to fail with IllegalStateException.
            // The locals below shadow the fields of the same name; the setUp() service is not used.
309         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
310         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class);
311         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
312         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
313         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
314                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
315         when(tenantContextProvider.getTenantContext()).thenReturn(null);
316 
317         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
318     }
319 
320     @Test(expected = IllegalStateException.class)
321     public void addMessageFailsWithNullTenantId() throws Exception {
            // Builds a separate service whose tenant context reports a null tenant id and expects
            // addMessage to fail with IllegalStateException.
            // The locals below shadow the fields of the same name; the setUp() service is not used.
322         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
            // Deep stubs let getTenantContext().getTenantId() be stubbed in a single chain below.
323         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
324         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
325         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
326         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
327                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
328 
329         when(tenantContextProvider.getTenantContext().getTenantId()).thenReturn(null);
330         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
331     }
332 
333     @Test(expected = IllegalStateException.class)
334     public void addMessageFailsWithEmptyTenantId() throws Exception {
            // Builds a separate service whose tenant context carries an empty-string tenant id and
            // expects addMessage to fail with IllegalStateException.
            // The locals below shadow the fields of the same name; the setUp() service is not used.
335         AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
336         TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
337         TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
338         MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
339         SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
340                 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
341 
342         when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId("").build());
343         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
344     }
345 
346     @Test
347     public void consumerPoolCanRecoverFromOverLimitException() throws Exception {
            // Checks that a single OverLimitException from receiveMessage does not stop consumption:
            // the message sent afterwards must still be processed, deleted and recorded in payloads.
348         // setup sqsClient such that the first invocation of sqsClient.receiveMessage(...) throws an OverLimitException
349         // setup a subsequent invocation to proceed normally
350         doThrow(OverLimitException.class)
351                 .doCallRealMethod()
352                     .when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
353 
354         messageRunnerService.initialiseMessageConsumers();
355         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
356 
357         // ensure message is eventually processed after we recover from the OverLimitException
358         verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(any(MessageContext.class));
359         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString()); // wait until MessageRunner.processMessage(...) invocation has completed to reduce the flakiness of the following assertions
360         assertThat(payloads, hasSize(1));
361         assertThat(payloads, hasItem("1"));
362     }
363 
364     @Test
365     public void consumerPoolCanRecoverFromError() throws Exception {
            // Checks that a java.lang.Error thrown by the runner does not stop consumption: with a
            // short visibility timeout, a deleteMessage call must be seen within 2.5 seconds and the
            // payload must end up in payloads exactly once.
366         final int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to enable message redelivery after failure
367         final int schedulingBufferSeconds = 1;
            // Replaces the service built in setUp() with one that uses the short timeouts above.
368         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
369                 CONCURRENT_CONSUMERS, registryService,
370                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
371                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 1));
372 
373         // setup messageRunner such that the first invocation of messageRunner.processMessage(...) throws a LinkageError
374         // setup a subsequent invocation to proceed normally
375         doThrow(LinkageError.class)
376                 .doCallRealMethod()
377                     .when(messageRunner).processMessage(any(MessageContext.class));
378 
379         messageRunnerService.initialiseMessageConsumers();
380         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
381 
382         verify(spyingSqsClient, timeout(2500)).deleteMessage(anyString(), anyString());
383         assertThat(payloads, hasSize(1));
384         assertThat(payloads, hasItem("1"));
385     }
386 
387     @Test
388     public void ensureWorkContextAvailableForMessageProcessing() throws Exception {
            // Registers a runner that resolves an ImmutableWorkContextReference while processing and
            // records any IllegalStateException it gets; the test passes only if the runner ran within
            // one second and recorded no exception.
389         CountDownLatch latch = new CountDownLatch(1);
390         MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("workContext");
            // Holds the failure seen on the consumer thread so it can be asserted on the test thread.
391         AtomicReference<Exception> exception = new AtomicReference<>();
392         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
393             try {
394                 new ImmutableWorkContextReference<>(() -> "foobar").get();
395             } catch (IllegalStateException e) {
396                 exception.set(e);
397             } finally {
                    // Release the test thread whether or not the lookup succeeded.
398                 latch.countDown();
399             }
400         });
401 
402         messageRunnerService.initialiseMessageConsumers();
403         messageRunnerService.addMessage(Message.create(messageRunnerKey, "1"));
404 
405         assertTrue("test did not complete within timeout", latch.await(1, TimeUnit.SECONDS));
406         assertNull("No work context available for message processing.", exception.get());
407     }
408 
409     @Test
410     public void ensureTenantIdIsSentByProducer() throws Exception {
            // Deserializes the body of the captured send request and checks that its tenant id
            // attribute equals TENANT_ID.
411         messageRunnerService.initialiseMessageConsumers();
412         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
413 
414         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
415         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
416 
417         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
418         assertThat(nestedMessage.getAttribute(TENANT_ID_ATTRIBUTE_NAME), is(TENANT_ID));
419     }
420 
421     @Test
422     public void ensureMessageRunnerKeyIsSentByProducer() throws Exception {
            // Deserializes the body of the captured send request and checks that its message runner
            // key attribute equals MESSAGE_RUNNER_KEY.toString().
423         messageRunnerService.initialiseMessageConsumers();
424         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
425 
426         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
427         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
428 
429         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
430         assertThat(nestedMessage.getAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME), is(MESSAGE_RUNNER_KEY.toString()));
431     }
432 
433     @Test
434     public void ensurePayloadIsSentByProducer() throws Exception {
            // Deserializes the body of the captured send request and checks that its payload equals
            // the payload passed to addMessage.
435         String payload = "payload";
436         messageRunnerService.initialiseMessageConsumers();
437         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, payload));
438 
439         verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
440         String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
441 
442         NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
443         assertThat(nestedMessage.getPayload(), is(payload));
444     }
445 
446     @Test
447     public void dontProcessAndDeleteMessagesThatCantBeDeserialized() throws Exception {
            // Makes every deserialize call throw and expects that, after one second, the runner was
            // never invoked while the message was deleted from the queue exactly once.
448         doThrow(MessageSerializationException.class).when(nestedMessageSerializer).deserialize(any(String.class));
449 
450         messageRunnerService.initialiseMessageConsumers();
451         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
452 
453         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
454         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
455     }
456 
457     @Test
458     public void deleteMessageWithoutTenantId() throws Exception {
            // Makes deserialize return a NestedMessage with no attributes (so no tenant id) and
            // expects that, after one second, the runner was never invoked while the message was
            // deleted exactly once.
459         doReturn(new NestedMessage()).when(nestedMessageSerializer).deserialize(anyString());
460 
461         messageRunnerService.initialiseMessageConsumers();
462         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
463 
464         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
465         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
466     }
467 
468     @Test
469     public void deleteMessageWithoutMessageRunnerKey() throws Exception {
            // Makes deserialize return a NestedMessage that has a tenant id but no message runner key
            // and expects that, after one second, the runner was never invoked while the message was
            // deleted exactly once.
470         NestedMessage nestedMessage = new NestedMessage().addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123");
471         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
472 
473         messageRunnerService.initialiseMessageConsumers();
474         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
475 
476         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
477         verify(spyingSqsClient).deleteMessage(anyString(), anyString());
478     }
479 
480     @Test
481     public void dontProcessMessageWithoutMessageRunner() throws Exception {
            // Simulates a consumer that has no runner registered for a received message: the message
            // must be neither processed nor deleted.
482         NestedMessage nestedMessage = new NestedMessage()
483                 .addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123")
484                 .addAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME, MESSAGE_RUNNER_KEY.toString());
485         doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
            // First lookup uses the real registry, every later lookup finds nothing.
            // NOTE(review): presumably the first lookup is the one made by addMessage below, which
            // would otherwise fail fast — confirm against SQSMessageRunnerService.
486         doCallRealMethod()
487                 .doReturn(Optional.empty())
488                 .when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
489 
            // Unlike most tests here, the message is added before the consumers are started.
490         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
491         messageRunnerService.initialiseMessageConsumers();
492 
493         verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
494         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
495     }
496 
497     @Test(expected = MessageSerializationException.class)
498     public void failFastIfMessageCantBeSerialized() throws Exception {
            // Makes messageInformationService.toPayload throw and expects addMessage to surface the
            // MessageSerializationException to the caller. Consumers are not started.
499         doThrow(MessageSerializationException.class).when(messageInformationService).toPayload(any(Message.class));
500 
501         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
502     }
503 
504     @Test
505     public void ensureTenantContextEstablishedOnConsumer() throws Exception {
            // After a message is sent, the mocked tenantIdSetter must be called with TENANT_ID
            // within the verify timeout.
506         messageRunnerService.initialiseMessageConsumers();
507         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
508 
509         verify(tenantIdSetter, timeout(VERIFY_TIMEOUT_MILLIS)).setTenantId(TENANT_ID);
510     }
511 
512     @Test
513     public void messageNotProcessedOrDeletedIfTenantContextEstablishmentThrowsException() throws Exception {
            // Makes tenantIdSetter.setTenantId throw for any id and expects that the consumer keeps
            // polling (at least two receiveMessage calls) while the message is neither deleted nor
            // handed to the runner. Skipped unless exactly one consumer is configured.
514         Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
515         doThrow(new RuntimeException("Error creating TenantContext from a tenantId (perhaps tenant context service is currently unavailable?)")).when(tenantIdSetter).setTenantId(anyString());
516 
517         messageRunnerService.initialiseMessageConsumers();
518         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
519         verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
            // These two checks are made immediately, once the second poll has been observed.
520         verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
521         verify(messageRunner, never()).processMessage(any(MessageContext.class));
522     }
523 
524     @Test(expected = IllegalArgumentException.class)
525     public void concurrentConsumersCantBeNegative() throws Exception {
            // Constructing the service with -1 concurrent consumers must throw IllegalArgumentException.
            // Because the constructor throws, the field keeps the instance built in setUp().
526         int concurrentConsumers = -1;
527 
528         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
529                 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
530     }
531 
532     @Test(expected = IllegalArgumentException.class)
533     public void concurrentConsumersCantBeZero() throws Exception {
            // Constructing the service with 0 concurrent consumers must throw IllegalArgumentException.
            // Because the constructor throws, the field keeps the instance built in setUp().
534         int concurrentConsumers = 0;
535 
536         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
537                 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
538     }
539 
540     @Test
541     public void supportEarlyAcknowledgementOfMessage() throws Exception {
            // Registers a runner that acknowledges the message first and then fails, and checks that
            // the message is delivered exactly once during the three seconds that follow.
            //
            // The short timeouts are now actually passed to the visibility timeout manager, using the
            // same constructor and values as consumerPoolCanRecoverFromError. Previously the local
            // visibilityTimeoutSeconds was declared but never used. With this configuration an
            // unacknowledged message is redelivered well inside the three-second wait, so the
            // assertion can fail when early acknowledgement does not work.
542         AtomicInteger deliveryCount = new AtomicInteger(0);
543         int visibilityTimeoutSeconds = 1; // use a short visibility timeout here to encourage message redelivery (and therefore have a chance to prevent it)
            int schedulingBufferSeconds = 1;
544         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
545                 CONCURRENT_CONSUMERS, registryService,
546                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
547                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 1));
548         MessageRunnerKey fooMessageRunnerKey = MessageRunnerKey.of("fooMessageRunnerKey");
549         registryService.registerMessageRunner(fooMessageRunnerKey, messageContext -> {
550             deliveryCount.incrementAndGet();
551             messageContext.acknowledge(); // acknowledge message at start of message runner before any processing begins
552             throw new RuntimeException("message runner error");
553         });
554 
555         messageRunnerService.initialiseMessageConsumers();
556         messageRunnerService.addMessage(Message.create(fooMessageRunnerKey, "payload"));
557 
            // Wait longer than visibility timeout plus scheduling buffer so a redelivery would be seen.
558         Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
559         assertThat(deliveryCount.get(), is(1));
560     }
561 
562     @Test(expected = MessageRunnerNotRegisteredException.class)
563     public void failFastWhenAddingMessageWithNoRegisteredMessageRunner() throws Exception {
            // Makes the registry report no runner for MESSAGE_RUNNER_KEY and expects addMessage to
            // throw MessageRunnerNotRegisteredException. Consumers are not started.
564         doReturn(Optional.empty()).when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
565 
566         messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
567     }
568 
569     @Test
570     public void messageRunnerResponsiveToCancellationViaShutdown() throws Exception {
            // Registers a runner that loops until its MessageContext reports a cancellation request,
            // then shuts the service down and expects the runner to leave the loop within twice its
            // polling interval.
571         final CountDownLatch enterLatch = new CountDownLatch(1);
572         final CountDownLatch exitLatch = new CountDownLatch(1);
573         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
574         final int sleepMillis = 100;
575         registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
                // Signal that processing has started before entering the polling loop.
576             enterLatch.countDown();
577             while (!messageContext.isCancellationRequested()) {
578                 Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
579             }
580             exitLatch.countDown();
581         });
582         messageRunnerService.initialiseMessageConsumers();
583         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
584         assertThat("MessageRunner did not start. Cannot assert cancellation on a MessageRunner that did not start",
585                 enterLatch.await(1, TimeUnit.SECONDS), is(true));
586 
            // tearDown() calls shutdown() again after this explicit call.
587         messageRunnerService.shutdown();
588 
589         assertThat("MessageRunner did not respond to cancellation", exitLatch.await(sleepMillis * 2, TimeUnit.MILLISECONDS), is(true));
590     }
591 
592     @Test
593     public void cancelAutoAcknowledgement() throws Exception {
594         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("cancelAckMessageRunnerKey");
595         registryService.registerMessageRunner(messageRunnerKey, MessageContext::cancelAutoAcknowledgementOfMessage);
596 
597         messageRunnerService.initialiseMessageConsumers();
598         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
599 
600         verify(spyingSqsClient, after(VERIFY_TIMEOUT_MILLIS).never()).deleteMessage(anyString(), anyString());
601     }
602 
603     @Test
604     public void autoExtensionOfVisibilityTimeout() throws Exception {
605         final int visibilityTimeoutSeconds = 1;
606 
607         // use more than one consumer to allow the a particular message to be processed concurrently
608         // this should cause this test to fail if the visibility timeout extension code is not working properly
609         final int concurrentConsumers = 2;
610         final int longRunningTaskDurationMillis = 2500;
611 
612         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
613         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
614             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
615         }));
616         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
617 
618         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
619                 concurrentConsumers, registryService,
620                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
621                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
622 
623         messageRunnerService.initialiseMessageConsumers();
624         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
625 
626         verify(longRunningMessageRunner, after(longRunningTaskDurationMillis + 500).times(1)).processMessage(any(MessageContext.class));
627         verify(spyingSqsClient, times(2)).changeMessageVisibility(anyString(), anyString(), anyInt());
628     }
629 
630     /**
631      * When the visibility timeout of a message expires before we've managed to extend it, then
632      * the ChangeMessageVisibility request will throw a MessageNotInflightException.
633      *
634      * <p>Also, since the visibility timeout has expired, we expect another MessageRunner to start executing the message
635      *
636      * <p>We should log this incident appropriately and tune the scheduling buffer time to be big enough to make this very unlikely.
637      */
638     @Test
639     public void autoExtensionOfVisibilityTimeoutFailsBecauseMessageNoLongerInFlight() throws Exception {
640         final int visibilityTimeoutSeconds = 1;
641 
642         // use more than one consumer to allow the a particular message to be processed concurrently
643         // this should cause this test to fail if the visibility timeout extension code is not working properly
644         final int concurrentConsumers = 2;
645         final int longRunningTaskDurationMillis = 2500;
646         final int schedulingBufferSeconds = 1;
647 
648         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
649         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
650             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
651         }));
652         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
653 
654         doThrow(MessageNotInflightException.class).when(spyingSqsClient).changeMessageVisibility(anyString(), anyString(), anyInt());
655 
656         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
657                 concurrentConsumers, registryService,
658                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
659                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 2));
660 
661         messageRunnerService.initialiseMessageConsumers();
662         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
663 
664         verify(longRunningMessageRunner, after(3000).times(2)).processMessage(any(MessageContext.class));
665     }
666 
667     @Test
668     public void autoExtensionOfVisibilityTimeoutCancelledOnMessageProcessingException() throws Exception {
669         final int visibilityTimeoutSeconds = 1;
670         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("erroneousMessageRunnerKey");
671         final MessageRunner erroneousMessageRunner = mock(MessageRunner.class);
672         doThrow(RuntimeException.class).when(erroneousMessageRunner).processMessage(any(MessageContext.class));
673         registryService.registerMessageRunner(messageRunnerKey, erroneousMessageRunner);
674 
675         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
676                 CONCURRENT_CONSUMERS, registryService,
677                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
678                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
679 
680         messageRunnerService.initialiseMessageConsumers();
681         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
682 
683         // the first visibility timeout extension occurs after the visibilityTimeout expires (that is 1 second in this test)
684         // wait a little more than this time to allow an extension to happen that would fail this test if this test case is not addressed in the code
685         final int waitMillis = visibilityTimeoutSeconds * 2 * 1000;
686         verify(spyingSqsClient, after(waitMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
687     }
688 
689     @Test
690     public void autoExtensionOfVisibilityTimeoutCancelledOnEarlyAcknowledgement() throws Exception {
691         final int visibilityTimeoutSeconds = 1;
692         final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
693         final int longRunningTaskDurationMillis = 1100;
694         final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
695             // acknowledge early in the message runner (we should cancel any future extensions of the visibility timeout
696             // after this since it does not make sense to extend the visibility timeout of a deleted message)
697             context.acknowledge();
698             Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS); // need the original 10 seconds plus two 10 second extensions to complete
699         }));
700         registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
701 
702         messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
703                 CONCURRENT_CONSUMERS, registryService,
704                 tenantIdSetter, messageInformationService, nestedMessageSerializer,
705                 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
706 
707         messageRunnerService.initialiseMessageConsumers();
708         messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
709 
710         verify(spyingSqsClient, after(longRunningTaskDurationMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
711     }
712 
713     private static TenantContext newTenantContext(String tenantId) {
714         TenantContext tenantContext = mock(TenantContext.class);
715         when(tenantContext.getTenantId()).thenReturn(tenantId);
716 
717         return tenantContext;
718     }
719 
720     private static class DelegatingMessageRunner implements MessageRunner {
721         private final MessageRunner delegate;
722 
723         public DelegatingMessageRunner(MessageRunner delegate) {
724             this.delegate = delegate;
725         }
726 
727         @Override
728         public void processMessage(MessageContext context) {
729             delegate.processMessage(context);
730         }
731     }
732 
733     private static class DefaultMessageRunner implements MessageRunner {
734         private final Queue<String> payloads;
735 
736         public DefaultMessageRunner(Queue<String> payloads) {
737             this.payloads = requireNonNull(payloads);
738         }
739 
740         @Override
741         public void processMessage(MessageContext messageContext) {
742             payloads.offer(messageContext.getPayload().orElse(null));
743         }
744     }
745 }