View Javadoc

1   package com.atlassian.messagequeue.internal.lifecycle;
2   
import com.amazonaws.services.sqs.AmazonSQSClient;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;
import com.amazonaws.util.EC2MetadataUtils;
import com.atlassian.messagequeue.internal.core.DefaultThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;
21  
22  /**
23   * Observes EC2 instance lifecycle notifications.
24   *
25   * <p>Allows listeners to be registered to respond to an instance terminating notification</p>.
26   *
27   * @see <a href="http://docs.aws.amazon.com/autoscaling/latest/userguide/lifecycle-hooks.html">AWS Autoscaling lifecycle hooks</a>
28   * @see <a href="https://extranet.atlassian.com/display/MICROS/How+To%3A+Lifecycle+hooks">Micros lifecycle hooks</a>
29   */
30  public class DefaultInstanceLifecycleNotificationObserver implements InstanceLifecycleNotificationObserver {
31      private static final Logger log = LoggerFactory.getLogger(DefaultInstanceLifecycleNotificationObserver.class);
32  
33      private static final int DEFAULT_RECEIVE_WAIT_TIME_SECONDS = Integer.getInteger("instance.lifecycle.observer.receive.wait.time", 20);
34      private static final int DEFAULT_VISIBILITY_TIMEOUT_SECONDS = Integer.getInteger("instance.lifecycle.observer.receive.visibility.timeout", 30);
35      private static final long AWAIT_TERMINATION_TIMEOUT_SECONDS = Long.getLong("instance.lifecycle.observer.await.termination.timeout", TimeUnit.SECONDS.convert(1, TimeUnit.HOURS));
36      private static final String MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_SUBJECT = "instanceLifecycle.notification.subject";
37      private static final String MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_MESSAGE = "instanceLifecycle.notification.message";
38      private static final String MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_LIFECYCLE_TRANSITION = "instanceLifecycle.notification.lifecycleTransition";
39  
40      private final String instanceLifecycleNotificationQueueUrl;
41      private final List<InstanceLifecycleListener> instanceTerminatingListeners;
42      private final AmazonSQSClient amazonSqsClient;
43      private final NotificationDeserializer notificationDeserializer;
44      private final Supplier<String> ec2InstanceIdSupplier; // supplies the EC2 instance ID of the currently running instance
45      private final ExecutorService messageConsumerThreadPool;
46      private final ExecutorService listenerInvokerThreadPool;
47      private final int receiveWaitTimeSeconds;
48      private final int visibilityTimeoutSeconds;
49  
50      private volatile boolean shuttingDown;
51  
52      /**
53       * Constructs an instance of InstanceLifecycleNotificationObserver.
54       *
55       * @param instanceLifecycleNotificationQueueUrl URL of SQS queue where lifecycle notifications are being sent to.
56       * @param amazonSqsClient SQS client
57       * @param notificationDeserializer notification deserializer
58       * @see <a href="https://extranet.atlassian.com/display/MICROS/How+To%3A+Lifecycle+hooks">Micros lifecycle hooks</a>
59       */
60      public DefaultInstanceLifecycleNotificationObserver(String instanceLifecycleNotificationQueueUrl,
61                                                          AmazonSQSClient amazonSqsClient,
62                                                          NotificationDeserializer notificationDeserializer) {
63          this(instanceLifecycleNotificationQueueUrl, amazonSqsClient, notificationDeserializer, EC2MetadataUtils::getInstanceId, DEFAULT_RECEIVE_WAIT_TIME_SECONDS, DEFAULT_VISIBILITY_TIMEOUT_SECONDS);
64      }
65  
66      /**
67       * Visible for tests
68       */
69      DefaultInstanceLifecycleNotificationObserver(String instanceLifecycleNotificationQueueUrl,
70                                                   AmazonSQSClient amazonSqsClient,
71                                                   NotificationDeserializer notificationDeserializer,
72                                                   Supplier<String> ec2InstanceIdSupplier, int receiveWaitTimeSeconds,
73                                                   int visibilityTimeoutSeconds) {
74          this.instanceLifecycleNotificationQueueUrl = requireNonNull(instanceLifecycleNotificationQueueUrl, "instanceLifecycleNotificationsQueueUrl");
75          this.notificationDeserializer = requireNonNull(notificationDeserializer, "notificationDeserializer");
76          this.ec2InstanceIdSupplier = requireNonNull(ec2InstanceIdSupplier, "ec2InstanceIdSupplier");
77          this.instanceTerminatingListeners = new ArrayList<>();
78          this.amazonSqsClient = requireNonNull(amazonSqsClient, "amazonSQSClient");
79          this.receiveWaitTimeSeconds = receiveWaitTimeSeconds;
80          this.visibilityTimeoutSeconds = visibilityTimeoutSeconds;
81          this.messageConsumerThreadPool = Executors.newSingleThreadExecutor(new DefaultThreadFactory("instance-lifecycle-notification-message-consumer-thread-%d", (t, throwable) -> {
82              log.warn("Instance lifecycle notification message consumer thread '{}' died due to an exception.", t.getName(), throwable);
83          }));
84          this.listenerInvokerThreadPool = Executors.newCachedThreadPool(new DefaultThreadFactory("instance-lifecycle-notification-listener-invoker-thread-%d", (t, throwable) -> {
85              log.warn("Instance lifecycle notification listener invoker thread '{}' died due to an exception.", t.getName(), throwable);
86          }));
87  
88          log.info("Constructing {} (queueUrl: {}, receiveWaitTimeSeconds: {}, visibilityTimeoutSeconds: {})",
89              this.getClass().getSimpleName(), instanceLifecycleNotificationQueueUrl, receiveWaitTimeSeconds, visibilityTimeoutSeconds);
90      }
91  
92      @Override
93      public void initialise() {
94          log.info("Initialising lifecycle notification message consumer ...");
95          this.messageConsumerThreadPool.submit(new InstanceLifecycleNotificationMessageConsumer());
96      }
97  
98      /**
99       * Register a listener to respond to instance terminating notification.
100      *
101      * @param listener a listener that will be invoked on an instance terminating notification.
102      */
103     @Override
104     public void addInstanceTerminatingListener(InstanceLifecycleListener listener) {
105         instanceTerminatingListeners.add(listener);
106     }
107 
108     private class InstanceLifecycleNotificationMessageConsumer implements Runnable {
109         @Override
110         public void run() {
111             while (!Thread.currentThread().isInterrupted() && !shuttingDown) {
112                 final ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(instanceLifecycleNotificationQueueUrl)
113                         .withMaxNumberOfMessages(1)
114                         .withWaitTimeSeconds(receiveWaitTimeSeconds)
115                         .withVisibilityTimeout(visibilityTimeoutSeconds);
116 
117                 final ReceiveMessageResult receiveMessageResult;
118                 try {
119                     receiveMessageResult = amazonSqsClient.receiveMessage(receiveMessageRequest);
120                     processReceiveMessageResult(receiveMessageResult);
121                 } catch (Throwable t) {
122                     log.error("Error occurred while consuming a message from SQS", t);
123                 }
124             }
125         }
126 
127         private void processReceiveMessageResult(ReceiveMessageResult receiveMessageResult) {
128             if (receiveMessageResult.getMessages().isEmpty()) {
129                 return;
130             }
131 
132             if (receiveMessageResult.getMessages().size() > 1) {
133                 throw new AssertionError("Number of messages received greater than the max number specified in the request.");
134             }
135 
136             final Message message = receiveMessageResult.getMessages().get(0);
137 
138             log.info("Consuming message (QueueUrl: {}, MessageId: {}, ReceiptHandle: {})",
139                     instanceLifecycleNotificationQueueUrl,
140                     message.getMessageId(),
141                     message.getReceiptHandle());
142 
143             try {
144                 notificationDeserializer.deserialize(message.getBody(), SNSNotification.class).ifPresent(snsNotification -> {
145                     MDC.put(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_SUBJECT, nullToEmpty(snsNotification.getSubject()));
146                     MDC.put(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_MESSAGE, nullToEmpty(snsNotification.getMessage()));
147                     try {
148                         log.info("Processing SNS notification (Subject: {})", snsNotification.getSubject());
149                     } finally {
150                         MDC.remove(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_SUBJECT);
151                         MDC.remove(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_MESSAGE);
152                     }
153 
154                     notificationDeserializer.deserialize(snsNotification.getMessage(), InstanceLifecycleNotification.class).ifPresent(instanceLifecycleNotification -> {
155                         MDC.put(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_LIFECYCLE_TRANSITION, nullToEmpty(instanceLifecycleNotification.getLifecycleTransition()));
156                         try {
157                             log.info("Processing lifecycle notification (LifecycleTransition: {}, LifecycleHookName: {}, " +
158                                             "AutoScalingGroupName: {}, LifecycleActionToken: {}, EC2InstanceId: {})",
159                                     instanceLifecycleNotification.getLifecycleTransition(),
160                                     instanceLifecycleNotification.getLifecycleHookName(),
161                                     instanceLifecycleNotification.getAutoScalingGroupName(),
162                                     instanceLifecycleNotification.getLifecycleActionToken(),
163                                     instanceLifecycleNotification.getEc2InstanceId());
164                         } finally {
165                             MDC.remove(MDC_INSTANCE_LIFECYCLE_SNS_NOTIFICATION_LIFECYCLE_TRANSITION);
166                         }
167 
168                         if (ec2InstanceIdSupplier.get() != null && ec2InstanceIdSupplier.get().equals(instanceLifecycleNotification.getEc2InstanceId()) &&
169                                 nullToEmpty(snsNotification.getSubject()).startsWith("Auto Scaling") &&
170                                 "autoscaling:EC2_INSTANCE_TERMINATING".equals(instanceLifecycleNotification.getLifecycleTransition())) {
171                             log.info("Invoking instance terminating listeners ...", ec2InstanceIdSupplier.get());
172                             instanceTerminatingListeners.forEach(listener -> {
173                                 listenerInvokerThreadPool.execute(() -> {
174                                     final InstanceLifecycleContext instanceLifecycleContext =
175                                             new DefaultInstanceLifecycleContext(instanceLifecycleNotification);
176                                     listener.onInstanceLifecycleNotification(instanceLifecycleContext);
177                                 });
178                             });
179                         }
180                     });
181                 });
182             } finally {
183                 log.info("Deleting message (MessageId: {}, ReceiptHandle: {})", message.getMessageId(), message.getReceiptHandle());
184                 amazonSqsClient.deleteMessage(instanceLifecycleNotificationQueueUrl, message.getReceiptHandle());
185             }
186         }
187     }
188 
189     @Override
190     public void shutdown() {
191         log.info("Shutting down {}", this.getClass().getSimpleName());
192 
193         this.shuttingDown = true;
194         this.messageConsumerThreadPool.shutdown();
195         try {
196             if (!messageConsumerThreadPool.awaitTermination(AWAIT_TERMINATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
197                 log.warn("Pool did not terminate in {} seconds", AWAIT_TERMINATION_TIMEOUT_SECONDS);
198             }
199         } catch (InterruptedException ie) {
200             Thread.currentThread().interrupt();
201         }
202     }
203 
204     private static String nullToEmpty(String value) {
205         return value == null ? "" : value;
206     }
207 }