diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 0819bed33a..5d8f96ba13 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -84,23 +84,18 @@ jobs: echo "::add-mask::$COHERE_KEY" && echo "build and run tests" && ./gradlew build -x spotlessJava && echo "Publish to Maven Local" && ./gradlew publishToMavenLocal -x spotlessJava && - echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava' + echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava + echo "Run Jacoco test coverage" && && ./gradlew jacocoTestReport && cp -v plugin/build/reports/jacoco/test/jacocoTestReport.xml ./jacocoTestReport.xml' plugin=`basename $(ls plugin/build/distributions/*.zip)` echo $plugin mv -v plugin/build/distributions/$plugin ./ echo "build-test-linux=$plugin" >> $GITHUB_OUTPUT - - name: Upload Coverage Report - uses: codecov/codecov-action@v3 - with: - flags: ml-commons - token: ${{ secrets.CODECOV_TOKEN }} - - uses: actions/upload-artifact@v4 + if: ${{ matrix.os }} == "ubuntu-latest" with: - name: ml-plugin-linux-${{ matrix.java }} - path: ${{ steps.step-build-test-linux.outputs.build-test-linux }} - if-no-files-found: error + name: coverage-report-${{ matrix.os }}-${{ matrix.java }} + path: ./jacocoTestReport.xml Test-ml-linux-docker: @@ -200,6 +195,24 @@ jobs: flags: ml-commons token: ${{ secrets.CODECOV_TOKEN }} + Precommit-codecov: + needs: Build-ml-linux + strategy: + matrix: + java: [21, 23] + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/download-artifact@v4 + with: + name: coverage-report-${{ matrix.os }}-${{ matrix.java }} + path: ./ + - name: Upload Coverage Report + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./jacocoTestReport.xml + Build-ml-windows: strategy: matrix: diff --git a/client/build.gradle b/client/build.gradle index de73fbc1a9..92a80a8780 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -17,6 +17,7 @@ dependencies { implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow') implementation project(path: ":${rootProject.name}-common", configuration: 'shadow') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java similarity index 54% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java index 3e28366e40..7d21a86609 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java @@ -10,15 +10,20 @@ import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Queue; import java.util.Spliterators; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import org.apache.commons.lang3.math.NumberUtils; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; @@ -30,6 +35,9 @@ import org.opensearch.action.admin.indices.stats.IndexStats; import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.pagination.IndexPaginationStrategy; +import org.opensearch.action.pagination.PageParams; +import org.opensearch.action.pagination.PageToken; import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.IndicesOptions; import org.opensearch.cluster.health.ClusterIndexHealth; @@ -37,6 +45,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Table; import org.opensearch.common.Table.Cell; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; @@ -50,10 +59,15 @@ import lombok.Getter; import lombok.Setter; - -@ToolAnnotation(CatIndexTool.TYPE) -public class CatIndexTool implements Tool { - public static final String TYPE = "CatIndexTool"; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(ListIndexTool.TYPE) +public class ListIndexTool implements Tool { + public static final String TYPE = "ListIndexTool"; + // This needs to be changed once it's changed in opensearch core in RestIndicesListAction. + private static final int MAX_SUPPORTED_LIST_INDICES_PAGE_SIZE = 5000; + public static final int DEFAULT_PAGE_SIZE = 100; private static final String DEFAULT_DESCRIPTION = String .join( " ", @@ -65,7 +79,7 @@ public class CatIndexTool implements Tool { @Setter @Getter - private String name = CatIndexTool.TYPE; + private String name = ListIndexTool.TYPE; @Getter @Setter private String description = DEFAULT_DESCRIPTION; @@ -80,7 +94,7 @@ public class CatIndexTool implements Tool { @SuppressWarnings("unused") private ClusterService clusterService; - public CatIndexTool(Client client, ClusterService clusterService) { + public ListIndexTool(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; @@ -96,9 +110,8 @@ public Object parse(Object o) { @Override public void run(Map parameters, ActionListener listener) { - // TODO: This logic exactly matches the OpenSearch _cat/indices REST action. If code at - // o.o.rest/action/cat/RestIndicesAction.java changes those changes need to be reflected here - // https://github.com/opensearch-project/ml-commons/pull/1582#issuecomment-1796962876 + // TODO: This logic exactly matches the OpenSearch _list/indices REST action. If code at + // o.o.rest/action/list/RestIndicesListAction.java changes those changes need to be reflected here @SuppressWarnings("unchecked") List indexList = parameters.containsKey("indices") ? gson.fromJson(parameters.get("indices"), List.class) @@ -106,13 +119,16 @@ public void run(Map parameters, ActionListener listener) final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); - final boolean local = parameters.containsKey("local") ? Boolean.parseBoolean("local") : false; - final TimeValue clusterManagerNodeTimeout = DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; + final boolean local = parameters.containsKey("local") && Boolean.parseBoolean(parameters.get("local")); final boolean includeUnloadedSegments = Boolean.parseBoolean(parameters.get("include_unloaded_segments")); + final int pageSize = parameters.containsKey("page_size") + ? NumberUtils.toInt(parameters.get("page_size"), DEFAULT_PAGE_SIZE) + : DEFAULT_PAGE_SIZE; + final PageParams pageParams = new PageParams(null, PageParams.PARAM_ASC_SORT_VALUE, pageSize); final ActionListener internalListener = ActionListener.notifyOnce(ActionListener.wrap(table -> { // Handle empty table - if (table.getRows().isEmpty()) { + if (table == null || table.getRows().isEmpty()) { @SuppressWarnings("unchecked") T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); listener.onResponse(empty); @@ -131,57 +147,151 @@ public void run(Map parameters, ActionListener listener) listener.onResponse(response); }, listener::onFailure)); - sendGetSettingsRequest( + fetchClusterInfoAndPages( indices, + local, + includeUnloadedSegments, + pageParams, indicesOptions, + new ConcurrentLinkedQueue<>(), + internalListener + ); + } + + private void fetchClusterInfoAndPages( + String[] indices, + boolean local, + boolean includeUnloadedSegments, + PageParams pageParams, + IndicesOptions indicesOptions, + Queue> pageResults, + ActionListener
originalListener + ) { + // First fetch metadata like index setting and cluster states and then fetch index details in batches to save efforts. + sendGetSettingsRequest(indices, indicesOptions, local, client, new ActionListener<>() { + @Override + public void onResponse(final GetSettingsResponse getSettingsResponse) { + // The list of indices that will be returned is determined by the indices returned from the Get Settings call. + // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the + // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so + // force the IndicesOptions for all the sub-requests to be as inclusive as possible. + final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); + // Indices that were successfully resolved during the get settings request might be deleted when the + // subsequent cluster state, cluster health and indices stats requests execute. We have to distinguish two cases: + // 1) the deleted index was explicitly passed as parameter to the /_cat/indices request. In this case we + // want the subsequent requests to fail. + // 2) the deleted index was resolved as part of a wildcard or _all. In this case, we want the subsequent + // requests not to fail on the deleted index (as we want to ignore wildcards that cannot be resolved). + // This behavior can be ensured by letting the cluster state, cluster health and indices stats requests + // re-resolve the index names with the same indices options that we used for the initial cluster state + // request (strictExpand). + sendClusterStateRequest(indices, subRequestIndicesOptions, local, client, new ActionListener<>() { + @Override + public void onResponse(ClusterStateResponse clusterStateResponse) { + // Starts to fetch index details here, if a batch fails build whatever we have and return. + fetchPages( + indices, + local, + DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT, + includeUnloadedSegments, + pageParams, + pageResults, + clusterStateResponse, + getSettingsResponse, + subRequestIndicesOptions, + originalListener + ); + } + + @Override + public void onFailure(final Exception e) { + originalListener.onFailure(e); + } + }); + } + + @Override + public void onFailure(final Exception e) { + originalListener.onFailure(e); + } + }); + } + + private void fetchPages( + String[] indices, + boolean local, + TimeValue clusterManagerNodeTimeout, + boolean includeUnloadedSegments, + PageParams pageParams, + Queue> pageResults, + ClusterStateResponse clusterStateResponse, + GetSettingsResponse getSettingsResponse, + IndicesOptions subRequestIndicesOptions, + ActionListener
originalListener + ) { + final ActionListener iterativeListener = ActionListener.wrap(r -> { + // when previous response returns, build next request with response and invoke again. + PageParams nextPageParams = new PageParams(r.getNextToken(), pageParams.getSort(), pageParams.getSize()); + // when next page doesn't exist or reaches max supported page size, return. + if (r.getNextToken() == null || pageResults.size() >= MAX_SUPPORTED_LIST_INDICES_PAGE_SIZE) { + Table table = buildTable(clusterStateResponse, getSettingsResponse, pageResults); + originalListener.onResponse(table); + } else { + fetchPages( + indices, + local, + clusterManagerNodeTimeout, + includeUnloadedSegments, + nextPageParams, + pageResults, + clusterStateResponse, + getSettingsResponse, + subRequestIndicesOptions, + originalListener + ); + } + }, e -> { + log.error("Failed to fetch index info for page: {}", pageParams.getRequestedToken()); + // Do not throw the exception, just return whatever we have. + originalListener.onResponse(buildTable(clusterStateResponse, getSettingsResponse, pageResults)); + }); + IndexPaginationStrategy paginationStrategy = getPaginationStrategy(pageParams, clusterStateResponse); + // For non-paginated queries, indicesToBeQueried would be same as indices retrieved from + // rest request and unresolved, while for paginated queries, it would be a list of indices + // already resolved by ClusterStateRequest and to be displayed in a page. + final String[] indicesToBeQueried = Objects.isNull(paginationStrategy) + ? indices + : paginationStrategy.getRequestedEntities().toArray(new String[0]); + // After the group listener returns, one page complete and prepare for next page. + // We put the single page result into the pageResults queue for future buildTable. + final GroupedActionListener groupedListener = createGroupedListener( + pageResults, + paginationStrategy.getResponseToken(), + iterativeListener + ); + + sendIndicesStatsRequest( + indicesToBeQueried, + subRequestIndicesOptions, + includeUnloadedSegments, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + + sendClusterHealthRequest( + indicesToBeQueried, + subRequestIndicesOptions, local, clusterManagerNodeTimeout, client, - new ActionListener() { - @Override - public void onResponse(final GetSettingsResponse getSettingsResponse) { - final GroupedActionListener groupedListener = createGroupedListener(4, internalListener); - groupedListener.onResponse(getSettingsResponse); - - // The list of indices that will be returned is determined by the indices returned from the Get Settings call. - // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the - // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so - // force the IndicesOptions for all the sub-requests to be as inclusive as possible. - final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); - - sendIndicesStatsRequest( - indices, - subRequestIndicesOptions, - includeUnloadedSegments, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - sendClusterStateRequest( - indices, - subRequestIndicesOptions, - local, - clusterManagerNodeTimeout, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - sendClusterHealthRequest( - indices, - subRequestIndicesOptions, - local, - clusterManagerNodeTimeout, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - } - - @Override - public void onFailure(final Exception e) { - internalListener.onFailure(e); - } - } + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) ); } + protected IndexPaginationStrategy getPaginationStrategy(PageParams pageParams, ClusterStateResponse clusterStateResponse) { + return new IndexPaginationStrategy(pageParams, clusterStateResponse.getState()); + } + @Override public String getType() { return TYPE; @@ -199,7 +309,6 @@ private void sendGetSettingsRequest( final String[] indices, final IndicesOptions indicesOptions, final boolean local, - final TimeValue clusterManagerNodeTimeout, final Client client, final ActionListener listener ) { @@ -207,7 +316,7 @@ private void sendGetSettingsRequest( request.indices(indices); request.indicesOptions(indicesOptions); request.local(local); - request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); request.names(IndexSettings.INDEX_SEARCH_THROTTLED.getKey()); client.admin().indices().getSettings(request, listener); @@ -217,7 +326,6 @@ private void sendClusterStateRequest( final String[] indices, final IndicesOptions indicesOptions, final boolean local, - final TimeValue clusterManagerNodeTimeout, final Client client, final ActionListener listener ) { @@ -226,7 +334,7 @@ private void sendClusterStateRequest( request.indices(indices); request.indicesOptions(indicesOptions); request.local(local); - request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); client.admin().cluster().state(request, listener); } @@ -266,39 +374,24 @@ private void sendIndicesStatsRequest( client.admin().indices().stats(request, listener); } - private GroupedActionListener createGroupedListener(final int size, final ActionListener
listener) { - return new GroupedActionListener<>(new ActionListener>() { + // group listener only accept two action response: IndicesStatsResponse and ClusterHealthResponse + private GroupedActionListener createGroupedListener( + final Queue> pageResults, + final PageToken pageToken, + final ActionListener listener + ) { + return new GroupedActionListener<>(new ActionListener<>() { @Override public void onResponse(final Collection responses) { - try { - GetSettingsResponse settingsResponse = extractResponse(responses, GetSettingsResponse.class); - Map indicesSettings = StreamSupport - .stream(Spliterators.spliterator(settingsResponse.getIndexToSettings().entrySet(), 0), false) - .collect(Collectors.toMap(cursor -> cursor.getKey(), cursor -> cursor.getValue())); - - ClusterStateResponse stateResponse = extractResponse(responses, ClusterStateResponse.class); - Map indicesStates = StreamSupport - .stream(stateResponse.getState().getMetadata().spliterator(), false) - .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); - - ClusterHealthResponse healthResponse = extractResponse(responses, ClusterHealthResponse.class); - Map indicesHealths = healthResponse.getIndices(); - - IndicesStatsResponse statsResponse = extractResponse(responses, IndicesStatsResponse.class); - Map indicesStats = statsResponse.getIndices(); - - Table responseTable = buildTable(indicesSettings, indicesHealths, indicesStats, indicesStates); - listener.onResponse(responseTable); - } catch (Exception e) { - onFailure(e); - } + pageResults.add(responses); + listener.onResponse(pageToken); } @Override public void onFailure(final Exception e) { listener.onFailure(e); } - }, size); + }, 2); } @Override @@ -307,9 +400,9 @@ public boolean validate(Map parameters) { } /** - * Factory for the {@link CatIndexTool} + * Factory for the {@link ListIndexTool} */ - public static class Factory implements Tool.Factory { + public static class Factory implements Tool.Factory { private Client client; private ClusterService clusterService; @@ -322,7 +415,7 @@ public static Factory getInstance() { if (INSTANCE != null) { return INSTANCE; } - synchronized (CatIndexTool.class) { + synchronized (ListIndexTool.class) { if (INSTANCE != null) { return INSTANCE; } @@ -342,8 +435,8 @@ public void init(Client client, ClusterService clusterService) { } @Override - public CatIndexTool create(Map map) { - return new CatIndexTool(client, clusterService); + public ListIndexTool create(Map map) { + return new ListIndexTool(client, clusterService); } @Override @@ -396,21 +489,34 @@ private Table getTableWithHeader() { } private Table buildTable( - final Map indicesSettings, - final Map indicesHealths, - final Map indicesStats, - final Map indicesMetadatas + ClusterStateResponse clusterStateResponse, + GetSettingsResponse getSettingsResponse, + Queue> responses ) { + if (responses == null || responses.isEmpty() || responses.peek().isEmpty()) { + return null; + } + Tuple, Map> tuple = aggregateResults(responses); final Table table = getTableWithHeader(); AtomicInteger rowNum = new AtomicInteger(0); + Map indicesSettings = StreamSupport + .stream(Spliterators.spliterator(getSettingsResponse.getIndexToSettings().entrySet(), 0), false) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + Map indicesStates = StreamSupport + .stream(clusterStateResponse.getState().getMetadata().spliterator(), false) + .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); + + Map indicesHealths = tuple.v2(); + Map indicesStats = tuple.v1(); indicesSettings.forEach((indexName, settings) -> { - if (!indicesMetadatas.containsKey(indexName)) { + if (!indicesStates.containsKey(indexName)) { // the index exists in the Get Indices response but is not present in the cluster state: // it is likely that the index was deleted in the meanwhile, so we ignore it. return; } - final IndexMetadata indexMetadata = indicesMetadatas.get(indexName); + final IndexMetadata indexMetadata = indicesStates.get(indexName); final IndexMetadata.State indexState = indexMetadata.getState(); final IndexStats indexStats = indicesStats.get(indexName); @@ -448,15 +554,28 @@ private Table buildTable( table.addCell(totalStats.getStore() == null ? null : totalStats.getStore().size()); table.addCell(primaryStats.getStore() == null ? null : primaryStats.getStore().size()); - table.endRow(); }); - return table; } - @SuppressWarnings("unchecked") - private static A extractResponse(final Collection responses, Class c) { - return (A) responses.stream().filter(c::isInstance).findFirst().get(); + private Tuple, Map> aggregateResults(Queue> responses) { + // Each batch produces a collection of action response, aggregate them together to build table easier. + Map indexStatsMap = new HashMap<>(); + Map clusterIndexHealthMap = new HashMap<>(); + for (Collection response : responses) { + if (response != null && !response.isEmpty()) { + response.forEach(x -> { + if (x instanceof IndicesStatsResponse) { + indexStatsMap.putAll(((IndicesStatsResponse) x).getIndices()); + } else if (x instanceof ClusterHealthResponse) { + clusterIndexHealthMap.putAll(((ClusterHealthResponse) x).getIndices()); + } else { + throw new IllegalStateException("Unexpected action response type: " + x.getClass().getName()); + } + }); + } + } + return new Tuple<>(indexStatsMap, clusterIndexHealthMap); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java deleted file mode 100644 index 31dd909d1a..0000000000 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.tools; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.Version; -import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; -import org.opensearch.action.admin.cluster.state.ClusterStateResponse; -import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; -import org.opensearch.action.admin.indices.stats.CommonStats; -import org.opensearch.action.admin.indices.stats.CommonStatsFlags; -import org.opensearch.action.admin.indices.stats.IndexStats; -import org.opensearch.action.admin.indices.stats.IndexStats.IndexStatsBuilder; -import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; -import org.opensearch.action.admin.indices.stats.ShardStats; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.health.ClusterIndexHealth; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.metadata.IndexMetadata.State; -import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.routing.IndexRoutingTable; -import org.opensearch.cluster.routing.IndexShardRoutingTable; -import org.opensearch.cluster.routing.ShardRouting; -import org.opensearch.cluster.routing.ShardRoutingState; -import org.opensearch.cluster.routing.TestShardRouting; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.UUIDs; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.Index; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.index.shard.ShardPath; -import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.engine.tools.CatIndexTool.Factory; -import org.opensearch.transport.client.AdminClient; -import org.opensearch.transport.client.Client; -import org.opensearch.transport.client.ClusterAdminClient; -import org.opensearch.transport.client.IndicesAdminClient; - -public class CatIndexToolTests { - - @Mock - private Client client; - @Mock - private AdminClient adminClient; - @Mock - private IndicesAdminClient indicesAdminClient; - @Mock - private ClusterAdminClient clusterAdminClient; - @Mock - private ClusterService clusterService; - @Mock - private ClusterState clusterState; - @Mock - private Metadata metadata; - @Mock - private GetSettingsResponse getSettingsResponse; - @Mock - private IndicesStatsResponse indicesStatsResponse; - @Mock - private ClusterStateResponse clusterStateResponse; - @Mock - private ClusterHealthResponse clusterHealthResponse; - @Mock - private IndexMetadata indexMetadata; - @Mock - private IndexRoutingTable indexRoutingTable; - - private Map indicesParams; - private Map otherParams; - private Map emptyParams; - - @Before - public void setup() { - MockitoAnnotations.openMocks(this); - - when(adminClient.indices()).thenReturn(indicesAdminClient); - when(adminClient.cluster()).thenReturn(clusterAdminClient); - when(client.admin()).thenReturn(adminClient); - - when(indexMetadata.getState()).thenReturn(State.OPEN); - when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT); - - when(metadata.index(any(String.class))).thenReturn(indexMetadata); - when(clusterState.metadata()).thenReturn(metadata); - when(clusterService.state()).thenReturn(clusterState); - - CatIndexTool.Factory.getInstance().init(client, clusterService); - - indicesParams = Map.of("index", "[\"foo\"]"); - otherParams = Map.of("other", "[\"bar\"]"); - emptyParams = Collections.emptyMap(); - } - - @Test - public void testRunAsyncNoIndices() throws Exception { - @SuppressWarnings("unchecked") - ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor - .forClass(ActionListener.class); - doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor - .forClass(ActionListener.class); - doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); - - when(getSettingsResponse.getIndexToSettings()).thenReturn(Collections.emptyMap()); - when(indicesStatsResponse.getIndices()).thenReturn(Collections.emptyMap()); - when(clusterStateResponse.getState()).thenReturn(clusterState); - when(clusterState.getMetadata()).thenReturn(metadata); - when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[0])); - - when(clusterHealthResponse.getIndices()).thenReturn(Collections.emptyMap()); - - Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); - - tool.run(otherParams, listener); - settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); - statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); - clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); - clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); - - future.join(); - assertEquals("There were no results searching the indices parameter [null].", future.get()); - } - - @Test - public void testRunAsyncIndexStats() throws Exception { - String indexName = "foo"; - Index index = new Index(indexName, UUIDs.base64UUID()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor - .forClass(ActionListener.class); - doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); - - @SuppressWarnings("unchecked") - ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor - .forClass(ActionListener.class); - doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); - - when(getSettingsResponse.getIndexToSettings()).thenReturn(Map.of("foo", Settings.EMPTY)); - - int shardId = 0; - ShardId shId = new ShardId(index, shardId); - Path path = Files.createTempDirectory("temp").resolve("indices").resolve(index.getUUID()).resolve(String.valueOf(shardId)); - ShardPath shardPath = new ShardPath(false, path, path, shId); - ShardRouting routing = TestShardRouting.newShardRouting(shId, "node", true, ShardRoutingState.STARTED); - CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL); - IndexStats fooStats = new IndexStatsBuilder(index.getName(), index.getUUID()) - .add(new ShardStats(routing, shardPath, commonStats, null, null, null, null)) - .build(); - when(indicesStatsResponse.getIndices()).thenReturn(Map.of(indexName, fooStats)); - - when(indexMetadata.getIndex()).thenReturn(index); - when(indexMetadata.getNumberOfShards()).thenReturn(5); - when(indexMetadata.getNumberOfReplicas()).thenReturn(1); - when(clusterStateResponse.getState()).thenReturn(clusterState); - when(clusterState.getMetadata()).thenReturn(metadata); - when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[] { indexMetadata })); - @SuppressWarnings("unchecked") - Iterator iterator = (Iterator) mock(Iterator.class); - when(iterator.hasNext()).thenReturn(false); - when(indexRoutingTable.iterator()).thenReturn(iterator); - ClusterIndexHealth fooHealth = new ClusterIndexHealth(indexMetadata, indexRoutingTable); - when(clusterHealthResponse.getIndices()).thenReturn(Map.of(indexName, fooHealth)); - - // Now make the call - Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); - - tool.run(otherParams, listener); - settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); - statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); - clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); - clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); - - future.orTimeout(10, TimeUnit.SECONDS).join(); - String response = future.get(); - String[] responseRows = response.trim().split("\\n"); - - assertEquals(2, responseRows.length); - String header = responseRows[0]; - String fooRow = responseRows[1]; - assertEquals(header.split("\\t").length, fooRow.split("\\t").length); - assertEquals( - "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)", - header - ); - assertEquals("1,red,open,foo,null,5,1,0,0,0b,0b", fooRow); - } - - @Test - public void testTool() { - Factory instance = CatIndexTool.Factory.getInstance(); - assertEquals(instance, CatIndexTool.Factory.getInstance()); - assertTrue(instance.getDefaultDescription().contains("tool")); - - Tool tool = instance.create(Collections.emptyMap()); - assertEquals(CatIndexTool.TYPE, tool.getType()); - assertTrue(tool.validate(indicesParams)); - assertTrue(tool.validate(otherParams)); - assertFalse(tool.validate(emptyParams)); - } -} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java new file mode 100644 index 0000000000..822812e4fe --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java @@ -0,0 +1,314 @@ +package org.opensearch.ml.engine.tools; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.core.index.Index; +import org.opensearch.index.shard.DocsStats; +import org.opensearch.index.store.StoreStats; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.ClusterAdminClient; +import org.opensearch.transport.client.IndicesAdminClient; + +import com.google.common.collect.ImmutableMap; + +public class ListIndexToolTests { + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private IndexMetadata indexMetadata; + @Mock + private IndexRoutingTable indexRoutingTable; + @Mock + private Index index; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(client.admin()).thenReturn(adminClient); + + when(indexMetadata.getState()).thenReturn(IndexMetadata.State.OPEN); + when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT); + + when(metadata.index(any(String.class))).thenReturn(indexMetadata); + when(indexMetadata.getIndex()).thenReturn(index); + when(indexMetadata.getIndexUUID()).thenReturn(UUIDs.base64UUID()); + when(index.getName()).thenReturn("index-1"); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + + ListIndexTool.Factory.getInstance().init(client, clusterService); + } + + @Test + public void test_getType() { + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + assert (tool.getType().equals("ListIndexTool")); + } + + @Test + public void test_run_successful_1() { + mockUp(); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + verifyResult(tool, createParameters("[\"index-1\"]", "true", "10", "true")); + } + + @Test + public void test_run_successful_2() { + mockUp(); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + verifyResult(tool, createParameters(null, null, null, null)); + } + + private Map createParameters(String indices, String local, String pageSize, String includeUnloadedSegments) { + Map parameters = new HashMap<>(); + if (indices != null) { + parameters.put("indices", indices); + } + if (local != null) { + parameters.put("local", local); + } + if (pageSize != null) { + parameters.put("page_size", pageSize); + } + if (includeUnloadedSegments != null) { + parameters.put("include_unloaded_segments", includeUnloadedSegments); + } + return parameters; + } + + private void verifyResult(Tool tool, Map parameters) { + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + tool.run(parameters, listener); + verify(listener).onResponse(captor.capture()); + System.out.println(captor.getValue()); + assert captor.getValue().contains("1,red,open,index-1"); + assert captor.getValue().contains("5,1,100,10,100kb,100kb"); + } + + private void mockUp() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + GetSettingsResponse response = mock(GetSettingsResponse.class); + Map indexToSettings = new HashMap<>(); + indexToSettings.put("index-1", Settings.EMPTY); + indexToSettings.put("index-2", Settings.EMPTY); + when(response.getIndexToSettings()).thenReturn(indexToSettings); + actionListener.onResponse(response); + return null; + }).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class)); + + // clusterStateResponse.getState().getMetadata().spliterator() + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + ClusterStateResponse response = mock(ClusterStateResponse.class); + when(response.getState()).thenReturn(clusterState); + actionListener.onResponse(response); + return null; + }).when(clusterAdminClient).state(any(ClusterStateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + IndicesStatsResponse response = mock(IndicesStatsResponse.class); + Map indicesStats = new HashMap<>(); + IndexStats indexStats = mock(IndexStats.class); + // mock primary stats + CommonStats primaryStats = mock(CommonStats.class); + DocsStats docsStats = mock(DocsStats.class); + when(docsStats.getCount()).thenReturn(100L); + when(docsStats.getDeleted()).thenReturn(10L); + when(primaryStats.getDocs()).thenReturn(docsStats); + StoreStats primaryStoreStats = mock(StoreStats.class); + when(primaryStoreStats.size()).thenReturn(ByteSizeValue.parseBytesSizeValue("100k", "mock_setting_name")); + when(primaryStats.getStore()).thenReturn(primaryStoreStats); + // end mock primary stats + + // mock total stats + CommonStats totalStats = mock(CommonStats.class); + DocsStats totalDocsStats = mock(DocsStats.class); + when(totalDocsStats.getCount()).thenReturn(100L); + when(totalDocsStats.getDeleted()).thenReturn(10L); + StoreStats totalStoreStats = mock(StoreStats.class); + when(totalStoreStats.size()).thenReturn(ByteSizeValue.parseBytesSizeValue("100k", "mock_setting_name")); + when(totalStats.getStore()).thenReturn(totalStoreStats); + // end mock common stats + + when(indexStats.getPrimaries()).thenReturn(primaryStats); + when(indexStats.getTotal()).thenReturn(totalStats); + indicesStats.put("index-1", indexStats); + when(response.getIndices()).thenReturn(indicesStats); + actionListener.onResponse(response); + return null; + }).when(indicesAdminClient).stats(any(IndicesStatsRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + ClusterHealthResponse response = mock(ClusterHealthResponse.class); + Map clusterIndexHealthMap = new HashMap<>(); + when(indexMetadata.getNumberOfShards()).thenReturn(5); + when(indexMetadata.getNumberOfReplicas()).thenReturn(1); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[] { indexMetadata })); + Iterator iterator = (Iterator) mock(Iterator.class); + when(iterator.hasNext()).thenReturn(false); + when(indexRoutingTable.iterator()).thenReturn(iterator); + ClusterIndexHealth health = new ClusterIndexHealth(indexMetadata, indexRoutingTable); + clusterIndexHealthMap.put("index-1", health); + when(response.getIndices()).thenReturn(clusterIndexHealthMap); + actionListener.onResponse(response); + return null; + }).when(clusterAdminClient).health(any(ClusterHealthRequest.class), isA(ActionListener.class)); + } + + @Test + public void test_run_withEmptyTableResult() { + Map parameters = createParameters("[\"index-1\"]", "true", "10", "true"); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + GetSettingsResponse response = mock(GetSettingsResponse.class); + Map indexToSettings = new HashMap<>(); + indexToSettings.put("index-1", Settings.EMPTY); + indexToSettings.put("index-2", Settings.EMPTY); + when(response.getIndexToSettings()).thenReturn(indexToSettings); + actionListener.onResponse(response); + return null; + }).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class)); + + // clusterStateResponse.getState().getMetadata().spliterator() + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + ClusterStateResponse response = mock(ClusterStateResponse.class); + when(response.getState()).thenReturn(clusterState); + actionListener.onResponse(response); + return null; + }).when(clusterAdminClient).state(any(ClusterStateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(null); + return null; + }).when(indicesAdminClient).stats(any(IndicesStatsRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(null); + return null; + }).when(clusterAdminClient).health(any(ClusterHealthRequest.class), isA(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + tool.run(parameters, listener); + verify(listener).onResponse(captor.capture()); + System.out.println(captor.getValue()); + assert captor.getValue().contains("There were no results searching the indices parameter"); + } + + @Test + public void test_run_failed() { + Map parameters = new HashMap<>(); + parameters.put("indices", "[\"index-1\"]"); + parameters.put("page_size", "10"); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("failed to get settings")); + return null; + }).when(indicesAdminClient).getSettings(any(GetSettingsRequest.class), isA(ActionListener.class)); + + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(RuntimeException.class); + tool.run(parameters, listener); + verify(listener).onFailure(captor.capture()); + System.out.println(captor.getValue().getMessage()); + assert (captor.getValue().getMessage().contains("failed to get settings")); + } + + @Test + public void test_validate() { + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + assert tool.validate(ImmutableMap.of("runtimeParameter", "value1")); + assert !tool.validate(null); + assert !tool.validate(Collections.emptyMap()); + } + + @Test + public void test_getDefaultDescription() { + Tool.Factory factory = ListIndexTool.Factory.getInstance(); + System.out.println(factory.getDefaultDescription()); + assert (factory + .getDefaultDescription() + .equals( + "This tool gets index information from the OpenSearch cluster. It takes 2 optional arguments named `index` which is a comma-delimited list of one or more indices to get information from (default is an empty list meaning all indices), and `local` which means whether to return information from the local node only instead of the cluster manager node (default is false). The tool returns the indices information, including `health`, `status`, `index`, `uuid`, `pri`, `rep`, `docs.count`, `docs.deleted`, `store.size`, `pri.store. size `, `pri.store.size`, `pri.store`." + )); + } + + @Test + public void test_getDefaultType() { + Tool.Factory factory = ListIndexTool.Factory.getInstance(); + System.out.println(factory.getDefaultType()); + assert (factory.getDefaultType().equals("ListIndexTool")); + } + + @Test + public void test_getDefaultVersion() { + Tool.Factory factory = ListIndexTool.Factory.getInstance(); + assert factory.getDefaultVersion() == null; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index f9d3421d7f..a74ef178b6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -190,9 +190,9 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; -import org.opensearch.ml.engine.tools.CatIndexTool; import org.opensearch.ml.engine.tools.ConnectorTool; import org.opensearch.ml.engine.tools.IndexMappingTool; +import org.opensearch.ml.engine.tools.ListIndexTool; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.engine.tools.SearchIndexTool; import org.opensearch.ml.engine.tools.VisualizationsTool; @@ -644,7 +644,7 @@ public Collection createComponents( MLModelTool.Factory.getInstance().init(client); AgentTool.Factory.getInstance().init(client); - CatIndexTool.Factory.getInstance().init(client, clusterService); + ListIndexTool.Factory.getInstance().init(client, clusterService); IndexMappingTool.Factory.getInstance().init(client); SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); VisualizationsTool.Factory.getInstance().init(client); @@ -652,7 +652,7 @@ public Collection createComponents( toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); - toolFactories.put(CatIndexTool.TYPE, CatIndexTool.Factory.getInstance()); + toolFactories.put(ListIndexTool.TYPE, ListIndexTool.Factory.getInstance()); toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance()); toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance()); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 8096da1fbc..33fcd4d8ae 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -302,10 +302,6 @@ private void validateOutput(String errorMsg, Map output, String List outputList = (List) output.get("output"); assertEquals(errorMsg, 1, outputList.size()); assertTrue(errorMsg, outputList.get(0) instanceof Map); - String typeErrorMsg = errorMsg - + " first element in the output list is type of: " - + ((Map) outputList.get(0)).get("data").getClass().getName(); - assertTrue(typeErrorMsg, ((Map) outputList.get(0)).get("data") instanceof List); assertEquals(errorMsg, ((Map) outputList.get(0)).get("data_type"), dataType); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java index 239fa9c917..a997f4cb48 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestCohereInferenceIT.java @@ -88,10 +88,6 @@ private void validateOutput(String errorMsg, Map output, String List outputList = (List) output.get("output"); assertEquals(errorMsg, 2, outputList.size()); assertTrue(errorMsg, outputList.get(0) instanceof Map); - String typeErrorMsg = errorMsg - + " first element in the output list is type of: " - + ((Map) outputList.get(0)).get("data").getClass().getName(); - assertTrue(typeErrorMsg, ((Map) outputList.get(0)).get("data") instanceof List); assertTrue(errorMsg, ((Map) outputList.get(0)).get("data_type").equals(dataType)); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java index 49254f4d1d..9c53e60158 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLFlowAgentIT.java @@ -78,8 +78,8 @@ public static Response registerAgentWithCatIndexTool() throws IOException { + " \"description\": \"this is a test agent for the CatIndexTool\",\n" + " \"tools\": [\n" + " {\n" - + " \"type\": \"CatIndexTool\",\n" - + " \"name\": \"DemoCatIndexTool\",\n" + + " \"type\": \"ListIndexTool\",\n" + + " \"name\": \"DemoListIndexTool\",\n" + " \"parameters\": {\n" + " \"input\": \"${parameters.question}\"\n" + " }\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java index df6c661641..b627a93a8a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java @@ -31,7 +31,7 @@ import org.opensearch.ml.common.transport.tools.MLGetToolAction; import org.opensearch.ml.common.transport.tools.MLToolGetRequest; import org.opensearch.ml.common.transport.tools.MLToolGetResponse; -import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.ml.engine.tools.ListIndexTool; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -61,7 +61,7 @@ public void setup() { Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type"); Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version"); - Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool); toolFactories.put("mockTool", mockFactory); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java index 5e52af4f9a..22d14c28af 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java @@ -30,7 +30,7 @@ import org.opensearch.ml.common.transport.tools.MLListToolsAction; import org.opensearch.ml.common.transport.tools.MLToolsListRequest; import org.opensearch.ml.common.transport.tools.MLToolsListResponse; -import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.ml.engine.tools.ListIndexTool; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -59,7 +59,7 @@ public void setup() { Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type"); Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version"); - Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool); toolFactories.put("mockTool", mockFactory); restMLListToolsAction = new RestMLListToolsAction(toolFactories); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java b/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java new file mode 100644 index 0000000000..3074fe090c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import org.apache.commons.lang3.StringUtils; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.engine.tools.ListIndexTool; +import org.opensearch.ml.rest.RestBaseAgentToolsIT; +import org.opensearch.ml.utils.TestHelper; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ListIndexToolIT extends RestBaseAgentToolsIT { + private String agentId; + private final String question = "{\"parameters\":{\"question\":\"please help list all the index status in the current cluster?\"}}"; + + @Before + public void setUpCluster() throws Exception { + registerListIndexFlowAgent(); + } + + private List createIndices(int count) throws IOException { + List indices = new ArrayList<>(); + for (int i = 0; i < count; i++) { + String indexName = "test" + StringUtils.toRootLowerCase(randomAlphaOfLength(5)); + createIndex(indexName, Settings.EMPTY); + indices.add(indexName); + } + return indices; + } + + private void registerListIndexFlowAgent() throws Exception { + String requestBody = Files + .readString( + Path.of(this.getClass().getClassLoader().getResource("org/opensearch/ml/tools/ListIndexAgentRegistration.json").toURI()) + ); + registerMLAgent(client(), requestBody, response -> agentId = (String) response.get("agent_id")); + } + + public void testListIndexWithFewIndices() throws IOException { + List indices = createIndices(ListIndexTool.DEFAULT_PAGE_SIZE); + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null); + String responseStr = TestHelper.httpEntityToString(response.getEntity()); + String toolOutput = extractResult(responseStr); + String[] actualLines = toolOutput.split("\\n"); + long testIndexCount = Arrays.stream(actualLines).filter(x -> x.contains("test")).count(); + assert testIndexCount == indices.size(); + for (String index : indices) { + assert Objects.requireNonNull(toolOutput).contains(index); + } + } + + public void testListIndexWithMoreThan100Indices() throws IOException { + List indices = createIndices(ListIndexTool.DEFAULT_PAGE_SIZE + 1); + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null); + String responseStr = TestHelper.httpEntityToString(response.getEntity()); + String toolOutput = extractResult(responseStr); + String[] actualLines = toolOutput.split("\\n"); + long testIndexCount = Arrays.stream(actualLines).filter(x -> x.contains("test")).count(); + assert testIndexCount == indices.size(); + for (String index : indices) { + assert Objects.requireNonNull(toolOutput).contains(index); + } + } + + /** + * An example of responseStr: + * { + * "inference_results": [ + * { + * "output": [ + * { + * "name": "response", + * "result": "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\n1,yellow,open,test4,6ohWskucQ3u3xV9tMjXCkA,1,1,0,0,208b,208b\n2,yellow,open,test5,5AQLe-Z3QKyyLibbZ3Xcng,1,1,0,0,208b,208b\n3,yellow,open,test2,66Cj3zjlQ-G8I3vWeEONpQ,1,1,0,0,208b,208b\n4,yellow,open,test3,6A-aVxPiTj2U9GnupHQ3BA,1,1,0,0,208b,208b\n5,yellow,open,test8,-WKw-SCET3aTFuWCMMixrw,1,1,0,0,208b,208b" + * } + * ] + * } + * ] + * } + * @param responseStr + * @return + */ + private String extractResult(String responseStr) { + JsonArray output = JsonParser + .parseString(responseStr) + .getAsJsonObject() + .get("inference_results") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("output") + .getAsJsonArray(); + for (JsonElement element : output) { + if ("response".equals(element.getAsJsonObject().get("name").getAsString())) { + return element.getAsJsonObject().get("result").getAsString(); + } + } + return null; + } +} diff --git a/plugin/src/test/resources/org/opensearch/ml/tools/ListIndexAgentRegistration.json b/plugin/src/test/resources/org/opensearch/ml/tools/ListIndexAgentRegistration.json new file mode 100644 index 0000000000..fd2f6e8a07 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/tools/ListIndexAgentRegistration.json @@ -0,0 +1,19 @@ +{ + "name": "list index tool flow agent", + "type": "flow", + "description": "this is a test agent", + "llm": { + "model_id": "dummy_model", + "parameters": { + "max_iteration": 5, + "stop_when_no_tool_found": true + } + }, + "tools": [ + { + "type": "ListIndexTool", + "name": "ListIndexTool" + } + ], + "app_type": "my_app" +}