Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class TypeCheckEnv {
container_ = std::move(container);
}

void set_expected_type(const Type& type) { expected_type_ = std::move(type); }

const absl::optional<Type>& expected_type() const { return expected_type_; }

absl::Span<const std::unique_ptr<TypeIntrospector>> type_providers() const {
return type_providers_;
}
Expand Down Expand Up @@ -198,6 +202,8 @@ class TypeCheckEnv {

// Type providers for custom types.
std::vector<std::unique_ptr<TypeIntrospector>> type_providers_;

absl::optional<Type> expected_type_;
};

} // namespace cel::checker_internal
Expand Down
13 changes: 12 additions & 1 deletion checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ absl::StatusOr<AstType> FlattenType(const Type& type) {
return absl::InternalError(
absl::StrCat("Unsupported type: ", type.DebugString()));
}
} // namespace
}

class ResolveVisitor : public AstVisitorBase {
public:
Expand Down Expand Up @@ -322,6 +322,13 @@ class ResolveVisitor : public AstVisitorBase {

const absl::Status& status() const { return status_; }

void AssertExpectedType(const Expr& expr, const Type& expected_type) {
Type observed = GetTypeOrDyn(&expr);
if (!inference_context_->IsAssignable(observed, expected_type)) {
ReportTypeMismatch(expr.id(), expected_type, observed);
}
}

private:
struct ComprehensionScope {
const Expr* comprehension_expr;
Expand Down Expand Up @@ -1231,6 +1238,10 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
AstTraverse(ast_impl.root_expr(), visitor, opts);
CEL_RETURN_IF_ERROR(visitor.status());

if (env_.expected_type().has_value()) {
visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type());
}

// If any issues are errors, return without an AST.
for (const auto& issue : issues) {
if (issue.severity() == Severity::kError) {
Expand Down
2 changes: 1 addition & 1 deletion checker/internal/type_checker_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace cel::checker_internal {
// See cel::TypeCheckerBuilder for constructing instances.
class TypeCheckerImpl : public TypeChecker {
public:
TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {})
explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {})
: env_(std::move(env)), options_(options) {}

TypeCheckerImpl(const TypeCheckerImpl&) = delete;
Expand Down
42 changes: 42 additions & 0 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,48 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) {
std::make_unique<AstType>(ast_internal::DynamicType())))))));
}

TEST(TypeCheckerImplTest, ExpectedTypeMatches) {
google::protobuf::Arena arena;
TypeCheckEnv env(GetSharedTestingDescriptorPool());

env.set_expected_type(MapType(&arena, StringType(), StringType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());

const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);

EXPECT_THAT(
ast_impl.type_map(),
Contains(Pair(
ast_impl.root_expr().id(),
Eq(AstType(ast_internal::MapType(
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
std::make_unique<AstType>(
ast_internal::PrimitiveType::kString)))))));
}

TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) {
google::protobuf::Arena arena;
TypeCheckEnv env(GetSharedTestingDescriptorPool());

env.set_expected_type(MapType(&arena, StringType(), StringType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_FALSE(result.IsValid());
EXPECT_THAT(
result.GetIssues(),
Contains(IsIssueWithSubstring(
Severity::kError,
"expected type 'map<string, string>' but found 'map<string, int>'")));
}

TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
env.set_container("google.protobuf");
Expand Down
5 changes: 5 additions & 0 deletions checker/type_checker_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "checker/internal/type_checker_impl.h"
#include "checker/type_checker.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
Expand Down Expand Up @@ -170,4 +171,8 @@ void TypeCheckerBuilder::set_container(absl::string_view container) {
env_.set_container(std::string(container));
}

void TypeCheckerBuilder::SetExpectedType(const Type& type) {
env_.set_expected_type(type);
}

} // namespace cel
3 changes: 3 additions & 0 deletions checker/type_checker_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "checker/internal/type_check_env.h"
#include "checker/type_checker.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "google/protobuf/descriptor.h"

Expand Down Expand Up @@ -91,6 +92,8 @@ class TypeCheckerBuilder {
absl::Status AddVariable(const VariableDecl& decl);
absl::Status AddFunction(const FunctionDecl& decl);

void SetExpectedType(const Type& type);

// Adds function declaration overloads to the TypeChecker being built.
//
// Attempts to merge with any existing overloads for a function decl with the
Expand Down