Skip to content

Commit 4d416db

Browse files
authored
Add per_second function support for timechart command (#4464)
Add per_second() support to the timechart command by implementing Option 3 (Eval Transformation). --------- Signed-off-by: Chen Dai <[email protected]>
1 parent ef783f1 commit 4d416db

File tree

17 files changed

+622
-14
lines changed

17 files changed

+622
-14
lines changed

core/src/main/java/org/opensearch/sql/ast/tree/Timechart.java

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,43 @@
55

66
package org.opensearch.sql.ast.tree;
77

8+
import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
9+
import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral;
10+
import static org.opensearch.sql.ast.dsl.AstDSL.eval;
11+
import static org.opensearch.sql.ast.dsl.AstDSL.function;
12+
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;
13+
import static org.opensearch.sql.ast.expression.IntervalUnit.SECOND;
14+
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.sum;
15+
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.timestampadd;
16+
import static org.opensearch.sql.ast.tree.Timechart.PerFunctionRateExprBuilder.timestampdiff;
17+
import static org.opensearch.sql.calcite.plan.OpenSearchConstants.IMPLICIT_FIELD_TIMESTAMP;
18+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE;
19+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY;
20+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUM;
21+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPADD;
22+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMPDIFF;
23+
824
import com.google.common.collect.ImmutableList;
925
import java.util.List;
26+
import java.util.Locale;
27+
import java.util.Map;
28+
import java.util.Optional;
1029
import lombok.AllArgsConstructor;
1130
import lombok.EqualsAndHashCode;
1231
import lombok.Getter;
32+
import lombok.RequiredArgsConstructor;
1333
import lombok.ToString;
1434
import org.opensearch.sql.ast.AbstractNodeVisitor;
35+
import org.opensearch.sql.ast.dsl.AstDSL;
36+
import org.opensearch.sql.ast.expression.AggregateFunction;
37+
import org.opensearch.sql.ast.expression.Field;
38+
import org.opensearch.sql.ast.expression.Function;
39+
import org.opensearch.sql.ast.expression.IntervalUnit;
40+
import org.opensearch.sql.ast.expression.Let;
41+
import org.opensearch.sql.ast.expression.Span;
42+
import org.opensearch.sql.ast.expression.SpanUnit;
1543
import org.opensearch.sql.ast.expression.UnresolvedExpression;
44+
import org.opensearch.sql.calcite.utils.PlanUtils;
1645

1746
/** AST node representing the Timechart operation. */
1847
@Getter
@@ -49,8 +78,9 @@ public Timechart useOther(Boolean useOther) {
4978
}
5079

5180
@Override
52-
public Timechart attach(UnresolvedPlan child) {
53-
return toBuilder().child(child).build();
81+
public UnresolvedPlan attach(UnresolvedPlan child) {
82+
// Transform after child attached to avoid unintentionally overriding it
83+
return toBuilder().child(child).build().transformPerFunction();
5484
}
5585

5686
@Override
@@ -62,4 +92,112 @@ public List<UnresolvedPlan> getChild() {
6292
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
6393
return nodeVisitor.visitTimechart(this, context);
6494
}
95+
96+
/**
97+
* Transform per function to eval-based post-processing on sum result by timechart. Specifically,
98+
* calculate how many seconds are in the time bucket based on the span option dynamically, then
99+
* divide the aggregated sum value by the number of seconds to get the per-second rate.
100+
*
101+
* <p>For example, with span=5m per_second(field): per second rate = sum(field) / 300 seconds
102+
*
103+
* @return eval+timechart if per function present, or the original timechart otherwise.
104+
*/
105+
private UnresolvedPlan transformPerFunction() {
106+
Optional<PerFunction> perFuncOpt = PerFunction.from(aggregateFunction);
107+
if (perFuncOpt.isEmpty()) {
108+
return this;
109+
}
110+
111+
PerFunction perFunc = perFuncOpt.get();
112+
Span span = (Span) this.binExpression;
113+
Field spanStartTime = AstDSL.field(IMPLICIT_FIELD_TIMESTAMP);
114+
Function spanEndTime = timestampadd(span.getUnit(), span.getValue(), spanStartTime);
115+
Function spanSeconds = timestampdiff(SECOND, spanStartTime, spanEndTime);
116+
117+
return eval(
118+
timechart(AstDSL.alias(perFunc.aggName, sum(perFunc.aggArg))),
119+
let(perFunc.aggName).multiply(perFunc.seconds).dividedBy(spanSeconds));
120+
}
121+
122+
private Timechart timechart(UnresolvedExpression newAggregateFunction) {
123+
return this.toBuilder().aggregateFunction(newAggregateFunction).build();
124+
}
125+
126+
/** TODO: extend to support additional per_* functions */
127+
@RequiredArgsConstructor
128+
static class PerFunction {
129+
private static final Map<String, Integer> UNIT_SECONDS = Map.of("per_second", 1);
130+
private final String aggName;
131+
private final UnresolvedExpression aggArg;
132+
private final int seconds;
133+
134+
static Optional<PerFunction> from(UnresolvedExpression aggExpr) {
135+
if (!(aggExpr instanceof AggregateFunction)) {
136+
return Optional.empty();
137+
}
138+
139+
AggregateFunction aggFunc = (AggregateFunction) aggExpr;
140+
String aggFuncName = aggFunc.getFuncName().toLowerCase(Locale.ROOT);
141+
if (!UNIT_SECONDS.containsKey(aggFuncName)) {
142+
return Optional.empty();
143+
}
144+
145+
String aggName = toAggName(aggFunc);
146+
return Optional.of(
147+
new PerFunction(aggName, aggFunc.getField(), UNIT_SECONDS.get(aggFuncName)));
148+
}
149+
150+
private static String toAggName(AggregateFunction aggFunc) {
151+
String fieldName =
152+
(aggFunc.getField() instanceof Field)
153+
? ((Field) aggFunc.getField()).getField().toString()
154+
: aggFunc.getField().toString();
155+
return String.format(Locale.ROOT, "%s(%s)", aggFunc.getFuncName(), fieldName);
156+
}
157+
}
158+
159+
private PerFunctionRateExprBuilder let(String fieldName) {
160+
return new PerFunctionRateExprBuilder(AstDSL.field(fieldName));
161+
}
162+
163+
/** Fluent builder for creating Let expressions with mathematical operations. */
164+
static class PerFunctionRateExprBuilder {
165+
private final Field field;
166+
private UnresolvedExpression expr;
167+
168+
PerFunctionRateExprBuilder(Field field) {
169+
this.field = field;
170+
this.expr = field;
171+
}
172+
173+
PerFunctionRateExprBuilder multiply(Integer multiplier) {
174+
// Promote to double literal to avoid integer division in downstream
175+
this.expr =
176+
function(
177+
MULTIPLY.getName().getFunctionName(), expr, doubleLiteral(multiplier.doubleValue()));
178+
return this;
179+
}
180+
181+
Let dividedBy(UnresolvedExpression divisor) {
182+
return AstDSL.let(field, function(DIVIDE.getName().getFunctionName(), expr, divisor));
183+
}
184+
185+
static UnresolvedExpression sum(UnresolvedExpression field) {
186+
return aggregate(SUM.getName().getFunctionName(), field);
187+
}
188+
189+
static Function timestampadd(
190+
SpanUnit unit, UnresolvedExpression value, UnresolvedExpression timestampField) {
191+
UnresolvedExpression intervalUnit =
192+
stringLiteral(PlanUtils.spanUnitToIntervalUnit(unit).toString());
193+
return function(
194+
TIMESTAMPADD.getName().getFunctionName(), intervalUnit, value, timestampField);
195+
}
196+
197+
static Function timestampdiff(
198+
IntervalUnit unit, UnresolvedExpression start, UnresolvedExpression end) {
199+
return function(
200+
TIMESTAMPDIFF.getName().getFunctionName(), stringLiteral(unit.toString()), start, end);
201+
}
202+
}
65203
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,9 @@ public RelNode visitFlatten(Flatten node, CalcitePlanContext context) {
19141914

19151915
/** Helper method to get the function name for proper column naming */
19161916
private String getValueFunctionName(UnresolvedExpression aggregateFunction) {
1917+
if (aggregateFunction instanceof Alias) {
1918+
return ((Alias) aggregateFunction).getName();
1919+
}
19171920
if (!(aggregateFunction instanceof AggregateFunction)) {
19181921
return "value";
19191922
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,59 @@ static SpanUnit intervalUnitToSpanUnit(IntervalUnit unit) {
7474
};
7575
}
7676

77+
static IntervalUnit spanUnitToIntervalUnit(SpanUnit unit) {
78+
switch (unit) {
79+
case MILLISECOND:
80+
case MS:
81+
return IntervalUnit.MICROSECOND;
82+
case SECOND:
83+
case SECONDS:
84+
case SEC:
85+
case SECS:
86+
case S:
87+
return IntervalUnit.SECOND;
88+
case MINUTE:
89+
case MINUTES:
90+
case MIN:
91+
case MINS:
92+
case m:
93+
return IntervalUnit.MINUTE;
94+
case HOUR:
95+
case HOURS:
96+
case HR:
97+
case HRS:
98+
case H:
99+
return IntervalUnit.HOUR;
100+
case DAY:
101+
case DAYS:
102+
case D:
103+
return IntervalUnit.DAY;
104+
case WEEK:
105+
case WEEKS:
106+
case W:
107+
return IntervalUnit.WEEK;
108+
case MONTH:
109+
case MONTHS:
110+
case MON:
111+
case M:
112+
return IntervalUnit.MONTH;
113+
case QUARTER:
114+
case QUARTERS:
115+
case QTR:
116+
case QTRS:
117+
case Q:
118+
return IntervalUnit.QUARTER;
119+
case YEAR:
120+
case YEARS:
121+
case Y:
122+
return IntervalUnit.YEAR;
123+
case UNKNOWN:
124+
return IntervalUnit.UNKNOWN;
125+
default:
126+
throw new UnsupportedOperationException("Unsupported span unit: " + unit);
127+
}
128+
}
129+
77130
static RexNode makeOver(
78131
CalcitePlanContext context,
79132
BuiltinFunctionName functionName,
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.ast.tree;
7+
8+
import static org.junit.jupiter.api.Assertions.assertEquals;
9+
import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
10+
import static org.opensearch.sql.ast.dsl.AstDSL.alias;
11+
import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral;
12+
import static org.opensearch.sql.ast.dsl.AstDSL.field;
13+
import static org.opensearch.sql.ast.dsl.AstDSL.function;
14+
import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral;
15+
import static org.opensearch.sql.ast.dsl.AstDSL.relation;
16+
17+
import org.junit.jupiter.api.Test;
18+
import org.junit.jupiter.params.ParameterizedTest;
19+
import org.junit.jupiter.params.provider.CsvSource;
20+
import org.opensearch.sql.ast.dsl.AstDSL;
21+
import org.opensearch.sql.ast.expression.AggregateFunction;
22+
import org.opensearch.sql.ast.expression.Let;
23+
import org.opensearch.sql.ast.expression.Span;
24+
import org.opensearch.sql.ast.expression.SpanUnit;
25+
import org.opensearch.sql.ast.expression.UnresolvedExpression;
26+
27+
class TimechartTest {
28+
29+
@ParameterizedTest
30+
@CsvSource({"1, m, MINUTE", "30, s, SECOND", "5, m, MINUTE", "2, h, HOUR", "1, d, DAY"})
31+
void should_transform_per_second_for_different_spans(
32+
int spanValue, String spanUnit, String expectedIntervalUnit) {
33+
withTimechart(span(spanValue, spanUnit), perSecond("bytes"))
34+
.whenTransformingPerFunction()
35+
.thenExpect(
36+
eval(
37+
let(
38+
"per_second(bytes)",
39+
divide(
40+
multiply("per_second(bytes)", 1.0),
41+
timestampdiff(
42+
"SECOND",
43+
"@timestamp",
44+
timestampadd(expectedIntervalUnit, spanValue, "@timestamp")))),
45+
timechart(span(spanValue, spanUnit), alias("per_second(bytes)", sum("bytes")))));
46+
}
47+
48+
@Test
49+
void should_not_transform_non_per_functions() {
50+
withTimechart(span(1, "m"), sum("bytes"))
51+
.whenTransformingPerFunction()
52+
.thenExpect(timechart(span(1, "m"), sum("bytes")));
53+
}
54+
55+
@Test
56+
void should_preserve_all_fields_during_per_function_transformation() {
57+
Timechart original =
58+
new Timechart(relation("logs"), perSecond("bytes"))
59+
.span(span(5, "m"))
60+
.by(field("status"))
61+
.limit(20)
62+
.useOther(false);
63+
64+
Timechart expected =
65+
new Timechart(relation("logs"), alias("per_second(bytes)", sum("bytes")))
66+
.span(span(5, "m"))
67+
.by(field("status"))
68+
.limit(20)
69+
.useOther(false);
70+
71+
withTimechart(original)
72+
.whenTransformingPerFunction()
73+
.thenExpect(
74+
eval(
75+
let(
76+
"per_second(bytes)",
77+
divide(
78+
multiply("per_second(bytes)", 1.0),
79+
timestampdiff(
80+
"SECOND", "@timestamp", timestampadd("MINUTE", 5, "@timestamp")))),
81+
expected));
82+
}
83+
84+
// Fluent API for readable test assertions
85+
86+
private static TransformationAssertion withTimechart(Span spanExpr, AggregateFunction aggFunc) {
87+
return new TransformationAssertion(timechart(spanExpr, aggFunc));
88+
}
89+
90+
private static TransformationAssertion withTimechart(Timechart timechart) {
91+
return new TransformationAssertion(timechart);
92+
}
93+
94+
private static Timechart timechart(Span spanExpr, UnresolvedExpression aggExpr) {
95+
// Set child here because expected object won't call attach below
96+
return new Timechart(relation("t"), aggExpr).span(spanExpr).limit(10).useOther(true);
97+
}
98+
99+
private static Span span(int value, String unit) {
100+
return AstDSL.span(field("@timestamp"), intLiteral(value), SpanUnit.of(unit));
101+
}
102+
103+
private static AggregateFunction perSecond(String fieldName) {
104+
return (AggregateFunction) aggregate("per_second", field(fieldName));
105+
}
106+
107+
private static AggregateFunction sum(String fieldName) {
108+
return (AggregateFunction) aggregate("sum", field(fieldName));
109+
}
110+
111+
private static Let let(String fieldName, UnresolvedExpression expression) {
112+
return AstDSL.let(field(fieldName), expression);
113+
}
114+
115+
private static UnresolvedExpression multiply(String fieldName, double right) {
116+
return function("*", field(fieldName), doubleLiteral(right));
117+
}
118+
119+
private static UnresolvedExpression divide(
120+
UnresolvedExpression left, UnresolvedExpression right) {
121+
return function("/", left, right);
122+
}
123+
124+
private static UnresolvedExpression timestampadd(String unit, int value, String timestampField) {
125+
return function(
126+
"timestampadd", AstDSL.stringLiteral(unit), intLiteral(value), field(timestampField));
127+
}
128+
129+
private static UnresolvedExpression timestampdiff(
130+
String unit, String startField, UnresolvedExpression end) {
131+
return function("timestampdiff", AstDSL.stringLiteral(unit), field(startField), end);
132+
}
133+
134+
private static UnresolvedPlan eval(Let letExpr, Timechart timechartExpr) {
135+
return AstDSL.eval(timechartExpr, letExpr);
136+
}
137+
138+
private static class TransformationAssertion {
139+
private final Timechart timechart;
140+
private UnresolvedPlan result;
141+
142+
TransformationAssertion(Timechart timechart) {
143+
this.timechart = timechart;
144+
}
145+
146+
public TransformationAssertion whenTransformingPerFunction() {
147+
this.result = timechart.attach(timechart.getChild().get(0));
148+
return this;
149+
}
150+
151+
public void thenExpect(UnresolvedPlan expected) {
152+
assertEquals(expected, result);
153+
}
154+
}
155+
}

0 commit comments

Comments
 (0)