Skip to content

Commit f7d4ff9

Browse files
jnthntatumcopybara-github
authored andcommitted
Optimize bookkeeping in the program builder:
- don't apply incremental cleanup as program steps are rewritten, just delete everything with ProgramBuilder dtor. - hint expected size for program nodes - reserve size for the special function handlder map PiperOrigin-RevId: 742338612
1 parent 76fbb08 commit f7d4ff9

File tree

6 files changed

+210
-144
lines changed

6 files changed

+210
-144
lines changed

eval/compiler/constant_folding.cc

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ using ::google::api::expr::runtime::ProgramOptimizer;
7171
using ::google::api::expr::runtime::ProgramOptimizerFactory;
7272
using ::google::api::expr::runtime::Resolver;
7373

74+
enum class IsConst {
75+
kConditional,
76+
kNonConst,
77+
};
78+
7479
class ConstantFoldingExtension : public ProgramOptimizer {
7580
public:
7681
ConstantFoldingExtension(
@@ -92,10 +97,6 @@ class ConstantFoldingExtension : public ProgramOptimizer {
9297
const Expr& node) override;
9398

9499
private:
95-
enum class IsConst {
96-
kConditional,
97-
kNonConst,
98-
};
99100
// Most constant folding evaluations are simple
100101
// binary operators.
101102
static constexpr size_t kDefaultStackLimit = 4;
@@ -114,51 +115,42 @@ class ConstantFoldingExtension : public ProgramOptimizer {
114115
std::vector<IsConst> is_const_;
115116
};
116117

117-
absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context,
118-
const Expr& node) {
119-
struct IsConstVisitor {
120-
IsConst operator()(const Constant&) { return IsConst::kConditional; }
121-
IsConst operator()(const IdentExpr&) { return IsConst::kNonConst; }
122-
IsConst operator()(const ComprehensionExpr&) {
118+
IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) {
119+
switch (expr.kind_case()) {
120+
case ExprKindCase::kConstant:
121+
return IsConst::kConditional;
122+
case ExprKindCase::kIdentExpr:
123+
return IsConst::kNonConst;
124+
case ExprKindCase::kComprehensionExpr:
123125
// Not yet supported, need to identify whether range and
124126
// iter vars are compatible with const folding.
125127
return IsConst::kNonConst;
126-
}
127-
IsConst operator()(const StructExpr& create_struct) {
128+
case ExprKindCase::kStructExpr:
128129
return IsConst::kNonConst;
129-
}
130-
IsConst operator()(const cel::MapExpr& map_expr) {
131-
// Not yet supported but should be possible in the future.
130+
case ExprKindCase::kMapExpr:
132131
// Empty maps are rare and not currently supported as they may eventually
133132
// have similar issues to empty list when used within comprehensions or
134133
// macros.
135-
if (map_expr.entries().empty()) {
134+
if (expr.map_expr().entries().empty()) {
136135
return IsConst::kNonConst;
137136
}
138137
return IsConst::kConditional;
139-
}
140-
IsConst operator()(const ListExpr& create_list) {
141-
if (create_list.elements().empty()) {
142-
// TODO: Don't fold for empty list to allow comprehension
138+
case ExprKindCase::kListExpr:
139+
if (expr.list_expr().elements().empty()) {
140+
// Don't fold for empty list to allow comprehension
143141
// list append optimization.
144142
return IsConst::kNonConst;
145143
}
146144
return IsConst::kConditional;
147-
}
148-
149-
IsConst operator()(const SelectExpr&) { return IsConst::kConditional; }
150-
151-
IsConst operator()(const cel::UnspecifiedExpr&) {
152-
return IsConst::kNonConst;
153-
}
154-
155-
IsConst operator()(const CallExpr& call) {
145+
case ExprKindCase::kSelectExpr:
146+
return IsConst::kConditional;
147+
case ExprKindCase::kCallExpr: {
148+
const auto& call = expr.call_expr();
156149
// Short Circuiting operators not yet supported.
157150
if (call.function() == kAnd || call.function() == kOr ||
158151
call.function() == kTernary) {
159152
return IsConst::kNonConst;
160153
}
161-
162154
// For now we skip constant folding for cel.@block. We do not yet setup
163155
// slots. When we enable constant folding for comprehensions (like
164156
// cel.bind), we can address cel.@block.
@@ -167,23 +159,24 @@ absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context,
167159
}
168160

169161
int arg_len = call.args().size() + (call.has_target() ? 1 : 0);
170-
std::vector<cel::Kind> arg_matcher(arg_len, cel::Kind::kAny);
171162
// Check for any lazy overloads (activation dependant)
172163
if (!resolver
173-
.FindLazyOverloads(call.function(), call.has_target(),
174-
arg_matcher)
164+
.FindLazyOverloads(call.function(), call.has_target(), arg_len)
175165
.empty()) {
176166
return IsConst::kNonConst;
177167
}
178168

179169
return IsConst::kConditional;
180170
}
171+
case ExprKindCase::kUnspecifiedExpr:
172+
default:
173+
return IsConst::kNonConst;
174+
}
175+
}
181176

182-
const Resolver& resolver;
183-
};
184-
185-
IsConst is_const =
186-
absl::visit(IsConstVisitor{context.resolver()}, node.kind());
177+
absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context,
178+
const Expr& node) {
179+
IsConst is_const = IsConstExpr(node, context.resolver());
187180
is_const_.push_back(is_const);
188181

189182
return absl::OkStatus();

eval/compiler/constant_folding_test.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,37 +352,37 @@ TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) {
352352

353353
ProgramBuilder program_builder;
354354
// Simulate the visitor order.
355-
program_builder.EnterSubexpression(&create_list);
355+
ASSERT_TRUE(program_builder.EnterSubexpression(&create_list) != nullptr);
356356

357357
// 0
358-
program_builder.EnterSubexpression(&elem0);
358+
ASSERT_TRUE(program_builder.EnterSubexpression(&elem0) != nullptr);
359359
ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1));
360360
program_builder.AddStep(std::move(step));
361361
program_builder.ExitSubexpression(&elem0);
362362

363363
// 1
364-
program_builder.EnterSubexpression(&elem1);
364+
ASSERT_TRUE(program_builder.EnterSubexpression(&elem1));
365365
ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2));
366366
program_builder.AddStep(std::move(step));
367367
program_builder.ExitSubexpression(&elem1);
368368

