3333#include " absl/container/flat_hash_map.h"
3434#include " absl/container/flat_hash_set.h"
3535#include " absl/container/node_hash_map.h"
36+ #include " absl/functional/any_invocable.h"
3637#include " absl/log/absl_check.h"
3738#include " absl/log/check.h"
3839#include " absl/status/status.h"
@@ -104,6 +105,7 @@ using ::cel::runtime_internal::IssueCollector;
104105
105106constexpr absl::string_view kOptionalOrFn = " or" ;
106107constexpr absl::string_view kOptionalOrValueFn = " orValue" ;
108+ constexpr absl::string_view kBlock = " cel.@block" ;
107109
108110// Forward declare to resolve circular dependency for short_circuiting visitors.
109111class FlatExprVisitor ;
@@ -375,7 +377,7 @@ bool IsBind(const cel::ast_internal::Comprehension* comprehension) {
375377}
376378
377379bool IsBlock (const cel::ast_internal::Call* call) {
378- return call->function () == " cel.@block " ;
380+ return call->function () == kBlock ;
379381}
380382
381383// Visitor for Comprehension expressions.
@@ -464,6 +466,19 @@ absl::flat_hash_set<int32_t> MakeOptionalIndicesSet(
464466
465467class FlatExprVisitor : public cel ::AstVisitor {
466468 public:
469+ enum class CallHandlerResult {
470+ // The call was intercepted, no additional processing is needed.
471+ kIntercepted ,
472+ // The call was not intercepted, continue with the default processing.
473+ kNotIntercepted ,
474+ };
475+
476+ // Handler for functions with builtin implementations.
477+ // This is used to replace the usual dispatcher step that applies
478+ // the arguments to a candidate function from the function registry.
479+ using CallHandler = absl::AnyInvocable<CallHandlerResult(
480+ const cel::ast_internal::Expr&, const cel::ast_internal::Call&)>;
481+
467482 FlatExprVisitor (
468483 const Resolver& resolver, const cel::RuntimeOptions& options,
469484 std::vector<std::unique_ptr<ProgramOptimizer>> program_optimizers,
@@ -481,7 +496,22 @@ class FlatExprVisitor : public cel::AstVisitor {
481496 issue_collector_(issue_collector),
482497 program_builder_(program_builder),
483498 extension_context_(extension_context),
484- enable_optional_types_(enable_optional_types) {}
499+ enable_optional_types_(enable_optional_types) {
500+ call_handlers_[cel::builtin::kIndex ] =
501+ [this ](const cel::ast_internal::Expr& expr,
502+ const cel::ast_internal::Call& call) {
503+ return HandleIndex (expr, call);
504+ };
505+ call_handlers_[kBlock ] = [this ](const cel::ast_internal::Expr& expr,
506+ const cel::ast_internal::Call& call) {
507+ return HandleBlock (expr, call);
508+ };
509+ call_handlers_[cel::builtin::kAdd ] =
510+ [this ](const cel::ast_internal::Expr& expr,
511+ const cel::ast_internal::Call& call) {
512+ return HandleListAppend (expr, call);
513+ };
514+ }
485515
486516 void PreVisitExpr (const cel::ast_internal::Expr& expr) override {
487517 ValidateOrError (!absl::holds_alternative<cel::UnspecifiedExpr>(expr.kind ()),
@@ -1209,83 +1239,19 @@ class FlatExprVisitor : public cel::AstVisitor {
12091239 if (cond_visitor) {
12101240 cond_visitor->PostVisit (&expr);
12111241 cond_visitor_stack_.pop ();
1212- if (call_expr.function () == cel::builtin::kTernary ) {
1213- MaybeMakeTernaryRecursive (&expr);
1214- } else if (call_expr.function () == cel::builtin::kOr ) {
1215- MaybeMakeShortcircuitRecursive (&expr, /* is_or= */ true );
1216- } else if (call_expr.function () == cel::builtin::kAnd ) {
1217- MaybeMakeShortcircuitRecursive (&expr, /* is_or= */ false );
1218- } else if (enable_optional_types_) {
1219- if (call_expr.function () == kOptionalOrFn ) {
1220- MaybeMakeOptionalShortcircuitRecursive (&expr,
1221- /* is_or_value= */ false );
1222- } else if (call_expr.function () == kOptionalOrValueFn ) {
1223- MaybeMakeOptionalShortcircuitRecursive (&expr,
1224- /* is_or_value= */ true );
1225- }
1226- }
12271242 return ;
12281243 }
12291244
1230- // Special case for "_[_]".
1231- if (call_expr.function () == cel::builtin::kIndex ) {
1232- auto depth = RecursionEligible ();
1233- if (depth.has_value ()) {
1234- auto args = ExtractRecursiveDependencies ();
1235- if (args.size () != 2 ) {
1236- SetProgressStatusError (absl::InvalidArgumentError (
1237- " unexpected number of args for builtin index operator" ));
1238- return ;
1239- }
1240- SetRecursiveStep (CreateDirectContainerAccessStep (
1241- std::move (args[0 ]), std::move (args[1 ]),
1242- enable_optional_types_, expr.id ()),
1243- *depth + 1 );
1245+ // Check if the call is intercepted by a custom handler.
1246+ if (auto handler = call_handlers_.find (call_expr.function ());
1247+ handler != call_handlers_.end ()) {
1248+ CallHandlerResult result = handler->second (expr, call_expr);
1249+ if (result == CallHandlerResult::kIntercepted ) {
12441250 return ;
1245- }
1246- AddStep (CreateContainerAccessStep (call_expr, expr.id (),
1247- enable_optional_types_));
1248- return ;
1249- }
1250-
1251- if (block_.has_value ()) {
1252- BlockInfo& block = *block_;
1253- if (block.expr == &expr) {
1254- block.in = false ;
1255- index_manager ().ReleaseSlots (block.slot_count );
1256- AddStep (CreateClearSlotsStep (block.index , block.slot_count , -1 ));
1257- return ;
1258- }
1259- }
1260-
1261- // Establish the search criteria for a given function.
1262- absl::string_view function = call_expr.function ();
1263-
1264- // Check to see if this is a special case of add that should really be
1265- // treated as a list append
1266- if (!comprehension_stack_.empty () &&
1267- comprehension_stack_.back ().is_optimizable_list_append ) {
1268- // Already checked that this is an optimizeable comprehension,
1269- // check that this is the correct list append node.
1270- const cel::ast_internal::Comprehension* comprehension =
1271- comprehension_stack_.back ().comprehension ;
1272- const cel::ast_internal::Expr& loop_step = comprehension->loop_step ();
1273- // Macro loop_step for a map() will contain a list concat operation:
1274- // accu_var + [elem]
1275- if (&loop_step == &expr) {
1276- function = cel::builtin::kRuntimeListAppend ;
1277- }
1278- // Macro loop_step for a filter() will contain a ternary:
1279- // filter ? accu_var + [elem] : accu_var
1280- if (loop_step.has_call_expr () &&
1281- loop_step.call_expr ().function () == cel::builtin::kTernary &&
1282- loop_step.call_expr ().args ().size () == 3 &&
1283- &(loop_step.call_expr ().args ()[1 ]) == &expr) {
1284- function = cel::builtin::kRuntimeListAppend ;
1285- }
1251+ } // otherwise, apply default function handling.
12861252 }
12871253
1288- AddResolvedFunctionStep (&call_expr, &expr, function);
1254+ AddResolvedFunctionStep (&call_expr, &expr, call_expr. function () );
12891255 }
12901256
12911257 void PreVisitComprehension (
@@ -1880,9 +1846,17 @@ class FlatExprVisitor : public cel::AstVisitor {
18801846 return std::make_pair (std::move (resolved_name), std::move (fields));
18811847 }
18821848
1849+ CallHandlerResult HandleIndex (const cel::ast_internal::Expr& expr,
1850+ const cel::ast_internal::Call& call);
1851+ CallHandlerResult HandleBlock (const cel::ast_internal::Expr& expr,
1852+ const cel::ast_internal::Call& call);
1853+ CallHandlerResult HandleListAppend (const cel::ast_internal::Expr& expr,
1854+ const cel::ast_internal::Call& call);
1855+
18831856 const Resolver& resolver_;
18841857 ValueManager& value_factory_;
18851858 absl::Status progress_status_;
1859+ absl::flat_hash_map<std::string, CallHandler> call_handlers_;
18861860
18871861 std::stack<
18881862 std::pair<const cel::ast_internal::Expr*, std::unique_ptr<CondVisitor>>>
@@ -1912,6 +1886,82 @@ class FlatExprVisitor : public cel::AstVisitor {
19121886 absl::optional<BlockInfo> block_;
19131887};
19141888
1889+ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex (
1890+ const cel::ast_internal::Expr& expr,
1891+ const cel::ast_internal::Call& call_expr) {
1892+ ABSL_DCHECK (call_expr.function () == cel::builtin::kIndex );
1893+ auto depth = RecursionEligible ();
1894+
1895+ if (depth.has_value ()) {
1896+ auto args = ExtractRecursiveDependencies ();
1897+ if (args.size () != 2 ) {
1898+ SetProgressStatusError (absl::InvalidArgumentError (
1899+ " unexpected number of args for builtin index operator" ));
1900+ return CallHandlerResult::kIntercepted ;
1901+ }
1902+ SetRecursiveStep (
1903+ CreateDirectContainerAccessStep (std::move (args[0 ]), std::move (args[1 ]),
1904+ enable_optional_types_, expr.id ()),
1905+ *depth + 1 );
1906+ return CallHandlerResult::kIntercepted ;
1907+ }
1908+ AddStep (
1909+ CreateContainerAccessStep (call_expr, expr.id (), enable_optional_types_));
1910+ return CallHandlerResult::kIntercepted ;
1911+ }
1912+
1913+ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock (
1914+ const cel::ast_internal::Expr& expr,
1915+ const cel::ast_internal::Call& call_expr) {
1916+ ABSL_DCHECK (call_expr.function () == kBlock );
1917+ if (!block_.has_value () || block_->expr != &expr) {
1918+ SetProgressStatusError (absl::InvalidArgumentError (
1919+ " unexpected number call to internal cel.@block" ));
1920+ return CallHandlerResult::kIntercepted ;
1921+ }
1922+ BlockInfo& block = *block_;
1923+ block.in = false ;
1924+ index_manager ().ReleaseSlots (block.slot_count );
1925+ AddStep (CreateClearSlotsStep (block.index , block.slot_count , -1 ));
1926+ return CallHandlerResult::kIntercepted ;
1927+ }
1928+
1929+ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend (
1930+ const cel::ast_internal::Expr& expr,
1931+ const cel::ast_internal::Call& call_expr) {
1932+ ABSL_DCHECK (call_expr.function () == cel::builtin::kAdd );
1933+
1934+ // Check to see if this is a special case of add that should really be
1935+ // treated as a list append
1936+ if (!comprehension_stack_.empty () &&
1937+ comprehension_stack_.back ().is_optimizable_list_append ) {
1938+ // Already checked that this is an optimizeable comprehension,
1939+ // check that this is the correct list append node.
1940+ const cel::ast_internal::Comprehension* comprehension =
1941+ comprehension_stack_.back ().comprehension ;
1942+ const cel::ast_internal::Expr& loop_step = comprehension->loop_step ();
1943+ // Macro loop_step for a map() will contain a list concat operation:
1944+ // accu_var + [elem]
1945+ if (&loop_step == &expr) {
1946+ AddResolvedFunctionStep (&call_expr, &expr,
1947+ cel::builtin::kRuntimeListAppend );
1948+ return CallHandlerResult::kIntercepted ;
1949+ }
1950+ // Macro loop_step for a filter() will contain a ternary:
1951+ // filter ? accu_var + [elem] : accu_var
1952+ if (loop_step.has_call_expr () &&
1953+ loop_step.call_expr ().function () == cel::builtin::kTernary &&
1954+ loop_step.call_expr ().args ().size () == 3 &&
1955+ &(loop_step.call_expr ().args ()[1 ]) == &expr) {
1956+ AddResolvedFunctionStep (&call_expr, &expr,
1957+ cel::builtin::kRuntimeListAppend );
1958+ return CallHandlerResult::kIntercepted ;
1959+ }
1960+ }
1961+
1962+ return CallHandlerResult::kNotIntercepted ;
1963+ }
1964+
19151965void BinaryCondVisitor::PreVisit (const cel::ast_internal::Expr* expr) {
19161966 switch (cond_) {
19171967 case BinaryCond::kAnd :
@@ -2009,6 +2059,26 @@ void BinaryCondVisitor::PostVisit(const cel::ast_internal::Expr* expr) {
20092059 visitor_->SetProgressStatusError (
20102060 jump_step_.set_target (visitor_->GetCurrentIndex ()));
20112061 }
2062+ // Handle maybe replacing the subprogram with a recursive version. This needs
2063+ // to happen after the jump step is updated (though it may get overwritten).
2064+ switch (cond_) {
2065+ case BinaryCond::kAnd :
2066+ visitor_->MaybeMakeShortcircuitRecursive (expr, /* is_or=*/ false );
2067+ break ;
2068+ case BinaryCond::kOr :
2069+ visitor_->MaybeMakeShortcircuitRecursive (expr, /* is_or=*/ true );
2070+ break ;
2071+ case BinaryCond::kOptionalOr :
2072+ visitor_->MaybeMakeOptionalShortcircuitRecursive (expr,
2073+ /* is_or_value=*/ false );
2074+ break ;
2075+ case BinaryCond::kOptionalOrValue :
2076+ visitor_->MaybeMakeOptionalShortcircuitRecursive (expr,
2077+ /* is_or_value=*/ true );
2078+ break ;
2079+ default :
2080+ ABSL_UNREACHABLE ();
2081+ }
20122082}
20132083
20142084void TernaryCondVisitor::PreVisit (const cel::ast_internal::Expr* expr) {
@@ -2073,7 +2143,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num,
20732143 // clattered.
20742144}
20752145
2076- void TernaryCondVisitor::PostVisit (const cel::ast_internal::Expr*) {
2146+ void TernaryCondVisitor::PostVisit (const cel::ast_internal::Expr* expr ) {
20772147 // Determine and set jump offset in jump instruction.
20782148 if (visitor_->ValidateOrError (
20792149 error_jump_.exists (),
@@ -2087,6 +2157,7 @@ void TernaryCondVisitor::PostVisit(const cel::ast_internal::Expr*) {
20872157 visitor_->SetProgressStatusError (
20882158 jump_after_first_.set_target (visitor_->GetCurrentIndex ()));
20892159 }
2160+ visitor_->MaybeMakeTernaryRecursive (expr);
20902161}
20912162
20922163void ExhaustiveTernaryCondVisitor::PreVisit (
@@ -2099,6 +2170,7 @@ void ExhaustiveTernaryCondVisitor::PreVisit(
20992170void ExhaustiveTernaryCondVisitor::PostVisit (
21002171 const cel::ast_internal::Expr* expr) {
21012172 visitor_->AddStep (CreateTernaryStep (expr->id ()));
2173+ visitor_->MaybeMakeTernaryRecursive (expr);
21022174}
21032175
21042176void ComprehensionVisitor::PreVisit (const cel::ast_internal::Expr* expr) {
0 commit comments