View Javadoc

1   package com.atlassian.vcache.internal.core;
2   
3   import org.slf4j.Logger;
4   import org.slf4j.LoggerFactory;
5   
6   import java.util.ArrayList;
7   import java.util.Collection;
8   import java.util.Collections;
9   import java.util.function.BiFunction;
10  
11  import static java.util.Objects.requireNonNull;
12  
13  /**
14   * Helper class that guards against a thread recursively performing the same key. This is useful to guard
15   * against behaviour like a thread calling {@link java.util.concurrent.ConcurrentHashMap#compute(Object, BiFunction)}
16   * recursively.
17   *
18   * @param <T> the key type.
19   * @since 1.8.2
20   */
21  public class RecursionDetector<T> {
22      private static final Logger log = LoggerFactory.getLogger(RecursionDetector.class);
23  
24      /**
25       * Used to track inflight calls.
26       */
27      private final ThreadLocal<Collection<T>> INFLIGHT = new ThreadLocal<Collection<T>>() {
28          @Override
29          protected Collection<T> initialValue() {
30              // Using an ArrayList as I believe that size of the collection will be very small, and hence the
31              // performance will be faster than a Set.
32              return new ArrayList<>();
33          }
34      };
35  
36      public Guard<T> guardOn(T key) {
37          return new Guard<>(Collections.singletonList(key), INFLIGHT.get());
38      }
39  
40      public Guard<T> guardOn(Iterable<T> keys) {
41          return new Guard<>(keys, INFLIGHT.get());
42      }
43  
44      /**
45       * Helper class to detect recursive calls to the same key.
46       *
47       * @param <T> the key type.
48       */
49      public static class Guard<T> implements AutoCloseable {
50          private final Iterable<T> keys;
51          private final Collection<T> active;
52  
53          private Guard(Iterable<T> keys, Collection<T> active) {
54              this.keys = requireNonNull(keys);
55              this.active = requireNonNull(active);
56  
57              // first confirm none of the keys are being used, otherwise could leave the threadlocal in a bad state.
58              // Naive implementation would loop once, checking each key and then adding them. If an error was detected,
59              // then we would need to remove all the added keys, before throwing the exception.
60              keys.forEach(k -> {
61                  if (active.contains(k)) {
62                      log.error("Detected recursive call for key {}.", k);
63                      throw new IllegalStateException("Recursive call with key: " + k);
64                  }
65              });
66  
67              // now add all the keys
68              keys.forEach(active::add);
69          }
70  
71          @Override
72          public void close() {
73              keys.forEach(active::remove);
74          }
75      }
76  }