View Javadoc

1   package com.atlassian.messagequeue.internal.sqs;
2   
3   import com.amazonaws.AbortedException;
4   import com.amazonaws.http.timers.client.SdkInterruptedException;
5   import com.amazonaws.services.sqs.AmazonSQS;
6   import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
7   import com.amazonaws.services.sqs.model.ReceiveMessageResult;
8   import com.atlassian.messagequeue.MessageSerializationException;
9   import com.atlassian.messagequeue.internal.core.NestedMessage;
10  import com.atlassian.messagequeue.internal.core.NestedMessageConsumer;
11  import com.atlassian.messagequeue.internal.core.NestedMessageSerializer;
12  import org.slf4j.Logger;
13  import org.slf4j.LoggerFactory;
14  import org.slf4j.MDC;
15  
16  import java.util.concurrent.Future;
17  import java.util.concurrent.atomic.AtomicBoolean;
18  
19  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.APPROXIMATE_RECEIVE_COUNT;
20  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.MDC_MESSAGE_ID;
21  import static com.atlassian.messagequeue.internal.sqs.SQSMessageRunnerService.SENT_TIMESTAMP;
22  import static java.util.Objects.requireNonNull;
23  
24  class SQSMessageConsumer implements Runnable {
25      private static final Logger log = LoggerFactory.getLogger(SQSMessageConsumer.class);
26      private static final String ALL = "ALL";
27      private static final String MDC_MESSAGE_SENT_TIMESTAMP = "amq.messageSentTimestamp";
28      private static final String MDC_MESSAGE_APPROXIMATE_RECEIVE_COUNT = "amq.messageApproximateReceiveCount";
29      private final SQSConsumerQueueConfig queueConfig;
30      private final int receiveWaitTimeSeconds;
31      private final AtomicBoolean shuttingDown;
32      private final AmazonSQS amazonSQSClient;
33      private final NestedMessageSerializer nestedMessageSerializer;
34      private final NestedMessageConsumer nestedMessageConsumer;
35      private final SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager;
36  
37      SQSMessageConsumer(SQSConsumerQueueConfig queueConfig, int receiveWaitTimeSeconds, AtomicBoolean shuttingDown, AmazonSQS amazonSQSClient,
38                         NestedMessageSerializer nestedMessageSerializer, NestedMessageConsumer nestedMessageConsumer,
39                         SQSMessageVisibilityTimeoutManager sqsMessageVisibilityTimeoutManager) {
40          this.queueConfig = requireNonNull(queueConfig);
41          this.receiveWaitTimeSeconds = requireNonNull(receiveWaitTimeSeconds);
42          this.shuttingDown = requireNonNull(shuttingDown);
43          this.amazonSQSClient = requireNonNull(amazonSQSClient);
44          this.nestedMessageSerializer = requireNonNull(nestedMessageSerializer);
45          this.nestedMessageConsumer = requireNonNull(nestedMessageConsumer);
46          this.sqsMessageVisibilityTimeoutManager = requireNonNull(sqsMessageVisibilityTimeoutManager);
47      }
48  
49      @Override
50      public void run() {
51          while (!Thread.currentThread().isInterrupted() && !shuttingDown.get()) {
52              final ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(queueConfig.getQueueUrl())
53                      .withMaxNumberOfMessages(1)
54                      .withAttributeNames(SENT_TIMESTAMP, APPROXIMATE_RECEIVE_COUNT)
55                      .withMessageAttributeNames(ALL)
56                      .withWaitTimeSeconds(receiveWaitTimeSeconds)
57                      .withVisibilityTimeout(sqsMessageVisibilityTimeoutManager.getVisibilityTimeoutSeconds(queueConfig));
58  
59              final ReceiveMessageResult receiveMessageResult;
60              try {
61                  receiveMessageResult = amazonSQSClient.receiveMessage(receiveMessageRequest);
62                  processReceiveMessageResult(receiveMessageResult);
63              } catch (Throwable t) {
64                  if (isThrowableThrownForInterrupt(t)) {
65                      break;
66                  } else {
67                      log.error("Error occurred while consuming a message from SQS", t);
68                  }
69              }
70          }
71      }
72  
73      private static boolean isThrowableThrownForInterrupt(Throwable throwable) {
74          return throwable.getCause() instanceof SdkInterruptedException
75                  || throwable instanceof AbortedException; // thrown by aws SDK InputStream classes in response to an interrupt
76      }
77  
78      private void processReceiveMessageResult(ReceiveMessageResult receiveMessageResult) {
79          if (receiveMessageResult.getMessages().isEmpty()) {
80              return;
81          }
82  
83          if (receiveMessageResult.getMessages().size() > 1) {
84              throw new AssertionError("Number of messages received greater than the max number specified in the request.");
85          }
86  
87          final com.amazonaws.services.sqs.model.Message message = receiveMessageResult.getMessages().get(0);
88  
89          final Future<?> visibilityTimeoutExtensionFuture = sqsMessageVisibilityTimeoutManager.scheduleVisibilityTimeoutExtension(queueConfig, message.getReceiptHandle());
90          MDC.put(MDC_MESSAGE_ID, message.getMessageId());
91          try {
92              if (log.isInfoEnabled()) {
93                  log.info("Consuming message from {} (messageId: {}, receiptHandle: {})", queueConfig.getQueueUrl(), message.getMessageId(), message.getReceiptHandle());
94              }
95  
96              final NestedMessage nestedMessage;
97              try {
98                  nestedMessage = nestedMessageSerializer.deserialize(message.getBody());
99              } catch (MessageSerializationException e) {
100                 log.error("Message received could not be deserialized: {}. Message will be deleted (messageID: {})", e.getMessage(), message.getMessageId());
101                 amazonSQSClient.deleteMessage(queueConfig.getQueueUrl(), message.getReceiptHandle());
102                 return;
103             }
104 
105             final String sentTimestamp = message.getAttributes().get(SENT_TIMESTAMP);
106             final String approximateReceiveCount = message.getAttributes().get(APPROXIMATE_RECEIVE_COUNT);
107             MDC.put(MDC_MESSAGE_SENT_TIMESTAMP, sentTimestamp == null ? "" : sentTimestamp);
108             MDC.put(MDC_MESSAGE_APPROXIMATE_RECEIVE_COUNT, approximateReceiveCount == null ? "" : approximateReceiveCount);
109             try {
110                 nestedMessageConsumer.consume(nestedMessage, new SQSMessageContext(message.getMessageId(),
111                         nestedMessage.getPayload(), amazonSQSClient, queueConfig.getQueueUrl(), message.getReceiptHandle(), shuttingDown, visibilityTimeoutExtensionFuture));
112             } finally {
113                 MDC.remove(MDC_MESSAGE_SENT_TIMESTAMP);
114                 MDC.remove(MDC_MESSAGE_APPROXIMATE_RECEIVE_COUNT);
115             }
116         } finally {
117             MDC.remove(MDC_MESSAGE_ID);
118 
119             if (!visibilityTimeoutExtensionFuture.isDone()) {
120                 final boolean cancelled = visibilityTimeoutExtensionFuture.cancel(true);
121                 log.info("Cancelled extension of visibility timeout. Cancellation status: {}", cancelled);
122             }
123         }
124     }
125 }