From e4a32ba644db8baf556d0f295e674050a3f0d2fe Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Mon, 17 Nov 2025 13:43:37 -0500 Subject: [PATCH] [ty] Constraint sets compare generic callables correctly (#21392) Constraint sets can now track subtyping/assignability/etc of generic callables correctly. For instance: ```py def identity[T](t: T) -> T: return t constraints = ConstraintSet.always() static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int])) static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str])) ``` A generic callable can be considered an intersection of all of its possible specializations, and an assignability check with an intersection as the lhs side succeeds of _any_ of the intersected types satisfies the check. Put another way, if someone expects to receive any function with a signature of `(int) -> int`, we can give them `identity`. Note that the corresponding check using `is_subtype_of` directly does not yet work, since #20093 has not yet hooked up the core typing relationship logic to use constraint sets: ```py # These currently fail static_assert(is_subtype_of(TypeOf[identity], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[identity], Callable[[str], str])) ``` To do this, we add a new _existential quantification_ operation on constraint sets. This takes in a list of typevars and _removes_ those typevars from the constraint set. Conceptually, we return a new constraint set that evaluates to `true` when there was _any_ assignment of the removed typevars that caused the old constraint set to evaluate to `true`. When comparing a generic constraint set, we add its typevars to the `inferable` set, and figure out whatever constraints would allow any specialization to satisfy the check. We then use the new existential quantification operator to remove those new typevars, since the caller doesn't (and shouldn't) know anything about them. --------- Co-authored-by: David Peter --- .../resources/mdtest/protocols.md | 12 +-- .../type_properties/implies_subtype_of.md | 97 +++++++++++++++++ .../type_properties/is_assignable_to.md | 45 ++++++++ .../mdtest/type_properties/is_subtype_of.md | 48 +++++++++ crates/ty_python_semantic/src/types.rs | 22 ++-- .../src/types/constraints.rs | 90 ++++++++++++++-- .../ty_python_semantic/src/types/generics.rs | 24 ++++- .../src/types/signatures.rs | 100 +++++++++++++++--- 8 files changed, 387 insertions(+), 51 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index f410da25df..8d41073902 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -2099,18 +2099,14 @@ static_assert(is_equivalent_to(LegacyFunctionScoped, NewStyleFunctionScoped)) # static_assert(is_assignable_to(NominalNewStyle, NewStyleFunctionScoped)) static_assert(is_assignable_to(NominalNewStyle, LegacyFunctionScoped)) -# TODO: should pass -static_assert(is_subtype_of(NominalNewStyle, NewStyleFunctionScoped)) # error: [static-assert-error] -# TODO: should pass -static_assert(is_subtype_of(NominalNewStyle, LegacyFunctionScoped)) # error: [static-assert-error] +static_assert(is_subtype_of(NominalNewStyle, NewStyleFunctionScoped)) +static_assert(is_subtype_of(NominalNewStyle, LegacyFunctionScoped)) static_assert(not is_assignable_to(NominalNewStyle, UsesSelf)) static_assert(is_assignable_to(NominalLegacy, NewStyleFunctionScoped)) static_assert(is_assignable_to(NominalLegacy, LegacyFunctionScoped)) -# TODO: should pass -static_assert(is_subtype_of(NominalLegacy, NewStyleFunctionScoped)) # error: [static-assert-error] -# TODO: should pass -static_assert(is_subtype_of(NominalLegacy, LegacyFunctionScoped)) # error: [static-assert-error] +static_assert(is_subtype_of(NominalLegacy, NewStyleFunctionScoped)) +static_assert(is_subtype_of(NominalLegacy, LegacyFunctionScoped)) static_assert(not is_assignable_to(NominalLegacy, UsesSelf)) static_assert(not is_assignable_to(NominalWithSelf, NewStyleFunctionScoped)) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md index a6c6793de3..35768ef76d 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md @@ -349,4 +349,101 @@ def mutually_constrained[T, U](): static_assert(not given_int.implies_subtype_of(Invariant[str], Invariant[T])) ``` +## Generic callables + +A generic callable can be considered equivalent to an intersection of all of its possible +specializations. That means that a generic callable is a subtype of any particular specialization. +(If someone expects a function that works with a particular specialization, it's fine to hand them +the generic callable.) + +```py +from typing import Callable +from ty_extensions import CallableTypeOf, ConstraintSet, TypeOf, is_subtype_of, static_assert + +def identity[T](t: T) -> T: + return t + +type GenericIdentity[T] = Callable[[T], T] + +constraints = ConstraintSet.always() + +static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int])) +static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str])) +static_assert(not constraints.implies_subtype_of(TypeOf[identity], Callable[[str], int])) + +static_assert(constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[int], int])) +static_assert(constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[str], str])) +static_assert(not constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[str], int])) + +static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[int])) +static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[str])) +# This gives us the default specialization, GenericIdentity[Unknown], which does +# not participate in subtyping. +static_assert(not constraints.implies_subtype_of(TypeOf[identity], GenericIdentity)) +``` + +The reverse is not true — if someone expects a generic function that can be called with any +specialization, we cannot hand them a function that only works with one specialization. + +```py +static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity])) +static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity])) +static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity])) + +static_assert(not constraints.implies_subtype_of(Callable[[int], int], CallableTypeOf[identity])) +static_assert(not constraints.implies_subtype_of(Callable[[str], str], CallableTypeOf[identity])) +static_assert(not constraints.implies_subtype_of(Callable[[str], int], CallableTypeOf[identity])) + +static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity])) +static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity])) +# This gives us the default specialization, GenericIdentity[Unknown], which does +# not participate in subtyping. +static_assert(not constraints.implies_subtype_of(GenericIdentity, TypeOf[identity])) +``` + +Unrelated typevars in the constraint set do not affect whether the subtyping check succeeds or +fails. + +```py +def unrelated[T](): + # Note that even though this typevar is also named T, it is not the same typevar as T@identity! + constraints = ConstraintSet.range(bool, T, int) + + static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str])) + static_assert(not constraints.implies_subtype_of(TypeOf[identity], Callable[[str], int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[str])) + + static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity])) + static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity])) + static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity])) + static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity])) + static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity])) +``` + +The generic callable's typevar _also_ does not affect whether the subtyping check succeeds or fails! + +```py +def identity2[T](t: T) -> T: + # This constraint set refers to the same typevar as the generic function types below! + constraints = ConstraintSet.range(bool, T, int) + + static_assert(constraints.implies_subtype_of(TypeOf[identity2], Callable[[int], int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity2], Callable[[str], str])) + # TODO: no error + # error: [static-assert-error] + static_assert(not constraints.implies_subtype_of(TypeOf[identity2], Callable[[str], int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity2], GenericIdentity[int])) + static_assert(constraints.implies_subtype_of(TypeOf[identity2], GenericIdentity[str])) + + static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity2])) + static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity2])) + static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity2])) + static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity2])) + static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity2])) + + return t +``` + [subtyping]: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 1386a9e158..3ac4f9b652 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -1,5 +1,10 @@ # Assignable-to relation +```toml +[environment] +python-version = "3.12" +``` + The `is_assignable_to(S, T)` relation below checks if type `S` is assignable to type `T` (target). This allows us to check if a type `S` can be used in a context where a type `T` is expected (function arguments, variable assignments). See the [typing documentation] for a precise definition @@ -1227,6 +1232,46 @@ from ty_extensions import static_assert, is_assignable_to static_assert(is_assignable_to(type, Callable[..., Any])) ``` +### Generic callables + +A generic callable can be considered equivalent to an intersection of all of its possible +specializations. That means that a generic callable is assignable to any particular specialization. +(If someone expects a function that works with a particular specialization, it's fine to hand them +the generic callable.) + +```py +from typing import Callable +from ty_extensions import CallableTypeOf, TypeOf, is_assignable_to, static_assert + +def identity[T](t: T) -> T: + return t + +static_assert(is_assignable_to(TypeOf[identity], Callable[[int], int])) +static_assert(is_assignable_to(TypeOf[identity], Callable[[str], str])) +# TODO: no error +# error: [static-assert-error] +static_assert(not is_assignable_to(TypeOf[identity], Callable[[str], int])) + +static_assert(is_assignable_to(CallableTypeOf[identity], Callable[[int], int])) +static_assert(is_assignable_to(CallableTypeOf[identity], Callable[[str], str])) +# TODO: no error +# error: [static-assert-error] +static_assert(not is_assignable_to(CallableTypeOf[identity], Callable[[str], int])) +``` + +The reverse is not true — if someone expects a generic function that can be called with any +specialization, we cannot hand them a function that only works with one specialization. + +```py +static_assert(not is_assignable_to(Callable[[int], int], TypeOf[identity])) +static_assert(not is_assignable_to(Callable[[str], str], TypeOf[identity])) +static_assert(not is_assignable_to(Callable[[str], int], TypeOf[identity])) + +static_assert(not is_assignable_to(Callable[[int], int], CallableTypeOf[identity])) +static_assert(not is_assignable_to(Callable[[str], str], CallableTypeOf[identity])) +static_assert(not is_assignable_to(Callable[[str], int], CallableTypeOf[identity])) +``` + ## Generics ### Assignability of generic types parameterized by gradual types diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 6034a52529..a2b9ca89d0 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -2207,6 +2207,54 @@ static_assert(is_subtype_of(CallableTypeOf[overload_ab], CallableTypeOf[overload static_assert(is_subtype_of(CallableTypeOf[overload_ba], CallableTypeOf[overload_ab])) ``` +### Generic callables + +A generic callable can be considered equivalent to an intersection of all of its possible +specializations. That means that a generic callable is a subtype of any particular specialization. +(If someone expects a function that works with a particular specialization, it's fine to hand them +the generic callable.) + +```py +from typing import Callable +from ty_extensions import CallableTypeOf, TypeOf, is_subtype_of, static_assert + +def identity[T](t: T) -> T: + return t + +# TODO: Confusingly, these are not the same results as the corresponding checks in +# is_assignable_to.md, even though all of these types are fully static. We have some heuristics that +# currently conflict with each other, that we are in the process of removing with the constraint set +# work. +# TODO: no error +# error: [static-assert-error] +static_assert(is_subtype_of(TypeOf[identity], Callable[[int], int])) +# TODO: no error +# error: [static-assert-error] +static_assert(is_subtype_of(TypeOf[identity], Callable[[str], str])) +static_assert(not is_subtype_of(TypeOf[identity], Callable[[str], int])) + +# TODO: no error +# error: [static-assert-error] +static_assert(is_subtype_of(CallableTypeOf[identity], Callable[[int], int])) +# TODO: no error +# error: [static-assert-error] +static_assert(is_subtype_of(CallableTypeOf[identity], Callable[[str], str])) +static_assert(not is_subtype_of(CallableTypeOf[identity], Callable[[str], int])) +``` + +The reverse is not true — if someone expects a generic function that can be called with any +specialization, we cannot hand them a function that only works with one specialization. + +```py +static_assert(not is_subtype_of(Callable[[int], int], TypeOf[identity])) +static_assert(not is_subtype_of(Callable[[str], str], TypeOf[identity])) +static_assert(not is_subtype_of(Callable[[str], int], TypeOf[identity])) + +static_assert(not is_subtype_of(Callable[[int], int], CallableTypeOf[identity])) +static_assert(not is_subtype_of(Callable[[str], str], CallableTypeOf[identity])) +static_assert(not is_subtype_of(Callable[[str], int], CallableTypeOf[identity])) +``` + [gradual form]: https://typing.python.org/en/latest/spec/glossary.html#term-gradual-form [gradual tuple]: https://typing.python.org/en/latest/spec/tuples.html#tuple-type-form [special case for float and complex]: https://typing.python.org/en/latest/spec/special-types.html#special-cases-for-float-and-complex diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a79c0c33c5..89ae6ff685 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1944,9 +1944,7 @@ impl<'db> Type<'db> { }) } - (Type::TypeVar(bound_typevar), _) - if bound_typevar.is_inferable(db, inferable) && relation.is_assignability() => - { + (Type::TypeVar(bound_typevar), _) if bound_typevar.is_inferable(db, inferable) => { // The implicit lower bound of a typevar is `Never`, which means // that it is always assignable to any other type. @@ -2086,9 +2084,12 @@ impl<'db> Type<'db> { } // TODO: Infer specializations here - (Type::TypeVar(bound_typevar), _) | (_, Type::TypeVar(bound_typevar)) - if bound_typevar.is_inferable(db, inferable) => - { + (_, Type::TypeVar(bound_typevar)) if bound_typevar.is_inferable(db, inferable) => { + ConstraintSet::from(false) + } + (Type::TypeVar(bound_typevar), _) => { + // All inferable cases should have been handled above + assert!(!bound_typevar.is_inferable(db, inferable)); ConstraintSet::from(false) } @@ -2542,13 +2543,8 @@ impl<'db> Type<'db> { disjointness_visitor, ), - // Other than the special cases enumerated above, nominal-instance types, - // newtype-instance types, and typevars are never subtypes of any other variants - (Type::TypeVar(bound_typevar), _) => { - // All inferable cases should have been handled above - assert!(!bound_typevar.is_inferable(db, inferable)); - ConstraintSet::from(false) - } + // Other than the special cases enumerated above, nominal-instance types, and + // newtype-instance types are never subtypes of any other variants (Type::NominalInstance(_), _) => ConstraintSet::from(false), (Type::NewTypeInstance(_), _) => ConstraintSet::from(false), } diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index a2ad10134d..f8b66c33e4 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -66,8 +66,8 @@ use salsa::plumbing::AsId; use crate::Db; use crate::types::generics::InferableTypeVars; use crate::types::{ - BoundTypeVarInstance, IntersectionType, Type, TypeRelation, TypeVarBoundOrConstraints, - UnionType, + BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation, + TypeVarBoundOrConstraints, UnionType, }; /// An extension trait for building constraint sets from [`Option`] values. @@ -301,6 +301,20 @@ impl<'db> ConstraintSet<'db> { } } + /// Reduces the set of inferable typevars for this constraint set. You provide an iterator of + /// the typevars that were inferable when this constraint set was created, and which should be + /// abstracted away. Those typevars will be removed from the constraint set, and the constraint + /// set will return true whenever there was _any_ specialization of those typevars that + /// returned true before. + pub(crate) fn reduce_inferable( + self, + db: &'db dyn Db, + to_remove: impl IntoIterator>, + ) -> Self { + let node = self.node.exists(db, to_remove); + Self { node } + } + pub(crate) fn range( db: &'db dyn Db, lower: Type<'db>, @@ -803,13 +817,24 @@ impl<'db> Node<'db> { // When checking subtyping involving a typevar, we can turn the subtyping check into a // constraint (i.e, "is `T` a subtype of `int` becomes the constraint `T ≤ int`), and then // check when the BDD implies that constraint. + // + // Note that we are NOT guaranteed that `lhs` and `rhs` will always be fully static, since + // these types are coming in from arbitrary subtyping checks that the caller might want to + // perform. So we have to take the appropriate materialization when translating the check + // into a constraint. let constraint = match (lhs, rhs) { - (Type::TypeVar(bound_typevar), _) => { - ConstrainedTypeVar::new_node(db, bound_typevar, Type::Never, rhs) - } - (_, Type::TypeVar(bound_typevar)) => { - ConstrainedTypeVar::new_node(db, bound_typevar, lhs, Type::object()) - } + (Type::TypeVar(bound_typevar), _) => ConstrainedTypeVar::new_node( + db, + bound_typevar, + Type::Never, + rhs.bottom_materialization(db), + ), + (_, Type::TypeVar(bound_typevar)) => ConstrainedTypeVar::new_node( + db, + bound_typevar, + lhs.top_materialization(db), + Type::object(), + ), _ => panic!("at least one type should be a typevar"), }; @@ -888,6 +913,29 @@ impl<'db> Node<'db> { true } + /// Returns a new BDD that is the _existential abstraction_ of `self` for a set of typevars. + /// The result will return true whenever `self` returns true for _any_ assignment of those + /// typevars. The result will not contain any constraints that mention those typevars. + fn exists( + self, + db: &'db dyn Db, + bound_typevars: impl IntoIterator>, + ) -> Self { + bound_typevars + .into_iter() + .fold(self.simplify(db), |abstracted, bound_typevar| { + abstracted.exists_one(db, bound_typevar) + }) + } + + fn exists_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Self { + match self { + Node::AlwaysTrue => Node::AlwaysTrue, + Node::AlwaysFalse => Node::AlwaysFalse, + Node::Interior(interior) => interior.exists_one(db, bound_typevar), + } + } + /// Returns a new BDD that returns the same results as `self`, but with some inputs fixed to /// particular values. (Those variables will not be checked when evaluating the result, and /// will not be present in the result.) @@ -1301,6 +1349,32 @@ impl<'db> InteriorNode<'db> { } } + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn exists_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Node<'db> { + let self_constraint = self.constraint(db); + let self_typevar = self_constraint.typevar(db).identity(db); + match bound_typevar.cmp(&self_typevar) { + // If the typevar that this node checks is "later" than the typevar we're abstracting + // over, then we have reached a point in the BDD where the abstraction can no longer + // affect the result, and we can return early. + Ordering::Less => Node::Interior(self), + // If the typevar that this node checks _is_ the typevar we're abstracting over, then + // we replace this node with the OR of its if_false/if_true edges. That is, the result + // is true if there's any assignment of this node's constraint that is true. + Ordering::Equal => { + let if_true = self.if_true(db).exists_one(db, bound_typevar); + let if_false = self.if_false(db).exists_one(db, bound_typevar); + if_true.or(db, if_false) + } + // Otherwise, we abstract the if_false/if_true edges recursively. + Ordering::Greater => { + let if_true = self.if_true(db).exists_one(db, bound_typevar); + let if_false = self.if_false(db).exists_one(db, bound_typevar); + Node::new(db, self_constraint, if_true, if_false) + } + } + } + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] fn restrict_one( self, diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 8f0a0e0bd6..169b69e496 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::collections::hash_map::Entry; use std::fmt::Display; -use itertools::Itertools; +use itertools::{Either, Itertools}; use ruff_python_ast as ast; use rustc_hash::{FxHashMap, FxHashSet}; @@ -145,10 +145,24 @@ impl<'db> BoundTypeVarInstance<'db> { } impl<'a, 'db> InferableTypeVars<'a, 'db> { - pub(crate) fn merge(&'a self, other: Option<&'a InferableTypeVars<'a, 'db>>) -> Self { - match other { - Some(other) => InferableTypeVars::Two(self, other), - None => *self, + pub(crate) fn merge(&'a self, other: &'a InferableTypeVars<'a, 'db>) -> Self { + match (self, other) { + (InferableTypeVars::None, other) | (other, InferableTypeVars::None) => *other, + _ => InferableTypeVars::Two(self, other), + } + } + + // This is not an IntoIterator implementation because I have no desire to try to name the + // iterator type. + pub(crate) fn iter(self) -> impl Iterator> { + match self { + InferableTypeVars::None => Either::Left(Either::Left(std::iter::empty())), + InferableTypeVars::One(typevars) => Either::Right(typevars.iter().copied()), + InferableTypeVars::Two(left, right) => { + let chained: Box>> = + Box::new(left.iter().chain(right.iter())); + Either::Left(Either::Right(chained)) + } } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 774a2bf5b7..74fe451e50 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -628,6 +628,13 @@ impl<'db> Signature<'db> { } } + fn inferable_typevars(&self, db: &'db dyn Db) -> InferableTypeVars<'db, 'db> { + match self.generic_context { + Some(generic_context) => generic_context.inferable_typevars(db), + None => InferableTypeVars::None, + } + } + /// Return `true` if `self` has exactly the same set of possible static materializations as /// `other` (if `self` represents the same set of possible sets of possible runtime objects as /// `other`). @@ -638,15 +645,40 @@ impl<'db> Signature<'db> { inferable: InferableTypeVars<'_, 'db>, visitor: &IsEquivalentVisitor<'db>, ) -> ConstraintSet<'db> { - // The typevars in self and other should also be considered inferable when checking whether - // two signatures are equivalent. - let self_inferable = - (self.generic_context).map(|generic_context| generic_context.inferable_typevars(db)); - let other_inferable = - (other.generic_context).map(|generic_context| generic_context.inferable_typevars(db)); - let inferable = inferable.merge(self_inferable.as_ref()); - let inferable = inferable.merge(other_inferable.as_ref()); + // If either signature is generic, their typevars should also be considered inferable when + // checking whether the signatures are equivalent, since we only need to find one + // specialization that causes the check to succeed. + // + // TODO: We should alpha-rename these typevars, too, to correctly handle when a generic + // callable refers to typevars from within the context that defines them. This primarily + // comes up when referring to a generic function recursively from within its body: + // + // def identity[T](t: T) -> T: + // # Here, TypeOf[identity2] is a generic callable that should consider T to be + // # inferable, even though other uses of T in the function body are non-inferable. + // return t + let self_inferable = self.inferable_typevars(db); + let other_inferable = other.inferable_typevars(db); + let inferable = inferable.merge(&self_inferable); + let inferable = inferable.merge(&other_inferable); + // `inner` will create a constraint set that references these newly inferable typevars. + let when = self.is_equivalent_to_inner(db, other, inferable, visitor); + + // But the caller does not need to consider those extra typevars. Whatever constraint set + // we produce, we reduce it back down to the inferable set that the caller asked about. + // If we introduced new inferable typevars, those will be existentially quantified away + // before returning. + when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter())) + } + + fn is_equivalent_to_inner( + &self, + db: &'db dyn Db, + other: &Signature<'db>, + inferable: InferableTypeVars<'_, 'db>, + visitor: &IsEquivalentVisitor<'db>, + ) -> ConstraintSet<'db> { let mut result = ConstraintSet::from(true); let mut check_types = |self_type: Option>, other_type: Option>| { let self_type = self_type.unwrap_or(Type::unknown()); @@ -735,6 +767,49 @@ impl<'db> Signature<'db> { relation: TypeRelation<'db>, relation_visitor: &HasRelationToVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>, + ) -> ConstraintSet<'db> { + // If either signature is generic, their typevars should also be considered inferable when + // checking whether one signature is a subtype/etc of the other, since we only need to find + // one specialization that causes the check to succeed. + // + // TODO: We should alpha-rename these typevars, too, to correctly handle when a generic + // callable refers to typevars from within the context that defines them. This primarily + // comes up when referring to a generic function recursively from within its body: + // + // def identity[T](t: T) -> T: + // # Here, TypeOf[identity2] is a generic callable that should consider T to be + // # inferable, even though other uses of T in the function body are non-inferable. + // return t + let self_inferable = self.inferable_typevars(db); + let other_inferable = other.inferable_typevars(db); + let inferable = inferable.merge(&self_inferable); + let inferable = inferable.merge(&other_inferable); + + // `inner` will create a constraint set that references these newly inferable typevars. + let when = self.has_relation_to_inner( + db, + other, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ); + + // But the caller does not need to consider those extra typevars. Whatever constraint set + // we produce, we reduce it back down to the inferable set that the caller asked about. + // If we introduced new inferable typevars, those will be existentially quantified away + // before returning. + when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter())) + } + + fn has_relation_to_inner( + &self, + db: &'db dyn Db, + other: &Signature<'db>, + inferable: InferableTypeVars<'_, 'db>, + relation: TypeRelation<'db>, + relation_visitor: &HasRelationToVisitor<'db>, + disjointness_visitor: &IsDisjointVisitor<'db>, ) -> ConstraintSet<'db> { /// A helper struct to zip two slices of parameters together that provides control over the /// two iterators individually. It also keeps track of the current parameter in each @@ -797,15 +872,6 @@ impl<'db> Signature<'db> { } } - // The typevars in self and other should also be considered inferable when checking whether - // two signatures are equivalent. - let self_inferable = - (self.generic_context).map(|generic_context| generic_context.inferable_typevars(db)); - let other_inferable = - (other.generic_context).map(|generic_context| generic_context.inferable_typevars(db)); - let inferable = inferable.merge(self_inferable.as_ref()); - let inferable = inferable.merge(other_inferable.as_ref()); - let mut result = ConstraintSet::from(true); let mut check_types = |type1: Option>, type2: Option>| { let type1 = type1.unwrap_or(Type::unknown());