Skip to content

Commit d812134

Browse files
committed
Star Tree Search request/response changes
Signed-off-by: Sandesh Kumar <[email protected]>
1 parent b99c73a commit d812134

26 files changed

+1195
-22
lines changed

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/builder/BaseStarTreeBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ public List<SequentialDocValuesIterator> getMetricReaders(SegmentWriteState stat
198198
continue;
199199
}
200200
FieldInfo metricFieldInfo = state.fieldInfos.fieldInfo(metric.getField());
201-
// if (metricStat != MetricStat.COUNT) {
201+
// if (metricStat != MetricStat.VALUE_COUNT) {
202202
if (metricFieldInfo == null) {
203203
metricFieldInfo = StarTreeUtils.getFieldInfo(metric.getField(), 1);
204204
}

server/src/main/java/org/opensearch/index/query/QueryShardContext.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@
5656
import org.opensearch.index.IndexSortConfig;
5757
import org.opensearch.index.analysis.IndexAnalyzers;
5858
import org.opensearch.index.cache.bitset.BitsetFilterCache;
59+
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
60+
import org.opensearch.index.compositeindex.datacube.Dimension;
61+
import org.opensearch.index.compositeindex.datacube.Metric;
62+
import org.opensearch.index.compositeindex.datacube.MetricStat;
5963
import org.opensearch.index.fielddata.IndexFieldData;
64+
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
6065
import org.opensearch.index.mapper.ContentPath;
6166
import org.opensearch.index.mapper.DerivedFieldResolver;
6267
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
@@ -73,12 +78,22 @@
7378
import org.opensearch.script.ScriptContext;
7479
import org.opensearch.script.ScriptFactory;
7580
import org.opensearch.script.ScriptService;
81+
import org.opensearch.search.aggregations.AggregatorFactory;
82+
import org.opensearch.search.aggregations.metrics.AvgAggregatorFactory;
83+
import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory;
84+
import org.opensearch.search.aggregations.metrics.MinAggregatorFactory;
85+
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
86+
import org.opensearch.search.aggregations.metrics.ValueCountAggregatorFactory;
7687
import org.opensearch.search.aggregations.support.AggregationUsageService;
88+
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
7789
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
7890
import org.opensearch.search.lookup.SearchLookup;
91+
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
92+
import org.opensearch.search.startree.StarTreeQuery;
7993
import org.opensearch.transport.RemoteClusterAware;
8094

8195
import java.io.IOException;
96+
import java.util.ArrayList;
8297
import java.util.HashMap;
8398
import java.util.HashSet;
8499
import java.util.List;
@@ -89,6 +104,7 @@
89104
import java.util.function.LongSupplier;
90105
import java.util.function.Predicate;
91106
import java.util.function.Supplier;
107+
import java.util.stream.Collectors;
92108

93109
import static java.util.Collections.emptyList;
94110
import static java.util.Collections.emptyMap;
@@ -522,6 +538,80 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction<QueryBuil
522538
}
523539
}
524540

541+
public ParsedQuery toStarTreeQuery(
542+
CompositeIndexFieldInfo starTree,
543+
CompositeDataCubeFieldType compositeIndexFieldInfo,
544+
QueryBuilder queryBuilder,
545+
Query query
546+
) {
547+
Map<String, List<Predicate<Long>>> predicateMap;
548+
549+
if (queryBuilder == null) {
550+
predicateMap = null;
551+
} else if (queryBuilder instanceof TermQueryBuilder) {
552+
List<String> supportedDimensions = compositeIndexFieldInfo.getDimensions()
553+
.stream()
554+
.map(Dimension::getField)
555+
.collect(Collectors.toList());
556+
predicateMap = getStarTreePredicates(queryBuilder, supportedDimensions);
557+
} else {
558+
return null;
559+
}
560+
561+
StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap);
562+
OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query);
563+
return new ParsedQuery(originalOrStarTreeQuery);
564+
}
565+
566+
/**
567+
* Parse query body to star-tree predicates
568+
* @param queryBuilder
569+
* @return predicates to match
570+
*/
571+
private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
572+
TermQueryBuilder tq = (TermQueryBuilder) queryBuilder;
573+
String field = tq.fieldName();
574+
if (supportedDimensions.contains(field) == false) {
575+
throw new IllegalArgumentException("unsupported field in star-tree");
576+
}
577+
long inputQueryVal = Long.parseLong(tq.value().toString());
578+
579+
// Get or create the list of predicates for the given field
580+
Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();
581+
List<Predicate<Long>> predicates = predicateMap.getOrDefault(field, new ArrayList<>());
582+
583+
// Create a predicate to match the input query value
584+
Predicate<Long> predicate = dimVal -> dimVal == inputQueryVal;
585+
predicates.add(predicate);
586+
587+
// Put the predicates list back into the map
588+
predicateMap.put(field, predicates);
589+
return predicateMap;
590+
}
591+
592+
public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) {
593+
String field;
594+
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
595+
.stream()
596+
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));
597+
598+
// Map to associate supported AggregatorFactory classes with their corresponding MetricStat
599+
Map<Class<? extends ValuesSourceAggregatorFactory>, MetricStat> aggregatorStatMap = Map.of(
600+
SumAggregatorFactory.class, MetricStat.SUM,
601+
MaxAggregatorFactory.class, MetricStat.MAX,
602+
MinAggregatorFactory.class, MetricStat.MIN,
603+
ValueCountAggregatorFactory.class, MetricStat.VALUE_COUNT,
604+
AvgAggregatorFactory.class, MetricStat.AVG
605+
);
606+
607+
MetricStat metricStat = aggregatorStatMap.get(aggregatorFactory.getClass());
608+
if (metricStat != null) {
609+
field = ((ValuesSourceAggregatorFactory)aggregatorFactory).getField();
610+
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat);
611+
}
612+
return false;
613+
}
614+
525615
public Index index() {
526616
return indexSettings.getIndex();
527617
}

