diff --git a/CHANGELOG.md b/CHANGELOG.md index 976017cae8da0..68975b7988536 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support AutoExpand for SearchReplica ([#17741](https://github.com/opensearch-project/OpenSearch/pull/17741)) - Implement fixed interval refresh task scheduling ([#17777](https://github.com/opensearch-project/OpenSearch/pull/17777)) - Add GRPC DocumentService and Bulk endpoint ([#17727](https://github.com/opensearch-project/OpenSearch/pull/17727)) +- Support multi-threaded writes, updates and deletes in pull-based ingestion ([#17771](https://github.com/opensearch-project/OpenSearch/pull/17771)) ### Changed - Migrate BC libs to their FIPS counterparts ([#14912](https://github.com/opensearch-project/OpenSearch/pull/14912)) diff --git a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/IngestFromKafkaIT.java b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/IngestFromKafkaIT.java index 86d8710f4daab..990343b42bc57 100644 --- a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/IngestFromKafkaIT.java +++ b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/IngestFromKafkaIT.java @@ -15,7 +15,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.indices.pollingingest.PollingIngestStats; import org.opensearch.plugins.PluginInfo; import org.opensearch.test.OpenSearchIntegTestCase; @@ -32,7 +34,7 @@ import static org.awaitility.Awaitility.await; /** - * Integration test for Kafka ingestion. 
+ * Integration test for Kafka ingestion */ @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0) public class IngestFromKafkaIT extends KafkaIngestionBaseIT { @@ -73,8 +75,8 @@ public void testKafkaIngestion() { } public void testKafkaIngestion_RewindByTimeStamp() { - produceData("1", "name1", "24", 1739459500000L); - produceData("2", "name2", "20", 1739459800000L); + produceData("1", "name1", "24", 1739459500000L, "index"); + produceData("2", "name2", "20", 1739459800000L, "index"); // create an index with ingestion source from kafka createIndex( @@ -135,4 +137,36 @@ public void testCloseIndex() throws Exception { ensureGreen(indexName); client().admin().indices().close(Requests.closeIndexRequest(indexName)).get(); } + + public void testUpdateAndDelete() throws Exception { + // Step 1: Produce message and wait for it to be searchable + + produceData("1", "name", "25", defaultMessageTimestamp, "index"); + createIndexWithDefaultSettings(1, 0); + ensureGreen(indexName); + waitForState(() -> { + BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1")); + SearchResponse response = client().prepareSearch(indexName).setQuery(query).get(); + assertThat(response.getHits().getTotalHits().value(), is(1L)); + return 25 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age"); + }); + + // Step 2: Update age field from 25 to 30 and validate + + produceData("1", "name", "30", defaultMessageTimestamp, "index"); + waitForState(() -> { + BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1")); + SearchResponse response = client().prepareSearch(indexName).setQuery(query).get(); + assertThat(response.getHits().getTotalHits().value(), is(1L)); + return 30 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age"); + }); + + // Step 3: Delete the document and validate + produceData("1", "name", "30", defaultMessageTimestamp, "delete"); + waitForState(() -> { + BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1")); + SearchResponse response = client().prepareSearch(indexName).setQuery(query).get(); + return response.getHits().getTotalHits().value() == 0; + }); + } } diff --git a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/KafkaIngestionBaseIT.java b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/KafkaIngestionBaseIT.java index a9ae195332117..cbeb33515a89e 100644 --- a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/KafkaIngestionBaseIT.java +++ b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/KafkaIngestionBaseIT.java @@ -93,14 +93,15 @@ private void stopKafka() { } protected void produceData(String id, String name, String age) { - produceData(id, name, age, defaultMessageTimestamp); + produceData(id, name, age, defaultMessageTimestamp, "index"); } - protected void produceData(String id, String name, String age, long timestamp) { + protected void produceData(String id, String name, String age, long timestamp, String opType) { String payload = String.format( Locale.ROOT, - "{\"_id\":\"%s\", \"_op_type:\":\"index\",\"_source\":{\"name\":\"%s\", \"age\": %s}}", + "{\"_id\":\"%s\", \"_op_type\":\"%s\",\"_source\":{\"name\":\"%s\", \"age\": %s}}", id, + opType, name, age ); @@ -159,10 +160,10 @@ protected ResumeIngestionResponse resumeIngestion(String indexName) throws Execu } protected void 
createIndexWithDefaultSettings(int numShards, int numReplicas) { - createIndexWithDefaultSettings(indexName, numShards, numReplicas); + createIndexWithDefaultSettings(indexName, numShards, numReplicas, 1); } - protected void createIndexWithDefaultSettings(String indexName, int numShards, int numReplicas) { + protected void createIndexWithDefaultSettings(String indexName, int numShards, int numReplicas, int numProcessorThreads) { createIndex( indexName, Settings.builder() @@ -173,6 +174,7 @@ protected void createIndexWithDefaultSettings(String indexName, int numShards, i .put("ingestion_source.param.topic", topicName) .put("ingestion_source.param.bootstrap_servers", kafka.getBootstrapServers()) .put("index.replication.type", "SEGMENT") + .put("ingestion_source.num_processor_threads", numProcessorThreads) // set custom kafka consumer properties .put("ingestion_source.param.fetch.min.bytes", 30000) .put("ingestion_source.param.enable.auto.commit", false) diff --git a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/RemoteStoreKafkaIT.java b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/RemoteStoreKafkaIT.java index 54adeaa1396e5..3519ecc35c6f7 100644 --- a/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/RemoteStoreKafkaIT.java +++ b/plugins/ingestion-kafka/src/internalClusterTest/java/org/opensearch/plugin/kafka/RemoteStoreKafkaIT.java @@ -246,8 +246,8 @@ public void testPaginatedGetIngestionState() throws ExecutionException, Interrup internalCluster().startClusterManagerOnlyNode(); internalCluster().startDataOnlyNode(); internalCluster().startDataOnlyNode(); - createIndexWithDefaultSettings("index1", 5, 0); - createIndexWithDefaultSettings("index2", 5, 0); + createIndexWithDefaultSettings("index1", 5, 0, 1); + createIndexWithDefaultSettings("index2", 5, 0, 1); ensureGreen("index1"); ensureGreen("index2"); diff --git a/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java index 9005c830167f9..3b1fa311fda9d 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/IndexMetadata.java @@ -786,6 +786,20 @@ public Iterator> settings() { Property.Dynamic ); + /** + * Defines the number of processor threads that will write to the Lucene index. This setting is currently disabled + * and will be enabled once the feature is ready. A default value of 1 will be used.
+ */ + public static final String SETTING_INGESTION_SOURCE_NUM_PROCESSOR_THREADS = "index.ingestion_source.num_processor_threads"; + public static final Setting INGESTION_SOURCE_NUM_PROCESSOR_THREADS_SETTING = Setting.intSetting( + SETTING_INGESTION_SOURCE_NUM_PROCESSOR_THREADS, + 1, + 1, + 1, + Setting.Property.IndexScope, + Setting.Property.Final + ); + public static final Setting.AffixSetting INGESTION_SOURCE_PARAMS_SETTING = Setting.prefixKeySetting( "index.ingestion_source.param.", key -> new Setting<>(key, "", (value) -> { @@ -1025,8 +1039,9 @@ public IngestionSource getIngestionSource() { ); final IngestionErrorStrategy.ErrorStrategy errorStrategy = INGESTION_SOURCE_ERROR_STRATEGY_SETTING.get(settings); + final int numProcessorThreads = INGESTION_SOURCE_NUM_PROCESSOR_THREADS_SETTING.get(settings); final Map ingestionSourceParams = INGESTION_SOURCE_PARAMS_SETTING.getAsMap(settings); - return new IngestionSource(ingestionSourceType, pointerInitReset, errorStrategy, ingestionSourceParams); + return new IngestionSource(ingestionSourceType, pointerInitReset, errorStrategy, numProcessorThreads, ingestionSourceParams); } return null; } diff --git a/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java b/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java index fd28acf3246ad..7af981669798c 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/IngestionSource.java @@ -23,18 +23,21 @@ public class IngestionSource { private String type; private PointerInitReset pointerInitReset; private IngestionErrorStrategy.ErrorStrategy errorStrategy; + private int numMessageProcessorThreads; private Map params; public IngestionSource( String type, PointerInitReset pointerInitReset, IngestionErrorStrategy.ErrorStrategy errorStrategy, + int numMessageProcessorThreads, Map params ) { this.type = type; this.pointerInitReset = pointerInitReset; this.params = params; this.errorStrategy = errorStrategy; + this.numMessageProcessorThreads = numMessageProcessorThreads; } public String getType() { @@ -53,6 +56,10 @@ public Map params() { return params; } + public int getNumMessageProcessorThreads() { + return numMessageProcessorThreads; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -61,12 +68,13 @@ public boolean equals(Object o) { return Objects.equals(type, ingestionSource.type) && Objects.equals(pointerInitReset, ingestionSource.pointerInitReset) && Objects.equals(errorStrategy, ingestionSource.errorStrategy) - && Objects.equals(params, ingestionSource.params); + && Objects.equals(params, ingestionSource.params) + && Objects.equals(numMessageProcessorThreads, ingestionSource.numMessageProcessorThreads); } @Override public int hashCode() { - return Objects.hash(type, pointerInitReset, params, errorStrategy); + return Objects.hash(type, pointerInitReset, params, errorStrategy, numMessageProcessorThreads); } @Override @@ -81,6 +89,9 @@ public String toString() { + ",error_strategy='" + errorStrategy + '\'' + + ",numMessageProcessorThreads='" + + numMessageProcessorThreads + + '\'' + ", params=" + params + '}'; diff --git a/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java index 3793b9b09e3b2..0cd9523e5b49b 100644 --- a/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java +++ 
b/server/src/main/java/org/opensearch/common/settings/IndexScopedSettings.java @@ -268,6 +268,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexMetadata.INGESTION_SOURCE_POINTER_INIT_RESET_VALUE_SETTING, IndexMetadata.INGESTION_SOURCE_PARAMS_SETTING, IndexMetadata.INGESTION_SOURCE_ERROR_STRATEGY_SETTING, + IndexMetadata.INGESTION_SOURCE_NUM_PROCESSOR_THREADS_SETTING, // validate that built-in similarities don't get redefined Setting.groupSetting("index.similarity.", (s) -> { diff --git a/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java b/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java index 1d5d104394558..009d8e44a3e06 100644 --- a/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java +++ b/server/src/main/java/org/opensearch/index/engine/IngestionEngine.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.Term; import org.apache.lucene.search.IndexSearcher; import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.streamingingestion.state.ShardIngestionState; @@ -22,6 +23,8 @@ import org.opensearch.index.mapper.DocumentMapperForType; import org.opensearch.index.mapper.IdFieldMapper; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.index.mapper.ParsedDocument; +import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.index.translog.NoOpTranslogManager; import org.opensearch.index.translog.Translog; @@ -43,6 +46,7 @@ import java.util.Set; import java.util.function.BiFunction; +import static org.opensearch.action.index.IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP; import static org.opensearch.index.translog.Translog.EMPTY_TRANSLOG_SNAPSHOT; /** @@ -117,7 +121,8 @@ public void start() { resetState, resetValue, ingestionErrorStrategy, - initialPollerState + initialPollerState, + ingestionSource.getNumMessageProcessorThreads() ); streamPoller.start(); } @@ -153,8 +158,13 @@ public IndexResult index(Index index) throws IOException { } private IndexResult indexIntoLucene(Index index) throws IOException { - // todo: handle updates - addDocs(index.docs(), indexWriter); + if (index.getAutoGeneratedIdTimestamp() != UNSET_AUTO_GENERATED_TIMESTAMP) { + assert index.getAutoGeneratedIdTimestamp() >= 0 : "autoGeneratedIdTimestamp must be positive but was: " + + index.getAutoGeneratedIdTimestamp(); + addDocs(index.docs(), indexWriter); + } else { + updateDocs(index.uid(), index.docs(), indexWriter); + } return new IndexResult(index.version(), index.primaryTerm(), index.seqNo(), true); } @@ -166,9 +176,28 @@ private void addDocs(final List docs, final IndexWriter i } } + private void updateDocs(final Term uid, final List docs, final IndexWriter indexWriter) throws IOException { + if (docs.size() > 1) { + indexWriter.softUpdateDocuments(uid, docs, softDeletesField); + } else { + indexWriter.softUpdateDocument(uid, docs.get(0), softDeletesField); + } + } + @Override public DeleteResult delete(Delete delete) throws IOException { - return null; + assert Objects.equals(delete.uid().field(), IdFieldMapper.NAME) : delete.uid().field(); + ensureOpen(); + final ParsedDocument tombstone = engineConfig.getTombstoneDocSupplier().newDeleteTombstoneDoc(delete.id()); + assert tombstone.docs().size() == 1 : "Tombstone doc should have single doc [" + tombstone + "]"; + final ParseContext.Document doc = tombstone.docs().get(0); + 
assert doc.getField(SeqNoFieldMapper.TOMBSTONE_NAME) != null : "Delete tombstone document but _tombstone field is not set [" + + doc + + " ]"; + doc.add(softDeletesField); + indexWriter.softUpdateDocument(delete.uid(), doc, softDeletesField); + // delete result is unused in ingestion flow + return new DeleteResult(1, delete.primaryTerm(), -1, true); } @Override diff --git a/server/src/main/java/org/opensearch/index/engine/InternalEngine.java b/server/src/main/java/org/opensearch/index/engine/InternalEngine.java index 064e757c6ebb7..7e171e3f1714c 100644 --- a/server/src/main/java/org/opensearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/opensearch/index/engine/InternalEngine.java @@ -161,6 +161,7 @@ public class InternalEngine extends Engine { protected final AtomicLong maxUnsafeAutoIdTimestamp = new AtomicLong(-1); protected final SoftDeletesPolicy softDeletesPolicy; protected final AtomicBoolean shouldPeriodicallyFlushAfterBigMerge = new AtomicBoolean(false); + protected final NumericDocValuesField softDeletesField = Lucene.newSoftDeletesField(); @Nullable protected final String historyUUID; @@ -197,7 +198,6 @@ public class InternalEngine extends Engine { private final CounterMetric numDocDeletes = new CounterMetric(); private final CounterMetric numDocAppends = new CounterMetric(); private final CounterMetric numDocUpdates = new CounterMetric(); - private final NumericDocValuesField softDeletesField = Lucene.newSoftDeletesField(); private final LastRefreshedCheckpointListener lastRefreshedCheckpointListener; private final CompletionStatsCache completionStatsCache; diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java index e1a4f7d3b4b7d..470761c6ae2a1 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java @@ -21,8 +21,6 @@ import java.util.Locale; import java.util.Objects; import java.util.Set; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -48,8 +46,6 @@ public class DefaultStreamPoller implements StreamPoller { private ExecutorService consumerThread; - private ExecutorService processorThread; - // start of the batch, inclusive private IngestionShardPointer batchStartPointer; private boolean includeBatchStartPointer = false; @@ -59,16 +55,14 @@ public class DefaultStreamPoller implements StreamPoller { private Set persistedPointers; - private BlockingQueue> blockingQueue; - - private MessageProcessorRunnable processorRunnable; - private final CounterMetric totalPolledCount = new CounterMetric(); // A pointer to the max persisted pointer for optimizing the check @Nullable private IngestionShardPointer maxPersistedPointer; + private PartitionedBlockingQueueContainer blockingQueueContainer; + public DefaultStreamPoller( IngestionShardPointer startPointer, Set persistedPointers, @@ -77,13 +71,14 @@ public DefaultStreamPoller( ResetState resetState, String resetValue, IngestionErrorStrategy errorStrategy, - State initialState + State initialState, + int numProcessorThreads ) { this( startPointer, persistedPointers, consumer, - new MessageProcessorRunnable(new ArrayBlockingQueue<>(100), ingestionEngine, errorStrategy), + new PartitionedBlockingQueueContainer(numProcessorThreads, 
consumer.getShardId(), ingestionEngine, errorStrategy), resetState, resetValue, errorStrategy, @@ -95,7 +90,7 @@ public DefaultStreamPoller( IngestionShardPointer startPointer, Set persistedPointers, IngestionShardConsumer consumer, - MessageProcessorRunnable processorRunnable, + PartitionedBlockingQueueContainer blockingQueueContainer, ResetState resetState, String resetValue, IngestionErrorStrategy errorStrategy, @@ -110,22 +105,13 @@ public DefaultStreamPoller( if (!this.persistedPointers.isEmpty()) { maxPersistedPointer = this.persistedPointers.stream().max(IngestionShardPointer::compareTo).get(); } - this.processorRunnable = processorRunnable; - blockingQueue = processorRunnable.getBlockingQueue(); + this.blockingQueueContainer = blockingQueueContainer; this.consumerThread = Executors.newSingleThreadExecutor( r -> new Thread( r, String.format(Locale.ROOT, "stream-poller-consumer-%d-%d", consumer.getShardId(), System.currentTimeMillis()) ) ); - - // TODO: allow multiple threads for processing the messages in parallel - this.processorThread = Executors.newSingleThreadExecutor( - r -> new Thread( - r, - String.format(Locale.ROOT, "stream-poller-processor-%d-%d", consumer.getShardId(), System.currentTimeMillis()) - ) - ); this.errorStrategy = errorStrategy; } @@ -143,7 +129,7 @@ public void start() { // when we start, we need to include the batch start pointer in the read for the first read includeBatchStartPointer = true; consumerThread.submit(this::startPoll); - processorThread.submit(processorRunnable); + blockingQueueContainer.startProcessorThreads(); } /** @@ -158,6 +144,8 @@ protected void startPoll() { } logger.info("Starting poller for shard {}", consumer.getShardId()); + IngestionShardPointer lastProcessedPointer = null; + boolean encounteredError = false; while (true) { try { if (closed) { @@ -204,15 +192,18 @@ } state = State.POLLING; List> results; - if (includeBatchStartPointer) { + // TODO: handle multi-writer scenarios to provide at-least-once semantics + if (encounteredError && lastProcessedPointer != null) { + results = consumer.readNext(lastProcessedPointer, false, MAX_POLL_SIZE, POLL_TIMEOUT); + } else if (includeBatchStartPointer) { results = consumer.readNext(batchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT); } else { results = consumer.readNext(MAX_POLL_SIZE, POLL_TIMEOUT); } + encounteredError = false; if (results.isEmpty()) { // no new records continue; @@ -220,13 +211,8 @@ state = State.PROCESSING; // process the records - boolean firstInBatch = true; for (IngestionShardConsumer.ReadResult result : results) { - if (firstInBatch) { - // update the batch start pointer to the next batch - batchStartPointer = result.getPointer(); - firstInBatch = false; - } + lastProcessedPointer = result.getPointer(); // check if the message is already processed if (isProcessed(result.getPointer())) { continue; } totalPolledCount.inc(); - blockingQueue.put(result); + blockingQueueContainer.add(result); logger.debug( "Put message {} with pointer {} to the blocking queue", @@ -245,7 +231,8 @@ // for future reads, we do not need to include the batch start pointer, and read from the last successful pointer.
includeBatchStartPointer = false; } catch (Throwable e) { - logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e); + encounteredError = true; + logger.error("Error in polling the shard {}, lastProcessedPointer {}: {}", consumer.getShardId(), lastProcessedPointer, e); errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.POLLING); if (!errorStrategy.shouldIgnoreError(e, IngestionErrorStrategy.ErrorStage.POLLING)) { @@ -311,10 +298,10 @@ public void close() { logger.error("Error in closing the poller of shard {}: {}", consumer.getShardId(), e); } } - blockingQueue.clear(); + consumerThread.shutdown(); // interrupts the processor - processorThread.shutdownNow(); + blockingQueueContainer.close(); logger.info("closed the poller of shard {}", consumer.getShardId()); } @@ -337,7 +324,7 @@ public IngestionShardPointer getBatchStartPointer() { public PollingIngestStats getStats() { PollingIngestStats.Builder builder = new PollingIngestStats.Builder(); builder.setTotalPolledCount(totalPolledCount.count()); - builder.setTotalProcessedCount(processorRunnable.getStats().count()); + builder.setTotalProcessedCount(blockingQueueContainer.getTotalProcessedCount()); return builder.build(); } @@ -353,6 +340,6 @@ public IngestionErrorStrategy getErrorStrategy() { @Override public void updateErrorStrategy(IngestionErrorStrategy errorStrategy) { this.errorStrategy = errorStrategy; - processorRunnable.setErrorStrategy(errorStrategy); + blockingQueueContainer.updateErrorStrategy(errorStrategy); } } diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/IngestionUtils.java b/server/src/main/java/org/opensearch/indices/pollingingest/IngestionUtils.java new file mode 100644 index 0000000000000..48701540fd97d --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/IngestionUtils.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.opensearch.common.UUIDs; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaTypeRegistry; + +import java.util.Map; + +/** + * Holds common utilities for streaming ingestion. 
+ */ +public final class IngestionUtils { + + private IngestionUtils() {} + + public static Map getParsedPayloadMap(byte[] payload) { + BytesReference payloadBR = new BytesArray(payload); + Map payloadMap = XContentHelper.convertToMap(payloadBR, false, MediaTypeRegistry.xContentType(payloadBR)).v2(); + return payloadMap; + } + + public static String generateID() { + return UUIDs.base64UUID(); + } +} diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java index 2066f348243b8..2525895d2fc4a 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java @@ -16,11 +16,8 @@ import org.opensearch.common.lucene.uid.Versions; import org.opensearch.common.metrics.CounterMetric; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.index.IngestionShardConsumer; import org.opensearch.index.IngestionShardPointer; import org.opensearch.index.Message; import org.opensearch.index.VersionType; @@ -38,6 +35,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; +import static org.opensearch.action.index.IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; /** @@ -52,7 +50,8 @@ public class MessageProcessorRunnable implements Runnable { private static final int WAIT_BEFORE_RETRY_DURATION_MS = 5000; private volatile IngestionErrorStrategy errorStrategy; - private final BlockingQueue> blockingQueue; + private volatile boolean closed = false; + private final BlockingQueue> blockingQueue; private final MessageProcessor messageProcessor; private final CounterMetric stats = new CounterMetric(); @@ -63,7 +62,7 @@ public class MessageProcessorRunnable implements Runnable { * @param engine the ingestion engine */ public MessageProcessorRunnable( - BlockingQueue> blockingQueue, + BlockingQueue> blockingQueue, IngestionEngine engine, IngestionErrorStrategy errorStrategy ) { @@ -76,7 +75,7 @@ public MessageProcessorRunnable( * @param messageProcessor the message processor */ MessageProcessorRunnable( - BlockingQueue> blockingQueue, + BlockingQueue> blockingQueue, MessageProcessor messageProcessor, IngestionErrorStrategy errorStrategy ) { @@ -109,40 +108,48 @@ static class MessageProcessor { * Process the message and create an engine operation. It also records the offset in the document as (1) a point * field used for range search, (2) a stored field for retrieval. 
* - * @param message the message to process - * @param pointer the pointer to the message + * @param updateMessage the update message to process */ - protected void process(Message message, IngestionShardPointer pointer) { - byte[] payload = (byte[]) message.getPayload(); - + protected void process(ShardUpdateMessage updateMessage) { try { - Engine.Operation operation = getOperation(payload, pointer); + Engine.Operation operation = getOperation(updateMessage); switch (operation.operationType()) { case INDEX: engine.index((Engine.Index) operation); break; case DELETE: - engine.delete((Engine.Delete) operation); + if (updateMessage.autoGeneratedIdTimestamp() == UNSET_AUTO_GENERATED_TIMESTAMP) { + engine.delete((Engine.Delete) operation); + } else { + logger.info("Delete operation dropped since ID was not found in the message"); + } break; default: throw new IllegalArgumentException("Invalid operation: " + operation); } } catch (IOException e) { - logger.error("Failed to process operation from message {} at pointer {}: {}", message, pointer, e); + logger.error( + "Failed to process operation from message {} at pointer {}: {}", + updateMessage.originalMessage(), + updateMessage.pointer(), + e + ); throw new RuntimeException(e); } } /** * Visible for testing. Get the engine operation from the message. - * @param payload the payload of the message - * @param pointer the pointer to the message + * @param updateMessage an update message containing payload and pointer for the update * @return the engine operation */ - protected Engine.Operation getOperation(byte[] payload, IngestionShardPointer pointer) throws IOException { - BytesReference payloadBR = new BytesArray(payload); - Map payloadMap = XContentHelper.convertToMap(payloadBR, false, MediaTypeRegistry.xContentType(payloadBR)).v2(); + protected Engine.Operation getOperation(ShardUpdateMessage updateMessage) throws IOException { + Map payloadMap = updateMessage.parsedPayloadMap(); + if (payloadMap == null) { + payloadMap = IngestionUtils.getParsedPayloadMap((byte[]) updateMessage.originalMessage().getPayload()); + } + IngestionShardPointer pointer = updateMessage.pointer(); String id = (String) payloadMap.getOrDefault(ID, "null"); if (payloadMap.containsKey(OP_TYPE) && !(payloadMap.get(OP_TYPE) instanceof String)) { // TODO: add metric @@ -177,7 +184,7 @@ protected Engine.Operation getOperation(byte[] payload, IngestionShardPointer po document.add(new StoredField(IngestionShardPointer.OFFSET_FIELD, pointer.asString())); operation = new Engine.Index( - new Term("_id", id), + new Term(IdFieldMapper.NAME, Uid.encodeId(id)), doc, 0, 1, @@ -185,7 +192,7 @@ protected Engine.Operation getOperation(byte[] payload, IngestionShardPointer po VersionType.INTERNAL, Engine.Operation.Origin.PRIMARY, System.nanoTime(), - System.currentTimeMillis(), + updateMessage.autoGeneratedIdTimestamp(), false, UNASSIGNED_SEQ_NO, 0 @@ -219,7 +226,7 @@ private static BytesReference convertToBytes(Object object) throws IOException { return BytesReference.bytes(XContentFactory.jsonBuilder().map((Map) object)); } - BlockingQueue> getBlockingQueue() { + BlockingQueue> getBlockingQueue() { return blockingQueue; } @@ -229,27 +236,27 @@ private static BytesReference convertToBytes(Object object) throws IOException { */ @Override public void run() { - IngestionShardConsumer.ReadResult readResult = null; + ShardUpdateMessage updateMessage = null; - while (!(Thread.currentThread().isInterrupted())) { + while (Thread.currentThread().isInterrupted() == false && closed == false) { try 
{ - if (readResult == null) { - readResult = blockingQueue.poll(1000, TimeUnit.MILLISECONDS); + if (updateMessage == null) { + updateMessage = blockingQueue.poll(1000, TimeUnit.MILLISECONDS); } } catch (InterruptedException e) { // TODO: add metric logger.debug("MessageProcessorRunnable poll interruptedException", e); Thread.currentThread().interrupt(); // Restore interrupt status } - if (readResult != null) { + if (updateMessage != null) { try { stats.inc(); - messageProcessor.process(readResult.getMessage(), readResult.getPointer()); - readResult = null; + messageProcessor.process(updateMessage); + updateMessage = null; } catch (Exception e) { errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.PROCESSING); if (errorStrategy.shouldIgnoreError(e, IngestionErrorStrategy.ErrorStage.PROCESSING)) { - readResult = null; + updateMessage = null; } else { waitBeforeRetry(); } @@ -278,4 +285,8 @@ public IngestionErrorStrategy getErrorStrategy() { public void setErrorStrategy(IngestionErrorStrategy errorStrategy) { this.errorStrategy = errorStrategy; } + + public void close() { + closed = true; + } } diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainer.java b/server/src/main/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainer.java new file mode 100644 index 0000000000000..ddfa2cbff547f --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainer.java @@ -0,0 +1,196 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.core.common.Strings; +import org.opensearch.index.IngestionShardConsumer; +import org.opensearch.index.IngestionShardPointer; +import org.opensearch.index.Message; +import org.opensearch.index.engine.IngestionEngine; + +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.opensearch.action.index.IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP; + +/** + * This class holds a blocking queue per partition. A processor thread is started for each partition to consume updates + * and write to the Lucene index. + */ +public class PartitionedBlockingQueueContainer { + private static final Logger logger = LogManager.getLogger(PartitionedBlockingQueueContainer.class); + private static final int defaultQueueSize = 100; + private static final String ID = "_id"; + + private final int numPartitions; + + // partition mappings + private final Map>> partitionToQueueMap; + private final Map partitionToMessageProcessorMap; + private final Map partitionToProcessorExecutorMap; + + /** + * Initializes partitions and processor threads for the given number of partitions.
+ */ + public PartitionedBlockingQueueContainer( + int numPartitions, + int shardId, + IngestionEngine ingestionEngine, + IngestionErrorStrategy errorStrategy + ) { + partitionToQueueMap = new ConcurrentHashMap<>(); + partitionToMessageProcessorMap = new ConcurrentHashMap<>(); + partitionToProcessorExecutorMap = new ConcurrentHashMap<>(); + this.numPartitions = numPartitions; + + logger.info("Initializing processors for shard {} using {} partitions", shardId, numPartitions); + String processorThreadNamePrefix = String.format( + Locale.ROOT, + "stream-poller-processor-shard-%d-%d", + shardId, + System.currentTimeMillis() + ); + + for (int partition = 0; partition < numPartitions; partition++) { + String processorThreadName = String.format(Locale.ROOT, "%s-partition-%d", processorThreadNamePrefix, partition); + ExecutorService executorService = Executors.newSingleThreadExecutor(r -> new Thread(r, processorThreadName)); + partitionToProcessorExecutorMap.put(partition, executorService); + partitionToQueueMap.put(partition, new ArrayBlockingQueue<>(defaultQueueSize)); + + MessageProcessorRunnable messageProcessorRunnable = new MessageProcessorRunnable( + partitionToQueueMap.get(partition), + ingestionEngine, + errorStrategy + ); + partitionToMessageProcessorMap.put(partition, messageProcessorRunnable); + } + } + + /** + * Initializes a single partition for the provided messageProcessorRunnable. This constructor is for testing purposes only. + */ + PartitionedBlockingQueueContainer(MessageProcessorRunnable messageProcessorRunnable, int shardId) { + partitionToQueueMap = new ConcurrentHashMap<>(); + partitionToMessageProcessorMap = new ConcurrentHashMap<>(); + partitionToProcessorExecutorMap = new ConcurrentHashMap<>(); + this.numPartitions = 1; + + partitionToQueueMap.put(0, messageProcessorRunnable.getBlockingQueue()); + partitionToMessageProcessorMap.put(0, messageProcessorRunnable); + ExecutorService executorService = Executors.newSingleThreadExecutor( + r -> new Thread( + r, + String.format(Locale.ROOT, "stream-poller-processor-shard-%d-%d-partition-0", shardId, System.currentTimeMillis()) + ) + ); + partitionToProcessorExecutorMap.put(0, executorService); + } + + /** + * Starts the processor threads to read updates and write to the index. + */ + public void startProcessorThreads() { + for (int partition = 0; partition < numPartitions; partition++) { + ExecutorService executorService = partitionToProcessorExecutorMap.get(partition); + MessageProcessorRunnable messageProcessorRunnable = partitionToMessageProcessorMap.get(partition); + executorService.submit(messageProcessorRunnable); + } + } + + /** + * Adds a document to the blocking queue. The document ID is used to identify the blocking queue partition. + * If an ID is not present, a new one is auto-generated. + */ + public void add(IngestionShardConsumer.ReadResult readResult) + throws InterruptedException { + Map payloadMap = IngestionUtils.getParsedPayloadMap((byte[]) readResult.getMessage().getPayload()); + String id; + long autoGeneratedIdTimestamp = UNSET_AUTO_GENERATED_TIMESTAMP; + + if (payloadMap.containsKey(ID)) { + id = (String) payloadMap.get(ID); + } else { + id = IngestionUtils.generateID(); + payloadMap.put(ID, id); + autoGeneratedIdTimestamp = System.currentTimeMillis(); + } + + ShardUpdateMessage updateMessage = new ShardUpdateMessage( + readResult.getPointer(), + readResult.getMessage(), + payloadMap, + autoGeneratedIdTimestamp + ); + + int partition = getPartitionFromID(id); + partitionToQueueMap.get(partition).put(updateMessage); + } + + /** + * Stops the processor threads and shuts down the executors. + */ + public void close() { + partitionToQueueMap.values().forEach(queue -> queue.clear()); + partitionToMessageProcessorMap.values().forEach(MessageProcessorRunnable::close); + partitionToProcessorExecutorMap.values().forEach(ExecutorService::shutdown); + partitionToQueueMap.clear(); + partitionToProcessorExecutorMap.clear(); + partitionToMessageProcessorMap.clear(); + } + + /** + * Returns the total number of processed updates across all partitions. + */ + public long getTotalProcessedCount() { + return partitionToMessageProcessorMap.values() + .stream() + .map(MessageProcessorRunnable::getStats) + .mapToLong(CounterMetric::count) + .sum(); + } + + /** + * Updates the error strategy in all available message processors. + */ + public void updateErrorStrategy(IngestionErrorStrategy errorStrategy) { + partitionToMessageProcessorMap.values().forEach(messageProcessor -> messageProcessor.setErrorStrategy(errorStrategy)); + } + + private int getPartitionFromID(String id) { + if (Strings.isEmpty(id)) { + return 0; + } + return Math.floorMod(id.hashCode(), numPartitions); + } + + Map getPartitionToMessageProcessorMap() { + return partitionToMessageProcessorMap; + } + + Map getPartitionToProcessorExecutorMap() { + return partitionToProcessorExecutorMap; + } + + Map>> getPartitionToQueueMap() { + return partitionToQueueMap; + } +} diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/ShardUpdateMessage.java b/server/src/main/java/org/opensearch/indices/pollingingest/ShardUpdateMessage.java new file mode 100644 index 0000000000000..fd4af460f33f5 --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/pollingingest/ShardUpdateMessage.java @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.opensearch.index.IngestionShardPointer; +import org.opensearch.index.Message; + +import java.util.Map; + +/** + * Holds the original message consumed from the streaming source, the corresponding pointer, and the parsed payload + * map. This record is used by the pull-based ingestion processor/writer threads to update the index.
+ */ +public record ShardUpdateMessage(T pointer, M originalMessage, Map< + String, + Object> parsedPayloadMap, long autoGeneratedIdTimestamp) { +} diff --git a/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java b/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java index 05037f33c3965..260ba9b125bf3 100644 --- a/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java +++ b/server/src/test/java/org/opensearch/cluster/metadata/IngestionSourceTests.java @@ -26,7 +26,7 @@ public class IngestionSourceTests extends OpenSearchTestCase { public void testConstructorAndGetters() { Map params = new HashMap<>(); params.put("key", "value"); - IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, params); + IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, 1, params); assertEquals("type", source.getType()); assertEquals(StreamPoller.ResetState.REWIND_BY_OFFSET, source.getPointerInitReset().getType()); @@ -38,38 +38,38 @@ public void testConstructorAndGetters() { public void testEquals() { Map params1 = new HashMap<>(); params1.put("key", "value"); - IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, params1); + IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, 1, params1); Map params2 = new HashMap<>(); params2.put("key", "value"); - IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, params2); + IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, 1, params2); assertTrue(source1.equals(source2)); assertTrue(source2.equals(source1)); - IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, params1); + IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, 1, params1); assertFalse(source1.equals(source3)); } public void testHashCode() { Map params1 = new HashMap<>(); params1.put("key", "value"); - IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, params1); + IngestionSource source1 = new IngestionSource("type", pointerInitReset, DROP, 1, params1); Map params2 = new HashMap<>(); params2.put("key", "value"); - IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, params2); + IngestionSource source2 = new IngestionSource("type", pointerInitReset, DROP, 1, params2); assertEquals(source1.hashCode(), source2.hashCode()); - IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, params1); + IngestionSource source3 = new IngestionSource("differentType", pointerInitReset, DROP, 1, params1); assertNotEquals(source1.hashCode(), source3.hashCode()); } public void testToString() { Map params = new HashMap<>(); params.put("key", "value"); - IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, params); + IngestionSource source = new IngestionSource("type", pointerInitReset, DROP, 1, params); String expected = - "IngestionSource{type='type',pointer_init_reset='PointerInitReset{type='REWIND_BY_OFFSET', value=1000}',error_strategy='DROP', params={key=value}}"; + "IngestionSource{type='type',pointer_init_reset='PointerInitReset{type='REWIND_BY_OFFSET', value=1000}',error_strategy='DROP',numMessageProcessorThreads='1', params={key=value}}"; assertEquals(expected, source.toString()); } } diff --git a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java 
b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java index 6d71a3763fbc9..35e34912528b3 100644 --- a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java +++ b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java @@ -45,6 +45,7 @@ public class DefaultStreamPollerTests extends OpenSearchTestCase { private FakeIngestionSource.FakeIngestionConsumer fakeConsumer; private MessageProcessorRunnable processorRunnable; private MessageProcessorRunnable.MessageProcessor processor; + private PartitionedBlockingQueueContainer partitionedBlockingQueueContainer; private List messages; private Set persistedPointers; private final int awaitTime = 300; @@ -61,17 +62,19 @@ public void setUp() throws Exception { processor = mock(MessageProcessorRunnable.MessageProcessor.class); errorStrategy = new DropIngestionErrorStrategy("ingestion_source"); processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor, errorStrategy); + partitionedBlockingQueueContainer = new PartitionedBlockingQueueContainer(processorRunnable, 0); persistedPointers = new HashSet<>(); poller = new DefaultStreamPoller( new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, fakeConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.NONE, "", errorStrategy, StreamPoller.State.NONE ); + partitionedBlockingQueueContainer.startProcessorThreads(); } @After @@ -79,6 +82,7 @@ public void tearDown() throws Exception { if (!poller.isClosed()) { poller.close(); } + partitionedBlockingQueueContainer.close(); super.tearDown(); } @@ -88,7 +92,7 @@ public void testPauseAndResume() throws InterruptedException { doAnswer(invocation -> { pauseLatch.countDown(); return null; - }).when(processor).process(any(), any()); + }).when(processor).process(any()); poller.pause(); poller.start(); @@ -99,19 +103,19 @@ public void testPauseAndResume() throws InterruptedException { assertFalse("Messages should not be processed while paused", processedWhilePaused); assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState()); assertTrue(poller.isPaused()); - verify(processor, never()).process(any(), any()); + verify(processor, never()).process(any()); CountDownLatch resumeLatch = new CountDownLatch(2); doAnswer(invocation -> { resumeLatch.countDown(); return null; - }).when(processor).process(any(), any()); + }).when(processor).process(any()); poller.resume(); resumeLatch.await(); assertFalse(poller.isPaused()); // 2 messages are processed - verify(processor, times(2)).process(any(), any()); + verify(processor, times(2)).process(any()); } public void testSkipProcessed() throws InterruptedException { @@ -123,7 +127,7 @@ public void testSkipProcessed() throws InterruptedException { new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, fakeConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.NONE, "", errorStrategy, @@ -134,12 +138,12 @@ public void testSkipProcessed() throws InterruptedException { doAnswer(invocation -> { latch.countDown(); return null; - }).when(processor).process(any(), any()); + }).when(processor).process(any()); poller.start(); latch.await(); // 2 messages are processed, 2 messages are skipped - verify(processor, times(2)).process(any(), any()); + verify(processor, times(2)).process(any()); assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(2), poller.getMaxPersistedPointer()); } @@ -161,7 +165,7 
@@ public void testResetStateEarliest() throws InterruptedException { new FakeIngestionSource.FakeIngestionShardPointer(1), persistedPointers, fakeConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.EARLIEST, "", errorStrategy, @@ -171,13 +175,13 @@ public void testResetStateEarliest() throws InterruptedException { doAnswer(invocation -> { latch.countDown(); return null; - }).when(processor).process(any(), any()); + }).when(processor).process(any()); poller.start(); latch.await(); // 2 messages are processed - verify(processor, times(2)).process(any(), any()); + verify(processor, times(2)).process(any()); } public void testResetStateLatest() throws InterruptedException { @@ -185,7 +189,7 @@ public void testResetStateLatest() throws InterruptedException { new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, fakeConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.LATEST, "", errorStrategy, @@ -195,7 +199,7 @@ public void testResetStateLatest() throws InterruptedException { poller.start(); waitUntil(() -> poller.getState() == DefaultStreamPoller.State.POLLING, awaitTime, TimeUnit.MILLISECONDS); // no messages processed - verify(processor, never()).process(any(), any()); + verify(processor, never()).process(any()); // reset to the latest assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(2), poller.getBatchStartPointer()); } @@ -205,7 +209,7 @@ public void testResetStateRewindByOffset() throws InterruptedException { new FakeIngestionSource.FakeIngestionShardPointer(2), persistedPointers, fakeConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.REWIND_BY_OFFSET, "1", errorStrategy, @@ -215,12 +219,12 @@ public void testResetStateRewindByOffset() throws InterruptedException { doAnswer(invocation -> { latch.countDown(); return null; - }).when(processor).process(any(), any()); + }).when(processor).process(any()); poller.start(); latch.await(); // 1 message is processed - verify(processor, times(1)).process(any(), any()); + verify(processor, times(1)).process(any()); } public void testStartPollWithoutStart() { @@ -278,7 +282,7 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, mockConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.NONE, "", errorStrategy, @@ -288,7 +292,7 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte Thread.sleep(sleepTime); verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); - verify(processor, times(2)).process(any(), any()); + verify(processor, times(2)).process(any()); } public void testBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException { @@ -324,7 +328,7 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, mockConsumer, - processorRunnable, + partitionedBlockingQueueContainer, StreamPoller.ResetState.NONE, "", errorStrategy, @@ -334,7 +338,7 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt Thread.sleep(sleepTime); verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); - verify(processor, never()).process(any(), any()); + verify(processor, never()).process(any()); 
assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState()); assertTrue(poller.isPaused()); } @@ -343,15 +347,16 @@ public void testProcessingErrorWithBlockErrorIngestionStrategy() throws TimeoutE messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); - doThrow(new RuntimeException("Error processing update")).when(processor).process(any(), any()); + doThrow(new RuntimeException("Error processing update")).when(processor).process(any()); BlockIngestionErrorStrategy mockErrorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source")); processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor, mockErrorStrategy); + PartitionedBlockingQueueContainer blockingQueueContainer = new PartitionedBlockingQueueContainer(processorRunnable, 0); poller = new DefaultStreamPoller( new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, fakeConsumer, - processorRunnable, + blockingQueueContainer, StreamPoller.ResetState.NONE, "", mockErrorStrategy, @@ -361,12 +366,43 @@ public void testProcessingErrorWithBlockErrorIngestionStrategy() throws TimeoutE Thread.sleep(sleepTime); verify(mockErrorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.PROCESSING)); - verify(processor, times(1)).process(any(), any()); + verify(processor, times(1)).process(any()); // poller will continue to poll if an error is encountered during message processing but will be blocked by // the write to blockingQueue assertEquals(DefaultStreamPoller.State.POLLING, poller.getState()); } + public void testInitialConsumerReadTransientError() throws TimeoutException, InterruptedException { + messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + + DropIngestionErrorStrategy mockErrorStrategy = spy(new DropIngestionErrorStrategy("ingestion_source")); + processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(5), processor, mockErrorStrategy); + PartitionedBlockingQueueContainer blockingQueueContainer = new PartitionedBlockingQueueContainer(processorRunnable, 0); + FakeIngestionSource.FakeIngestionConsumer consumerSpy = spy(fakeConsumer); + + // fail consumer's first poll attempt + doThrow(new RuntimeException("failed to poll messages")).doCallRealMethod() + .when(consumerSpy) + .readNext(any(), anyBoolean(), anyLong(), anyInt()); + + poller = new DefaultStreamPoller( + new FakeIngestionSource.FakeIngestionShardPointer(0), + persistedPointers, + consumerSpy, + blockingQueueContainer, + StreamPoller.ResetState.NONE, + "", + mockErrorStrategy, + StreamPoller.State.NONE + ); + poller.start(); + Thread.sleep(sleepTime); + + verify(mockErrorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); + assertEquals(4, blockingQueueContainer.getTotalProcessedCount()); + } + public void testUpdateErrorStrategy() { assertTrue(poller.getErrorStrategy() instanceof DropIngestionErrorStrategy); assertTrue(processorRunnable.getErrorStrategy() instanceof DropIngestionErrorStrategy); diff --git a/server/src/test/java/org/opensearch/indices/pollingingest/MessageProcessorTests.java b/server/src/test/java/org/opensearch/indices/pollingingest/MessageProcessorTests.java index 273e25c0a5bfc..4f87c41b22672 
100644 --- a/server/src/test/java/org/opensearch/indices/pollingingest/MessageProcessorTests.java +++ b/server/src/test/java/org/opensearch/indices/pollingingest/MessageProcessorTests.java @@ -8,6 +8,7 @@ package org.opensearch.indices.pollingingest; +import org.opensearch.index.Message; import org.opensearch.index.engine.Engine; import org.opensearch.index.engine.FakeIngestionSource; import org.opensearch.index.engine.IngestionEngine; @@ -55,7 +56,9 @@ public void testGetIndexOperation() throws IOException { when(documentMapper.parse(any())).thenReturn(parsedDocument); when(parsedDocument.rootDoc()).thenReturn(new ParseContext.Document()); - Engine.Operation operation = processor.getOperation(payload, pointer); + Engine.Operation operation = processor.getOperation( + new ShardUpdateMessage(pointer, mock(Message.class), IngestionUtils.getParsedPayloadMap(payload), 0) + ); assertTrue(operation instanceof Engine.Index); ArgumentCaptor captor = ArgumentCaptor.forClass(SourceToParse.class); @@ -68,7 +71,9 @@ public void testGetDeleteOperation() throws IOException { byte[] payload = "{\"_id\":\"1\",\"_op_type\":\"delete\"}".getBytes(StandardCharsets.UTF_8); FakeIngestionSource.FakeIngestionShardPointer pointer = new FakeIngestionSource.FakeIngestionShardPointer(0); - Engine.Operation operation = processor.getOperation(payload, pointer); + Engine.Operation operation = processor.getOperation( + new ShardUpdateMessage(pointer, mock(Message.class), IngestionUtils.getParsedPayloadMap(payload), 0) + ); assertTrue(operation instanceof Engine.Delete); Engine.Delete deleteOperation = (Engine.Delete) operation; @@ -79,13 +84,17 @@ public void testSkipNoSourceIndexOperation() throws IOException { byte[] payload = "{\"_id\":\"1\"}".getBytes(StandardCharsets.UTF_8); FakeIngestionSource.FakeIngestionShardPointer pointer = new FakeIngestionSource.FakeIngestionShardPointer(0); - Engine.Operation operation = processor.getOperation(payload, pointer); + Engine.Operation operation = processor.getOperation( + new ShardUpdateMessage(pointer, mock(Message.class), IngestionUtils.getParsedPayloadMap(payload), 0) + ); assertNull(operation); // source has wrong type payload = "{\"_id\":\"1\", \"_source\":1}".getBytes(StandardCharsets.UTF_8); - operation = processor.getOperation(payload, pointer); + operation = processor.getOperation( + new ShardUpdateMessage(pointer, mock(Message.class), IngestionUtils.getParsedPayloadMap(payload), 0) + ); assertNull(operation); } @@ -93,7 +102,9 @@ public void testUnsupportedOperation() throws IOException { byte[] payload = "{\"_id\":\"1\", \"_op_tpe\":\"update\"}".getBytes(StandardCharsets.UTF_8); FakeIngestionSource.FakeIngestionShardPointer pointer = new FakeIngestionSource.FakeIngestionShardPointer(0); - Engine.Operation operation = processor.getOperation(payload, pointer); + Engine.Operation operation = processor.getOperation( + new ShardUpdateMessage(pointer, mock(Message.class), IngestionUtils.getParsedPayloadMap(payload), 0) + ); assertNull(operation); } } diff --git a/server/src/test/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainerTests.java b/server/src/test/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainerTests.java new file mode 100644 index 0000000000000..aa8747fdd746f --- /dev/null +++ b/server/src/test/java/org/opensearch/indices/pollingingest/PartitionedBlockingQueueContainerTests.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this 
file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.pollingingest; + +import org.opensearch.core.common.Strings; +import org.opensearch.index.IngestionShardConsumer; +import org.opensearch.index.engine.FakeIngestionSource; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.After; +import org.junit.Before; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeoutException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class PartitionedBlockingQueueContainerTests extends OpenSearchTestCase { + private MessageProcessorRunnable processorRunnable; + private MessageProcessorRunnable.MessageProcessor processor; + private PartitionedBlockingQueueContainer blockingQueueContainer; + private FakeIngestionSource.FakeIngestionConsumer fakeConsumer; + private List messages; + + @Before + public void setUp() throws Exception { + super.setUp(); + messages = new ArrayList<>(); + messages.add("{\"_id\":\"1\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + fakeConsumer = new FakeIngestionSource.FakeIngestionConsumer(messages, 0); + processor = mock(MessageProcessorRunnable.MessageProcessor.class); + processorRunnable = new MessageProcessorRunnable( + new ArrayBlockingQueue<>(5), + processor, + new DropIngestionErrorStrategy("ingestion_source") + ); + this.blockingQueueContainer = new PartitionedBlockingQueueContainer(processorRunnable, 0); + } + + @After + public void tearDown() throws Exception { + blockingQueueContainer.close(); + super.tearDown(); + } + + public void testAddMessage() throws TimeoutException, InterruptedException { + assertEquals(1, blockingQueueContainer.getPartitionToQueueMap().size()); + assertEquals(1, blockingQueueContainer.getPartitionToMessageProcessorMap().size()); + assertEquals(1, blockingQueueContainer.getPartitionToProcessorExecutorMap().size()); + + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResults = fakeConsumer.readNext( + fakeConsumer.earliestPointer(), + true, + 5, + 100 + ); + + CountDownLatch updatesLatch = new CountDownLatch(2); + doAnswer(invocation -> { + updatesLatch.countDown(); + return null; + }).when(processor).process(any()); + + for (IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage> readResult : readResults) { + blockingQueueContainer.add(readResult); + } + + // verify ID is present on all messages + blockingQueueContainer.getPartitionToQueueMap() + .get(0) + .forEach(update -> assertTrue(Strings.hasText((String) update.parsedPayloadMap().get("_id")))); + + // start processor threads and verify updates are processed + blockingQueueContainer.startProcessorThreads(); + updatesLatch.await(); + assertEquals(2, blockingQueueContainer.getTotalProcessedCount()); + verify(processor, times(2)).process(any()); + } +}
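A few illustrative sketches follow; they are not part of the patch.

The new index.ingestion_source.num_processor_threads setting is registered with both a minimum and a maximum of 1, so multi-threaded processing stays gated off until the feature is ready. A minimal sketch of how an index opts into pull-based ingestion with this setting, mirroring createIndexWithDefaultSettings in KafkaIngestionBaseIT; the index name, topic, and broker address are placeholders:

import org.opensearch.common.settings.Settings;

// inside an OpenSearchIntegTestCase subclass such as KafkaIngestionBaseIT
Settings settings = Settings.builder()
    .put("ingestion_source.type", "kafka")
    .put("ingestion_source.param.topic", "test-topic")                 // placeholder topic
    .put("ingestion_source.param.bootstrap_servers", "localhost:9092") // placeholder broker
    .put("ingestion_source.num_processor_threads", 1)                  // values other than 1 are currently rejected
    .put("index.number_of_shards", 1)
    .put("index.number_of_replicas", 0)
    .put("index.replication.type", "SEGMENT")
    .build();
createIndex("test-index", settings);
ensureGreen("test-index");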
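The payload format the poller consumes is the one produceData writes: an optional _id, an _op_type of "index" or "delete", and a _source document. Re-sending an _id acts as an update, while omitting _id triggers server-side ID generation, which also means later updates and deletes cannot target that document. A sketch of producing such messages with the plain Kafka client; the broker address and topic are assumptions:

import java.util.Properties;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.StringSerializer;

public class IngestionPayloadDemo {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092"); // assumed broker address
        props.put("key.serializer", StringSerializer.class.getName());
        props.put("value.serializer", StringSerializer.class.getName());

        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props)) {
            // index a document, update it by re-sending the same _id, then delete it
            String create = "{\"_id\":\"1\", \"_op_type\":\"index\",\"_source\":{\"name\":\"name\", \"age\": 25}}";
            String update = "{\"_id\":\"1\", \"_op_type\":\"index\",\"_source\":{\"name\":\"name\", \"age\": 30}}";
            String delete = "{\"_id\":\"1\", \"_op_type\":\"delete\"}";
            producer.send(new ProducerRecord<>("test-topic", "1", create));
            producer.send(new ProducerRecord<>("test-topic", "1", update));
            producer.send(new ProducerRecord<>("test-topic", "1", delete));
        }
    }
}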
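PartitionedBlockingQueueContainer routes every update through Math.floorMod(id.hashCode(), numPartitions), so all operations on a given document ID serialize onto one processor thread while different IDs proceed in parallel; this is what keeps the update and delete in testUpdateAndDelete ordered after the original index operation. A self-contained sketch of that routing invariant:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

public class PartitionRoutingDemo {
    public static void main(String[] args) throws InterruptedException {
        int numPartitions = 4;
        List<BlockingQueue<String>> queues = new ArrayList<>(); // one queue per partition
        for (int i = 0; i < numPartitions; i++) {
            queues.add(new ArrayBlockingQueue<>(100));
        }

        for (String id : new String[] { "doc-a", "doc-b", "doc-a", "doc-c", "doc-a" }) {
            // floorMod keeps the partition non-negative even when hashCode() is negative
            int partition = Math.floorMod(id.hashCode(), numPartitions);
            queues.get(partition).put(id);
            System.out.println(id + " -> partition " + partition);
        }
        // all three "doc-a" operations land in the same queue, so a single processor
        // thread applies them in arrival order even though partitions run concurrently
    }
}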
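On the engine side, both updateDocs and the new delete path go through IndexWriter.softUpdateDocument, which atomically marks the previous version of a document as soft-deleted while indexing its replacement (a tombstone, in the delete case). A standalone Lucene sketch of that mechanism; the field names and values are illustrative, not the engine's actual mapping:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.ByteBuffersDirectory;

public class SoftUpdateDemo {
    public static void main(String[] args) throws Exception {
        String softDeletesField = "__soft_deletes"; // same field name Lucene.newSoftDeletesField() uses
        try (
            ByteBuffersDirectory dir = new ByteBuffersDirectory();
            IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setSoftDeletesField(softDeletesField))
        ) {
            Document v1 = new Document();
            v1.add(new StringField("_id", "1", Field.Store.YES));
            v1.add(new StoredField("age", 25));
            writer.addDocument(v1); // initial version, like addDocs() in IngestionEngine

            Document v2 = new Document();
            v2.add(new StringField("_id", "1", Field.Store.YES));
            v2.add(new StoredField("age", 30));
            // atomically soft-deletes v1 and indexes v2, like updateDocs() in IngestionEngine
            writer.softUpdateDocument(new Term("_id", "1"), v2, new NumericDocValuesField(softDeletesField, 1));

            try (DirectoryReader reader = DirectoryReader.open(writer)) {
                System.out.println("live docs: " + reader.numDocs()); // 1, only the updated version is live
            }
        }
    }
}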