View Javadoc

1   package com.atlassian.vcache.internal.redis;
2   
import com.atlassian.marshalling.api.MarshallingPair;
import com.atlassian.vcache.CasIdentifier;
import com.atlassian.vcache.DirectExternalCache;
import com.atlassian.vcache.ExternalCacheException;
import com.atlassian.vcache.ExternalCacheSettings;
import com.atlassian.vcache.IdentifiedValue;
import com.atlassian.vcache.PutPolicy;
import com.atlassian.vcache.internal.RequestContext;
import com.atlassian.vcache.internal.core.DefaultIdentifiedValue;
import com.atlassian.vcache.internal.core.ExternalCacheKeyGenerator;
import com.atlassian.vcache.internal.core.VCacheCoreUtils;
import com.atlassian.vcache.internal.core.service.AbstractExternalCache;
import com.atlassian.vcache.internal.core.service.VersionedExternalCacheRequestContext;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.Response;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static com.atlassian.vcache.internal.core.VCacheCoreUtils.isEmpty;
import static com.atlassian.vcache.internal.core.VCacheCoreUtils.marshall;
import static com.atlassian.vcache.internal.core.VCacheCoreUtils.unmarshall;
import static java.util.Objects.requireNonNull;
40  
41  /**
42   * Implementation of the {@link DirectExternalCache} that uses Redis.
43   *
44   * @param <V> the value type
45   * @since 1.0
46   */
47  class RedisDirectExternalCache<V>
48          extends AbstractExternalCache<V>
49          implements DirectExternalCache<V> {
50      private static final Logger log = LoggerFactory.getLogger(RedisDirectExternalCache.class);
51  
52      /**
53       * Script bytes for performing a {@link #replaceIf(String, CasIdentifier, Object)} operation. Accepts:
54       * <ul>
55       * <li>Single key</li>
56       * <li>
57       * Following arguments:
58       * <ol>
59       * <li>old value to match</li>
60       * <li>time to live in seconds</li>
61       * <li>new value to replace</li>
62       * </ol>
63       * </li>
64       * </ul>
65       */
66      private static final byte[] LUA_REPLACE_IF_SCRIPT =
67              ("if redis.call(\"get\",KEYS[1]) == ARGV[1] then " +
68                      "    return redis.call(\"setex\",KEYS[1],ARGV[2],ARGV[3]) " +
69                      "else " +
70                      "    return \"FAIL\" " +
71                      "end").getBytes();
72  
73      /**
74       * Script bytes for performing a {@link #removeIf(String, CasIdentifier)} operation. Accepts:
75       * <ul>
76       * <li>Single key</li>
77       * <li>
78       * Following arguments:
79       * <ol>
80       * <li>old value to match</li>
81       * </ol>
82       * </li>
83       * </ul>
84       */
85      private static final byte[] LUA_REMOVE_IF_SCRIPT =
86              ("if redis.call(\"get\",KEYS[1]) == ARGV[1] then " +
87                      "    return redis.call(\"del\",KEYS[1]) " +
88                      "else " +
89                      "    return 0 " +
90                      "end").getBytes();
91  
92      private final Supplier<Jedis> clientSupplier;
93      private final Supplier<RequestContext> contextSupplier;
94      private final ExternalCacheKeyGenerator keyGenerator;
95      private final MarshallingPair<V> valueMarshalling;
96      private final int defaultTtl;
97  
98      RedisDirectExternalCache(
99              Supplier<Jedis> clientSupplier,
100             Supplier<RequestContext> contextSupplier,
101             ExternalCacheKeyGenerator keyGenerator,
102             String name,
103             MarshallingPair<V> valueMarshalling,
104             ExternalCacheSettings settings) {
105         super(name);
106         this.clientSupplier = requireNonNull(clientSupplier);
107         this.contextSupplier = requireNonNull(contextSupplier);
108         this.keyGenerator = requireNonNull(keyGenerator);
109         this.valueMarshalling = requireNonNull(valueMarshalling);
110         this.defaultTtl = VCacheCoreUtils.roundUpToSeconds(settings.getDefaultTtl().get());
111     }
112 
113     @Override
114     public CompletionStage<Optional<V>> get(String internalKey) {
115         return perform(() -> {
116             final String externalKey = buildExternalKey(internalKey);
117             try (Jedis client = clientSupplier.get()) {
118                 return unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
119             }
120         });
121     }
122 
123     @Override
124     public CompletionStage<V> get(String internalKey, Supplier<V> supplier) {
125         return perform(() -> directGet(internalKey, supplier));
126     }
127 
128     @Override
129     public CompletionStage<Optional<IdentifiedValue<V>>> getIdentified(String internalKey) {
130         return perform(() -> {
131             final String externalKey = buildExternalKey(internalKey);
132             final byte[] existingValueBytes;
133             try (Jedis client = clientSupplier.get()) {
134                 existingValueBytes = client.get(externalKey.getBytes());
135             }
136 
137             if (existingValueBytes == null) {
138                 return Optional.empty();
139             }
140 
141             final CasIdentifier identifier = new RedisCasIdentifier(existingValueBytes);
142             final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
143                     identifier, valueMarshalling.getUnmarshaller().unmarshallFrom(existingValueBytes));
144             return Optional.of(iv);
145         });
146     }
147 
148     @Override
149     public CompletionStage<IdentifiedValue<V>> getIdentified(String internalKey, Supplier<V> supplier) {
150         return perform(() -> {
151             final V value = directGet(internalKey, supplier);
152             final CasIdentifier identifier = new RedisCasIdentifier(marshall(value, valueMarshalling));
153             return new DefaultIdentifiedValue<>(identifier, value);
154         });
155     }
156 
157     @Override
158     public CompletionStage<Map<String, Optional<V>>> getBulk(Iterable<String> internalKeys) {
159         return perform(() -> {
160             if (isEmpty(internalKeys)) {
161                 return new HashMap<>();
162             }
163 
164             // De-duplicate the keys and calculate the externalKeys
165             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
166 
167             final Set<String> externalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
168                     .map(cacheContext::externalEntryKeyFor)
169                     .collect(Collectors.toSet());
170 
171 
172             return RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling)
173                     .entrySet().stream()
174                     .collect(Collectors.toMap(
175                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
176                             Map.Entry::getValue));
177         });
178     }
179 
180     @Override
181     public CompletionStage<Map<String, V>> getBulk(Function<Set<String>, Map<String, V>> factory, Iterable<String> internalKeys) {
182         return perform(() -> {
183             if (isEmpty(internalKeys)) {
184                 return new HashMap<>();
185             }
186 
187             // De-duplicate the keys and calculate the externalKeys
188             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
189 
190             final Set<String> externalKeys = Collections.unmodifiableSet(
191                     StreamSupport.stream(internalKeys.spliterator(), false)
192                             .map(cacheContext::externalEntryKeyFor)
193                             .collect(Collectors.toSet()));
194 
195             // Get the known values from the external cache.
196             final Map<String, Optional<V>> candidateValues =
197                     RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling);
198 
199             // Add the known values to the grand result
200             final Map<String, V> grandResult = candidateValues.entrySet().stream()
201                     .filter(e -> e.getValue().isPresent())
202                     .collect(Collectors.toMap(
203                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
204                             e -> e.getValue().get()));
205             getLogger().trace("Cache {}: getBulk(Function): {} of {} entries have values",
206                     name, grandResult.size(), externalKeys.size());
207 
208             // Calculate the missing keys
209             final List<String> missingExternalKeys = candidateValues.entrySet().stream()
210                     .filter(e -> !e.getValue().isPresent())
211                     .map(Map.Entry::getKey)
212                     .collect(Collectors.toList());
213 
214             if (!missingExternalKeys.isEmpty()) {
215                 // Okay, need to get the missing values and mapping from externalKeys to internalKeys
216                 final Set<String> missingInternalKeys = Collections.unmodifiableSet(
217                         missingExternalKeys.stream().map(cacheContext::internalEntryKeyFor).collect(Collectors.toSet()));
218                 final Map<String, V> missingValues = factory.apply(missingInternalKeys);
219 
220                 // Okay, got the missing values, now need to add them to Redis
221                 try (Jedis client = clientSupplier.get()) {
222                     final Pipeline pipeline = client.pipelined();
223                     final Map<String, Response<String>> internalKeyToResponseMap =
224                             missingValues.entrySet().stream().collect(Collectors.toMap(
225                                     Map.Entry::getKey,
226                                     e -> pipeline.setex(
227                                             cacheContext.externalEntryKeyFor(e.getKey()).getBytes(),
228                                             defaultTtl,
229                                             marshall(e.getValue(), valueMarshalling))
230                             ));
231                     pipeline.sync();
232 
233                     internalKeyToResponseMap.entrySet().stream()
234                             .filter(e -> !RedisUtils.OK.equals(e.getValue().get()))
235                             .forEach(e -> log.warn("Cache {}: Unable to set key {}", name, e.getKey()));
236                 }
237 
238                 grandResult.putAll(missingValues);
239             }
240 
241             return grandResult;
242         });
243     }
244 
245     @Override
246     public CompletionStage<Map<String, Optional<IdentifiedValue<V>>>> getBulkIdentified(Iterable<String> internalKeys) {
247         return perform(() -> {
248             if (isEmpty(internalKeys)) {
249                 return new HashMap<>();
250             }
251 
252             // De-duplicate the keys and calculate the externalKeys
253             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
254 
255             final Set<String> externalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
256                     .map(cacheContext::externalEntryKeyFor)
257                     .collect(Collectors.toSet());
258 
259             return RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling)
260                     .entrySet().stream()
261                     .collect(Collectors.toMap(
262                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
263                             e -> buildIdentifiedValue(e.getValue())));
264         });
265     }
266 
267     @Override
268     public CompletionStage<Boolean> put(String internalKey, V value, PutPolicy policy) {
269         return perform(() -> {
270             final String externalKey = buildExternalKey(internalKey);
271             final byte[] valueBytes = valueMarshalling.getMarshaller().marshallToBytes(requireNonNull(value));
272 
273             return RedisUtils.putOperationForPolicy(
274                     policy, clientSupplier, externalKey, defaultTtl, valueBytes);
275         });
276     }
277 
278     @Override
279     public CompletionStage<Boolean> removeIf(String internalKey, CasIdentifier casId) {
280         return perform(() -> {
281             final String externalKey = buildExternalKey(internalKey);
282 
283             try (Jedis client = clientSupplier.get()) {
284                 final List<byte[]> keys = Arrays.asList(externalKey.getBytes());
285                 final List<byte[]> args = Arrays.asList(RedisUtils.safeExtractValue(casId));
286                 final Number numDeleted = (Number) client.eval(LUA_REMOVE_IF_SCRIPT, keys, args);
287                 return numDeleted.longValue() > 0;
288             }
289         });
290     }
291 
292     @Override
293     public CompletionStage<Boolean> replaceIf(String internalKey, CasIdentifier casId, V newValue) {
294         return perform(() -> {
295             final String externalKey = buildExternalKey(internalKey);
296 
297             try (Jedis client = clientSupplier.get()) {
298                 final List<byte[]> keys = Arrays.asList(externalKey.getBytes());
299                 final List<byte[]> args = Arrays.asList(
300                         RedisUtils.safeExtractValue(casId),
301                         Integer.toString(defaultTtl).getBytes(),
302                         marshall(newValue, valueMarshalling));
303                 final byte[] resultBytes = (byte[]) client.eval(LUA_REPLACE_IF_SCRIPT, keys, args);
304                 return Arrays.equals(RedisUtils.OK.getBytes(), resultBytes);
305             }
306         });
307     }
308 
309     @Override
310     public CompletionStage<Void> remove(Iterable<String> internalKeys) {
311         // There is no bulk delete in the api, so need to remove each one async
312         return perform(() -> {
313             if (isEmpty(internalKeys)) {
314                 return null;
315             }
316 
317             // Lodge all the requests for delete
318             try (Jedis client = clientSupplier.get()) {
319                 final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
320                 final List<byte[]> externalKeysList = StreamSupport.stream(internalKeys.spliterator(), false)
321                         .map(cacheContext::externalEntryKeyFor)
322                         .map(String::getBytes)
323                         .collect(Collectors.toList());
324                 final byte[][] externalKeysAsBytes = externalKeysList.toArray(new byte[externalKeysList.size()][]);
325 
326                 final long numDeleted = client.del(externalKeysAsBytes);
327                 if (numDeleted != externalKeysAsBytes.length) {
328                     log.info("Cache {}: only able to delete {} of {} keys", name, numDeleted, externalKeysAsBytes.length);
329                 }
330             }
331 
332             return null;
333         });
334     }
335 
336     @Override
337     public CompletionStage<Void> removeAll() {
338         return perform(() -> {
339             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
340             cacheContext.updateCacheVersion(
341                     RedisUtils.incrementCacheVersion(clientSupplier, cacheContext.externalCacheVersionKey()));
342             return null;
343         });
344     }
345 
346     @VisibleForTesting
347     void refreshCacheVersion() {
348         // Refresh the cacheVersion. Useful if want to get the current state of the external cache in testing.
349         final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
350         cacheContext.updateCacheVersion(
351                 RedisUtils.obtainCacheVersion(
352                         clientSupplier,
353                         cacheContext.externalCacheVersionKey(),
354                         defaultTtl + 1));
355     }
356 
357     private String buildExternalKey(String internalKey) {
358         final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
359         return cacheContext.externalEntryKeyFor(internalKey);
360     }
361 
362     protected VersionedExternalCacheRequestContext<V> ensureCacheContext() {
363         final RequestContext requestContext = contextSupplier.get();
364 
365         return requestContext.computeIfAbsent(this, () -> {
366             // Need to build a new context, which involves getting the current cache version, or setting it if it does
367             // not exist.
368             log.trace("Cache {}: Setting up a new context", name);
369             final VersionedExternalCacheRequestContext<V> newCacheContext =
370                     new VersionedExternalCacheRequestContext<>(keyGenerator, name, requestContext::partitionIdentifier);
371             newCacheContext.updateCacheVersion(
372                     RedisUtils.obtainCacheVersion(
373                             clientSupplier,
374                             newCacheContext.externalCacheVersionKey(),
375                             defaultTtl + 1));
376             return newCacheContext;
377         });
378     }
379 
380     @Override
381     protected Logger getLogger() {
382         return log;
383     }
384 
385     @Override
386     protected ExternalCacheException mapException(Exception ex) {
387         return RedisUtils.mapException(ex);
388     }
389 
390     private Optional<IdentifiedValue<V>> buildIdentifiedValue(Optional<V> from) {
391         return from.flatMap(v -> {
392             final CasIdentifier identifier = new RedisCasIdentifier(marshall(v, valueMarshalling));
393             final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(identifier, v);
394             return Optional.of(iv);
395         });
396     }
397 
398     private V directGet(String internalKey, Supplier<V> supplier) {
399         final String externalKey = buildExternalKey(internalKey);
400         try (Jedis client = clientSupplier.get()) {
401             final Optional<V> existingValue = unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
402             if (existingValue.isPresent()) {
403                 return existingValue.get();
404             }
405         }
406 
407         log.trace("Cache {}, creating candidate for key {}", name, internalKey);
408         final V candidateValue = requireNonNull(supplier.get());
409         final byte[] candidateValueBytes = valueMarshalling.getMarshaller().marshallToBytes(candidateValue);
410 
411         // Loop until either able to add the candidate value, or retrieve one that has been added by another thread
412         try (Jedis client = clientSupplier.get()) {
413             for (; ; ) {
414                 final long addOp = client.setnx(externalKey.getBytes(), candidateValueBytes);
415                 if (addOp == 1) {
416                     // I break here, rather than just return, due to battling with the compiler. Unless written
417                     // this way, the compiler will not allow the lambda structure.
418                     break;
419                 }
420 
421                 log.info("Cache {}, unable to add candidate for key {}, retrieve what was added", name, internalKey);
422                 final Optional<V> otherAddedValue = unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
423                 if (otherAddedValue.isPresent()) {
424                     return otherAddedValue.get();
425                 }
426 
427                 log.info("Cache {}, unable to retrieve recently added candidate for key {}, looping", name, internalKey);
428             }
429         }
430 
431         return candidateValue;
432     }
433 }