Skip to content

Commit 833abb1

Browse files
jnthntatumcopybara-github
authored andcommitted
Add support for declaring an overall expected type to the type checker.
PiperOrigin-RevId: 689936845
1 parent 13e252f commit 833abb1

File tree

6 files changed

+69
-2
lines changed

6 files changed

+69
-2
lines changed

checker/internal/type_check_env.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ class TypeCheckEnv {
107107
container_ = std::move(container);
108108
}
109109

110+
void set_expected_type(const Type& type) { expected_type_ = std::move(type); }
111+
112+
const absl::optional<Type>& expected_type() const { return expected_type_; }
113+
110114
absl::Span<const std::unique_ptr<TypeIntrospector>> type_providers() const {
111115
return type_providers_;
112116
}
@@ -198,6 +202,8 @@ class TypeCheckEnv {
198202

199203
// Type providers for custom types.
200204
std::vector<std::unique_ptr<TypeIntrospector>> type_providers_;
205+
206+
absl::optional<Type> expected_type_;
201207
};
202208

203209
} // namespace cel::checker_internal

checker/internal/type_checker_impl.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ absl::StatusOr<AstType> FlattenType(const Type& type) {
239239
return absl::InternalError(
240240
absl::StrCat("Unsupported type: ", type.DebugString()));
241241
}
242-
} // namespace
242+
}
243243

244244
class ResolveVisitor : public AstVisitorBase {
245245
public:
@@ -322,6 +322,13 @@ class ResolveVisitor : public AstVisitorBase {
322322

323323
const absl::Status& status() const { return status_; }
324324

325+
void AssertExpectedType(const Expr& expr, const Type& expected_type) {
326+
Type observed = GetTypeOrDyn(&expr);
327+
if (!inference_context_->IsAssignable(observed, expected_type)) {
328+
ReportTypeMismatch(expr.id(), expected_type, observed);
329+
}
330+
}
331+
325332
private:
326333
struct ComprehensionScope {
327334
const Expr* comprehension_expr;
@@ -1231,6 +1238,10 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
12311238
AstTraverse(ast_impl.root_expr(), visitor, opts);
12321239
CEL_RETURN_IF_ERROR(visitor.status());
12331240

1241+
if (env_.expected_type().has_value()) {
1242+
visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type());
1243+
}
1244+
12341245
// If any issues are errors, return without an AST.
12351246
for (const auto& issue : issues) {
12361247
if (issue.severity() == Severity::kError) {

checker/internal/type_checker_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace cel::checker_internal {
3333
// See cel::TypeCheckerBuilder for constructing instances.
3434
class TypeCheckerImpl : public TypeChecker {
3535
public:
36-
TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {})
36+
explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {})
3737
: env_(std::move(env)), options_(options) {}
3838

3939
TypeCheckerImpl(const TypeCheckerImpl&) = delete;

checker/internal/type_checker_impl_test.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,48 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) {
12301230
std::make_unique<AstType>(ast_internal::DynamicType())))))));
12311231
}
12321232

1233+
TEST(TypeCheckerImplTest, ExpectedTypeMatches) {
1234+
google::protobuf::Arena arena;
1235+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
1236+
1237+
env.set_expected_type(MapType(&arena, StringType(), StringType()));
1238+
1239+
TypeCheckerImpl impl(std::move(env));
1240+
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}"));
1241+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
1242+
1243+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());
1244+
1245+
const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
1246+
1247+
EXPECT_THAT(
1248+
ast_impl.type_map(),
1249+
Contains(Pair(
1250+
ast_impl.root_expr().id(),
1251+
Eq(AstType(ast_internal::MapType(
1252+
std::make_unique<AstType>(ast_internal::PrimitiveType::kString),
1253+
std::make_unique<AstType>(
1254+
ast_internal::PrimitiveType::kString)))))));
1255+
}
1256+
1257+
TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) {
1258+
google::protobuf::Arena arena;
1259+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
1260+
1261+
env.set_expected_type(MapType(&arena, StringType(), StringType()));
1262+
1263+
TypeCheckerImpl impl(std::move(env));
1264+
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}"));
1265+
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
1266+
1267+
EXPECT_FALSE(result.IsValid());
1268+
EXPECT_THAT(
1269+
result.GetIssues(),
1270+
Contains(IsIssueWithSubstring(
1271+
Severity::kError,
1272+
"expected type 'map<string, string>' but found 'map<string, int>'")));
1273+
}
1274+
12331275
TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) {
12341276
TypeCheckEnv env(GetSharedTestingDescriptorPool());
12351277
env.set_container("google.protobuf");

checker/type_checker_builder.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "checker/internal/type_checker_impl.h"
3333
#include "checker/type_checker.h"
3434
#include "common/decl.h"
35+
#include "common/type.h"
3536
#include "common/type_introspector.h"
3637
#include "internal/status_macros.h"
3738
#include "internal/well_known_types.h"
@@ -170,4 +171,8 @@ void TypeCheckerBuilder::set_container(absl::string_view container) {
170171
env_.set_container(std::string(container));
171172
}
172173

174+
void TypeCheckerBuilder::SetExpectedType(const Type& type) {
175+
env_.set_expected_type(type);
176+
}
177+
173178
} // namespace cel

checker/type_checker_builder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "checker/internal/type_check_env.h"
3131
#include "checker/type_checker.h"
3232
#include "common/decl.h"
33+
#include "common/type.h"
3334
#include "common/type_introspector.h"
3435
#include "google/protobuf/descriptor.h"
3536

@@ -91,6 +92,8 @@ class TypeCheckerBuilder {
9192
absl::Status AddVariable(const VariableDecl& decl);
9293
absl::Status AddFunction(const FunctionDecl& decl);
9394

95+
void SetExpectedType(const Type& type);
96+
9497
// Adds function declaration overloads to the TypeChecker being built.
9598
//
9699
// Attempts to merge with any existing overloads for a function decl with the

0 commit comments

Comments
 (0)