server/src/main/java/org/opensearch/search/SearchService.java

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,16 @@
7777
import org.opensearch.index.IndexNotFoundException;
7878
import org.opensearch.index.IndexService;
7979
import org.opensearch.index.IndexSettings;
80+
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
8081
import org.opensearch.index.engine.Engine;
82+
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
8183
import org.opensearch.index.mapper.DerivedFieldResolver;
8284
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
85+
import org.opensearch.index.mapper.StarTreeMapper;
8386
import org.opensearch.index.query.InnerHitContextBuilder;
8487
import org.opensearch.index.query.MatchAllQueryBuilder;
8588
import org.opensearch.index.query.MatchNoneQueryBuilder;
89+
import org.opensearch.index.query.ParsedQuery;
8690
import org.opensearch.index.query.QueryBuilder;
8791
import org.opensearch.index.query.QueryRewriteContext;
8892
import org.opensearch.index.query.QueryShardContext;
@@ -97,11 +101,13 @@
97101
import org.opensearch.script.ScriptService;
98102
import org.opensearch.search.aggregations.AggregationInitializationException;
99103
import org.opensearch.search.aggregations.AggregatorFactories;
104+
import org.opensearch.search.aggregations.AggregatorFactory;
100105
import org.opensearch.search.aggregations.InternalAggregation;
101106
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
102107
import org.opensearch.search.aggregations.MultiBucketConsumerService;
103108
import org.opensearch.search.aggregations.SearchContextAggregations;
104109
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
110+
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
105111
import org.opensearch.search.builder.SearchSourceBuilder;
106112
import org.opensearch.search.collapse.CollapseContext;
107113
import org.opensearch.search.dfs.DfsPhase;
@@ -162,6 +168,7 @@
162168
import static org.opensearch.common.unit.TimeValue.timeValueHours;
163169
import static org.opensearch.common.unit.TimeValue.timeValueMillis;
164170
import static org.opensearch.common.unit.TimeValue.timeValueMinutes;
171+
import static org.opensearch.search.internal.SearchContext.TRACK_TOTAL_HITS_DISABLED;
165172

