1 package com.atlassian.messagequeue.internal.sqs;
2
3 import com.amazonaws.AbortedException;
4 import com.amazonaws.http.timers.client.SdkInterruptedException;
5 import com.amazonaws.services.sqs.AmazonSQSClient;
6 import com.amazonaws.services.sqs.model.InvalidMessageContentsException;
7 import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
8 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
9 import com.amazonaws.services.sqs.model.SendMessageRequest;
10 import com.amazonaws.services.sqs.model.SendMessageResult;
11 import com.amazonaws.services.sqs.model.UnsupportedOperationException;
12 import com.atlassian.messagequeue.Message;
13 import com.atlassian.messagequeue.MessageInformationService;
14 import com.atlassian.messagequeue.MessageRunnerNotRegisteredException;
15 import com.atlassian.messagequeue.MessageRunnerService;
16 import com.atlassian.messagequeue.MessageRunnerServiceException;
17 import com.atlassian.messagequeue.MessageSerializationException;
18 import com.atlassian.messagequeue.internal.core.NestedMessage;
19 import com.atlassian.messagequeue.internal.core.NestedMessageConsumer;
20 import com.atlassian.messagequeue.internal.core.NestedMessageSerializer;
21 import com.atlassian.messagequeue.internal.core.MessageRunnerRegistryHelper;
22 import com.atlassian.tenant.impl.TenantIdSetter;
23 import org.slf4j.Logger;
24 import org.slf4j.LoggerFactory;
25 import org.slf4j.MDC;
26
27 import javax.annotation.Nonnull;
28 import javax.annotation.PostConstruct;
29 import javax.annotation.PreDestroy;
30 import java.util.concurrent.ExecutorService;
31 import java.util.concurrent.Future;
32 import java.util.concurrent.LinkedBlockingQueue;
33 import java.util.concurrent.ThreadPoolExecutor;
34 import java.util.concurrent.TimeUnit;
35 import java.util.concurrent.atomic.AtomicBoolean;
36
37 import static com.atlassian.messagequeue.internal.core.NestedMessageConsumer.MDC_MESSAGE_RUNNER_KEY;
38 import static java.util.Objects.requireNonNull;
39
40
41
42
43
44
45
46
47
48
49
50 public class SQSMessageRunnerService implements MessageRunnerService {
51 static final int MAX_CONCURRENT_CONSUMERS = 2 * Runtime.getRuntime().availableProcessors();
52 static final String SENT_TIMESTAMP = "SentTimestamp";
53 static final long AWAIT_TERMINATION_TIMEOUT_SECONDS = Long.getLong("amq.sqs.await.termination.timeout", TimeUnit.SECONDS.convert(1, TimeUnit.HOURS));
54
55 private static final Logger log = LoggerFactory.getLogger(SQSMessageRunnerService.class);
56 private static final String ALL = "ALL";
57
58 private static final int CONCURRENT_CONSUMERS = Integer.getInteger("amq.sqs.concurrent.consumers", Runtime.getRuntime().availableProcessors());
59 private static final int DEFAULT_RECEIVE_WAIT_TIME_SECONDS = Integer.getInteger("amq.sqs.receive.wait.time", 20);
60 private static final String MDC_MESSAGE_ID = "AMQ-messageID";
61 private static final String MDC_MESSAGE_SENT_TIMESTAMP = "AMQ-messageSentTimestamp";
62
63 private final AmazonSQSClient amazonSQSClient;
64 private final MessageRunnerRegistryHelper messageRunnerRegistryHelper;
65 private final AtomicBoolean shuttingDown = new AtomicBoolean(false);
66 private final int receiveWaitTimeSeconds;
67 private final String queueUrl;
68 private final int concurrentConsumers;
69 private final ExecutorService threadPool;
70 private final MessageInformationService messageInformationService;
71 private final NestedMessageSerializer nestedMessageSerializer;
72 private final NestedMessageConsumer nestedMessageConsumer;
73 private final SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager;
74
75 public SQSMessageRunnerService(AmazonSQSClient amazonSQSClient, String queueUrl,
76 MessageRunnerRegistryHelper messageRunnerRegistryHelper,
77 TenantIdSetter tenantIdSetter,
78 MessageInformationService messageInformationService, NestedMessageSerializer nestedMessageSerializer) {
79 this(amazonSQSClient, queueUrl, DEFAULT_RECEIVE_WAIT_TIME_SECONDS,
80 CONCURRENT_CONSUMERS, messageRunnerRegistryHelper, tenantIdSetter, messageInformationService, nestedMessageSerializer,
81 new SQSMessageVisibilityTimeoutManager(amazonSQSClient, queueUrl));
82 }
83
84
85
86
87 SQSMessageRunnerService(AmazonSQSClient amazonSQSClient, String queueUrl, int receiveWaitTimeSeconds,
88 int concurrentConsumers,
89 MessageRunnerRegistryHelper messageRunnerRegistryHelper,
90 TenantIdSetter tenantIdSetter,
91 MessageInformationService messageInformationService,
92 NestedMessageSerializer nestedMessageSerializer,
93 SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager) {
94 this.amazonSQSClient = requireNonNull(amazonSQSClient);
95 this.messageRunnerRegistryHelper = requireNonNull(messageRunnerRegistryHelper);
96 this.receiveWaitTimeSeconds = receiveWaitTimeSeconds;
97 this.messageInformationService = requireNonNull(messageInformationService);
98 this.nestedMessageSerializer = requireNonNull(nestedMessageSerializer);
99 this.nestedMessageConsumer = new NestedMessageConsumer(messageRunnerRegistryHelper, tenantIdSetter);
100
101 if (concurrentConsumers <= 0 || concurrentConsumers > MAX_CONCURRENT_CONSUMERS) {
102 throw new IllegalArgumentException("concurrent consumers must be > 0 and <= " + MAX_CONCURRENT_CONSUMERS + ". Received: " + concurrentConsumers);
103 }
104 this.concurrentConsumers = concurrentConsumers;
105 this.queueUrl = requireNonNull(queueUrl, "queueUrl");
106
107 final DefaultThreadFactory threadFactory = new DefaultThreadFactory("sqs-consumer-thread-%d", (t, throwable) -> {
108 log.warn("SQS consumer thread '{}' died due to an exception.", t.getName(), throwable);
109 });
110
111 this.threadPool = new ThreadPoolExecutor(concurrentConsumers, concurrentConsumers, 0L, TimeUnit.MILLISECONDS,
112 new LinkedBlockingQueue<>(),
113 threadFactory) {
114 @Override
115 protected void afterExecute(Runnable r, Throwable t) {
116 if (t != null && !this.isShutdown()) {
117 this.execute(r);
118 }
119 }
120 };
121
122 this.sqsMessageVisibilityTimeoutManager = requireNonNull(sqsMessageVisibilityTimeoutManager, "sqsMessageVisibilityTimeoutManager");
123
124 log.info("Constructing {} (concurrentConsumers: {}, queueUrl: {}, receiveWaitTimeSeconds: {})",
125 this.getClass().getSimpleName(), concurrentConsumers, queueUrl, receiveWaitTimeSeconds);
126 }
127
128 @PostConstruct
129 public void initialiseMessageConsumers() {
130 log.info("Initialising SQS messsage consumers");
131 for (int i = 0; i < concurrentConsumers; i++) {
132 threadPool.execute(new SQSMessageConsumer());
133 }
134 }
135
136 private class SQSMessageConsumer implements Runnable {
137 @Override
138 public void run() {
139 while (!Thread.currentThread().isInterrupted() && !shuttingDown.get()) {
140 final ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(queueUrl)
141 .withMaxNumberOfMessages(1)
142 .withAttributeNames(SENT_TIMESTAMP)
143 .withMessageAttributeNames(ALL)
144 .withWaitTimeSeconds(receiveWaitTimeSeconds)
145 .withVisibilityTimeout(sqsMessageVisibilityTimeoutManager.getVisibilityTimeoutSeconds());
146
147 final ReceiveMessageResult receiveMessageResult;
148 try {
149 receiveMessageResult = amazonSQSClient.receiveMessage(receiveMessageRequest);
150 processReceiveMessageResult(receiveMessageResult);
151 } catch (Throwable t) {
152 if (isThrowbableThrownForInterrupt(t)) {
153 break;
154 } else {
155 log.error("Error occurred while consuming a message from SQS", t);
156 }
157 }
158 }
159 }
160
161 private void processReceiveMessageResult(ReceiveMessageResult receiveMessageResult) {
162 if (receiveMessageResult.getMessages().isEmpty()) {
163 return;
164 }
165
166 if (receiveMessageResult.getMessages().size() > 1) {
167 throw new AssertionError("Number of messages received greater than the max number specified in the request.");
168 }
169
170 final com.amazonaws.services.sqs.model.Message message = receiveMessageResult.getMessages().get(0);
171
172 final Future<?> visibilityTimeoutExtensionFuture = sqsMessageVisibilityTimeoutManager.scheduleVisibilityTimeoutExtension(message.getReceiptHandle());
173 try (MDC.MDCCloseable mdcMessageId = MDC.putCloseable(MDC_MESSAGE_ID, message.getMessageId())) {
174 if (log.isInfoEnabled()) {
175 log.info("Consuming message from {} (messageId: {}, receiptHandle: {})", queueUrl, message.getMessageId(), message.getReceiptHandle());
176 }
177
178 final NestedMessage nestedMessage;
179 try {
180 nestedMessage = nestedMessageSerializer.deserialize(message.getBody());
181 } catch (MessageSerializationException e) {
182 log.error("Message received could not be deserialized: {}. Message will be deleted (messageID: {})", e.getMessage(), message.getMessageId());
183 amazonSQSClient.deleteMessage(queueUrl, message.getReceiptHandle());
184 return;
185 }
186
187 try (MDC.MDCCloseable mdcMessageSentTimestamp = MDC.putCloseable(MDC_MESSAGE_SENT_TIMESTAMP, message.getAttributes().get(SENT_TIMESTAMP))) {
188 nestedMessageConsumer.consume(nestedMessage, new SQSMessageContext(message.getMessageId(),
189 nestedMessage.getPayload(), amazonSQSClient, queueUrl, message.getReceiptHandle(), shuttingDown, visibilityTimeoutExtensionFuture));
190 }
191 } finally {
192 if (!visibilityTimeoutExtensionFuture.isDone()) {
193 final boolean cancelled = visibilityTimeoutExtensionFuture.cancel(true);
194 log.info("Cancelled extension of visibility timeout. Cancellation status: {}", cancelled);
195 }
196 }
197 }
198 }
199
200 private static boolean isThrowbableThrownForInterrupt(Throwable throwable) {
201 return throwable.getCause() instanceof SdkInterruptedException
202 || throwable instanceof AbortedException;
203 }
204
205 @PreDestroy
206 public void shutdown() {
207 log.info("Shutting down {}", this.getClass().getSimpleName());
208
209 shuttingDown.compareAndSet(false, true);
210
211 threadPool.shutdown();
212 try {
213 if (!threadPool.awaitTermination(AWAIT_TERMINATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
214 log.warn("Pool did not terminate in {} seconds", AWAIT_TERMINATION_TIMEOUT_SECONDS);
215 }
216 } catch (InterruptedException ie) {
217 Thread.currentThread().interrupt();
218 }
219
220 sqsMessageVisibilityTimeoutManager.shutdown();
221 }
222
223 @Override
224 public void addMessage(@Nonnull Message message) {
225 if (!messageRunnerRegistryHelper.getMessageRunner(message.getRunnerKey()).isPresent()) {
226 throw new MessageRunnerNotRegisteredException(message.getRunnerKey());
227 }
228
229 final SendMessageResult sendMessageResult;
230 try {
231 final String queueUrl = messageInformationService.getQueueUrl(message.getRunnerKey());
232
233 sendMessageResult = amazonSQSClient.sendMessage(new SendMessageRequest(queueUrl,
234 messageInformationService.toPayload(message)));
235 } catch (InvalidMessageContentsException e) {
236 throw new MessageRunnerServiceException(e);
237 } catch (UnsupportedOperationException e) {
238 throw new java.lang.UnsupportedOperationException(e);
239 }
240
241 if (log.isInfoEnabled()) {
242 try (MDC.MDCCloseable mdcCloseable1 = MDC.putCloseable(MDC_MESSAGE_ID, sendMessageResult.getMessageId());
243 MDC.MDCCloseable mdcCloseable2 = MDC.putCloseable(MDC_MESSAGE_RUNNER_KEY, message.getRunnerKey().toString())) {
244 log.info("Message produced to {} (messageId: {}, messageRunnerKey: {})",
245 queueUrl, sendMessageResult.getMessageId(), message.getRunnerKey().toString());
246 }
247 }
248 }
249 }