1 package com.atlassian.messagequeue.internal.sqs;
2
3 import com.amazonaws.auth.BasicAWSCredentials;
4 import com.amazonaws.services.sqs.AmazonSQSClient;
5 import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
6 import com.amazonaws.services.sqs.model.MessageNotInflightException;
7 import com.amazonaws.services.sqs.model.OverLimitException;
8 import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
9 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
10 import com.amazonaws.services.sqs.model.SendMessageRequest;
11 import com.atlassian.messagequeue.Message;
12 import com.atlassian.messagequeue.MessageInformationService;
13 import com.atlassian.messagequeue.MessageRunnerKey;
14 import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
15 import com.atlassian.messagequeue.MessageRunnerServiceException;
16 import com.atlassian.messagequeue.MessageSerializationException;
17 import com.atlassian.messagequeue.internal.core.DefaultMessageInformationService;
18 import com.atlassian.messagequeue.internal.core.JacksonNestedMessageSerializer;
19 import com.atlassian.messagequeue.internal.core.NestedMessage;
20 import com.atlassian.messagequeue.internal.core.NestedMessageSerializer;
21 import com.atlassian.messagequeue.internal.core.DefaultMessageRunnerRegistryService;
22 import com.atlassian.messagequeue.registry.MessageContext;
23 import com.atlassian.messagequeue.registry.MessageRunner;
24 import com.atlassian.tenant.api.TenantContext;
25 import com.atlassian.tenant.api.TenantContextProvider;
26 import com.atlassian.tenant.impl.TenantIdSetter;
27 import com.atlassian.workcontext.api.ImmutableWorkContextReference;
28 import com.atlassian.workcontext.api.WorkContextDoorway;
29 import com.google.common.util.concurrent.Uninterruptibles;
30 import org.elasticmq.rest.sqs.SQSRestServer;
31 import org.elasticmq.rest.sqs.SQSRestServerBuilder;
32 import org.junit.After;
33 import org.junit.Assume;
34 import org.junit.Before;
35 import org.junit.Test;
36 import org.junit.runner.RunWith;
37 import org.mockito.ArgumentCaptor;
38 import org.mockito.Captor;
39 import org.mockito.Mock;
40 import org.mockito.Mockito;
41 import org.mockito.runners.MockitoJUnitRunner;
42
43 import java.util.Collections;
44 import java.util.Optional;
45 import java.util.Queue;
46 import java.util.concurrent.BlockingQueue;
47 import java.util.concurrent.CountDownLatch;
48 import java.util.concurrent.ExecutorService;
49 import java.util.concurrent.Executors;
50 import java.util.concurrent.LinkedBlockingQueue;
51 import java.util.concurrent.TimeUnit;
52 import java.util.concurrent.atomic.AtomicInteger;
53 import java.util.concurrent.atomic.AtomicReference;
54 import java.util.function.Supplier;
55 import java.util.stream.IntStream;
56
57 import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME;
58 import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_ID_ATTRIBUTE_NAME;
59 import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
60 import static com.google.common.base.Throwables.getRootCause;
61 import static com.google.common.base.Throwables.propagate;
62 import static java.util.Objects.requireNonNull;
63 import static org.hamcrest.CoreMatchers.hasItem;
64 import static org.hamcrest.CoreMatchers.not;
65 import static org.hamcrest.MatcherAssert.assertThat;
66 import static org.hamcrest.Matchers.contains;
67 import static org.hamcrest.Matchers.containsInAnyOrder;
68 import static org.hamcrest.Matchers.hasSize;
69 import static org.hamcrest.core.Is.is;
70 import static org.hamcrest.text.IsEmptyString.isEmptyOrNullString;
71 import static org.junit.Assert.assertNull;
72 import static org.junit.Assert.assertTrue;
73 import static org.mockito.Matchers.any;
74 import static org.mockito.Matchers.anyInt;
75 import static org.mockito.Matchers.anyString;
76 import static org.mockito.Mockito.after;
77 import static org.mockito.Mockito.doCallRealMethod;
78 import static org.mockito.Mockito.doReturn;
79 import static org.mockito.Mockito.doThrow;
80 import static org.mockito.Mockito.mock;
81 import static org.mockito.Mockito.never;
82 import static org.mockito.Mockito.spy;
83 import static org.mockito.Mockito.timeout;
84 import static org.mockito.Mockito.times;
85 import static org.mockito.Mockito.verify;
86 import static org.mockito.Mockito.when;
87
88 @RunWith(MockitoJUnitRunner.class)
89 public class SQSMessageRunnerServiceTest {
90 private static final MessageRunnerKey MESSAGE_RUNNER_KEY = MessageRunnerKey.of("testMessageRunner");
91 private static final int VISIBILITY_TIMEOUT_SECONDS = 10;
92 private static final int RECEIVE_WAIT_TIME_SECONDS = 0;
93 private static final String QUEUE_NAME = "testQueue";
94 private static final int CONCURRENT_CONSUMERS = 1;
95 private static final String TENANT_ID = "tenant-123";
96 private static final int VERIFY_TIMEOUT_MILLIS = 1000;
97 private static final int SCHEDULING_BUFFER_SECONDS = 5;
98
99 private SQSMessageRunnerService messageRunnerService;
100 private BlockingQueue<String> payloads;
101 private SQSRestServer sqsServer;
102 private AmazonSQSClient spyingSqsClient;
103 private WorkContextDoorway workContextDoorway;
104
105 @Mock
106 private TenantContextProvider tenantContextProvider;
107 @Mock
108 private TenantIdSetter tenantIdSetter;
109
110 @Captor
111 private ArgumentCaptor<ReceiveMessageRequest> receiveMessageRequestArgumentCaptor;
112 @Captor
113 private ArgumentCaptor<SendMessageRequest> sendMessageRequestArgumentCaptor;
114 @Captor
115 private ArgumentCaptor<MessageContext> messageContextArgumentCaptor;
116
117 private MessageRunner messageRunner;
118 private DefaultMessageRunnerRegistryService registryService;
119 private MessageInformationService messageInformationService;
120 private String queueUrl;
121 private NestedMessageSerializer nestedMessageSerializer;
122 private TenantContext.Builder tenantContextBuilder;
123 private SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager;
124
125 @Before
126 public void setUp() throws Exception {
127 sqsServer = SQSRestServerBuilder.start();
128
129 String endpoint = "http://localhost:9324";
130
131 AmazonSQSClient sqsClient = new AmazonSQSClient(new BasicAWSCredentials("x", "x")).withEndpoint(endpoint);
132 sqsClient.createQueue(QUEUE_NAME);
133 queueUrl = sqsClient.getQueueUrl(QUEUE_NAME).getQueueUrl();
134 spyingSqsClient = spy(sqsClient);
135
136 nestedMessageSerializer = spy(new JacksonNestedMessageSerializer());
137 registryService = spy(new DefaultMessageRunnerRegistryService());
138 messageInformationService = spy(new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer));
139 sqsMessageVisibilityTimeoutManager = new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, VISIBILITY_TIMEOUT_SECONDS, SCHEDULING_BUFFER_SECONDS, 2);
140
141 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
142 CONCURRENT_CONSUMERS, registryService,
143 tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
144 payloads = new LinkedBlockingQueue<>();
145
146 messageRunner = spy(new DefaultMessageRunner(payloads));
147 registryService.registerMessageRunner(MESSAGE_RUNNER_KEY, messageRunner);
148
149 tenantContextBuilder = new TenantContext.Builder()
150 .jdbcUrl("jdbcUrl")
151 .jdbcUsername("username")
152 .jdbcPassword("password");
153 when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId(TENANT_ID).build());
154
155 workContextDoorway = new WorkContextDoorway();
156 workContextDoorway.open();
157 }
158
159 @After
160 public void tearDown() throws Exception {
161 workContextDoorway.close();
162 messageRunnerService.shutdown();
163 sqsServer.stopAndWait();
164 }
165
166 @Test
167 public void messagePayloadIsPassedToMessageRunner() throws Exception {
168 String expectedPayload = "payload";
169
170 messageRunnerService.initialiseMessageConsumers();
171 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
172
173 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
174 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of(expectedPayload)));
175 }
176
177 @Test
178 public void messageIsAutoAcknowledgedAfterSuccessfulProcessing() throws Exception {
179 String expectedPayload = "payload";
180
181 messageRunnerService.initialiseMessageConsumers();
182 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
183
184 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
185 }
186
187 @Test
188 public void addMessageWithEmptyPayload() throws Exception {
189 messageRunnerService.initialiseMessageConsumers();
190 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, ""));
191
192 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
193 assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
194 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
195 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of("")));
196 }
197
198 @Test
199 public void addMessageWithNullPayload() throws Exception {
200 messageRunnerService.initialiseMessageConsumers();
201 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, null));
202
203 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
204 assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
205 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
206 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.empty()));
207 }
208
209 @Test
210 public void addMultipleMessagesSerially() throws Exception {
211 messageRunnerService.initialiseMessageConsumers();
212 int numberOfMessages = 20;
213 Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
214
215 range.get().forEach(i -> {
216 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
217 });
218
219 String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
220 verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
221 assertThat(payloads, containsInAnyOrder(expectedPayloads));
222 }
223
224 @Test
225 public void addMultipleMessagesConcurrently() throws Exception {
226 messageRunnerService.initialiseMessageConsumers();
227 int numberOfMessages = 20;
228 Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
229
230 ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
231 try {
232 range.get().forEach(i -> {
233 pool.execute(() -> {
234 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
235 });
236 });
237
238 String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
239 verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
240 assertThat(payloads, containsInAnyOrder(expectedPayloads));
241 } finally {
242 pool.shutdown();
243 pool.awaitTermination(1, TimeUnit.SECONDS);
244 }
245 }
246
247 @Test
248 public void messageNotDeletedIfRunningItThrowsAnException() throws Exception {
249 MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("failingRunner");
250 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
251 throw new RuntimeException();
252 });
253
254 messageRunnerService.initialiseMessageConsumers();
255 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
256
257 verify(spyingSqsClient, after(1000).never()).deleteMessage(anyString(), anyString());
258 }
259
260 @Test
261 public void receiveMessageRequestPerformedUsingConfiguredVisibilityTimeoutAndSchedulingBuffer() throws Exception {
262 messageRunnerService.initialiseMessageConsumers();
263 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
264
265 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
266 assertThat(receiveMessageRequestArgumentCaptor.getValue().getVisibilityTimeout(), is(VISIBILITY_TIMEOUT_SECONDS + SCHEDULING_BUFFER_SECONDS));
267 }
268
269 @Test
270 public void sentTimestampAttributeIsRequestedToAllowMessageLatencyToBeDerived() throws Exception {
271 messageRunnerService.initialiseMessageConsumers();
272 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
273
274 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
275 assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), contains(SENT_TIMESTAMP));
276 }
277
278 @Test
279 public void consumerResumesPollingIfNoMessagesReturnedInOneInvocationOfReceiveMessage() throws Exception {
280 Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
281 ReceiveMessageResult emptyReceiveMessageResult = mock(ReceiveMessageResult.class);
282 when(emptyReceiveMessageResult.getMessages()).thenReturn(Collections.emptyList());
283 doReturn(emptyReceiveMessageResult).doCallRealMethod().when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
284
285 messageRunnerService.initialiseMessageConsumers();
286 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
287
288 verify(spyingSqsClient, after(1000).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
289 }
290
291 @Test(expected = InvalidMessageContentsException.class)
292 public void addMessageWithPayloadContainingIllegalCharacters() throws Exception {
293 AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
294 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
295 SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "test", registryService,
296 tenantIdSetter, messageInformationService, nestedMessageSerializer);
297
298 when(amazonSQSClient.sendMessage(any(SendMessageRequest.class))).thenThrow(InvalidMessageContentsException.class);
299
300 try {
301 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload with invalid characters"));
302 } catch (MessageRunnerServiceException e) {
303 propagate(getRootCause(e));
304 }
305 }
306
307 @Test(expected = IllegalStateException.class)
308 public void addMessageFailsWithExceptionIfTenantContextNotPresent() throws Exception {
309 AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
310 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class);
311 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
312 MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
313 SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
314 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
315 when(tenantContextProvider.getTenantContext()).thenReturn(null);
316
317 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
318 }
319
320 @Test(expected = IllegalStateException.class)
321 public void addMessageFailsWithNullTenantId() throws Exception {
322 AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
323 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
324 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
325 MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
326 SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
327 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
328
329 when(tenantContextProvider.getTenantContext().getTenantId()).thenReturn(null);
330 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
331 }
332
333 @Test(expected = IllegalStateException.class)
334 public void addMessageFailsWithEmptyTenantId() throws Exception {
335 AmazonSQSClient amazonSQSClient = mock(AmazonSQSClient.class, Mockito.RETURNS_DEEP_STUBS);
336 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
337 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
338 MessageInformationService messageInformationService = new DefaultMessageInformationService(queueUrl, tenantContextProvider, nestedMessageSerializer);
339 SQSMessageRunnerService messageRunnerService = new SQSMessageRunnerService(amazonSQSClient, "queueUrl",
340 registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer);
341
342 when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId("").build());
343 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
344 }
345
346 @Test
347 public void consumerPoolCanRecoverFromOverLimitException() throws Exception {
348
349
350 doThrow(OverLimitException.class)
351 .doCallRealMethod()
352 .when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
353
354 messageRunnerService.initialiseMessageConsumers();
355 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
356
357
358 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(any(MessageContext.class));
359 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
360 assertThat(payloads, hasSize(1));
361 assertThat(payloads, hasItem("1"));
362 }
363
364 @Test
365 public void consumerPoolCanRecoverFromError() throws Exception {
366 final int visibilityTimeoutSeconds = 1;
367 final int schedulingBufferSeconds = 1;
368 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
369 CONCURRENT_CONSUMERS, registryService,
370 tenantIdSetter, messageInformationService, nestedMessageSerializer,
371 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 1));
372
373
374
375 doThrow(LinkageError.class)
376 .doCallRealMethod()
377 .when(messageRunner).processMessage(any(MessageContext.class));
378
379 messageRunnerService.initialiseMessageConsumers();
380 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
381
382 verify(spyingSqsClient, timeout(2500)).deleteMessage(anyString(), anyString());
383 assertThat(payloads, hasSize(1));
384 assertThat(payloads, hasItem("1"));
385 }
386
387 @Test
388 public void ensureWorkContextAvailableForMessageProcessing() throws Exception {
389 CountDownLatch latch = new CountDownLatch(1);
390 MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("workContext");
391 AtomicReference<Exception> exception = new AtomicReference<>();
392 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
393 try {
394 new ImmutableWorkContextReference<>(() -> "foobar").get();
395 } catch (IllegalStateException e) {
396 exception.set(e);
397 } finally {
398 latch.countDown();
399 }
400 });
401
402 messageRunnerService.initialiseMessageConsumers();
403 messageRunnerService.addMessage(Message.create(messageRunnerKey, "1"));
404
405 assertTrue("test did not complete within timeout", latch.await(1, TimeUnit.SECONDS));
406 assertNull("No work context available for message processing.", exception.get());
407 }
408
409 @Test
410 public void ensureTenantIdIsSentByProducer() throws Exception {
411 messageRunnerService.initialiseMessageConsumers();
412 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
413
414 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
415 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
416
417 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
418 assertThat(nestedMessage.getAttribute(TENANT_ID_ATTRIBUTE_NAME), is(TENANT_ID));
419 }
420
421 @Test
422 public void ensureMessageRunnerKeyIsSentByProducer() throws Exception {
423 messageRunnerService.initialiseMessageConsumers();
424 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
425
426 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
427 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
428
429 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
430 assertThat(nestedMessage.getAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME), is(MESSAGE_RUNNER_KEY.toString()));
431 }
432
433 @Test
434 public void ensurePayloadIsSentByProducer() throws Exception {
435 String payload = "payload";
436 messageRunnerService.initialiseMessageConsumers();
437 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, payload));
438
439 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
440 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
441
442 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
443 assertThat(nestedMessage.getPayload(), is(payload));
444 }
445
446 @Test
447 public void dontProcessAndDeleteMessagesThatCantBeDeserialized() throws Exception {
448 doThrow(MessageSerializationException.class).when(nestedMessageSerializer).deserialize(any(String.class));
449
450 messageRunnerService.initialiseMessageConsumers();
451 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
452
453 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
454 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
455 }
456
457 @Test
458 public void deleteMessageWithoutTenantId() throws Exception {
459 doReturn(new NestedMessage()).when(nestedMessageSerializer).deserialize(anyString());
460
461 messageRunnerService.initialiseMessageConsumers();
462 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
463
464 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
465 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
466 }
467
468 @Test
469 public void deleteMessageWithoutMessageRunnerKey() throws Exception {
470 NestedMessage nestedMessage = new NestedMessage().addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123");
471 doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
472
473 messageRunnerService.initialiseMessageConsumers();
474 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
475
476 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
477 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
478 }
479
480 @Test
481 public void dontProcessMessageWithoutMessageRunner() throws Exception {
482 NestedMessage nestedMessage = new NestedMessage()
483 .addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123")
484 .addAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME, MESSAGE_RUNNER_KEY.toString());
485 doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
486 doCallRealMethod()
487 .doReturn(Optional.empty())
488 .when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
489
490 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
491 messageRunnerService.initialiseMessageConsumers();
492
493 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
494 verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
495 }
496
497 @Test(expected = MessageSerializationException.class)
498 public void failFastIfMessageCantBeSerialized() throws Exception {
499 doThrow(MessageSerializationException.class).when(messageInformationService).toPayload(any(Message.class));
500
501 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
502 }
503
504 @Test
505 public void ensureTenantContextEstablishedOnConsumer() throws Exception {
506 messageRunnerService.initialiseMessageConsumers();
507 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
508
509 verify(tenantIdSetter, timeout(VERIFY_TIMEOUT_MILLIS)).setTenantId(TENANT_ID);
510 }
511
512 @Test
513 public void messageNotProcessedOrDeletedIfTenantContextEstablishmentThrowsException() throws Exception {
514 Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
515 doThrow(new RuntimeException("Error creating TenantContext from a tenantId (perhaps tenant context service is currently unavailable?)")).when(tenantIdSetter).setTenantId(anyString());
516
517 messageRunnerService.initialiseMessageConsumers();
518 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
519 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
520 verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
521 verify(messageRunner, never()).processMessage(any(MessageContext.class));
522 }
523
524 @Test(expected = IllegalArgumentException.class)
525 public void concurrentConsumersCantBeNegative() throws Exception {
526 int concurrentConsumers = -1;
527
528 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
529 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
530 }
531
532 @Test(expected = IllegalArgumentException.class)
533 public void concurrentConsumersCantBeZero() throws Exception {
534 int concurrentConsumers = 0;
535
536 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, QUEUE_NAME, RECEIVE_WAIT_TIME_SECONDS,
537 concurrentConsumers, registryService, tenantIdSetter, messageInformationService, nestedMessageSerializer, sqsMessageVisibilityTimeoutManager);
538 }
539
540 @Test
541 public void supportEarlyAcknowledgementOfMessage() throws Exception {
542 AtomicInteger deliveryCount = new AtomicInteger(0);
543 int visibilityTimeoutSeconds = 1;
544 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
545 CONCURRENT_CONSUMERS, registryService,
546 tenantIdSetter, messageInformationService, nestedMessageSerializer,
547 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl));
548 MessageRunnerKey fooMessageRunnerKey = MessageRunnerKey.of("fooMessageRunnerKey");
549 registryService.registerMessageRunner(fooMessageRunnerKey, messageContext -> {
550 deliveryCount.incrementAndGet();
551 messageContext.acknowledge();
552 throw new RuntimeException("message runner error");
553 });
554
555 messageRunnerService.initialiseMessageConsumers();
556 messageRunnerService.addMessage(Message.create(fooMessageRunnerKey, "payload"));
557
558 Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
559 assertThat(deliveryCount.get(), is(1));
560 }
561
562 @Test(expected = MessageRunnerNotRegisteredException.class)
563 public void failFastWhenAddingMessageWithNoRegisteredMessageRunner() throws Exception {
564 doReturn(Optional.empty()).when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
565
566 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
567 }
568
569 @Test
570 public void messageRunnerResponsiveToCancellationViaShutdown() throws Exception {
571 final CountDownLatch enterLatch = new CountDownLatch(1);
572 final CountDownLatch exitLatch = new CountDownLatch(1);
573 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
574 final int sleepMillis = 100;
575 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
576 enterLatch.countDown();
577 while (!messageContext.isCancellationRequested()) {
578 Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
579 }
580 exitLatch.countDown();
581 });
582 messageRunnerService.initialiseMessageConsumers();
583 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
584 assertThat("MessageRunner did not start. Cannot assert cancellation on a MessageRunner that did not start",
585 enterLatch.await(1, TimeUnit.SECONDS), is(true));
586
587 messageRunnerService.shutdown();
588
589 assertThat("MessageRunner did not respond to cancellation", exitLatch.await(sleepMillis * 2, TimeUnit.MILLISECONDS), is(true));
590 }
591
592 @Test
593 public void cancelAutoAcknowledgement() throws Exception {
594 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("cancelAckMessageRunnerKey");
595 registryService.registerMessageRunner(messageRunnerKey, MessageContext::cancelAutoAcknowledgementOfMessage);
596
597 messageRunnerService.initialiseMessageConsumers();
598 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
599
600 verify(spyingSqsClient, after(VERIFY_TIMEOUT_MILLIS).never()).deleteMessage(anyString(), anyString());
601 }
602
603 @Test
604 public void autoExtensionOfVisibilityTimeout() throws Exception {
605 final int visibilityTimeoutSeconds = 1;
606
607
608
609 final int concurrentConsumers = 2;
610 final int longRunningTaskDurationMillis = 2500;
611
612 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
613 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
614 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
615 }));
616 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
617
618 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
619 concurrentConsumers, registryService,
620 tenantIdSetter, messageInformationService, nestedMessageSerializer,
621 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
622
623 messageRunnerService.initialiseMessageConsumers();
624 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
625
626 verify(longRunningMessageRunner, after(longRunningTaskDurationMillis + 500).times(1)).processMessage(any(MessageContext.class));
627 verify(spyingSqsClient, times(2)).changeMessageVisibility(anyString(), anyString(), anyInt());
628 }
629
630
631
632
633
634
635
636
637
638 @Test
639 public void autoExtensionOfVisibilityTimeoutFailsBecauseMessageNoLongerInFlight() throws Exception {
640 final int visibilityTimeoutSeconds = 1;
641
642
643
644 final int concurrentConsumers = 2;
645 final int longRunningTaskDurationMillis = 2500;
646 final int schedulingBufferSeconds = 1;
647
648 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
649 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
650 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
651 }));
652 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
653
654 doThrow(MessageNotInflightException.class).when(spyingSqsClient).changeMessageVisibility(anyString(), anyString(), anyInt());
655
656 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
657 concurrentConsumers, registryService,
658 tenantIdSetter, messageInformationService, nestedMessageSerializer,
659 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, schedulingBufferSeconds, 2));
660
661 messageRunnerService.initialiseMessageConsumers();
662 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
663
664 verify(longRunningMessageRunner, after(3000).times(2)).processMessage(any(MessageContext.class));
665 }
666
667 @Test
668 public void autoExtensionOfVisibilityTimeoutCancelledOnMessageProcessingException() throws Exception {
669 final int visibilityTimeoutSeconds = 1;
670 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("erroneousMessageRunnerKey");
671 final MessageRunner erroneousMessageRunner = mock(MessageRunner.class);
672 doThrow(RuntimeException.class).when(erroneousMessageRunner).processMessage(any(MessageContext.class));
673 registryService.registerMessageRunner(messageRunnerKey, erroneousMessageRunner);
674
675 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
676 CONCURRENT_CONSUMERS, registryService,
677 tenantIdSetter, messageInformationService, nestedMessageSerializer,
678 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
679
680 messageRunnerService.initialiseMessageConsumers();
681 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
682
683
684
685 final int waitMillis = visibilityTimeoutSeconds * 2 * 1000;
686 verify(spyingSqsClient, after(waitMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
687 }
688
689 @Test
690 public void autoExtensionOfVisibilityTimeoutCancelledOnEarlyAcknowledgement() throws Exception {
691 final int visibilityTimeoutSeconds = 1;
692 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
693 final int longRunningTaskDurationMillis = 1100;
694 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
695
696
697 context.acknowledge();
698 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
699 }));
700 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
701
702 messageRunnerService = new SQSMessageRunnerService(spyingSqsClient, queueUrl, RECEIVE_WAIT_TIME_SECONDS,
703 CONCURRENT_CONSUMERS, registryService,
704 tenantIdSetter, messageInformationService, nestedMessageSerializer,
705 new SQSMessageVisibilityTimeoutManager(spyingSqsClient, queueUrl, visibilityTimeoutSeconds, 5, 2));
706
707 messageRunnerService.initialiseMessageConsumers();
708 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
709
710 verify(spyingSqsClient, after(longRunningTaskDurationMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
711 }
712
713 private static TenantContext newTenantContext(String tenantId) {
714 TenantContext tenantContext = mock(TenantContext.class);
715 when(tenantContext.getTenantId()).thenReturn(tenantId);
716
717 return tenantContext;
718 }
719
720 private static class DelegatingMessageRunner implements MessageRunner {
721 private final MessageRunner delegate;
722
723 public DelegatingMessageRunner(MessageRunner delegate) {
724 this.delegate = delegate;
725 }
726
727 @Override
728 public void processMessage(MessageContext context) {
729 delegate.processMessage(context);
730 }
731 }
732
733 private static class DefaultMessageRunner implements MessageRunner {
734 private final Queue<String> payloads;
735
736 public DefaultMessageRunner(Queue<String> payloads) {
737 this.payloads = requireNonNull(payloads);
738 }
739
740 @Override
741 public void processMessage(MessageContext messageContext) {
742 payloads.offer(messageContext.getPayload().orElse(null));
743 }
744 }
745 }