166173
/**
167174
* The main search service
@@ -1314,6 +1321,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
13141321
context.evaluateRequestShouldUseConcurrentSearch();
13151322
return;
13161323
}
1324+
// Can be marked false for majority cases for which star-tree cannot be used
1325+
// As we increment the cases where star-tree can be used, this can be set back to true
1326+
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();
1327+
13171328
SearchShardTarget shardTarget = context.shardTarget();
13181329
QueryShardContext queryShardContext = context.getQueryShardContext();
13191330
context.from(source.from());
@@ -1324,10 +1335,12 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
13241335
context.parsedQuery(queryShardContext.toQuery(source.query()));
13251336
}
13261337
if (source.postFilter() != null) {
1338+
canUseStarTree = false;
13271339
InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders);
13281340
context.parsedPostFilter(queryShardContext.toQuery(source.postFilter()));
13291341
}
1330-
if (innerHitBuilders.size() > 0) {
1342+
if (!innerHitBuilders.isEmpty()) {
1343+
canUseStarTree = false;
13311344
for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
13321345
try {
13331346
entry.getValue().build(context, context.innerHits());
@@ -1337,11 +1350,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
13371350
}
13381351
}
13391352
if (source.sorts() != null) {
1353+
canUseStarTree = false;
13401354
try {
13411355
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext());
1342-
if (optionalSort.isPresent()) {
1343-
context.sort(optionalSort.get());
1344-
}
1356+
optionalSort.ifPresent(context::sort);
13451357
} catch (IOException e) {
13461358
throw new SearchException(shardTarget, "failed to create sort elements", e);
13471359
}
@@ -1355,8 +1367,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
13551367
}
13561368
if (source.trackTotalHitsUpTo() != null) {
13571369
context.trackTotalHitsUpTo(source.trackTotalHitsUpTo());
1370+
canUseStarTree = canUseStarTree && (source.trackTotalHitsUpTo() == TRACK_TOTAL_HITS_DISABLED);
13581371
}
13591372
if (source.minScore() != null) {
1373+
canUseStarTree = false;
13601374
context.minimumScore(source.minScore());
13611375
}
13621376
if (source.timeout() != null) {
@@ -1496,6 +1510,50 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
14961510
if (source.profile()) {
14971511
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
14981512
}
1513+
1514+
if (canUseStarTree) {
1515+
try {
1516+
setStarTreeQuery(context, queryShardContext, source);
1517+
logger.debug("can use star tree");
1518+
} catch (IOException e) {
1519+
logger.debug("not using star tree");
1520+
}
1521+
}
1522+
}
1523+
1524+
private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source)
1525+
throws IOException {
1526+
1527+
if (source.aggregations() == null) {
1528+
return false;
1529+
}
1530+
1531+
// Current implementation assumes only single star-tree is supported
1532+
CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService()
1533+
.getCompositeFieldTypes()
1534+
.iterator()
1535+
.next();
1536+
CompositeIndexFieldInfo starTree = new CompositeIndexFieldInfo(
1537+
compositeMappedFieldType.name(),
1538+
compositeMappedFieldType.getCompositeIndexType()
1539+
);
1540+
1541+
ParsedQuery newParsedQuery = queryShardContext.toStarTreeQuery(starTree, compositeMappedFieldType, source.query(), context.query());
1542+
if (newParsedQuery == null) {
1543+
return false;
1544+
}
1545+
1546+
for (AggregatorFactory aggregatorFactory : context.aggregations().factories().getFactories()) {
1547+
if (!(aggregatorFactory instanceof ValuesSourceAggregatorFactory
1548+
&& aggregatorFactory.getSubFactories().getFactories().length == 0)) {
1549+
return false;
1550+
}
1551+
if (queryShardContext.validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory) == false) {
1552+
return false;
1553+
}
1554+
}
1555+
context.parsedQuery(newParsedQuery);
1556+
return true;
14991557
}
15001558

15011559
/**
@@ -1655,7 +1713,7 @@ public static boolean canMatchSearchAfter(
16551713
&& minMax != null
16561714
&& primarySortField != null
16571715
&& primarySortField.missing() == null
1658-
&& Objects.equals(trackTotalHitsUpto, SearchContext.TRACK_TOTAL_HITS_DISABLED)) {
1716+
&& Objects.equals(trackTotalHitsUpto, TRACK_TOTAL_HITS_DISABLED)) {
16591717
final Object searchAfterPrimary = searchAfter.fields[0];
16601718
if (primarySortField.order() == SortOrder.DESC) {
16611719
if (minMax.compareMin(searchAfterPrimary) > 0) {

server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ public static Builder builder() {
255255
return new Builder();
256256
}
257257

258-
private AggregatorFactories(AggregatorFactory[] factories) {
258+
public AggregatorFactories(AggregatorFactory[] factories) {
259259
this.factories = factories;
260260
}
261261

@@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() {
661661
return new PipelineTree(subTrees, aggregators);
662662
}
663663
}
664+
665+
public AggregatorFactory[] getFactories() {
666+
return factories;
667+
}
664668
}

server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() {
127127
public boolean evaluateChildFactories() {
128128
return factories.allFactoriesSupportConcurrentSearch();
129129
}
130+
131+
public AggregatorFactories getSubFactories() {
132+
return factories;
133+
}
130134
}

0 commit comments

Comments
 (0)