Skip to content

Commit d50f3d3

Browse files
jnthntatumcopybara-github
authored andcommitted
Update type checker to fail (return a status) if it fails to deduce the type of subexpression.
PiperOrigin-RevId: 698538634
1 parent ceb592c commit d50f3d3

File tree

2 files changed

+76
-31
lines changed

2 files changed

+76
-31
lines changed

checker/internal/type_checker_impl.cc

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ class ResolveVisitor : public AstVisitorBase {
318318
int error_count() const { return error_count_; }
319319

320320
void AssertExpectedType(const Expr& expr, const Type& expected_type) {
321-
Type observed = GetTypeOrDyn(&expr);
321+
Type observed = GetDeducedType(&expr);
322322
if (!inference_context_->IsAssignable(observed, expected_type)) {
323323
ReportTypeMismatch(expr.id(), expected_type, observed);
324324
}
@@ -405,7 +405,7 @@ class ResolveVisitor : public AstVisitorBase {
405405
absl::string_view resolved_name) {
406406
for (const auto& field : create_struct.fields()) {
407407
const Expr* value = &field.value();
408-
Type value_type = GetTypeOrDyn(value);
408+
Type value_type = GetDeducedType(value);
409409

410410
// Lookup message type by name to support WellKnownType creation.
411411
CEL_ASSIGN_OR_RETURN(
@@ -441,12 +441,22 @@ class ResolveVisitor : public AstVisitorBase {
441441

442442
void HandleOptSelect(const Expr& expr);
443443

444-
// TODO: This should switch to a failing check once all core
445-
// features are supported. For now, we allow dyn for implementing the
446-
// typechecker behaviors in isolation.
447-
Type GetTypeOrDyn(const Expr* expr) {
444+
// Get the assigned type of the given subexpression. Should only be called if
445+
// the given subexpression is expected to have already been checked.
446+
//
447+
// If unknown, returns DynType as a placeholder and reports an error.
448+
// Whether or not the subexpression is valid for the checker configuration,
449+
// the type checker should have assigned a type (possibly ErrorType). If there
450+
// is no assigned type, the type checker failed to handle the subexpression
451+
// and should not attempt to continue type checking.
452+
Type GetDeducedType(const Expr* expr) {
448453
auto iter = types_.find(expr);
449-
return iter != types_.end() ? iter->second : DynType();
454+
if (iter != types_.end()) {
455+
return iter->second;
456+
}
457+
status_.Update(absl::InvalidArgumentError(
458+
absl::StrCat("Could not deduce type for expression id: ", expr->id())));
459+
return DynType();
450460
}
451461

452462
absl::string_view container_;
@@ -560,6 +570,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr,
560570
ComputeSourceLocation(*ast_, expr.id()),
561571
absl::StrCat("unsupported constant type: ",
562572
constant.kind().index())));
573+
types_[&expr] = ErrorType();
563574
break;
564575
}
565576
}
@@ -599,7 +610,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
599610
auto assignability_context = inference_context_->CreateAssignabilityContext();
600611
for (const auto& entry : map.entries()) {
601612
const Expr* key = &entry.key();
602-
Type key_type = GetTypeOrDyn(key);
613+
Type key_type = GetDeducedType(key);
603614
if (!IsSupportedKeyType(key_type)) {
604615
// The Go type checker implementation can allow any type as a map key, but
605616
// per the spec this should be limited to the types listed in
@@ -626,7 +637,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) {
626637
assignability_context.Reset();
627638
for (const auto& entry : map.entries()) {
628639
const Expr* value = &entry.value();
629-
Type value_type = GetTypeOrDyn(value);
640+
Type value_type = GetDeducedType(value);
630641
if (entry.optional()) {
631642
if (value_type.IsOptional()) {
632643
value_type = value_type.GetOptional().GetParameter();
@@ -657,7 +668,7 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) {
657668
auto assignability_context = inference_context_->CreateAssignabilityContext();
658669
for (const auto& element : list.elements()) {
659670
const Expr* value = &element.expr();
660-
Type value_type = GetTypeOrDyn(value);
671+
Type value_type = GetDeducedType(value);
661672
if (element.optional()) {
662673
if (value_type.IsOptional()) {
663674
value_type = value_type.GetOptional().GetParameter();
@@ -707,6 +718,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr,
707718

708719
if (resolved_name.empty()) {
709720
ReportMissingReference(expr, create_struct.name());
721+
types_[&expr] = ErrorType();
710722
return;
711723
}
712724

@@ -716,6 +728,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr,
716728
ComputeSourceLocation(*ast_, expr.id()),
717729
absl::StrCat("type '", resolved_name,
718730
"' does not support message creation")));
731+
types_[&expr] = ErrorType();
719732
return;
720733
}
721734

@@ -758,13 +771,14 @@ void ResolveVisitor::PostVisitCall(const Expr& expr, const CallExpr& call) {
758771
const FunctionDecl* decl = ResolveFunctionCallShape(
759772
expr, call.function(), arg_count, call.has_target());
760773

761-
if (decl != nullptr) {
762-
ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(),
763-
/* is_namespaced= */ false);
774+
if (decl == nullptr) {
775+
ReportMissingReference(expr, call.function());
776+
types_[&expr] = ErrorType();
764777
return;
765778
}
766779

767-
ReportMissingReference(expr, call.function());
780+
ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(),
781+
/* is_namespaced= */ false);
768782
}
769783

770784
void ResolveVisitor::PreVisitComprehension(
@@ -786,7 +800,7 @@ void ResolveVisitor::PreVisitComprehension(
786800
void ResolveVisitor::PostVisitComprehension(
787801
const Expr& expr, const ComprehensionExpr& comprehension) {
788802
comprehension_scopes_.pop_back();
789-
types_[&expr] = GetTypeOrDyn(&comprehension.result());
803+
types_[&expr] = GetDeducedType(&comprehension.result());
790804
}
791805

792806
void ResolveVisitor::PreVisitComprehensionSubexpression(
@@ -839,11 +853,12 @@ void ResolveVisitor::PostVisitComprehensionSubexpression(
839853
// the corresponding variables can be referenced.
840854
switch (comprehension_arg) {
841855
case ComprehensionArg::ACCU_INIT:
842-
scope.accu_scope->InsertVariableIfAbsent(MakeVariableDecl(
843-
comprehension.accu_var(), GetTypeOrDyn(&comprehension.accu_init())));
856+
scope.accu_scope->InsertVariableIfAbsent(
857+
MakeVariableDecl(comprehension.accu_var(),
858+
GetDeducedType(&comprehension.accu_init())));
844859
break;
845860
case ComprehensionArg::ITER_RANGE: {
846-
Type range_type = GetTypeOrDyn(&comprehension.iter_range());
861+
Type range_type = GetDeducedType(&comprehension.iter_range());
847862
Type iter_type = DynType(); // iter_var for non comprehensions v2.
848863
Type iter_type1 = DynType(); // iter_var for comprehensions v2.
849864
Type iter_type2 = DynType(); // iter_var2 for comprehensions v2.
@@ -879,9 +894,6 @@ void ResolveVisitor::PostVisitComprehensionSubexpression(
879894
}
880895
break;
881896
}
882-
case ComprehensionArg::RESULT:
883-
types_[&expr] = types_[&expr];
884-
break;
885897
default:
886898
break;
887899
}
@@ -923,10 +935,10 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,
923935
std::vector<Type> arg_types;
924936
arg_types.reserve(arg_count);
925937
if (is_receiver) {
926-
arg_types.push_back(GetTypeOrDyn(&expr.call_expr().target()));
938+
arg_types.push_back(GetDeducedType(&expr.call_expr().target()));
927939
}
928940
for (int i = 0; i < expr.call_expr().args().size(); ++i) {
929-
arg_types.push_back(GetTypeOrDyn(&expr.call_expr().args()[i]));
941+
arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i]));
930942
}
931943

932944
absl::optional<TypeInferenceContext::OverloadResolution> resolution =
@@ -942,6 +954,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,
942954
out->append(type.DebugString());
943955
}),
944956
")'")));
957+
types_[&expr] = ErrorType();
945958
return;
946959
}
947960

@@ -1000,6 +1013,7 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
10001013

10011014
if (decl == nullptr) {
10021015
ReportMissingReference(expr, name);
1016+
types_[&expr] = ErrorType();
10031017
return;
10041018
}
10051019

@@ -1029,6 +1043,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
10291043

10301044
if (decl == nullptr) {
10311045
ReportMissingReference(expr, FormatCandidate(qualifiers));
1046+
types_[&expr] = ErrorType();
10321047
return;
10331048
}
10341049

@@ -1106,7 +1121,7 @@ absl::optional<Type> ResolveVisitor::CheckFieldType(int64_t id,
11061121
void ResolveVisitor::ResolveSelectOperation(const Expr& expr,
11071122
absl::string_view field,
11081123
const Expr& operand) {
1109-
const Type& operand_type = GetTypeOrDyn(&operand);
1124+
const Type& operand_type = GetDeducedType(&operand);
11101125

11111126
absl::optional<Type> result_type;
11121127
int64_t id = expr.id();
@@ -1122,12 +1137,15 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr,
11221137
result_type = CheckFieldType(id, operand_type, field);
11231138
}
11241139

1125-
if (result_type.has_value()) {
1126-
if (expr.select_expr().test_only()) {
1127-
types_[&expr] = BoolType();
1128-
} else {
1129-
types_[&expr] = *result_type;
1130-
}
1140+
if (!result_type.has_value()) {
1141+
types_[&expr] = ErrorType();
1142+
return;
1143+
}
1144+
1145+
if (expr.select_expr().test_only()) {
1146+
types_[&expr] = BoolType();
1147+
} else {
1148+
types_[&expr] = *result_type;
11311149
}
11321150
}
11331151

@@ -1147,14 +1165,15 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) {
11471165
return;
11481166
}
11491167

1150-
Type operand_type = GetTypeOrDyn(operand);
1168+
Type operand_type = GetDeducedType(operand);
11511169
if (operand_type.IsOptional()) {
11521170
operand_type = operand_type.GetOptional().GetParameter();
11531171
}
11541172

11551173
absl::optional<Type> field_type = CheckFieldType(
11561174
expr.id(), operand_type, field->const_expr().string_value());
11571175
if (!field_type.has_value()) {
1176+
types_[&expr] = ErrorType();
11581177
return;
11591178
}
11601179
const FunctionDecl* select_decl = env_->LookupFunction(kOptionalSelect);

checker/internal/type_checker_impl_test.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,32 @@ TEST(TypeCheckerImplTest, BadSourcePosition) {
13481348
"ERROR: <input>:-1:-1: undeclared reference to 'foo' (in container '')");
13491349
}
13501350

1351+
// Check that the TypeChecker will fail if no type is deduced for a
1352+
// subexpression. This is meant to be a guard against failing to account for new
1353+
// types of expressions in the type checker logic.
1354+
TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) {
1355+
google::protobuf::Arena arena;
1356+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
1357+
1358+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
1359+
env.InsertVariableIfAbsent(MakeVariableDecl("a", BoolType()));
1360+
env.InsertVariableIfAbsent(MakeVariableDecl("b", BoolType()));
1361+
1362+
TypeCheckerImpl impl(std::move(env));
1363+
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("a || b"));
1364+
auto& ast_impl = AstImpl::CastFromPublicAst(*ast);
1365+
1366+
// Assume that an unspecified expr kind is not deducible.
1367+
Expr unspecified_expr;
1368+
unspecified_expr.set_id(3);
1369+
ast_impl.root_expr().mutable_call_expr().mutable_args()[1] =
1370+
std::move(unspecified_expr);
1371+
1372+
ASSERT_THAT(impl.Check(std::move(ast)),
1373+
StatusIs(absl::StatusCode::kInvalidArgument,
1374+
"Could not deduce type for expression id: 3"));
1375+
}
1376+
13511377
TEST(TypeCheckerImplTest, BadLineOffsets) {
13521378
google::protobuf::Arena arena;
13531379
TypeCheckEnv env(GetSharedTestingDescriptorPool());

0 commit comments

Comments
 (0)