Merged
Changes from 7 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Bump `com.gradle.develocity` from 3.17.4 to 3.17.5 ([#14397](https://github.com/opensearch-project/OpenSearch/pull/14397))

### Changed
- [Tiered Caching] Move query recomputation logic outside write lock ([#14187](https://github.com/opensearch-project/OpenSearch/pull/14187))

### Deprecated

TieredSpilloverCache.java
@@ -8,6 +8,8 @@

package org.opensearch.cache.common.tier;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cache.common.policy.TookTimePolicy;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.cache.CacheType;
@@ -35,9 +37,13 @@
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.ToLongBiFunction;
@@ -61,6 +67,7 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {

// Used to avoid caching stale entries in lower tiers.
private static final List<RemovalReason> SPILLOVER_REMOVAL_REASONS = List.of(RemovalReason.EVICTED, RemovalReason.CAPACITY);
private static final Logger logger = LogManager.getLogger(TieredSpilloverCache.class);

private final ICache<K, V> diskCache;
private final ICache<K, V> onHeapCache;
@@ -86,6 +93,12 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
private final Map<ICache<K, V>, TierInfo> caches;
private final List<Predicate<V>> policies;

/**
* This map is used to handle concurrent requests for the same key in computeIfAbsent() to ensure we load the value
* only once.
*/
Map<ICacheKey<K>, CompletableFuture<Tuple<ICacheKey<K>, V>>> completableFutureMap = new ConcurrentHashMap<>();

TieredSpilloverCache(Builder<K, V> builder) {
Objects.requireNonNull(builder.onHeapCacheFactory, "onHeap cache builder can't be null");
Objects.requireNonNull(builder.diskCacheFactory, "disk cache builder can't be null");
@@ -190,10 +203,7 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
// Add the value to the onHeap cache via compute(). This is needed as there can be many requests for the same
// key at the same time and we only want to load the value once.
V value = null;
try (ReleasableLock ignore = writeLock.acquire()) {
value = onHeapCache.computeIfAbsent(key, loader);
}
V value = compute(key, loader);
// Handle stats
if (loader.isLoaded()) {
// The value was just computed and added to the cache by this thread. Register a miss for the heap cache, and the disk cache
@@ -222,6 +232,55 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
return cacheValueTuple.v1();
}

private V compute(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V> loader) throws Exception {
// Only one of the threads will succeed in putting a future into the map for the same key.
// The rest will fetch the existing future and wait for it to complete.
CompletableFuture<Tuple<ICacheKey<K>, V>> future = completableFutureMap.putIfAbsent(key, new CompletableFuture<>());
// Handler for post-processing the result. It takes either a Tuple<key, value> or an exception as input; on
// success it puts the value into the on-heap cache, and in either case it removes the key from the map.
BiFunction<Tuple<ICacheKey<K>, V>, Throwable, Void> handler = (pair, ex) -> {
if (pair != null) {
try (ReleasableLock ignore = writeLock.acquire()) {
onHeapCache.put(pair.v1(), pair.v2());
}
} else {
if (ex != null) {
logger.warn("Exception occurred while trying to compute the value", ex);
}
}
completableFutureMap.remove(key); // Remove the key from the map as it is no longer needed.
return null;
};
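// A null return from putIfAbsent() means this thread inserted the new future and is responsible for loading the value.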
if (future == null) {
V value = null;
future = completableFutureMap.get(key);
future.handle(handler);
try {
value = loader.load(key);
} catch (Exception ex) {
future.completeExceptionally(ex);
throw new ExecutionException(ex);
}
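// Treat a null value from the loader as a failure so that waiting threads are not left blocked.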
if (value == null) {
NullPointerException npe = new NullPointerException("Loader returned a null value");
future.completeExceptionally(npe);
throw new ExecutionException(npe);
} else {
future.complete(new Tuple<>(key, value));
}
}
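// Both the loading thread and any waiting threads read the final result from the shared future below.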
V value;
try {
value = future.get().v2();
if (future.isCompletedExceptionally()) {
throw new IllegalStateException("Future completed exceptionally but no error thrown");
}
} catch (InterruptedException ex) {
throw new IllegalStateException(ex);
}
return value;
}
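
For reference, a minimal, self-contained sketch of the per-key future pattern that compute() applies is shown below. It is illustrative only; the names SingleLoadDemo and loadOnce are hypothetical and not part of OpenSearch.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

class SingleLoadDemo<K, V> {
    // One future per in-flight key; the first caller creates it, later callers wait on it.
    private final ConcurrentHashMap<K, CompletableFuture<V>> inFlight = new ConcurrentHashMap<>();

    V loadOnce(K key, Function<K, V> loader) throws ExecutionException, InterruptedException {
        CompletableFuture<V> created = new CompletableFuture<>();
        CompletableFuture<V> existing = inFlight.putIfAbsent(key, created);
        if (existing != null) {
            // Another thread is already loading this key; just wait for its result.
            return existing.get();
        }
        try {
            created.complete(loader.apply(key)); // run the loader exactly once
        } catch (RuntimeException e) {
            created.completeExceptionally(e);    // waiters observe the failure as well
        } finally {
            inFlight.remove(key);                // the entry is only needed while the load is in flight
        }
        return created.get();
    }
}

Unlike this sketch, the production compute() above also publishes the loaded value into the on-heap cache under the write lock and wraps loader failures in ExecutionException before rethrowing.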

@Override
public void invalidate(ICacheKey<K> key) {
// We are trying to invalidate the key from all caches though it would be present in only one of them.
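The cleanup in compute() is registered through CompletableFuture.handle(), whose callback runs on both normal and exceptional completion. A small standalone illustration of that behaviour, independent of the cache classes (HandleDemo is a hypothetical name), might look like this:

import java.util.concurrent.CompletableFuture;

class HandleDemo {
    public static void main(String[] args) {
        CompletableFuture<String> loaded = new CompletableFuture<>();
        // handle() receives (value, null) on success and (null, throwable) on failure, so a single
        // callback can cover both the put-into-cache path and the log-and-clean-up path.
        loaded.handle((value, ex) -> {
            if (ex == null) {
                System.out.println("completed with: " + value); // analogous to onHeapCache.put(...)
            } else {
                System.out.println("failed with: " + ex);       // analogous to logger.warn(...)
            }
            return null; // compute()'s handler likewise returns null (Void)
        });
        loaded.complete("loaded-value"); // success path runs the callback

        CompletableFuture<String> failed = new CompletableFuture<>();
        failed.handle((value, ex) -> {
            System.out.println(ex != null ? "failed with: " + ex : "completed with: " + value);
            return null;
        });
        failed.completeExceptionally(new RuntimeException("Testing")); // failure path runs the callback
    }
}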
TieredSpilloverCacheTests.java
@@ -44,8 +44,12 @@
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -408,6 +412,7 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception {
assertEquals(onHeapCacheHit, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(cacheMiss + numOfItems1, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_DISK));
assertEquals(diskCacheHit, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_DISK));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputeIfAbsentWithEvictionsFromTieredCache() throws Exception {
@@ -802,7 +807,7 @@ public String load(ICacheKey<String> key) {
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
assertEquals(value, tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader));
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -811,7 +816,7 @@ public String load(ICacheKey<String> key) {
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for rest of tasks to be cancelled.
countDownLatch.await();
int numberOfTimesKeyLoaded = 0;
assertEquals(numberOfSameKeys, loadAwareCacheLoaderList.size());
for (int i = 0; i < loadAwareCacheLoaderList.size(); i++) {
@@ -824,6 +829,231 @@ public String load(ICacheKey<String> key) {
// We should see only one heap miss, and the rest hits
assertEquals(1, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(numberOfSameKeys - 1, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputeIfAbsentConcurrentlyWithMultipleKeys() throws Exception {
int onHeapCacheSize = randomIntBetween(300, 500);
int diskCacheSize = randomIntBetween(600, 700);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int iterations = 10;
int numberOfKeys = 20;
List<ICacheKey<String>> iCacheKeyList = new ArrayList<>();
for (int i = 0; i < numberOfKeys; i++) {
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
iCacheKeyList.add(key);
}
ExecutorService executorService = Executors.newFixedThreadPool(8);
CountDownLatch countDownLatch = new CountDownLatch(iterations * numberOfKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < iterations; i++) {
for (int j = 0; j < numberOfKeys; j++) {
int finalJ = j;
executorService.submit(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
isLoaded = true;
return iCacheKeyList.get(finalJ).key;
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
tieredSpilloverCache.computeIfAbsent(iCacheKeyList.get(finalJ), loadAwareCacheLoader);
} catch (Exception e) {
throw new RuntimeException(e);
}
countDownLatch.countDown();
});
}
}
countDownLatch.await();
int numberOfTimesKeyLoaded = 0;
assertEquals(iterations * numberOfKeys, loadAwareCacheLoaderList.size());
for (int i = 0; i < loadAwareCacheLoaderList.size(); i++) {
LoadAwareCacheLoader<ICacheKey<String>, String> loader = loadAwareCacheLoaderList.get(i);
if (loader.isLoaded()) {
numberOfTimesKeyLoaded++;
}
}
assertEquals(numberOfKeys, numberOfTimesKeyLoaded); // Each key should be loaded only once.
// We should see one heap miss per key, and the rest hits
assertEquals(numberOfKeys, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals((iterations * numberOfKeys) - numberOfKeys, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
executorService.shutdownNow();
}

public void testComputeIfAbsentConcurrentlyAndThrowsException() throws Exception {
int onHeapCacheSize = randomIntBetween(100, 300);
int diskCacheSize = randomIntBetween(200, 400);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int numberOfSameKeys = randomIntBetween(10, onHeapCacheSize - 1);
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
String value = UUID.randomUUID().toString();
AtomicInteger exceptionCount = new AtomicInteger();

Thread[] threads = new Thread[numberOfSameKeys];
Phaser phaser = new Phaser(numberOfSameKeys + 1);
CountDownLatch countDownLatch = new CountDownLatch(numberOfSameKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < numberOfSameKeys; i++) {
threads[i] = new Thread(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
throw new RuntimeException("Testing");
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
} catch (Exception e) {
exceptionCount.incrementAndGet();
assertEquals(ExecutionException.class, e.getClass());
assertEquals(RuntimeException.class, e.getCause().getClass());
assertEquals("Testing", e.getCause().getMessage());
} finally {
countDownLatch.countDown();
}
});
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for all threads to finish.

// Verify exception count was equal to number of requests
assertEquals(numberOfSameKeys, exceptionCount.get());
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputeIfAbsentConcurrentlyWithLoaderReturningNull() throws Exception {
int onHeapCacheSize = randomIntBetween(100, 300);
int diskCacheSize = randomIntBetween(200, 400);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int numberOfSameKeys = randomIntBetween(10, onHeapCacheSize - 1);
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
String value = UUID.randomUUID().toString();
AtomicInteger exceptionCount = new AtomicInteger();

Thread[] threads = new Thread[numberOfSameKeys];
Phaser phaser = new Phaser(numberOfSameKeys + 1);
CountDownLatch countDownLatch = new CountDownLatch(numberOfSameKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < numberOfSameKeys; i++) {
threads[i] = new Thread(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
return null;
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
} catch (Exception e) {
exceptionCount.incrementAndGet();
assertEquals(ExecutionException.class, e.getClass());
assertEquals(NullPointerException.class, e.getCause().getClass());
assertEquals("Loader returned a null value", e.getCause().getMessage());
} finally {
countDownLatch.countDown();
}
});
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for all threads to finish.

// Verify exception count was equal to number of requests
assertEquals(numberOfSameKeys, exceptionCount.get());
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testConcurrencyForEvictionFlowFromOnHeapToDiskTier() throws Exception {