1 package com.atlassian.vcache.internal.test;
2
3 import com.atlassian.vcache.internal.NameValidator;
4 import com.atlassian.vcache.internal.RequestContext;
5 import org.slf4j.Logger;
6 import org.slf4j.LoggerFactory;
7
8 import java.util.Map;
9 import java.util.Optional;
10 import java.util.concurrent.ConcurrentHashMap;
11 import java.util.function.Supplier;
12
13 import static com.atlassian.vcache.internal.NameValidator.requireValidPartitionIdentifier;
14 import static java.util.Objects.requireNonNull;
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 public class ThreadLocalRequestContextSupplier implements Supplier<RequestContext> {
43 private static final Logger log = LoggerFactory.getLogger(ThreadLocalRequestContextSupplier.class);
44 private final ThreadLocal<RequestContext> threadRequestContexts = new ThreadLocal<>();
45 private final Optional<Supplier<String>> lenientPartitionIdSupplier;
46
47 private ThreadLocalRequestContextSupplier(Optional<Supplier<String>> lenientPartitionIdSupplier) {
48 this.lenientPartitionIdSupplier = requireNonNull(lenientPartitionIdSupplier);
49 }
50
51
52
53
54 public static ThreadLocalRequestContextSupplier strictSupplier() {
55 return new ThreadLocalRequestContextSupplier(Optional.empty());
56 }
57
58
59
60
61
62
63 public static ThreadLocalRequestContextSupplier lenientSupplier(Supplier<String> partitionIdSupplier) {
64 log.warn("A lenient supplier has been created, TransactionalExternalCaches are now broken");
65 return new ThreadLocalRequestContextSupplier(Optional.of(partitionIdSupplier));
66 }
67
68 @Override
69 public RequestContext get() {
70 final RequestContext current = threadRequestContexts.get();
71 if (current == null) {
72 if (!lenientPartitionIdSupplier.isPresent()) {
73 log.error("Asked for request context when not initialised!");
74 throw new IllegalStateException("Thread has not been initialised.");
75 }
76
77 log.debug("Asked for request context when not initialised, returning a lenient one.");
78 return new LenientRequestContext();
79 }
80
81 return current;
82 }
83
84
85
86
87
88
89
90 public void initThread(String partitionId) {
91 final RequestContext current = threadRequestContexts.get();
92 if (current != null) {
93 log.error(
94 "Asked to initialise thread {} that is already initialised!",
95 Thread.currentThread().getName());
96 throw new IllegalStateException(
97 "Thread '" + Thread.currentThread().getName() + "' has already been initialised.");
98 }
99
100 requireValidPartitionIdentifier(partitionId);
101 log.trace("Initialise request context");
102 threadRequestContexts.set(new TestThreadLocalRequestContext(() -> partitionId));
103 }
104
105 public void clearThread() {
106 final RequestContext current = threadRequestContexts.get();
107 if (current == null) {
108 log.warn("Asked to clear a thread that is already clear!");
109 }
110
111 log.trace("Clear request context");
112 threadRequestContexts.remove();
113 }
114
115
116
117
118 private class LenientRequestContext implements RequestContext {
119 @Override
120 public String partitionIdentifier() {
121 return requireValidPartitionIdentifier(lenientPartitionIdSupplier.get().get());
122 }
123
124 @Override
125 public <T> T computeIfAbsent(Object key, Supplier<T> supplier) {
126 return supplier.get();
127 }
128
129 @Override
130 public <T> Optional<T> get(Object key) {
131 return Optional.empty();
132 }
133 }
134
135 private static class TestThreadLocalRequestContext implements RequestContext {
136
137 private final Supplier<String> partitionIdSupplier;
138 private String partitionId;
139 private final Map<Object, Object> map = new ConcurrentHashMap<>();
140
141 public TestThreadLocalRequestContext(Supplier<String> partitionIdSupplier) {
142 this.partitionIdSupplier = requireNonNull(partitionIdSupplier);
143 }
144
145 @Override
146 public String partitionIdentifier() {
147
148
149 if (partitionId == null) {
150 partitionId = partitionIdSupplier.get();
151 }
152 return partitionId;
153 }
154
155 @SuppressWarnings("unchecked")
156 @Override
157 public <T> T computeIfAbsent(Object key, Supplier<T> supplier) {
158
159 return (T) map.computeIfAbsent(requireNonNull(key), o -> requireNonNull(supplier.get()));
160 }
161
162 @SuppressWarnings("unchecked")
163 @Override
164 public <T> Optional<T> get(Object key) {
165 return Optional.ofNullable((T) map.get(key));
166 }
167 }
168 }