View Javadoc

1   package com.atlassian.vcache.internal.redis;
2   
3   import com.atlassian.marshalling.api.MarshallingPair;
4   import com.atlassian.vcache.CasIdentifier;
5   import com.atlassian.vcache.DirectExternalCache;
6   import com.atlassian.vcache.ExternalCacheException;
7   import com.atlassian.vcache.ExternalCacheSettings;
8   import com.atlassian.vcache.IdentifiedValue;
9   import com.atlassian.vcache.PutPolicy;
10  import com.atlassian.vcache.internal.RequestContext;
11  import com.atlassian.vcache.internal.core.DefaultIdentifiedValue;
12  import com.atlassian.vcache.internal.core.ExternalCacheKeyGenerator;
13  import com.atlassian.vcache.internal.core.VCacheCoreUtils;
14  import com.atlassian.vcache.internal.core.service.AbstractExternalCache;
15  import com.atlassian.vcache.internal.core.service.FactoryUtils;
16  import com.atlassian.vcache.internal.core.service.VersionedExternalCacheRequestContext;
17  import com.google.common.annotations.VisibleForTesting;
18  import org.slf4j.Logger;
19  import org.slf4j.LoggerFactory;
20  import redis.clients.jedis.Jedis;
21  import redis.clients.jedis.Pipeline;
22  import redis.clients.jedis.Response;
23  
24  import java.time.Duration;
25  import java.util.Arrays;
26  import java.util.Collections;
27  import java.util.HashMap;
28  import java.util.List;
29  import java.util.Map;
30  import java.util.Optional;
31  import java.util.Set;
32  import java.util.concurrent.CompletionStage;
33  import java.util.function.Function;
34  import java.util.function.Supplier;
35  import java.util.stream.Collectors;
36  import java.util.stream.StreamSupport;
37  
38  import static com.atlassian.vcache.internal.core.VCacheCoreUtils.isEmpty;
39  import static com.atlassian.vcache.internal.core.VCacheCoreUtils.marshall;
40  import static com.atlassian.vcache.internal.core.VCacheCoreUtils.unmarshall;
41  import static java.util.Objects.requireNonNull;
42  
43  /**
44   * Implementation of the {@link DirectExternalCache} that uses Redis.
45   *
46   * @param <V> the value type
47   * @since 1.0
48   */
49  class RedisDirectExternalCache<V>
50          extends AbstractExternalCache<V>
51          implements DirectExternalCache<V> {
52      private static final Logger log = LoggerFactory.getLogger(RedisDirectExternalCache.class);
53  
54      /**
55       * Script bytes for performing a {@link #replaceIf(String, CasIdentifier, Object)} operation. Accepts:
56       * <ul>
57       * <li>Single key</li>
58       * <li>
59       * Following arguments:
60       * <ol>
61       * <li>old value to match</li>
62       * <li>time to live in seconds</li>
63       * <li>new value to replace</li>
64       * </ol>
65       * </li>
66       * </ul>
67       */
68      private static final byte[] LUA_REPLACE_IF_SCRIPT =
69              ("if redis.call(\"get\",KEYS[1]) == ARGV[1] then " +
70                      "    return redis.call(\"setex\",KEYS[1],ARGV[2],ARGV[3]) " +
71                      "else " +
72                      "    return \"FAIL\" " +
73                      "end").getBytes();
74  
75      /**
76       * Script bytes for performing a {@link #removeIf(String, CasIdentifier)} operation. Accepts:
77       * <ul>
78       * <li>Single key</li>
79       * <li>
80       * Following arguments:
81       * <ol>
82       * <li>old value to match</li>
83       * </ol>
84       * </li>
85       * </ul>
86       */
87      private static final byte[] LUA_REMOVE_IF_SCRIPT =
88              ("if redis.call(\"get\",KEYS[1]) == ARGV[1] then " +
89                      "    return redis.call(\"del\",KEYS[1]) " +
90                      "else " +
91                      "    return 0 " +
92                      "end").getBytes();
93  
94      private final Supplier<Jedis> clientSupplier;
95      private final Supplier<RequestContext> contextSupplier;
96      private final ExternalCacheKeyGenerator keyGenerator;
97      private final MarshallingPair<V> valueMarshalling;
98      private final int defaultTtl;
99  
100     RedisDirectExternalCache(
101             Supplier<Jedis> clientSupplier,
102             Supplier<RequestContext> contextSupplier,
103             ExternalCacheKeyGenerator keyGenerator,
104             String name,
105             MarshallingPair<V> valueMarshalling,
106             ExternalCacheSettings settings,
107             Duration lockTimeout) {
108         super(name, lockTimeout);
109         this.clientSupplier = requireNonNull(clientSupplier);
110         this.contextSupplier = requireNonNull(contextSupplier);
111         this.keyGenerator = requireNonNull(keyGenerator);
112         this.valueMarshalling = requireNonNull(valueMarshalling);
113         this.defaultTtl = VCacheCoreUtils.roundUpToSeconds(settings.getDefaultTtl().get());
114     }
115 
116     @Override
117     public CompletionStage<Optional<V>> get(String internalKey) {
118         return perform(() -> {
119             final String externalKey = buildExternalKey(internalKey);
120             try (Jedis client = clientSupplier.get()) {
121                 return unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
122             }
123         });
124     }
125 
126     @Override
127     public CompletionStage<V> get(String internalKey, Supplier<V> supplier) {
128         return perform(() -> directGet(internalKey, supplier));
129     }
130 
131     @Override
132     public CompletionStage<Optional<IdentifiedValue<V>>> getIdentified(String internalKey) {
133         return perform(() -> {
134             final String externalKey = buildExternalKey(internalKey);
135             final byte[] existingValueBytes;
136             try (Jedis client = clientSupplier.get()) {
137                 existingValueBytes = client.get(externalKey.getBytes());
138             }
139 
140             if (existingValueBytes == null) {
141                 return Optional.empty();
142             }
143 
144             final CasIdentifier identifier = new RedisCasIdentifier(existingValueBytes);
145             final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(
146                     identifier, valueMarshalling.getUnmarshaller().unmarshallFrom(existingValueBytes));
147             return Optional.of(iv);
148         });
149     }
150 
151     @Override
152     public CompletionStage<IdentifiedValue<V>> getIdentified(String internalKey, Supplier<V> supplier) {
153         return perform(() -> {
154             final V value = directGet(internalKey, supplier);
155             final CasIdentifier identifier = new RedisCasIdentifier(marshall(value, valueMarshalling));
156             return new DefaultIdentifiedValue<>(identifier, value);
157         });
158     }
159 
160     @Override
161     public CompletionStage<Map<String, Optional<V>>> getBulk(Iterable<String> internalKeys) {
162         return perform(() -> {
163             if (isEmpty(internalKeys)) {
164                 return new HashMap<>();
165             }
166 
167             // De-duplicate the keys and calculate the externalKeys
168             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
169 
170             final Set<String> externalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
171                     .map(cacheContext::externalEntryKeyFor)
172                     .collect(Collectors.toSet());
173 
174 
175             return RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling)
176                     .entrySet().stream()
177                     .collect(Collectors.toMap(
178                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
179                             Map.Entry::getValue));
180         });
181     }
182 
183     @Override
184     public CompletionStage<Map<String, V>> getBulk(Function<Set<String>, Map<String, V>> factory, Iterable<String> internalKeys) {
185         return perform(() -> {
186             if (isEmpty(internalKeys)) {
187                 return new HashMap<>();
188             }
189 
190             // De-duplicate the keys and calculate the externalKeys
191             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
192 
193             final Set<String> externalKeys = Collections.unmodifiableSet(
194                     StreamSupport.stream(internalKeys.spliterator(), false)
195                             .map(cacheContext::externalEntryKeyFor)
196                             .collect(Collectors.toSet()));
197 
198             // Get the known values from the external cache.
199             final Map<String, Optional<V>> candidateValues =
200                     RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling);
201 
202             // Add the known values to the grand result
203             final Map<String, V> grandResult = candidateValues.entrySet().stream()
204                     .filter(e -> e.getValue().isPresent())
205                     .collect(Collectors.toMap(
206                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
207                             e -> e.getValue().get()));
208             getLogger().trace("Cache {}: getBulk(Function): {} of {} entries have values",
209                     name, grandResult.size(), externalKeys.size());
210 
211             // Calculate the missing keys
212             final List<String> missingExternalKeys = candidateValues.entrySet().stream()
213                     .filter(e -> !e.getValue().isPresent())
214                     .map(Map.Entry::getKey)
215                     .collect(Collectors.toList());
216 
217             if (!missingExternalKeys.isEmpty()) {
218                 // Okay, need to get the missing values and mapping from externalKeys to internalKeys
219                 final Set<String> missingInternalKeys = Collections.unmodifiableSet(
220                         missingExternalKeys.stream().map(cacheContext::internalEntryKeyFor).collect(Collectors.toSet()));
221                 final Map<String, V> missingValues = factory.apply(missingInternalKeys);
222                 FactoryUtils.verifyFactoryResult(missingValues, missingInternalKeys);
223 
224                 // Okay, got the missing values, now need to add them to Redis
225                 try (Jedis client = clientSupplier.get()) {
226                     final Pipeline pipeline = client.pipelined();
227                     final Map<String, Response<String>> internalKeyToResponseMap =
228                             missingValues.entrySet().stream().collect(Collectors.toMap(
229                                     Map.Entry::getKey,
230                                     e -> pipeline.setex(
231                                             cacheContext.externalEntryKeyFor(e.getKey()).getBytes(),
232                                             defaultTtl,
233                                             marshall(e.getValue(), valueMarshalling))
234                             ));
235                     pipeline.sync();
236 
237                     internalKeyToResponseMap.entrySet().stream()
238                             .filter(e -> !RedisUtils.OK.equals(e.getValue().get()))
239                             .forEach(e -> log.warn("Cache {}: Unable to set key {}", name, e.getKey()));
240                 }
241 
242                 grandResult.putAll(missingValues);
243             }
244 
245             return grandResult;
246         });
247     }
248 
249     @Override
250     public CompletionStage<Map<String, Optional<IdentifiedValue<V>>>> getBulkIdentified(Iterable<String> internalKeys) {
251         return perform(() -> {
252             if (isEmpty(internalKeys)) {
253                 return new HashMap<>();
254             }
255 
256             // De-duplicate the keys and calculate the externalKeys
257             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
258 
259             final Set<String> externalKeys = StreamSupport.stream(internalKeys.spliterator(), false)
260                     .map(cacheContext::externalEntryKeyFor)
261                     .collect(Collectors.toSet());
262 
263             return RedisUtils.directGetBulk(externalKeys, clientSupplier, valueMarshalling)
264                     .entrySet().stream()
265                     .collect(Collectors.toMap(
266                             e -> cacheContext.internalEntryKeyFor(e.getKey()),
267                             e -> buildIdentifiedValue(e.getValue())));
268         });
269     }
270 
271     @Override
272     public CompletionStage<Boolean> put(String internalKey, V value, PutPolicy policy) {
273         return perform(() -> {
274             final String externalKey = buildExternalKey(internalKey);
275             final byte[] valueBytes = valueMarshalling.getMarshaller().marshallToBytes(requireNonNull(value));
276 
277             return RedisUtils.putOperationForPolicy(
278                     policy, clientSupplier, externalKey, defaultTtl, valueBytes);
279         });
280     }
281 
282     @Override
283     public CompletionStage<Boolean> removeIf(String internalKey, CasIdentifier casId) {
284         return perform(() -> {
285             final String externalKey = buildExternalKey(internalKey);
286 
287             try (Jedis client = clientSupplier.get()) {
288                 final List<byte[]> keys = Arrays.asList(externalKey.getBytes());
289                 final List<byte[]> args = Arrays.asList(RedisUtils.safeExtractValue(casId));
290                 final Number numDeleted = (Number) client.eval(LUA_REMOVE_IF_SCRIPT, keys, args);
291                 return numDeleted.longValue() > 0;
292             }
293         });
294     }
295 
296     @Override
297     public CompletionStage<Boolean> replaceIf(String internalKey, CasIdentifier casId, V newValue) {
298         return perform(() -> {
299             final String externalKey = buildExternalKey(internalKey);
300 
301             try (Jedis client = clientSupplier.get()) {
302                 final List<byte[]> keys = Arrays.asList(externalKey.getBytes());
303                 final List<byte[]> args = Arrays.asList(
304                         RedisUtils.safeExtractValue(casId),
305                         Integer.toString(defaultTtl).getBytes(),
306                         marshall(newValue, valueMarshalling));
307                 final byte[] resultBytes = (byte[]) client.eval(LUA_REPLACE_IF_SCRIPT, keys, args);
308                 return Arrays.equals(RedisUtils.OK.getBytes(), resultBytes);
309             }
310         });
311     }
312 
313     @Override
314     public CompletionStage<Void> remove(Iterable<String> internalKeys) {
315         // There is no bulk delete in the api, so need to remove each one async
316         return perform(() -> {
317             if (isEmpty(internalKeys)) {
318                 return null;
319             }
320 
321             // Lodge all the requests for delete
322             try (Jedis client = clientSupplier.get()) {
323                 final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
324                 final List<byte[]> externalKeysList = StreamSupport.stream(internalKeys.spliterator(), false)
325                         .map(cacheContext::externalEntryKeyFor)
326                         .map(String::getBytes)
327                         .collect(Collectors.toList());
328                 final byte[][] externalKeysAsBytes = externalKeysList.toArray(new byte[externalKeysList.size()][]);
329 
330                 final long numDeleted = client.del(externalKeysAsBytes);
331                 if (numDeleted != externalKeysAsBytes.length) {
332                     log.info("Cache {}: only able to delete {} of {} keys", name, numDeleted, externalKeysAsBytes.length);
333                 }
334             }
335 
336             return null;
337         });
338     }
339 
340     @Override
341     public CompletionStage<Void> removeAll() {
342         return perform(() -> {
343             final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
344             cacheContext.updateCacheVersion(
345                     RedisUtils.incrementCacheVersion(clientSupplier, cacheContext.externalCacheVersionKey()));
346             return null;
347         });
348     }
349 
350     @VisibleForTesting
351     void refreshCacheVersion() {
352         // Refresh the cacheVersion. Useful if want to get the current state of the external cache in testing.
353         final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
354         cacheContext.updateCacheVersion(
355                 RedisUtils.obtainCacheVersion(
356                         clientSupplier,
357                         cacheContext.externalCacheVersionKey(),
358                         defaultTtl + 1));
359     }
360 
361     private String buildExternalKey(String internalKey) {
362         final VersionedExternalCacheRequestContext<V> cacheContext = ensureCacheContext();
363         return cacheContext.externalEntryKeyFor(internalKey);
364     }
365 
366     protected VersionedExternalCacheRequestContext<V> ensureCacheContext() {
367         final RequestContext requestContext = contextSupplier.get();
368 
369         return requestContext.computeIfAbsent(this, () -> {
370             // Need to build a new context, which involves getting the current cache version, or setting it if it does
371             // not exist.
372             log.trace("Cache {}: Setting up a new context", name);
373             final VersionedExternalCacheRequestContext<V> newCacheContext = new VersionedExternalCacheRequestContext<>(
374                     keyGenerator, name, requestContext::partitionIdentifier, lockTimeout);
375             newCacheContext.updateCacheVersion(
376                     RedisUtils.obtainCacheVersion(
377                             clientSupplier,
378                             newCacheContext.externalCacheVersionKey(),
379                             defaultTtl + 1));
380             return newCacheContext;
381         });
382     }
383 
384     @Override
385     protected Logger getLogger() {
386         return log;
387     }
388 
389     @Override
390     protected ExternalCacheException mapException(Exception ex) {
391         return RedisUtils.mapException(ex);
392     }
393 
394     private Optional<IdentifiedValue<V>> buildIdentifiedValue(Optional<V> from) {
395         return from.flatMap(v -> {
396             final CasIdentifier identifier = new RedisCasIdentifier(marshall(v, valueMarshalling));
397             final IdentifiedValue<V> iv = new DefaultIdentifiedValue<>(identifier, v);
398             return Optional.of(iv);
399         });
400     }
401 
402     private V directGet(String internalKey, Supplier<V> supplier) {
403         final String externalKey = buildExternalKey(internalKey);
404         try (Jedis client = clientSupplier.get()) {
405             final Optional<V> existingValue = unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
406             if (existingValue.isPresent()) {
407                 return existingValue.get();
408             }
409         }
410 
411         log.trace("Cache {}, creating candidate for key {}", name, internalKey);
412         final V candidateValue = requireNonNull(supplier.get());
413         final byte[] candidateValueBytes = valueMarshalling.getMarshaller().marshallToBytes(candidateValue);
414 
415         // Loop until either able to add the candidate value, or retrieve one that has been added by another thread
416         try (Jedis client = clientSupplier.get()) {
417             for (; ; ) {
418                 final long addOp = client.setnx(externalKey.getBytes(), candidateValueBytes);
419                 if (addOp == 1) {
420                     // I break here, rather than just return, due to battling with the compiler. Unless written
421                     // this way, the compiler will not allow the lambda structure.
422                     break;
423                 }
424 
425                 log.info("Cache {}, unable to add candidate for key {}, retrieve what was added", name, internalKey);
426                 final Optional<V> otherAddedValue = unmarshall(client.get(externalKey.getBytes()), valueMarshalling);
427                 if (otherAddedValue.isPresent()) {
428                     return otherAddedValue.get();
429                 }
430 
431                 log.info("Cache {}, unable to retrieve recently added candidate for key {}, looping", name, internalKey);
432             }
433         }
434 
435         return candidateValue;
436     }
437 }