369369
// 2
370-
program_builder.EnterSubexpression(&elem2);
370+
ASSERT_TRUE(program_builder.EnterSubexpression(&elem2) != nullptr);
371371
ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(3L), 3));
372372
program_builder.AddStep(std::move(step));
373373
program_builder.ExitSubexpression(&elem2);
374374

375375
// 3
376-
program_builder.EnterSubexpression(&elem2);
376+
ASSERT_TRUE(program_builder.EnterSubexpression(&elem3) != nullptr);
377377
ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(4L), 4));
378378
program_builder.AddStep(std::move(step));
379-
program_builder.ExitSubexpression(&elem2);
379+
program_builder.ExitSubexpression(&elem3);
380380

381381
// 4
382-
program_builder.EnterSubexpression(&elem2);
382+
ASSERT_TRUE(program_builder.EnterSubexpression(&elem4) != nullptr);
383383
ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(5L), 5));
384384
program_builder.AddStep(std::move(step));
385-
program_builder.ExitSubexpression(&elem2);
385+
program_builder.ExitSubexpression(&elem4);
386386

387387
// createlist
388388
ASSERT_OK_AND_ASSIGN(step, CreateCreateListStep(create_list.list_expr(), 6));

eval/compiler/flat_expr_builder.cc

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,31 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor {
257257
FlatExprVisitor* visitor_;
258258
};
259259

260+
// Returns a hint for the number of program nodes (steps or subexpressions) that
261+
// will be created for this expr.
262+
size_t SizeHint(const cel::Expr& expr) {
263+
switch (expr.kind_case()) {
264+
case cel::ExprKindCase::kConstant:
265+
return 1;
266+
case cel::ExprKindCase::kIdentExpr:
267+
return 1;
268+
case cel::ExprKindCase::kSelectExpr:
269+
return 2;
270+
case cel::ExprKindCase::kCallExpr:
271+
return expr.call_expr().args().size() +
272+
(expr.call_expr().has_target() ? 2 : 1);
273+
case cel::ExprKindCase::kListExpr:
274+
return expr.list_expr().elements().size() + 1;
275+
case cel::ExprKindCase::kStructExpr:
276+
return expr.struct_expr().fields().size() + 1;
277+
case cel::ExprKindCase::kMapExpr:
278+
return 2 * expr.struct_expr().fields().size() + 1;
279+
default:
280+
return 1;
281+
}
282+
return 0;
283+
}
284+
260285
// Returns whether this comprehension appears to be a standard map/filter
261286
// macro implementation. It is not exhaustive, so it is unsafe to use with
262287
// custom comprehensions outside of the standard macros or hand crafted ASTs.
@@ -491,6 +516,8 @@ class FlatExprVisitor : public cel::AstVisitor {
491516
program_builder_(program_builder),
492517
extension_context_(extension_context),
493518
enable_optional_types_(enable_optional_types) {
519+
constexpr size_t kCallHandlerSizeHint = 11;
520+
call_handlers_.reserve(kCallHandlerSizeHint);
494521
call_handlers_[cel::builtin::kIndex] = [this](const cel::Expr& expr,
495522
const cel::CallExpr& call) {
496523
return HandleIndex(expr, call);
@@ -565,7 +592,13 @@ class FlatExprVisitor : public cel::AstVisitor {
565592
}
566593
}
567594

568-
program_builder_.EnterSubexpression(&expr);
595+
auto* subexpression =
596+
program_builder_.EnterSubexpression(&expr, SizeHint(expr));
597+
if (subexpression == nullptr) {
598+
progress_status_.Update(
599+
absl::InternalError("same CEL expr visited twice"));
600+
return;
601+
}
569602

570603
for (const std::unique_ptr<ProgramOptimizer>& optimizer :
571604
program_optimizers_) {

0 commit comments

Comments
 (0)