diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 2c694dd2e..9cad4ae72 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -107,6 +107,10 @@ class TypeCheckEnv { container_ = std::move(container); } + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } + + const absl::optional& expected_type() const { return expected_type_; } + absl::Span> type_providers() const { return type_providers_; } @@ -198,6 +202,8 @@ class TypeCheckEnv { // Type providers for custom types. std::vector> type_providers_; + + absl::optional expected_type_; }; } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 4c1975bfa..fd665b5d9 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -239,7 +239,7 @@ absl::StatusOr FlattenType(const Type& type) { return absl::InternalError( absl::StrCat("Unsupported type: ", type.DebugString())); } -} // namespace +} class ResolveVisitor : public AstVisitorBase { public: @@ -322,6 +322,13 @@ class ResolveVisitor : public AstVisitorBase { const absl::Status& status() const { return status_; } + void AssertExpectedType(const Expr& expr, const Type& expected_type) { + Type observed = GetTypeOrDyn(&expr); + if (!inference_context_->IsAssignable(observed, expected_type)) { + ReportTypeMismatch(expr.id(), expected_type, observed); + } + } + private: struct ComprehensionScope { const Expr* comprehension_expr; @@ -1231,6 +1238,10 @@ absl::StatusOr TypeCheckerImpl::Check( AstTraverse(ast_impl.root_expr(), visitor, opts); CEL_RETURN_IF_ERROR(visitor.status()); + if (env_.expected_type().has_value()) { + visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); + } + // If any issues are errors, return without an AST. for (const auto& issue : issues) { if (issue.severity() == Severity::kError) { diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index f28621030..1b9062ec1 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -33,7 +33,7 @@ namespace cel::checker_internal { // See cel::TypeCheckerBuilder for constructing instances. class TypeCheckerImpl : public TypeChecker { public: - TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) + explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) : env_(std::move(env)), options_(options) {} TypeCheckerImpl(const TypeCheckerImpl&) = delete; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 50be6d671..e0ff26ff8 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -1230,6 +1230,48 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(ast_internal::DynamicType()))))))); } +TEST(TypeCheckerImplTest, ExpectedTypeMatches) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + + EXPECT_THAT( + ast_impl.type_map(), + Contains(Pair( + ast_impl.root_expr().id(), + Eq(AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique( + ast_internal::PrimitiveType::kString))))))); +} + +TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + env.set_expected_type(MapType(&arena, StringType(), StringType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, + "expected type 'map' but found 'map'"))); +} + TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container("google.protobuf"); diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index bd5eee3f9..c10675bf4 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -32,6 +32,7 @@ #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type.h" #include "common/type_introspector.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" @@ -170,4 +171,8 @@ void TypeCheckerBuilder::set_container(absl::string_view container) { env_.set_container(std::string(container)); } +void TypeCheckerBuilder::SetExpectedType(const Type& type) { + env_.set_expected_type(type); +} + } // namespace cel diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f6eb5aec0..1253c0cae 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -30,6 +30,7 @@ #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" @@ -91,6 +92,8 @@ class TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl); absl::Status AddFunction(const FunctionDecl& decl); + void SetExpectedType(const Type& type); + // Adds function declaration overloads to the TypeChecker being built. // // Attempts to merge with any existing overloads for a function decl with the