1 package com.atlassian.messagequeue.internal.sqs;
2
3 import com.amazonaws.services.sqs.AmazonSQS;
4 import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
5 import com.amazonaws.services.sqs.model.MessageNotInflightException;
6 import com.amazonaws.services.sqs.model.OverLimitException;
7 import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
8 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
9 import com.amazonaws.services.sqs.model.SendMessageRequest;
10 import com.atlassian.messagequeue.Message;
11 import com.atlassian.messagequeue.MessageInformationService;
12 import com.atlassian.messagequeue.MessageRunnerKey;
13 import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
14 import com.atlassian.messagequeue.MessageRunnerServiceException;
15 import com.atlassian.messagequeue.MessageSerializationException;
16 import com.atlassian.messagequeue.internal.core.NestedMessage;
17 import com.atlassian.messagequeue.internal.core.messagevalidators.MessageTenantDataIdValidator;
18 import com.atlassian.messagequeue.registry.MessageContext;
19 import com.atlassian.messagequeue.registry.MessageRunner;
20 import com.atlassian.messagequeue.registry.MessageValidator;
21 import com.atlassian.tenant.api.TenantContextProvider;
22 import com.atlassian.tenant.impl.TenantIdSetter;
23 import com.atlassian.workcontext.api.ImmutableWorkContextReference;
24 import com.google.common.util.concurrent.Uninterruptibles;
25 import org.junit.Assume;
26 import org.junit.Test;
27 import org.junit.runner.RunWith;
28 import org.mockito.Mockito;
29 import org.mockito.runners.MockitoJUnitRunner;
30
31 import java.util.Collections;
32 import java.util.Optional;
33 import java.util.concurrent.CountDownLatch;
34 import java.util.concurrent.ExecutorService;
35 import java.util.concurrent.Executors;
36 import java.util.concurrent.TimeUnit;
37 import java.util.concurrent.atomic.AtomicInteger;
38 import java.util.concurrent.atomic.AtomicReference;
39 import java.util.function.Supplier;
40 import java.util.stream.IntStream;
41
42 import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME;
43 import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_DATA_ID_ATTRIBUTE_NAME;
44 import static com.atlassian.messagequeue.internal.core.NestedMessageConstants.TENANT_ID_ATTRIBUTE_NAME;
45 import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.APPROXIMATE_RECEIVE_COUNT;
46 import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
47 import static com.google.common.base.Throwables.getRootCause;
48 import static com.google.common.base.Throwables.propagate;
49 import static org.hamcrest.CoreMatchers.hasItem;
50 import static org.hamcrest.CoreMatchers.not;
51 import static org.hamcrest.MatcherAssert.assertThat;
52 import static org.hamcrest.Matchers.containsInAnyOrder;
53 import static org.hamcrest.Matchers.hasSize;
54 import static org.hamcrest.core.Is.is;
55 import static org.hamcrest.text.IsEmptyString.isEmptyOrNullString;
56 import static org.junit.Assert.assertNull;
57 import static org.junit.Assert.assertTrue;
58 import static org.mockito.Matchers.any;
59 import static org.mockito.Matchers.anyInt;
60 import static org.mockito.Matchers.anyString;
61 import static org.mockito.Mockito.after;
62 import static org.mockito.Mockito.doCallRealMethod;
63 import static org.mockito.Mockito.doReturn;
64 import static org.mockito.Mockito.doThrow;
65 import static org.mockito.Mockito.mock;
66 import static org.mockito.Mockito.never;
67 import static org.mockito.Mockito.spy;
68 import static org.mockito.Mockito.timeout;
69 import static org.mockito.Mockito.times;
70 import static org.mockito.Mockito.verify;
71 import static org.mockito.Mockito.when;
72
73 @RunWith(MockitoJUnitRunner.class)
74 public class SQSMessageRunnerServiceTest extends AbstractSQSMessageRunnerServiceTest {
75
76 @Test
77 public void messagePayloadIsPassedToMessageRunner() throws Exception {
78 String expectedPayload = "payload";
79
80 messageRunnerService.initialiseMessageConsumers();
81 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
82
83 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
84 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of(expectedPayload)));
85 }
86
87 @Test
88 public void messageIsAutoAcknowledgedAfterSuccessfulProcessing() throws Exception {
89 String expectedPayload = "payload";
90
91 messageRunnerService.initialiseMessageConsumers();
92 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, expectedPayload));
93
94 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
95 }
96
97 @Test
98 public void addMessageWithEmptyPayload() throws Exception {
99 messageRunnerService.initialiseMessageConsumers();
100 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, ""));
101
102 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
103 assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
104 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
105 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.of("")));
106 }
107
108 @Test
109 public void addMessageWithNullPayload() throws Exception {
110 messageRunnerService.initialiseMessageConsumers();
111 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, null));
112
113 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
114 assertThat(sendMessageRequestArgumentCaptor.getValue().getMessageBody(), not(isEmptyOrNullString()));
115 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(messageContextArgumentCaptor.capture());
116 assertThat(messageContextArgumentCaptor.getValue().getPayload(), is(Optional.empty()));
117 }
118
119 @Test
120 public void addMultipleMessagesSerially() throws Exception {
121 messageRunnerService.initialiseMessageConsumers();
122 int numberOfMessages = 20;
123 Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
124
125 range.get().forEach(i -> {
126 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
127 });
128
129 String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
130 verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
131 assertThat(payloads, containsInAnyOrder(expectedPayloads));
132 }
133
134 @Test
135 public void addMultipleMessagesConcurrently() throws Exception {
136 messageRunnerService.initialiseMessageConsumers();
137 int numberOfMessages = 20;
138 Supplier<IntStream> range = () -> IntStream.range(0, numberOfMessages);
139
140 ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
141 try {
142 range.get().forEach(i -> {
143 pool.execute(() -> {
144 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, String.valueOf(i)));
145 });
146 });
147
148 String[] expectedPayloads = range.get().mapToObj(String::valueOf).toArray(String[]::new);
149 verify(spyingSqsClient, timeout(10000).times(numberOfMessages)).deleteMessage(anyString(), anyString());
150 assertThat(payloads, containsInAnyOrder(expectedPayloads));
151 } finally {
152 pool.shutdown();
153 pool.awaitTermination(1, TimeUnit.SECONDS);
154 }
155 }
156
157 @Test
158 public void messageNotDeletedIfRunningItThrowsAnException() throws Exception {
159 MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("failingRunner");
160 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
161 throw new RuntimeException();
162 });
163
164 messageRunnerService.initialiseMessageConsumers();
165 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
166
167 verify(spyingSqsClient, after(1000).never()).deleteMessage(anyString(), anyString());
168 }
169
170 @Test
171 public void receiveMessageRequestPerformedUsingConfiguredVisibilityTimeoutAndSchedulingBuffer() throws Exception {
172 messageRunnerService.initialiseMessageConsumers();
173 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
174
175 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
176 assertThat(receiveMessageRequestArgumentCaptor.getValue().getVisibilityTimeout(), is(25));
177 }
178
179 @Test
180 public void sentTimestampAttributeIsRequestedToAllowMessageLatencyToBeDerived() throws Exception {
181 messageRunnerService.initialiseMessageConsumers();
182 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
183
184 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
185 assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), hasItem(SENT_TIMESTAMP));
186 }
187
188 @Test
189 public void approximateReceiveCountAttributeIsRequested() throws Exception {
190 messageRunnerService.initialiseMessageConsumers();
191 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
192
193 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeastOnce()).receiveMessage(receiveMessageRequestArgumentCaptor.capture());
194 assertThat(receiveMessageRequestArgumentCaptor.getValue().getAttributeNames(), hasItem(APPROXIMATE_RECEIVE_COUNT));
195 }
196
197 @Test
198 public void consumerResumesPollingIfNoMessagesReturnedInOneInvocationOfReceiveMessage() throws Exception {
199 Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
200 ReceiveMessageResult emptyReceiveMessageResult = mock(ReceiveMessageResult.class);
201 when(emptyReceiveMessageResult.getMessages()).thenReturn(Collections.emptyList());
202 doReturn(emptyReceiveMessageResult).doCallRealMethod().when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
203
204 messageRunnerService.initialiseMessageConsumers();
205 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
206
207 verify(spyingSqsClient, after(1000).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
208 }
209
210 @Test(expected = InvalidMessageContentsException.class)
211 public void addMessageWithPayloadContainingIllegalCharacters() throws Exception {
212 AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
213 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
214 SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
215 .withAmazonSQSClient(amazonSQSClient)
216 .withMessageRunnerRegistryHelper(registryService)
217 .withTenantIdSetter(tenantIdSetter)
218 .withMessageInformationService(messageInformationService)
219 .withNestedMessageSerializer(nestedMessageSerializer)
220 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
221 .build();
222
223 when(amazonSQSClient.sendMessage(any(SendMessageRequest.class))).thenThrow(InvalidMessageContentsException.class);
224
225 try {
226 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload with invalid characters"));
227 } catch (MessageRunnerServiceException e) {
228 propagate(getRootCause(e));
229 }
230 }
231
232 @Test(expected = IllegalStateException.class)
233 public void addMessageFailsWithExceptionIfTenantContextNotPresent() throws Exception {
234 AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
235 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class);
236 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
237 MessageInformationService messageInformationService = new SQSMessageInformationService(
238 getDefaultMessageRunnerKeyToProducerMapper(),
239 tenantContextProvider,
240 nestedMessageSerializer,
241 tenantDataIdSupplier);
242 SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
243 .withAmazonSQSClient(amazonSQSClient)
244 .withMessageRunnerRegistryHelper(registryService)
245 .withTenantIdSetter(tenantIdSetter)
246 .withMessageInformationService(messageInformationService)
247 .withNestedMessageSerializer(nestedMessageSerializer)
248 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
249 .build();
250 when(tenantContextProvider.getTenantContext()).thenReturn(null);
251
252 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
253 }
254
255 @Test(expected = IllegalStateException.class)
256 public void addMessageFailsWithNullTenantId() throws Exception {
257 AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
258 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
259 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
260 MessageInformationService messageInformationService = new SQSMessageInformationService(
261 getDefaultMessageRunnerKeyToProducerMapper(),
262 tenantContextProvider,
263 nestedMessageSerializer,
264 tenantDataIdSupplier);
265 SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
266 .withAmazonSQSClient(amazonSQSClient)
267 .withMessageRunnerRegistryHelper(registryService)
268 .withTenantIdSetter(tenantIdSetter)
269 .withMessageInformationService(messageInformationService)
270 .withNestedMessageSerializer(nestedMessageSerializer)
271 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
272 .build();
273
274 when(tenantContextProvider.getTenantContext().getTenantId()).thenReturn(null);
275 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
276 }
277
278 @Test(expected = IllegalStateException.class)
279 public void addMessageFailsWithEmptyTenantId() throws Exception {
280 AmazonSQS amazonSQSClient = mock(AmazonSQS.class, Mockito.RETURNS_DEEP_STUBS);
281 TenantContextProvider tenantContextProvider = mock(TenantContextProvider.class, Mockito.RETURNS_DEEP_STUBS);
282 TenantIdSetter tenantIdSetter = mock(TenantIdSetter.class);
283 MessageInformationService messageInformationService = new SQSMessageInformationService(
284 getDefaultMessageRunnerKeyToProducerMapper(),
285 tenantContextProvider,
286 nestedMessageSerializer,
287 tenantDataIdSupplier);
288 SQSMessageRunnerService messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig())
289 .withAmazonSQSClient(amazonSQSClient)
290 .withMessageRunnerRegistryHelper(registryService)
291 .withTenantIdSetter(tenantIdSetter)
292 .withMessageInformationService(messageInformationService)
293 .withNestedMessageSerializer(nestedMessageSerializer)
294 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
295 .build();
296
297 when(tenantContextProvider.getTenantContext()).thenReturn(tenantContextBuilder.tenantId("").build());
298 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
299 }
300
301 @Test
302 public void consumerPoolCanRecoverFromOverLimitException() throws Exception {
303
304
305 doThrow(OverLimitException.class)
306 .doCallRealMethod()
307 .when(spyingSqsClient).receiveMessage(any(ReceiveMessageRequest.class));
308
309 messageRunnerService.initialiseMessageConsumers();
310 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
311
312
313 verify(messageRunner, timeout(VERIFY_TIMEOUT_MILLIS)).processMessage(any(MessageContext.class));
314 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS)).deleteMessage(anyString(), anyString());
315 assertThat(payloads, hasSize(1));
316 assertThat(payloads, hasItem("1"));
317 }
318
319 @Test
320 public void consumerPoolCanRecoverFromError() throws Exception {
321 final int visibilityTimeoutSeconds = 1;
322 final int schedulingBufferSeconds = 1;
323 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
324 .withAmazonSQSClient(spyingSqsClient)
325 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
326 .withMessageRunnerRegistryHelper(registryService)
327 .withTenantIdSetter(tenantIdSetter)
328 .withMessageInformationService(messageInformationService)
329 .withNestedMessageSerializer(nestedMessageSerializer)
330 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
331 spyingSqsClient, 1, schedulingBufferSeconds))
332 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
333 .build();
334
335
336
337 doThrow(LinkageError.class)
338 .doCallRealMethod()
339 .when(messageRunner).processMessage(any(MessageContext.class));
340
341 messageRunnerService.initialiseMessageConsumers();
342 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "1"));
343
344 verify(spyingSqsClient, timeout(10000)).deleteMessage(anyString(), anyString());
345 assertThat(payloads, hasSize(1));
346 assertThat(payloads, hasItem("1"));
347 }
348
349 @Test
350 public void ensureWorkContextAvailableForMessageProcessing() throws Exception {
351 CountDownLatch latch = new CountDownLatch(1);
352 MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("workContext");
353 AtomicReference<Exception> exception = new AtomicReference<>();
354 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
355 try {
356 new ImmutableWorkContextReference<>(() -> "foobar").get();
357 } catch (IllegalStateException e) {
358 exception.set(e);
359 } finally {
360 latch.countDown();
361 }
362 });
363
364 messageRunnerService.initialiseMessageConsumers();
365 messageRunnerService.addMessage(Message.create(messageRunnerKey, "1"));
366
367 assertTrue("test did not complete within timeout", latch.await(1, TimeUnit.SECONDS));
368 assertNull("No work context available for message processing.", exception.get());
369 }
370
371 @Test
372 public void ensureTenantIdIsSentByProducer() throws Exception {
373 messageRunnerService.initialiseMessageConsumers();
374 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
375
376 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
377 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
378
379 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
380 assertThat(nestedMessage.getAttribute(TENANT_ID_ATTRIBUTE_NAME), is(TENANT_ID));
381 }
382
383 @Test
384 public void ensureMessageRunnerKeyIsSentByProducer() throws Exception {
385 messageRunnerService.initialiseMessageConsumers();
386 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
387
388 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
389 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
390
391 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
392 assertThat(nestedMessage.getAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME), is(MESSAGE_RUNNER_KEY.toString()));
393 }
394
395 @Test
396 public void ensureTenantDataIdIsSentByProducer() throws Exception {
397 messageRunnerService.initialiseMessageConsumers();
398
399 tenantDataId = "New Id";
400
401 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
402
403 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
404 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
405
406 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
407 assertThat(nestedMessage.getAttribute(TENANT_DATA_ID_ATTRIBUTE_NAME), is(tenantDataId));
408 }
409
410 @Test
411 public void ensurePayloadIsSentByProducer() throws Exception {
412 String payload = "payload";
413 messageRunnerService.initialiseMessageConsumers();
414 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, payload));
415
416 verify(spyingSqsClient).sendMessage(sendMessageRequestArgumentCaptor.capture());
417 String messageBody = sendMessageRequestArgumentCaptor.getValue().getMessageBody();
418
419 NestedMessage nestedMessage = nestedMessageSerializer.deserialize(messageBody);
420 assertThat(nestedMessage.getPayload(), is(payload));
421 }
422
423 @Test
424 public void dontProcessAndDeleteMessagesThatCantBeDeserialized() throws Exception {
425 doThrow(MessageSerializationException.class).when(nestedMessageSerializer).deserialize(any(String.class));
426
427 messageRunnerService.initialiseMessageConsumers();
428 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
429
430 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
431 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
432 }
433
434 @Test
435 public void deleteMessageWithoutTenantId() throws Exception {
436 doReturn(new NestedMessage()).when(nestedMessageSerializer).deserialize(anyString());
437
438 messageRunnerService.initialiseMessageConsumers();
439 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
440
441 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
442 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
443 }
444
445 @Test
446 public void doNotProcessInvalidMessageWithValidator() throws Exception {
447 final MessageValidator messageTenantDataIdValidator = new MessageTenantDataIdValidator(tenantDataIdSupplier);
448 messageValidatorRegistry.registerMessageValidator(MessageTenantDataIdValidator.KEY, messageTenantDataIdValidator);
449
450 tenantDataId = "old";
451
452 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
453
454 tenantDataId = "new";
455
456 messageRunnerService.initialiseMessageConsumers();
457
458 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
459 }
460
461 @Test
462 public void doProcessValidMessageWithValidator() throws Exception {
463 final MessageValidator messageTenantDataIdValidator = new MessageTenantDataIdValidator(tenantDataIdSupplier);
464 messageValidatorRegistry.registerMessageValidator(MessageTenantDataIdValidator.KEY, messageTenantDataIdValidator);
465
466 tenantDataId = "old";
467
468 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
469 messageRunnerService.initialiseMessageConsumers();
470
471 verify(messageRunner, after(1000).times(1)).processMessage(any(MessageContext.class));
472 }
473
474 @Test
475 public void deleteMessageWithoutMessageRunnerKey() throws Exception {
476 NestedMessage nestedMessage = new NestedMessage().addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123");
477 doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
478
479 messageRunnerService.initialiseMessageConsumers();
480 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
481
482 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
483 verify(spyingSqsClient).deleteMessage(anyString(), anyString());
484 }
485
486 @Test
487 public void dontProcessMessageWithoutMessageRunner() throws Exception {
488 NestedMessage nestedMessage = new NestedMessage()
489 .addAttribute(TENANT_ID_ATTRIBUTE_NAME, "tenant-123")
490 .addAttribute(MESSAGE_RUNNER_KEY_ATTRIBUTE_NAME, MESSAGE_RUNNER_KEY.toString());
491 doReturn(nestedMessage).when(nestedMessageSerializer).deserialize(anyString());
492 doCallRealMethod()
493 .doReturn(Optional.empty())
494 .when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
495
496 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
497 messageRunnerService.initialiseMessageConsumers();
498
499 verify(messageRunner, after(1000).never()).processMessage(any(MessageContext.class));
500 verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
501 }
502
503 @Test(expected = MessageSerializationException.class)
504 public void failFastIfMessageCantBeSerialized() throws Exception {
505 doThrow(MessageSerializationException.class).when(messageInformationService).toPayload(any(Message.class));
506
507 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
508 }
509
510 @Test
511 public void ensureTenantContextEstablishedOnConsumer() throws Exception {
512 messageRunnerService.initialiseMessageConsumers();
513 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
514
515 verify(tenantIdSetter, timeout(VERIFY_TIMEOUT_MILLIS)).setTenantId(TENANT_ID);
516 }
517
518 @Test
519 public void messageNotProcessedOrDeletedIfTenantContextEstablishmentThrowsException() throws Exception {
520 Assume.assumeThat(CONCURRENT_CONSUMERS, is(1));
521 doThrow(new RuntimeException("Error creating TenantContext from a tenantId (perhaps tenant context service is currently unavailable?)")).when(tenantIdSetter).setTenantId(anyString());
522
523 messageRunnerService.initialiseMessageConsumers();
524 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
525 verify(spyingSqsClient, timeout(VERIFY_TIMEOUT_MILLIS).atLeast(2)).receiveMessage(any(ReceiveMessageRequest.class));
526 verify(spyingSqsClient, never()).deleteMessage(anyString(), anyString());
527 verify(messageRunner, never()).processMessage(any(MessageContext.class));
528 }
529
530 @Test
531 public void supportEarlyAcknowledgementOfMessage() throws Exception {
532 AtomicInteger deliveryCount = new AtomicInteger(0);
533 int visibilityTimeoutSeconds = 1;
534 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
535 .withAmazonSQSClient(spyingSqsClient)
536 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
537 .withMessageRunnerRegistryHelper(registryService)
538 .withTenantIdSetter(tenantIdSetter)
539 .withMessageInformationService(messageInformationService)
540 .withNestedMessageSerializer(nestedMessageSerializer)
541 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(spyingSqsClient))
542 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
543 .build();
544 MessageRunnerKey fooMessageRunnerKey = MessageRunnerKey.of("fooMessageRunnerKey");
545 registryService.registerMessageRunner(fooMessageRunnerKey, messageContext -> {
546 deliveryCount.incrementAndGet();
547 messageContext.acknowledge();
548 throw new RuntimeException("message runner error");
549 });
550
551 messageRunnerService.initialiseMessageConsumers();
552 messageRunnerService.addMessage(Message.create(fooMessageRunnerKey, "payload"));
553
554 Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
555 assertThat(deliveryCount.get(), is(1));
556 }
557
558 @Test(expected = MessageRunnerNotRegisteredException.class)
559 public void failFastWhenAddingMessageWithNoRegisteredMessageRunner() throws Exception {
560 doReturn(Optional.empty()).when(registryService).getMessageRunner(MESSAGE_RUNNER_KEY);
561
562 messageRunnerService.addMessage(Message.create(MESSAGE_RUNNER_KEY, "payload"));
563 }
564
565 @Test
566 public void messageRunnerResponsiveToCancellationViaShutdown() throws Exception {
567 final CountDownLatch enterLatch = new CountDownLatch(1);
568 final CountDownLatch exitLatch = new CountDownLatch(1);
569 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
570 final int sleepMillis = 100;
571 registryService.registerMessageRunner(messageRunnerKey, messageContext -> {
572 enterLatch.countDown();
573 while (!messageContext.isCancellationRequested()) {
574 Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
575 }
576 exitLatch.countDown();
577 });
578 messageRunnerService.initialiseMessageConsumers();
579 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
580 assertThat("MessageRunner did not start. Cannot assert cancellation on a MessageRunner that did not start",
581 enterLatch.await(1, TimeUnit.SECONDS), is(true));
582
583 messageRunnerService.shutdown();
584
585 assertThat("MessageRunner did not respond to cancellation", exitLatch.await(sleepMillis * 2, TimeUnit.MILLISECONDS), is(true));
586 }
587
588 @Test
589 public void cancelAutoAcknowledgement() throws Exception {
590 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("cancelAckMessageRunnerKey");
591 registryService.registerMessageRunner(messageRunnerKey, MessageContext::cancelAutoAcknowledgementOfMessage);
592
593 messageRunnerService.initialiseMessageConsumers();
594 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
595
596 verify(spyingSqsClient, after(VERIFY_TIMEOUT_MILLIS).never()).deleteMessage(anyString(), anyString());
597 }
598
599 @Test
600 public void autoExtensionOfVisibilityTimeout() throws Exception {
601 final int visibilityTimeoutSeconds = 1;
602
603
604
605 final int longRunningTaskDurationMillis = 2500;
606
607 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
608 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
609 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
610 }));
611 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
612
613 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
614 .withAmazonSQSClient(spyingSqsClient)
615 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
616 .withMessageRunnerRegistryHelper(registryService)
617 .withTenantIdSetter(tenantIdSetter)
618 .withMessageInformationService(messageInformationService)
619 .withNestedMessageSerializer(nestedMessageSerializer)
620 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
621 spyingSqsClient,
622 2,
623 5))
624 .withMessageValidatorRegistryHelper(messageValidatorRegistry).build();
625
626 messageRunnerService.initialiseMessageConsumers();
627 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
628
629 verify(longRunningMessageRunner, after(longRunningTaskDurationMillis + 500).times(1)).processMessage(any(MessageContext.class));
630 verify(spyingSqsClient, times(2)).changeMessageVisibility(anyString(), anyString(), anyInt());
631 }
632
633
634
635
636
637
638
639
640
641 @Test
642 public void autoExtensionOfVisibilityTimeoutFailsBecauseMessageNoLongerInFlight() throws Exception {
643 final int visibilityTimeoutSeconds = 1;
644
645
646
647 final int longRunningTaskDurationMillis = 2500;
648 final int schedulingBufferSeconds = 1;
649
650 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
651 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
652 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
653 }));
654 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
655
656 doThrow(MessageNotInflightException.class).when(spyingSqsClient).changeMessageVisibility(anyString(), anyString(), anyInt());
657
658 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
659 .withAmazonSQSClient(spyingSqsClient)
660 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
661 .withMessageRunnerRegistryHelper(registryService)
662 .withTenantIdSetter(tenantIdSetter)
663 .withMessageInformationService(messageInformationService)
664 .withNestedMessageSerializer(nestedMessageSerializer)
665 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
666 spyingSqsClient,
667 2,
668 schedulingBufferSeconds))
669 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
670 .build();
671
672 messageRunnerService.initialiseMessageConsumers();
673 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
674
675 verify(longRunningMessageRunner, after(3000).times(2)).processMessage(any(MessageContext.class));
676 }
677
678 @Test
679 public void autoExtensionOfVisibilityTimeoutCancelledOnMessageProcessingException() throws Exception {
680 final int visibilityTimeoutSeconds = 1;
681 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("erroneousMessageRunnerKey");
682 final MessageRunner erroneousMessageRunner = mock(MessageRunner.class);
683 doThrow(RuntimeException.class).when(erroneousMessageRunner).processMessage(any(MessageContext.class));
684 registryService.registerMessageRunner(messageRunnerKey, erroneousMessageRunner);
685
686 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig()).withAmazonSQSClient(spyingSqsClient).withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS).withMessageRunnerRegistryHelper(registryService).withTenantIdSetter(tenantIdSetter).withMessageInformationService(messageInformationService).withNestedMessageSerializer(nestedMessageSerializer).withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(spyingSqsClient,
687 2)).withMessageValidatorRegistryHelper(messageValidatorRegistry).build();
688
689 messageRunnerService.initialiseMessageConsumers();
690 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
691
692
693
694 final int waitMillis = visibilityTimeoutSeconds * 2 * 1000;
695 verify(spyingSqsClient, after(waitMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
696 }
697
698 @Test
699 public void autoExtensionOfVisibilityTimeoutCancelledOnEarlyAcknowledgement() throws Exception {
700 final int visibilityTimeoutSeconds = 1;
701 final MessageRunnerKey messageRunnerKey = MessageRunnerKey.of("longRunningMessageRunnerKey");
702 final int longRunningTaskDurationMillis = 1100;
703 final MessageRunner longRunningMessageRunner = spy(new DelegatingMessageRunner(context -> {
704
705
706 context.acknowledge();
707 Uninterruptibles.sleepUninterruptibly(longRunningTaskDurationMillis, TimeUnit.MILLISECONDS);
708 }));
709 registryService.registerMessageRunner(messageRunnerKey, longRunningMessageRunner);
710
711 messageRunnerService = SQSMessageRunnerService.newBuilder(getDefaultMessageRunnerKeyToProducerMapper(), getDefaultConsumerQueueConfig(visibilityTimeoutSeconds))
712 .withAmazonSQSClient(spyingSqsClient)
713 .withReceiveWaitTimeSeconds(RECEIVE_WAIT_TIME_SECONDS)
714 .withMessageRunnerRegistryHelper(registryService)
715 .withTenantIdSetter(tenantIdSetter)
716 .withMessageInformationService(messageInformationService)
717 .withNestedMessageSerializer(nestedMessageSerializer)
718 .withSqsMessageVisibilityTimeoutManager(new SQSMessageVisibilityTimeoutManager(
719 spyingSqsClient,
720 2))
721 .withMessageValidatorRegistryHelper(messageValidatorRegistry)
722 .build();
723
724 messageRunnerService.initialiseMessageConsumers();
725 messageRunnerService.addMessage(Message.create(messageRunnerKey, "payload"));
726
727 verify(spyingSqsClient, after(longRunningTaskDurationMillis).never()).changeMessageVisibility(anyString(), anyString(), anyInt());
728 }
729 }