Skip to content

Commit 033986e

Browse files
jnthntatumcopybara-github
authored andcommitted
[checker] Fix bug where recursive type could be inferred
The cycle check during inference had a bug where certain recursive definitions could be inferred. This would lead to the checker eventually crashing when it tried to realize a concrete type from the substitution map. ex. `[optional.none()].map(x, [?x, null, x])` PiperOrigin-RevId: 817307359
1 parent c9b86b1 commit 033986e

File tree

4 files changed

+64
-31
lines changed

4 files changed

+64
-31
lines changed

checker/internal/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ cc_library(
218218
srcs = ["type_inference_context.cc"],
219219
hdrs = ["type_inference_context.h"],
220220
deps = [
221+
":format_type_name",
221222
"//common:decl",
222223
"//common:type",
223224
"//common:type_kind",

checker/internal/type_inference_context.cc

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
#include "absl/log/absl_check.h"
2424
#include "absl/log/absl_log.h"
2525
#include "absl/strings/match.h"
26+
#include "absl/strings/str_cat.h"
27+
#include "absl/strings/str_join.h"
2628
#include "absl/strings/string_view.h"
2729
#include "absl/types/optional.h"
2830
#include "absl/types/span.h"
31+
#include "checker/internal/format_type_name.h"
2932
#include "common/decl.h"
3033
#include "common/type.h"
3134
#include "common/type_kind.h"
@@ -267,14 +270,15 @@ bool TypeInferenceContext::IsAssignableInternal(
267270
// Checking assignability to a specific type var
268271
// that has a prospective type assignment.
269272
to.kind() == TypeKind::kTypeParam &&
270-
prospective_substitutions.contains(to.AsTypeParam()->name())) {
271-
auto prospective_subs_cpy(prospective_substitutions);
273+
prospective_substitutions.contains(to.GetTypeParam().name())) {
274+
SubstitutionMap prospective_subs_cpy = prospective_substitutions;
272275
if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) ==
273276
RelativeGenerality::kMoreGeneral) {
274277
if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) &&
275-
!OccursWithin(to.name(), from_subs, prospective_subs_cpy)) {
276-
prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs;
277-
prospective_substitutions = prospective_subs_cpy;
278+
!OccursWithin(to.GetTypeParam().name(), from_subs,
279+
prospective_subs_cpy)) {
280+
prospective_subs_cpy[to.GetTypeParam().name()] = from_subs;
281+
prospective_substitutions = std::move(prospective_subs_cpy);
278282
return true;
279283
// otherwise, continue with normal assignability check.
280284
}
@@ -454,17 +458,35 @@ bool TypeInferenceContext::OccursWithin(
454458
//
455459
// This check guarantees that we don't introduce a recursive type definition
456460
// (a cycle in the substitution map).
457-
if (type.kind() == TypeKind::kTypeParam) {
458-
if (type.AsTypeParam()->name() == var_name) {
461+
//
462+
// We can't reuse Substitute here because it does the pointer chasing and
463+
// might hide a cycle.
464+
//
465+
// E.g.
466+
// T2 in T3 when
467+
// T3 -> T2 -> null_type;
468+
Type substitution = type;
469+
while (substitution.kind() == TypeKind::kTypeParam) {
470+
absl::string_view param_name = substitution.AsTypeParam()->name();
471+
if (param_name == var_name) {
459472
return true;
460473
}
461-
auto typeSubs = Substitute(type, substitutions);
462-
if (typeSubs != type && OccursWithin(var_name, typeSubs, substitutions)) {
463-
return true;
474+
475+
if (auto it = substitutions.find(param_name); it != substitutions.end()) {
476+
substitution = it->second;
477+
continue;
478+
}
479+
if (auto it = type_parameter_bindings_.find(param_name);
480+
it != type_parameter_bindings_.end() && it->second.type.has_value()) {
481+
substitution = it->second.type.value();
482+
continue;
464483
}
484+
485+
// Type parameter is free.
486+
return false;
465487
}
466488

467-
for (const auto& param : type.GetParameters()) {
489+
for (const auto& param : substitution.GetParameters()) {
468490
if (OccursWithin(var_name, param, substitutions)) {
469491
return true;
470492
}
@@ -526,19 +548,18 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl,
526548
ABSL_DCHECK_EQ(argument_types.size(),
527549
call_type_instance.param_types.size());
528550
bool is_match = true;
529-
SubstitutionMap prospective_substitutions;
551+
AssignabilityContext assignability_context = CreateAssignabilityContext();
530552
for (int i = 0; i < argument_types.size(); ++i) {
531-
if (!IsAssignableInternal(argument_types[i],
532-
call_type_instance.param_types[i],
533-
prospective_substitutions)) {
553+
if (!assignability_context.IsAssignable(
554+
argument_types[i], call_type_instance.param_types[i])) {
534555
is_match = false;
535556
break;
536557
}
537558
}
538559

539560
if (is_match) {
540561
matching_overloads.push_back(ovl);
541-
UpdateTypeParameterBindings(prospective_substitutions);
562+
assignability_context.UpdateInferredTypeAssignments();
542563
if (!result_type.has_value()) {
543564
result_type = call_type_instance.result_type;
544565
} else {
@@ -625,10 +646,23 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from,
625646
prospective_substitutions_);
626647
}
627648

649+
std::string TypeInferenceContext::DebugString() const {
650+
return absl::StrCat(
651+
"type_parameter_bindings: ",
652+
absl::StrJoin(
653+
type_parameter_bindings_, "\n ",
654+
[](std::string* out, const auto& binding) {
655+
absl::StrAppend(
656+
out, binding.first, " (", binding.second.name, ") -> ",
657+
checker_internal::FormatTypeName(
658+
binding.second.type.value_or(Type(TypeParamType("none")))));
659+
}));
660+
}
661+
628662
void TypeInferenceContext::AssignabilityContext::
629663
UpdateInferredTypeAssignments() {
630-
inference_context_.UpdateTypeParameterBindings(
631-
std::move(prospective_substitutions_));
664+
inference_context_.UpdateTypeParameterBindings(prospective_substitutions_);
665+
prospective_substitutions_.clear();
632666
}
633667

634668
void TypeInferenceContext::AssignabilityContext::Reset() {

checker/internal/type_inference_context.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "absl/container/node_hash_map.h"
2424
#include "absl/log/absl_check.h"
2525
#include "absl/strings/str_cat.h"
26-
#include "absl/strings/str_join.h"
2726
#include "absl/strings/string_view.h"
2827
#include "absl/types/optional.h"
2928
#include "absl/types/span.h"
@@ -141,18 +140,7 @@ class TypeInferenceContext {
141140
// Checks if `from` is assignable to `to`.
142141
bool IsAssignable(const Type& from, const Type& to);
143142

144-
std::string DebugString() const {
145-
return absl::StrCat(
146-
"type_parameter_bindings: ",
147-
absl::StrJoin(
148-
type_parameter_bindings_, "\n ",
149-
[](std::string* out, const auto& binding) {
150-
absl::StrAppend(
151-
out, binding.first, " (", binding.second.name, ") -> ",
152-
binding.second.type.value_or(Type(TypeParamType("none")))
153-
.DebugString());
154-
}));
155-
}
143+
std::string DebugString() const;
156144

157145
private:
158146
struct TypeVar {

checker/optional_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ INSTANTIATE_TEST_SUITE_P(
158158
"optional.none()",
159159
IsOptionalType(TypeSpec(DynTypeSpec())),
160160
},
161+
// Odd case -- the correct result might be a bespoke recursively-defined
162+
// type but CEL doesn't support that. Null is used because it is
163+
// implicitly assignable to optional types. This allows for a recursive
164+
// type to be non-trivial and verify the checker is actually avoiding
165+
// introducing a cyclic type.
166+
TestCase{
167+
"[optional.none()].map(x, [?x, null, x])",
168+
Eq(TypeSpec(ListTypeSpec(std::make_unique<TypeSpec>(
169+
ListTypeSpec(std::make_unique<TypeSpec>(NullTypeSpec())))))),
170+
},
161171
TestCase{
162172
"optional.of('abc').hasValue()",
163173
Eq(TypeSpec(PrimitiveType::kBool)),

0 commit comments

Comments
 (0)