From ef155f8fd4f4dc46e7ea8d0c87f2774eccbf48c8 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 19 Nov 2024 11:31:31 -0800 Subject: [PATCH] Add option to set input expression size limit in type checker. If exceeded, type checking fails early instead of fully visiting the input AST. PiperOrigin-RevId: 698088256 --- checker/checker_options.h | 6 ++++++ checker/internal/type_checker_impl.cc | 21 ++++++++++++++++++--- checker/internal/type_checker_impl_test.cc | 18 ++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/checker/checker_options.h b/checker/checker_options.h index 839446180..91fdad3e0 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -42,6 +42,12 @@ struct CheckerOptions { // Enabled by default, but can be disabled to preserve the original type name // as parsed. bool update_struct_type_names = true; + + // Maximum number (inclusive) of expression nodes to check for an input + // expression. + // + // If exceeded, the checker should return a status with code InvalidArgument. + int max_expression_node_count = 100000; }; } // namespace cel diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index bae1c43d7..c1a8ab4aa 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -458,9 +458,9 @@ class ResolveVisitor : public AstVisitorBase { // These are handled separately to disambiguate between namespaces and field // accesses absl::flat_hash_set deferred_select_operations_; - absl::Status status_; std::vector> comprehension_vars_; std::vector comprehension_scopes_; + absl::Status status_; // References that were resolved and may require AST rewrites. absl::flat_hash_map functions_; @@ -1252,8 +1252,23 @@ absl::StatusOr TypeCheckerImpl::Check( TraversalOptions opts; opts.use_comprehension_callbacks = true; - AstTraverse(ast_impl.root_expr(), visitor, opts); - CEL_RETURN_IF_ERROR(visitor.status()); + + auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts); + for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { + bool has_next = traversal.Step(visitor); + if (!visitor.status().ok()) { + return visitor.status(); + } + if (!has_next) { + break; + } + } + + if (!traversal.IsDone()) { + return absl::InvalidArgumentError( + absl::StrCat("Max expression node count exceeded: ", + options_.max_expression_node_count)); + } if (env_.expected_type().has_value()) { visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index a2b6fdade..c72e57a12 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -57,6 +57,7 @@ namespace checker_internal { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; using ::cel::expr::conformance::proto3::TestAllTypes; @@ -65,6 +66,7 @@ using ::testing::_; using ::testing::Contains; using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Property; @@ -1013,6 +1015,22 @@ TEST(TypeCheckerImplTest, NullLiteral) { EXPECT_TRUE(ast_impl.type_map()[1].has_null()); } +TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + CheckerOptions options; + options.max_expression_node_count = 2; + TypeCheckerImpl impl(std::move(env), options); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("{}.foo.bar")); + EXPECT_THAT(impl.Check(std::move(ast)), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expression node count exceeded: 2"))); +} + TEST(TypeCheckerImplTest, ComprehensionUnsupportedRange) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena;