Skip to content

Commit 59e866a

Browse files
jnthntatumcopybara-github
authored andcommitted
Refactor and standardize handling for special cased functions.
PiperOrigin-RevId: 715946371
1 parent 88bfbbf commit 59e866a

File tree

2 files changed

+147
-74
lines changed

2 files changed

+147
-74
lines changed

eval/compiler/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ cc_library(
136136
"@com_google_absl//absl/container:flat_hash_map",
137137
"@com_google_absl//absl/container:flat_hash_set",
138138
"@com_google_absl//absl/container:node_hash_map",
139+
"@com_google_absl//absl/functional:any_invocable",
139140
"@com_google_absl//absl/log:absl_check",
140141
"@com_google_absl//absl/log:check",
141142
"@com_google_absl//absl/status",

eval/compiler/flat_expr_builder.cc

Lines changed: 146 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "absl/container/flat_hash_map.h"
3434
#include "absl/container/flat_hash_set.h"
3535
#include "absl/container/node_hash_map.h"
36+
#include "absl/functional/any_invocable.h"
3637
#include "absl/log/absl_check.h"
3738
#include "absl/log/check.h"
3839
#include "absl/status/status.h"
@@ -104,6 +105,7 @@ using ::cel::runtime_internal::IssueCollector;
104105

105106
constexpr absl::string_view kOptionalOrFn = "or";
106107
constexpr absl::string_view kOptionalOrValueFn = "orValue";
108+
constexpr absl::string_view kBlock = "cel.@block";
107109

108110
// Forward declare to resolve circular dependency for short_circuiting visitors.
109111
class FlatExprVisitor;
@@ -375,7 +377,7 @@ bool IsBind(const cel::ast_internal::Comprehension* comprehension) {
375377
}
376378

377379
bool IsBlock(const cel::ast_internal::Call* call) {
378-
return call->function() == "cel.@block";
380+
return call->function() == kBlock;
379381
}
380382

381383
// Visitor for Comprehension expressions.
@@ -464,6 +466,19 @@ absl::flat_hash_set<int32_t> MakeOptionalIndicesSet(
464466

465467
class FlatExprVisitor : public cel::AstVisitor {
466468
public:
469+
enum class CallHandlerResult {
470+
// The call was intercepted, no additional processing is needed.
471+
kIntercepted,
472+
// The call was not intercepted, continue with the default processing.
473+
kNotIntercepted,
474+
};
475+
476+
// Handler for functions with builtin implementations.
477+
// This is used to replace the usual dispatcher step that applies
478+
// the arguments to a candidate function from the function registry.
479+
using CallHandler = absl::AnyInvocable<CallHandlerResult(
480+
const cel::ast_internal::Expr&, const cel::ast_internal::Call&)>;
481+
467482
FlatExprVisitor(
468483
const Resolver& resolver, const cel::RuntimeOptions& options,
469484
std::vector<std::unique_ptr<ProgramOptimizer>> program_optimizers,
@@ -481,7 +496,22 @@ class FlatExprVisitor : public cel::AstVisitor {
481496
issue_collector_(issue_collector),
482497
program_builder_(program_builder),
483498
extension_context_(extension_context),
484-
enable_optional_types_(enable_optional_types) {}
499+
enable_optional_types_(enable_optional_types) {
500+
call_handlers_[cel::builtin::kIndex] =
501+
[this](const cel::ast_internal::Expr& expr,
502+
const cel::ast_internal::Call& call) {
503+
return HandleIndex(expr, call);
504+
};
505+
call_handlers_[kBlock] = [this](const cel::ast_internal::Expr& expr,
506+
const cel::ast_internal::Call& call) {
507+
return HandleBlock(expr, call);
508+
};
509+
call_handlers_[cel::builtin::kAdd] =
510+
[this](const cel::ast_internal::Expr& expr,
511+
const cel::ast_internal::Call& call) {
512+
return HandleListAppend(expr, call);
513+
};
514+
}
485515

486516
void PreVisitExpr(const cel::ast_internal::Expr& expr) override {
487517
ValidateOrError(!absl::holds_alternative<cel::UnspecifiedExpr>(expr.kind()),
@@ -1209,83 +1239,19 @@ class FlatExprVisitor : public cel::AstVisitor {
12091239
if (cond_visitor) {
12101240
cond_visitor->PostVisit(&expr);
12111241
cond_visitor_stack_.pop();
1212-
if (call_expr.function() == cel::builtin::kTernary) {
1213-
MaybeMakeTernaryRecursive(&expr);
1214-
} else if (call_expr.function() == cel::builtin::kOr) {
1215-
MaybeMakeShortcircuitRecursive(&expr, /* is_or= */ true);
1216-
} else if (call_expr.function() == cel::builtin::kAnd) {
1217-
MaybeMakeShortcircuitRecursive(&expr, /* is_or= */ false);
1218-
} else if (enable_optional_types_) {
1219-
if (call_expr.function() == kOptionalOrFn) {
1220-
MaybeMakeOptionalShortcircuitRecursive(&expr,
1221-
/* is_or_value= */ false);
1222-
} else if (call_expr.function() == kOptionalOrValueFn) {
1223-
MaybeMakeOptionalShortcircuitRecursive(&expr,
1224-
/* is_or_value= */ true);
1225-
}
1226-
}
12271242
return;
12281243
}
12291244

1230-
// Special case for "_[_]".
1231-
if (call_expr.function() == cel::builtin::kIndex) {
1232-
auto depth = RecursionEligible();
1233-
if (depth.has_value()) {
1234-
auto args = ExtractRecursiveDependencies();
1235-
if (args.size() != 2) {
1236-
SetProgressStatusError(absl::InvalidArgumentError(
1237-
"unexpected number of args for builtin index operator"));
1238-
return;
1239-
}
1240-
SetRecursiveStep(CreateDirectContainerAccessStep(
1241-
std::move(args[0]), std::move(args[1]),
1242-
enable_optional_types_, expr.id()),
1243-
*depth + 1);
1245+
// Check if the call is intercepted by a custom handler.
1246+
if (auto handler = call_handlers_.find(call_expr.function());
1247+
handler != call_handlers_.end()) {
1248+
CallHandlerResult result = handler->second(expr, call_expr);
1249+
if (result == CallHandlerResult::kIntercepted) {
12441250
return;
1245-
}
1246-
AddStep(CreateContainerAccessStep(call_expr, expr.id(),
1247-
enable_optional_types_));
1248-
return;
1249-
}
1250-
1251-
if (block_.has_value()) {
1252-
BlockInfo& block = *block_;
1253-
if (block.expr == &expr) {
1254-
block.in = false;
1255-
index_manager().ReleaseSlots(block.slot_count);
1256-
AddStep(CreateClearSlotsStep(block.index, block.slot_count, -1));
1257-
return;
1258-
}
1259-
}
1260-
1261-
// Establish the search criteria for a given function.
1262-
absl::string_view function = call_expr.function();
1263-
1264-
// Check to see if this is a special case of add that should really be
1265-
// treated as a list append
1266-
if (!comprehension_stack_.empty() &&
1267-
comprehension_stack_.back().is_optimizable_list_append) {
1268-
// Already checked that this is an optimizeable comprehension,
1269-
// check that this is the correct list append node.
1270-
const cel::ast_internal::Comprehension* comprehension =
1271-
comprehension_stack_.back().comprehension;
1272-
const cel::ast_internal::Expr& loop_step = comprehension->loop_step();
1273-
// Macro loop_step for a map() will contain a list concat operation:
1274-
// accu_var + [elem]
1275-
if (&loop_step == &expr) {
1276-
function = cel::builtin::kRuntimeListAppend;
1277-
}
1278-
// Macro loop_step for a filter() will contain a ternary:
1279-
// filter ? accu_var + [elem] : accu_var
1280-
if (loop_step.has_call_expr() &&
1281-
loop_step.call_expr().function() == cel::builtin::kTernary &&
1282-
loop_step.call_expr().args().size() == 3 &&
1283-
&(loop_step.call_expr().args()[1]) == &expr) {
1284-
function = cel::builtin::kRuntimeListAppend;
1285-
}
1251+
} // otherwise, apply default function handling.
12861252
}
12871253

1288-
AddResolvedFunctionStep(&call_expr, &expr, function);
1254+
AddResolvedFunctionStep(&call_expr, &expr, call_expr.function());
12891255
}
12901256

12911257
void PreVisitComprehension(
@@ -1880,9 +1846,17 @@ class FlatExprVisitor : public cel::AstVisitor {
18801846
return std::make_pair(std::move(resolved_name), std::move(fields));
18811847
}
18821848

1849+
CallHandlerResult HandleIndex(const cel::ast_internal::Expr& expr,
1850+
const cel::ast_internal::Call& call);
1851+
CallHandlerResult HandleBlock(const cel::ast_internal::Expr& expr,
1852+
const cel::ast_internal::Call& call);
1853+
CallHandlerResult HandleListAppend(const cel::ast_internal::Expr& expr,
1854+
const cel::ast_internal::Call& call);
1855+
18831856
const Resolver& resolver_;
18841857
ValueManager& value_factory_;
18851858
absl::Status progress_status_;
1859+
absl::flat_hash_map<std::string, CallHandler> call_handlers_;
18861860

18871861
std::stack<
18881862
std::pair<const cel::ast_internal::Expr*, std::unique_ptr<CondVisitor>>>
@@ -1912,6 +1886,82 @@ class FlatExprVisitor : public cel::AstVisitor {
19121886
absl::optional<BlockInfo> block_;
19131887
};
19141888

1889+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex(
1890+
const cel::ast_internal::Expr& expr,
1891+
const cel::ast_internal::Call& call_expr) {
1892+
ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex);
1893+
auto depth = RecursionEligible();
1894+
1895+
if (depth.has_value()) {
1896+
auto args = ExtractRecursiveDependencies();
1897+
if (args.size() != 2) {
1898+
SetProgressStatusError(absl::InvalidArgumentError(
1899+
"unexpected number of args for builtin index operator"));
1900+
return CallHandlerResult::kIntercepted;
1901+
}
1902+
SetRecursiveStep(
1903+
CreateDirectContainerAccessStep(std::move(args[0]), std::move(args[1]),
1904+
enable_optional_types_, expr.id()),
1905+
*depth + 1);
1906+
return CallHandlerResult::kIntercepted;
1907+
}
1908+
AddStep(
1909+
CreateContainerAccessStep(call_expr, expr.id(), enable_optional_types_));
1910+
return CallHandlerResult::kIntercepted;
1911+
}
1912+
1913+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock(
1914+
const cel::ast_internal::Expr& expr,
1915+
const cel::ast_internal::Call& call_expr) {
1916+
ABSL_DCHECK(call_expr.function() == kBlock);
1917+
if (!block_.has_value() || block_->expr != &expr) {
1918+
SetProgressStatusError(absl::InvalidArgumentError(
1919+
"unexpected number call to internal cel.@block"));
1920+
return CallHandlerResult::kIntercepted;
1921+
}
1922+
BlockInfo& block = *block_;
1923+
block.in = false;
1924+
index_manager().ReleaseSlots(block.slot_count);
1925+
AddStep(CreateClearSlotsStep(block.index, block.slot_count, -1));
1926+
return CallHandlerResult::kIntercepted;
1927+
}
1928+
1929+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend(
1930+
const cel::ast_internal::Expr& expr,
1931+
const cel::ast_internal::Call& call_expr) {
1932+
ABSL_DCHECK(call_expr.function() == cel::builtin::kAdd);
1933+
1934+
// Check to see if this is a special case of add that should really be
1935+
// treated as a list append
1936+
if (!comprehension_stack_.empty() &&
1937+
comprehension_stack_.back().is_optimizable_list_append) {
1938+
// Already checked that this is an optimizeable comprehension,
1939+
// check that this is the correct list append node.
1940+
const cel::ast_internal::Comprehension* comprehension =
1941+
comprehension_stack_.back().comprehension;
1942+
const cel::ast_internal::Expr& loop_step = comprehension->loop_step();
1943+
// Macro loop_step for a map() will contain a list concat operation:
1944+
// accu_var + [elem]
1945+
if (&loop_step == &expr) {
1946+
AddResolvedFunctionStep(&call_expr, &expr,
1947+
cel::builtin::kRuntimeListAppend);
1948+
return CallHandlerResult::kIntercepted;
1949+
}
1950+
// Macro loop_step for a filter() will contain a ternary:
1951+
// filter ? accu_var + [elem] : accu_var
1952+
if (loop_step.has_call_expr() &&
1953+
loop_step.call_expr().function() == cel::builtin::kTernary &&
1954+
loop_step.call_expr().args().size() == 3 &&
1955+
&(loop_step.call_expr().args()[1]) == &expr) {
1956+
AddResolvedFunctionStep(&call_expr, &expr,
1957+
cel::builtin::kRuntimeListAppend);
1958+
return CallHandlerResult::kIntercepted;
1959+
}
1960+
}
1961+
1962+
return CallHandlerResult::kNotIntercepted;
1963+
}
1964+
19151965
void BinaryCondVisitor::PreVisit(const cel::ast_internal::Expr* expr) {
19161966
switch (cond_) {
19171967
case BinaryCond::kAnd:
@@ -2009,6 +2059,26 @@ void BinaryCondVisitor::PostVisit(const cel::ast_internal::Expr* expr) {
20092059
visitor_->SetProgressStatusError(
20102060
jump_step_.set_target(visitor_->GetCurrentIndex()));
20112061
}
2062+
// Handle maybe replacing the subprogram with a recursive version. This needs
2063+
// to happen after the jump step is updated (though it may get overwritten).
2064+
switch (cond_) {
2065+
case BinaryCond::kAnd:
2066+
visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false);
2067+
break;
2068+
case BinaryCond::kOr:
2069+
visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true);
2070+
break;
2071+
case BinaryCond::kOptionalOr:
2072+
visitor_->MaybeMakeOptionalShortcircuitRecursive(expr,
2073+
/*is_or_value=*/false);
2074+
break;
2075+
case BinaryCond::kOptionalOrValue:
2076+
visitor_->MaybeMakeOptionalShortcircuitRecursive(expr,
2077+
/*is_or_value=*/true);
2078+
break;
2079+
default:
2080+
ABSL_UNREACHABLE();
2081+
}
20122082
}
20132083

20142084
void TernaryCondVisitor::PreVisit(const cel::ast_internal::Expr* expr) {
@@ -2073,7 +2143,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num,
20732143
// clattered.
20742144
}
20752145

2076-
void TernaryCondVisitor::PostVisit(const cel::ast_internal::Expr*) {
2146+
void TernaryCondVisitor::PostVisit(const cel::ast_internal::Expr* expr) {
20772147
// Determine and set jump offset in jump instruction.
20782148
if (visitor_->ValidateOrError(
20792149
error_jump_.exists(),
@@ -2087,6 +2157,7 @@ void TernaryCondVisitor::PostVisit(const cel::ast_internal::Expr*) {
20872157
visitor_->SetProgressStatusError(
20882158
jump_after_first_.set_target(visitor_->GetCurrentIndex()));
20892159
}
2160+
visitor_->MaybeMakeTernaryRecursive(expr);
20902161
}
20912162

20922163
void ExhaustiveTernaryCondVisitor::PreVisit(
@@ -2099,6 +2170,7 @@ void ExhaustiveTernaryCondVisitor::PreVisit(
20992170
void ExhaustiveTernaryCondVisitor::PostVisit(
21002171
const cel::ast_internal::Expr* expr) {
21012172
visitor_->AddStep(CreateTernaryStep(expr->id()));
2173+
visitor_->MaybeMakeTernaryRecursive(expr);
21022174
}
21032175

21042176
void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr* expr) {

0 commit comments

Comments
 (0)