1717import org .opensearch .common .lucene .Lucene ;
1818import org .opensearch .index .codec .composite .CompositeIndexFieldInfo ;
1919import org .opensearch .index .codec .composite .CompositeIndexReader ;
20+ import org .opensearch .index .compositeindex .datacube .DateDimension ;
2021import org .opensearch .index .compositeindex .datacube .Dimension ;
2122import org .opensearch .index .compositeindex .datacube .Metric ;
2223import org .opensearch .index .compositeindex .datacube .MetricStat ;
2324import org .opensearch .index .compositeindex .datacube .startree .index .StarTreeValues ;
25+ import org .opensearch .index .compositeindex .datacube .startree .utils .date .DateTimeUnitAdapter ;
26+ import org .opensearch .index .compositeindex .datacube .startree .utils .date .DateTimeUnitRounding ;
2427import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .SortedNumericStarTreeValuesIterator ;
2528import org .opensearch .index .mapper .CompositeDataCubeFieldType ;
2629import org .opensearch .index .query .MatchAllQueryBuilder ;
2730import org .opensearch .index .query .QueryBuilder ;
2831import org .opensearch .index .query .TermQueryBuilder ;
2932import org .opensearch .search .aggregations .AggregatorFactory ;
3033import org .opensearch .search .aggregations .LeafBucketCollector ;
31- import org .opensearch .search .aggregations .LeafBucketCollectorBase ;
34+ import org .opensearch .search .aggregations .StarTreeBucketCollector ;
35+ import org .opensearch .search .aggregations .bucket .histogram .DateHistogramAggregatorFactory ;
3236import org .opensearch .search .aggregations .metrics .MetricAggregatorFactory ;
3337import org .opensearch .search .aggregations .support .ValuesSource ;
3438import org .opensearch .search .builder .SearchSourceBuilder ;
3741import org .opensearch .search .startree .StarTreeQueryContext ;
3842
3943import java .io .IOException ;
40- import java .util .HashMap ;
4144import java .util .List ;
4245import java .util .Map ;
46+ import java .util .Set ;
47+ import java .util .function .BiConsumer ;
4348import java .util .function .Consumer ;
4449import java .util .stream .Collectors ;
4550
@@ -74,10 +79,16 @@ public static StarTreeQueryContext getStarTreeQueryContext(SearchContext context
7479 );
7580
7681 for (AggregatorFactory aggregatorFactory : context .aggregations ().factories ().getFactories ()) {
77- MetricStat metricStat = validateStarTreeMetricSupport (compositeMappedFieldType , aggregatorFactory );
78- if (metricStat == null ) {
79- return null ;
82+ // first check for aggregation is a metric aggregation
83+ if (validateStarTreeMetricSupport (compositeMappedFieldType , aggregatorFactory )) {
84+ continue ;
85+ }
86+
87+ // if not a metric aggregation, check for applicable date histogram shape
88+ if (validateDateHistogramSupport (compositeMappedFieldType , aggregatorFactory )) {
89+ continue ;
8090 }
91+ return null ;
8192 }
8293
8394 // need to cache star tree values only for multiple aggregations
@@ -99,64 +110,85 @@ private static StarTreeQueryContext tryCreateStarTreeQueryContext(
99110 Map <String , Long > queryMap ;
100111 if (queryBuilder == null || queryBuilder instanceof MatchAllQueryBuilder ) {
101112 queryMap = null ;
102- } else if (queryBuilder instanceof TermQueryBuilder ) {
113+ } else if (queryBuilder instanceof TermQueryBuilder termQueryBuilder ) {
103114 // TODO: Add support for keyword fields
104- if (compositeFieldType .getDimensions ().stream ().anyMatch (d -> d .getDocValuesType () != DocValuesType .SORTED_NUMERIC )) {
105- // return null for non-numeric fields
106- return null ;
107- }
108-
109- List <String > supportedDimensions = compositeFieldType .getDimensions ()
115+ Dimension matchedDimension = compositeFieldType .getDimensions ()
110116 .stream ()
111- .map ( Dimension :: getField )
112- .collect ( Collectors . toList ());
113- queryMap = getStarTreePredicates ( queryBuilder , supportedDimensions );
114- if (queryMap == null ) {
117+ .filter ( d -> ( d . getField (). equals ( termQueryBuilder . fieldName ()) && d . getDocValuesType () == DocValuesType . SORTED_NUMERIC ) )
118+ .findFirst ()
119+ . orElse ( null );
120+ if (matchedDimension == null ) {
115121 return null ;
116122 }
123+ queryMap = Map .of (termQueryBuilder .fieldName (), Long .parseLong (termQueryBuilder .value ().toString ()));
117124 } else {
118125 return null ;
119126 }
120127 return new StarTreeQueryContext (compositeIndexFieldInfo , queryMap , cacheStarTreeValuesSize );
121128 }
122129
123- /**
124- * Parse query body to star-tree predicates
125- * @param queryBuilder to match star-tree supported query shape
126- * @return predicates to match
127- */
128- private static Map <String , Long > getStarTreePredicates (QueryBuilder queryBuilder , List <String > supportedDimensions ) {
129- TermQueryBuilder tq = (TermQueryBuilder ) queryBuilder ;
130- String field = tq .fieldName ();
131- if (!supportedDimensions .contains (field )) {
132- return null ;
133- }
134- long inputQueryVal = Long .parseLong (tq .value ().toString ());
135-
136- // Create a map with the field and the value
137- Map <String , Long > predicateMap = new HashMap <>();
138- predicateMap .put (field , inputQueryVal );
139- return predicateMap ;
140- }
141-
142- private static MetricStat validateStarTreeMetricSupport (
130+ private static boolean validateStarTreeMetricSupport (
143131 CompositeDataCubeFieldType compositeIndexFieldInfo ,
144132 AggregatorFactory aggregatorFactory
145133 ) {
146- if (aggregatorFactory instanceof MetricAggregatorFactory && aggregatorFactory .getSubFactories ().getFactories ().length == 0 ) {
134+ if (aggregatorFactory instanceof MetricAggregatorFactory metricAggregatorFactory
135+ && metricAggregatorFactory .getSubFactories ().getFactories ().length == 0 ) {
147136 String field ;
148137 Map <String , List <MetricStat >> supportedMetrics = compositeIndexFieldInfo .getMetrics ()
149138 .stream ()
150139 .collect (Collectors .toMap (Metric ::getField , Metric ::getMetrics ));
151140
152- MetricStat metricStat = ((MetricAggregatorFactory ) aggregatorFactory ).getMetricStat ();
153- field = ((MetricAggregatorFactory ) aggregatorFactory ).getField ();
141+ MetricStat metricStat = metricAggregatorFactory .getMetricStat ();
142+ field = metricAggregatorFactory .getField ();
143+
144+ return supportedMetrics .containsKey (field ) && supportedMetrics .get (field ).contains (metricStat );
145+ }
146+ return false ;
147+ }
148+
149+ private static boolean validateDateHistogramSupport (
150+ CompositeDataCubeFieldType compositeIndexFieldInfo ,
151+ AggregatorFactory aggregatorFactory
152+ ) {
153+ if (!(aggregatorFactory instanceof DateHistogramAggregatorFactory dateHistogramAggregatorFactory )
154+ || aggregatorFactory .getSubFactories ().getFactories ().length < 1 ) {
155+ return false ;
156+ }
157+
158+ // Find the DateDimension in the dimensions list
159+ DateDimension starTreeDateDimension = null ;
160+ for (Dimension dimension : compositeIndexFieldInfo .getDimensions ()) {
161+ if (dimension instanceof DateDimension ) {
162+ starTreeDateDimension = (DateDimension ) dimension ;
163+ break ;
164+ }
165+ }
166+
167+ // If no DateDimension is found, validation fails
168+ if (starTreeDateDimension == null ) {
169+ return false ;
170+ }
171+
172+ // Ensure the rounding is not null
173+ if (dateHistogramAggregatorFactory .getRounding () == null ) {
174+ return false ;
175+ }
176+
177+ // Find the closest valid interval in the DateTimeUnitRounding class associated with star tree
178+ DateTimeUnitRounding rounding = starTreeDateDimension .findClosestValidInterval (
179+ new DateTimeUnitAdapter (dateHistogramAggregatorFactory .getRounding ())
180+ );
181+ if (rounding == null ) {
182+ return false ;
183+ }
154184
155- if (field != null && supportedMetrics .containsKey (field ) && supportedMetrics .get (field ).contains (metricStat )) {
156- return metricStat ;
185+ // Validate all sub-factories
186+ for (AggregatorFactory subFactory : aggregatorFactory .getSubFactories ().getFactories ()) {
187+ if (!validateStarTreeMetricSupport (compositeIndexFieldInfo , subFactory )) {
188+ return false ;
157189 }
158190 }
159- return null ;
191+ return true ;
160192 }
161193
162194 public static CompositeIndexFieldInfo getSupportedStarTree (SearchContext context ) {
@@ -222,11 +254,37 @@ public static LeafBucketCollector getStarTreeLeafCollector(
222254 // Call the final consumer after processing all entries
223255 finalConsumer .run ();
224256
225- // Return a LeafBucketCollector that terminates collection
226- return new LeafBucketCollectorBase (sub , valuesSource .doubleValues (ctx )) {
257+ // Terminate after pre-computing aggregation
258+ throw new CollectionTerminatedException ();
259+ }
260+
261+ public static StarTreeBucketCollector getStarTreeBucketMetricCollector (
262+ CompositeIndexFieldInfo starTree ,
263+ String metric ,
264+ ValuesSource .Numeric valuesSource ,
265+ StarTreeBucketCollector parentCollector ,
266+ Consumer <Long > growArrays ,
267+ BiConsumer <Long , Long > updateBucket
268+ ) throws IOException {
269+ assert parentCollector != null ;
270+ return new StarTreeBucketCollector (parentCollector ) {
271+ String metricName = StarTreeUtils .fullyQualifiedFieldNameForStarTreeMetricsDocValues (
272+ starTree .getField (),
273+ ((ValuesSource .Numeric .FieldData ) valuesSource ).getIndexFieldName (),
274+ metric
275+ );
276+ SortedNumericStarTreeValuesIterator metricValuesIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
277+ .getMetricValuesIterator (metricName );
278+
227279 @ Override
228- public void collect (int doc , long bucket ) {
229- throw new CollectionTerminatedException ();
280+ public void collectStarTreeEntry (int starTreeEntryBit , long bucket ) throws IOException {
281+ growArrays .accept (bucket );
282+ // Advance the valuesIterator to the current bit
283+ if (!metricValuesIterator .advanceExact (starTreeEntryBit )) {
284+ return ; // Skip if no entries for this document
285+ }
286+ long metricValue = metricValuesIterator .nextValue ();
287+ updateBucket .accept (bucket , metricValue );
230288 }
231289 };
232290 }
@@ -240,7 +298,7 @@ public static FixedBitSet getStarTreeFilteredValues(SearchContext context, LeafR
240298 throws IOException {
241299 FixedBitSet result = context .getStarTreeQueryContext ().getStarTreeValues (ctx );
242300 if (result == null ) {
243- result = StarTreeFilter .getStarTreeResult (starTreeValues , context .getStarTreeQueryContext ().getQueryMap ());
301+ result = StarTreeFilter .getStarTreeResult (starTreeValues , context .getStarTreeQueryContext ().getQueryMap (), Set . of () );
244302 context .getStarTreeQueryContext ().setStarTreeValues (ctx , result );
245303 }
246304 return result ;
0 commit comments