From c02bd11b9390ba1b116e3769edbf0bb848d26a07 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Tue, 16 Dec 2025 12:16:49 -0500 Subject: [PATCH] [ty] Infer typevar specializations for `Callable` types (#21551) This is a first stab at solving https://github.com/astral-sh/ty/issues/500, at least in part, with the old solver. We add a new `TypeRelation` that lets us opt into using constraint sets to describe when a typevar is assignability to some type, and then use that to calculate a constraint set that describes when two callable types are assignable. If the callable types contain typevars, that constraint set will describe their valid specializations. We can then walk through all of the ways the constraint set can be satisfied, and record a type mapping in the old solver for each one. --------- Co-authored-by: Carl Meyer Co-authored-by: Alex Waygood --- crates/ruff_benchmark/benches/ty_walltime.rs | 2 +- .../resources/mdtest/annotations/self.md | 2 +- .../resources/mdtest/async.md | 4 +- .../resources/mdtest/dataclasses/fields.md | 3 +- .../resources/mdtest/decorators.md | 9 +- .../resources/mdtest/deprecated.md | 4 +- .../mdtest/generics/legacy/functions.md | 3 +- .../mdtest/generics/legacy/variables.md | 3 +- .../mdtest/generics/pep695/functions.md | 159 ++++++++- .../mdtest/generics/pep695/paramspec.md | 21 +- .../mdtest/generics/pep695/variables.md | 2 +- .../mdtest/generics/specialize_constrained.md | 3 +- .../resources/mdtest/literal_promotion.md | 2 +- ...pr…_-_Introduction_(cff2724f4c9d28c4).snap | 8 +- crates/ty_python_semantic/src/types.rs | 144 +++++++- crates/ty_python_semantic/src/types/class.rs | 4 +- .../src/types/constraints.rs | 328 +++++++++++++----- .../ty_python_semantic/src/types/generics.rs | 187 +++++++--- .../src/types/signatures.rs | 281 ++++++++++++++- 19 files changed, 978 insertions(+), 191 deletions(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 490cafc650..5826da6073 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -223,7 +223,7 @@ static STATIC_FRAME: Benchmark = Benchmark::new( max_dep_date: "2025-08-09", python_version: PythonVersion::PY311, }, - 950, + 1100, ); #[track_caller] diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/self.md b/crates/ty_python_semantic/resources/mdtest/annotations/self.md index 016cc848b8..b3563e8935 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/self.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/self.md @@ -194,7 +194,7 @@ reveal_type(B().name_does_not_matter()) # revealed: B reveal_type(B().positional_only(1)) # revealed: B reveal_type(B().keyword_only(x=1)) # revealed: B # TODO: This should deally be `B` -reveal_type(B().decorated_method()) # revealed: Unknown +reveal_type(B().decorated_method()) # revealed: Self@decorated_method reveal_type(B().a_property) # revealed: B diff --git a/crates/ty_python_semantic/resources/mdtest/async.md b/crates/ty_python_semantic/resources/mdtest/async.md index 416c88b09c..9fad1a2506 100644 --- a/crates/ty_python_semantic/resources/mdtest/async.md +++ b/crates/ty_python_semantic/resources/mdtest/async.md @@ -43,9 +43,7 @@ async def main(): loop = asyncio.get_event_loop() with concurrent.futures.ThreadPoolExecutor() as pool: result = await loop.run_in_executor(pool, blocking_function) - - # TODO: should be `int` - reveal_type(result) # revealed: Unknown + reveal_type(result) # revealed: int ``` ### `asyncio.Task` diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md index 28a69081e5..2912adb049 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md @@ -82,8 +82,7 @@ def get_default() -> str: reveal_type(field(default=1)) # revealed: dataclasses.Field[Literal[1]] reveal_type(field(default=None)) # revealed: dataclasses.Field[None] -# TODO: this could ideally be `dataclasses.Field[str]` with a better generics solver -reveal_type(field(default_factory=get_default)) # revealed: dataclasses.Field[Unknown] +reveal_type(field(default_factory=get_default)) # revealed: dataclasses.Field[str] ``` ## dataclass_transform field_specifiers diff --git a/crates/ty_python_semantic/resources/mdtest/decorators.md b/crates/ty_python_semantic/resources/mdtest/decorators.md index f92eca8003..124e2d9a82 100644 --- a/crates/ty_python_semantic/resources/mdtest/decorators.md +++ b/crates/ty_python_semantic/resources/mdtest/decorators.md @@ -144,11 +144,10 @@ from functools import cache def f(x: int) -> int: return x**2 -# TODO: Should be `_lru_cache_wrapper[int]` -reveal_type(f) # revealed: _lru_cache_wrapper[Unknown] - -# TODO: Should be `int` -reveal_type(f(1)) # revealed: Unknown +# revealed: _lru_cache_wrapper[int] +reveal_type(f) +# revealed: int +reveal_type(f(1)) ``` ## Lambdas as decorators diff --git a/crates/ty_python_semantic/resources/mdtest/deprecated.md b/crates/ty_python_semantic/resources/mdtest/deprecated.md index 80d5108508..9497c77fe4 100644 --- a/crates/ty_python_semantic/resources/mdtest/deprecated.md +++ b/crates/ty_python_semantic/resources/mdtest/deprecated.md @@ -11,9 +11,9 @@ classes. Uses of these items should subsequently produce a warning. from typing_extensions import deprecated @deprecated("use OtherClass") -def myfunc(): ... +def myfunc(x: int): ... -myfunc() # error: [deprecated] "use OtherClass" +myfunc(1) # error: [deprecated] "use OtherClass" ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 77e3c23c78..c674f7a9a1 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -555,8 +555,7 @@ def identity(x: T) -> T: def head(xs: list[T]) -> T: return xs[0] -# TODO: this should be `Literal[1]` -reveal_type(invoke(identity, 1)) # revealed: Unknown +reveal_type(invoke(identity, 1)) # revealed: Literal[1] # TODO: this should be `Unknown | int` reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/variables.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/variables.md index b419b61e71..04dfdf648f 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/variables.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/variables.md @@ -518,8 +518,7 @@ V = TypeVar("V", default="V") class D(Generic[V]): x: V -# TODO: we shouldn't leak a typevar like this in type inference -reveal_type(D().x) # revealed: V@D +reveal_type(D().x) # revealed: Unknown ``` ## Regression diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 9d3cca6b57..8121ce5d26 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -493,8 +493,7 @@ def identity[T](x: T) -> T: def head[T](xs: list[T]) -> T: return xs[0] -# TODO: this should be `Literal[1]` -reveal_type(invoke(identity, 1)) # revealed: Unknown +reveal_type(invoke(identity, 1)) # revealed: Literal[1] # TODO: this should be `Unknown | int` reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown @@ -736,3 +735,159 @@ def f[T](x: T, y: Not[T]) -> T: y = x # error: [invalid-assignment] return x ``` + +## `Callable` parameters + +We can recurse into the parameters and return values of `Callable` parameters to infer +specializations of a generic function. + +```py +from typing import Any, Callable, NoReturn, overload, Self +from ty_extensions import generic_context, into_callable + +def accepts_callable[**P, R](callable: Callable[P, R]) -> Callable[P, R]: + return callable + +def returns_int() -> int: + raise NotImplementedError + +# revealed: () -> int +reveal_type(into_callable(returns_int)) +# revealed: () -> int +reveal_type(accepts_callable(returns_int)) +# revealed: int +reveal_type(accepts_callable(returns_int)()) + +class ClassWithoutConstructor: ... + +# revealed: () -> ClassWithoutConstructor +reveal_type(into_callable(ClassWithoutConstructor)) +# revealed: () -> ClassWithoutConstructor +reveal_type(accepts_callable(ClassWithoutConstructor)) +# revealed: ClassWithoutConstructor +reveal_type(accepts_callable(ClassWithoutConstructor)()) + +class ClassWithNew: + def __new__(cls, *args, **kwargs) -> Self: + raise NotImplementedError + +# revealed: (...) -> ClassWithNew +reveal_type(into_callable(ClassWithNew)) +# revealed: (...) -> ClassWithNew +reveal_type(accepts_callable(ClassWithNew)) +# revealed: ClassWithNew +reveal_type(accepts_callable(ClassWithNew)()) + +class ClassWithInit: + def __init__(self) -> None: ... + +# revealed: () -> ClassWithInit +reveal_type(into_callable(ClassWithInit)) +# revealed: () -> ClassWithInit +reveal_type(accepts_callable(ClassWithInit)) +# revealed: ClassWithInit +reveal_type(accepts_callable(ClassWithInit)()) + +class ClassWithNewAndInit: + def __new__(cls, *args, **kwargs) -> Self: + raise NotImplementedError + + def __init__(self, x: int) -> None: ... + +# TODO: We do not currently solve a common behavioral supertype for the two solutions of P. +# revealed: ((...) -> ClassWithNewAndInit) | ((x: int) -> ClassWithNewAndInit) +reveal_type(into_callable(ClassWithNewAndInit)) +# TODO: revealed: ((...) -> ClassWithNewAndInit) | ((x: int) -> ClassWithNewAndInit) +# revealed: (...) -> ClassWithNewAndInit +reveal_type(accepts_callable(ClassWithNewAndInit)) +# revealed: ClassWithNewAndInit +reveal_type(accepts_callable(ClassWithNewAndInit)()) + +class Meta(type): + def __call__(cls, *args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError + +class ClassWithNoReturnMetatype(metaclass=Meta): + def __new__(cls, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError + +# TODO: The return types here are wrong, because we end up creating a constraint (Never ≤ R), which +# we confuse with "R has no lower bound". +# revealed: (...) -> Never +reveal_type(into_callable(ClassWithNoReturnMetatype)) +# TODO: revealed: (...) -> Never +# revealed: (...) -> Unknown +reveal_type(accepts_callable(ClassWithNoReturnMetatype)) +# TODO: revealed: Never +# revealed: Unknown +reveal_type(accepts_callable(ClassWithNoReturnMetatype)()) + +class Proxy: ... + +class ClassWithIgnoredInit: + def __new__(cls) -> Proxy: + return Proxy() + + def __init__(self, x: int) -> None: ... + +# revealed: () -> Proxy +reveal_type(into_callable(ClassWithIgnoredInit)) +# revealed: () -> Proxy +reveal_type(accepts_callable(ClassWithIgnoredInit)) +# revealed: Proxy +reveal_type(accepts_callable(ClassWithIgnoredInit)()) + +class ClassWithOverloadedInit[T]: + t: T # invariant + + @overload + def __init__(self: "ClassWithOverloadedInit[int]", x: int) -> None: ... + @overload + def __init__(self: "ClassWithOverloadedInit[str]", x: str) -> None: ... + def __init__(self, x: int | str) -> None: ... + +# TODO: The old solver cannot handle this overloaded constructor. The ideal solution is that we +# would solve **P once, and map it to the entire overloaded signature of the constructor. This +# mapping would have to include the return types, since there are different return types for each +# overload. We would then also have to determine that R must be equal to the return type of **P's +# solution. + +# revealed: Overload[(x: int) -> ClassWithOverloadedInit[int], (x: str) -> ClassWithOverloadedInit[str]] +reveal_type(into_callable(ClassWithOverloadedInit)) +# TODO: revealed: Overload[(x: int) -> ClassWithOverloadedInit[int], (x: str) -> ClassWithOverloadedInit[str]] +# revealed: Overload[(x: int) -> ClassWithOverloadedInit[int] | ClassWithOverloadedInit[str], (x: str) -> ClassWithOverloadedInit[int] | ClassWithOverloadedInit[str]] +reveal_type(accepts_callable(ClassWithOverloadedInit)) +# TODO: revealed: ClassWithOverloadedInit[int] +# revealed: ClassWithOverloadedInit[int] | ClassWithOverloadedInit[str] +reveal_type(accepts_callable(ClassWithOverloadedInit)(0)) +# TODO: revealed: ClassWithOverloadedInit[str] +# revealed: ClassWithOverloadedInit[int] | ClassWithOverloadedInit[str] +reveal_type(accepts_callable(ClassWithOverloadedInit)("")) + +class GenericClass[T]: + t: T # invariant + + def __new__(cls, x: list[T], y: list[T]) -> Self: + raise NotImplementedError + +def _(x: list[str]): + # TODO: This fails because we are not propagating GenericClass's generic context into the + # Callable that we create for it. + # revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass] + reveal_type(into_callable(GenericClass)) + # revealed: ty_extensions.GenericContext[T@GenericClass] + reveal_type(generic_context(into_callable(GenericClass))) + + # revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass] + reveal_type(accepts_callable(GenericClass)) + # TODO: revealed: ty_extensions.GenericContext[T@GenericClass] + # revealed: None + reveal_type(generic_context(accepts_callable(GenericClass))) + + # TODO: revealed: GenericClass[str] + # TODO: no errors + # revealed: GenericClass[T@GenericClass] + # error: [invalid-argument-type] + # error: [invalid-argument-type] + reveal_type(accepts_callable(GenericClass)(x, x)) +``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md index e6e6acd35c..df2508744a 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md @@ -503,7 +503,8 @@ class C[**P]: def __init__(self, f: Callable[P, int]) -> None: self.f = f -def f(x: int, y: str) -> bool: +# Note that the return type must match exactly, since C is invariant on the return type of C.f. +def f(x: int, y: str) -> int: return True c = C(f) @@ -618,6 +619,22 @@ reveal_type(foo.method) # revealed: bound method Foo[(int, str, /)].method(int, reveal_type(foo.method(1, "a")) # revealed: str ``` +### Gradual types propagate through `ParamSpec` inference + +```py +from typing import Callable + +def callable_identity[**P, R](func: Callable[P, R]) -> Callable[P, R]: + return func + +@callable_identity +def f(env: dict) -> None: + pass + +# revealed: (env: dict[Unknown, Unknown]) -> None +reveal_type(f) +``` + ### Overloads `overloaded.pyi`: @@ -662,7 +679,7 @@ reveal_type(change_return_type(int_int)) # revealed: Overload[(x: int) -> str, reveal_type(change_return_type(int_str)) # revealed: Overload[(x: int) -> str, (x: str) -> str] # error: [invalid-argument-type] -reveal_type(change_return_type(str_str)) # revealed: Overload[(x: int) -> str, (x: str) -> str] +reveal_type(change_return_type(str_str)) # revealed: (...) -> str # TODO: Both of these shouldn't raise an error # error: [invalid-argument-type] diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variables.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variables.md index d70c130649..cd24eb5f06 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variables.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variables.md @@ -883,7 +883,7 @@ reveal_type(C[int]().y) # revealed: int class D[T = T]: x: T -reveal_type(D().x) # revealed: T@D +reveal_type(D().x) # revealed: Unknown ``` [pep 695]: https://peps.python.org/pep-0695/ diff --git a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md index 0cf656ccc3..32956cdfa8 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md @@ -426,7 +426,8 @@ from ty_extensions import ConstraintSet, generic_context def mentions[T, U](): # (T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]) constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T]) - # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]] + # TODO: revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]] + # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = Unknown] reveal_type(generic_context(mentions).specialize_constrained(constraints)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md index eb79c44b6c..65fc1c1602 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md @@ -304,7 +304,7 @@ x11: list[Literal[1] | Literal[2] | Literal[3]] = [1, 2, 3] reveal_type(x11) # revealed: list[Literal[1, 2, 3]] x12: Y[Y[Literal[1]]] = [[1]] -reveal_type(x12) # revealed: list[Y[Literal[1]]] +reveal_type(x12) # revealed: list[list[Literal[1]]] x13: list[tuple[Literal[1], Literal[2], Literal[3]]] = [(1, 2, 3)] reveal_type(x13) # revealed: list[tuple[Literal[1], Literal[2], Literal[3]]] diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/deprecated.md_-_Tests_for_the_`@depr…_-_Introduction_(cff2724f4c9d28c4).snap b/crates/ty_python_semantic/resources/mdtest/snapshots/deprecated.md_-_Tests_for_the_`@depr…_-_Introduction_(cff2724f4c9d28c4).snap index 89eb99e534..4e4d8f0ae7 100644 --- a/crates/ty_python_semantic/resources/mdtest/snapshots/deprecated.md_-_Tests_for_the_`@depr…_-_Introduction_(cff2724f4c9d28c4).snap +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/deprecated.md_-_Tests_for_the_`@depr…_-_Introduction_(cff2724f4c9d28c4).snap @@ -15,9 +15,9 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/deprecated.md 1 | from typing_extensions import deprecated 2 | 3 | @deprecated("use OtherClass") - 4 | def myfunc(): ... + 4 | def myfunc(x: int): ... 5 | - 6 | myfunc() # error: [deprecated] "use OtherClass" + 6 | myfunc(1) # error: [deprecated] "use OtherClass" 7 | from typing_extensions import deprecated 8 | 9 | @deprecated("use BetterClass") @@ -42,9 +42,9 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/deprecated.md warning[deprecated]: The function `myfunc` is deprecated --> src/mdtest_snippet.py:6:1 | -4 | def myfunc(): ... +4 | def myfunc(x: int): ... 5 | -6 | myfunc() # error: [deprecated] "use OtherClass" +6 | myfunc(1) # error: [deprecated] "use OtherClass" | ^^^^^^ use OtherClass 7 | from typing_extensions import deprecated | diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 2632888aaf..e4a0228cd4 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -918,16 +918,33 @@ impl<'db> Type<'db> { previous: Self, cycle: &salsa::Cycle, ) -> Self { - // Avoid unioning two generic aliases of the same class together; this union will never - // simplify and is likely to cause downstream problems. This introduces the theoretical - // possibility of cycle oscillation involving such types (because we are not strictly - // widening the type on each iteration), but so far we have not seen an example of that. + // When we encounter a salsa cycle, we want to avoid oscillating between two or more types + // without converging on a fixed-point result. Most of the time, we union together the + // types from each cycle iteration to ensure that our result is monotonic, even if we + // encounter oscillation. + // + // However, there are a couple of cases where we don't want to do that, and want to use the + // later cycle iteration's result directly. This introduces the theoretical possibility of + // cycle oscillation involving such types (because we are not strictly widening the type on + // each iteration), but so far we have not seen an example of that. match (previous, self) { + // Avoid unioning two generic aliases of the same class together; this union will never + // simplify and is likely to cause downstream problems. (Type::GenericAlias(prev_alias), Type::GenericAlias(curr_alias)) if prev_alias.origin(db) == curr_alias.origin(db) => { self } + + // Similarly, don't union together two function literals, since there are several parts + // of our type inference machinery that assume that we infer a single FunctionLiteral + // type for each overload of each function definition. + (Type::FunctionLiteral(prev_function), Type::FunctionLiteral(curr_function)) + if prev_function.definition(db) == curr_function.definition(db) => + { + self + } + _ => { // Also avoid unioning in a previous type which contains a Divergent from the // current cycle, if the most-recent type does not. This cannot cause an @@ -1843,7 +1860,7 @@ impl<'db> Type<'db> { } } Type::ClassLiteral(class_literal) => { - Some(class_literal.default_specialization(db).into_callable(db)) + Some(class_literal.identity_specialization(db).into_callable(db)) } Type::GenericAlias(alias) => Some(ClassType::Generic(alias).into_callable(db)), @@ -1975,6 +1992,16 @@ impl<'db> Type<'db> { .is_always_satisfied(db) } + /// Return true if this type is assignable to type `target` using constraint-set assignability. + /// + /// This uses `TypeRelation::ConstraintSetAssignability`, which encodes typevar relations into + /// a constraint set and lets `satisfied_by_all_typevars` perform existential vs universal + /// reasoning depending on inferable typevars. + pub fn is_constraint_set_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool { + self.when_constraint_set_assignable_to(db, target, InferableTypeVars::None) + .is_always_satisfied(db) + } + fn when_assignable_to( self, db: &'db dyn Db, @@ -1984,6 +2011,20 @@ impl<'db> Type<'db> { self.has_relation_to(db, target, inferable, TypeRelation::Assignability) } + fn when_constraint_set_assignable_to( + self, + db: &'db dyn Db, + target: Type<'db>, + inferable: InferableTypeVars<'_, 'db>, + ) -> ConstraintSet<'db> { + self.has_relation_to( + db, + target, + inferable, + TypeRelation::ConstraintSetAssignability, + ) + } + /// Return `true` if it would be redundant to add `self` to a union that already contains `other`. /// /// See [`TypeRelation::Redundancy`] for more details. @@ -2049,6 +2090,21 @@ impl<'db> Type<'db> { return constraints.implies_subtype_of(db, self, target); } + // Handle the new constraint-set-based assignability relation next. Comparisons with a + // typevar are translated directly into a constraint set. + if relation.is_constraint_set_assignability() { + // A typevar satisfies a relation when...it satisfies the relation. Yes that's a + // tautology! We're moving the caller's subtyping/assignability requirement into a + // constraint set. If the typevar has an upper bound or constraints, then the relation + // only has to hold when the typevar has a valid specialization (i.e., one that + // satisfies the upper bound/constraints). + if let Type::TypeVar(bound_typevar) = self { + return ConstraintSet::constrain_typevar(db, bound_typevar, Type::Never, target); + } else if let Type::TypeVar(bound_typevar) = target { + return ConstraintSet::constrain_typevar(db, bound_typevar, self, Type::object()); + } + } + match (self, target) { // Everything is a subtype of `object`. (_, Type::NominalInstance(instance)) if instance.is_object() => { @@ -2129,7 +2185,7 @@ impl<'db> Type<'db> { ); ConstraintSet::from(match relation { TypeRelation::Subtyping | TypeRelation::SubtypingAssuming(_) => false, - TypeRelation::Assignability => true, + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => true, TypeRelation::Redundancy => match target { Type::Dynamic(_) => true, Type::Union(union) => union.elements(db).iter().any(Type::is_dynamic), @@ -2139,7 +2195,7 @@ impl<'db> Type<'db> { } (_, Type::Dynamic(_)) => ConstraintSet::from(match relation { TypeRelation::Subtyping | TypeRelation::SubtypingAssuming(_) => false, - TypeRelation::Assignability => true, + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => true, TypeRelation::Redundancy => match self { Type::Dynamic(_) => true, Type::Intersection(intersection) => { @@ -2403,14 +2459,19 @@ impl<'db> Type<'db> { TypeRelation::Subtyping | TypeRelation::Redundancy | TypeRelation::SubtypingAssuming(_) => self, - TypeRelation::Assignability => self.bottom_materialization(db), + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => { + self.bottom_materialization(db) + } }; intersection.negative(db).iter().when_all(db, |&neg_ty| { let neg_ty = match relation { TypeRelation::Subtyping | TypeRelation::Redundancy | TypeRelation::SubtypingAssuming(_) => neg_ty, - TypeRelation::Assignability => neg_ty.bottom_materialization(db), + TypeRelation::Assignability + | TypeRelation::ConstraintSetAssignability => { + neg_ty.bottom_materialization(db) + } }; self_ty.is_disjoint_from_impl( db, @@ -9780,6 +9841,22 @@ impl<'db> TypeVarInstance<'db> { )) } + fn type_is_self_referential(self, db: &'db dyn Db, ty: Type<'db>) -> bool { + let identity = self.identity(db); + any_over_type( + db, + ty, + &|ty| match ty { + Type::TypeVar(bound_typevar) => identity == bound_typevar.typevar(db).identity(db), + Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) => { + identity == typevar.identity(db) + } + _ => false, + }, + false, + ) + } + #[salsa::tracked( cycle_fn=lazy_bound_or_constraints_cycle_recover, cycle_initial=lazy_bound_or_constraints_cycle_initial, @@ -9802,6 +9879,11 @@ impl<'db> TypeVarInstance<'db> { } _ => return None, }; + + if self.type_is_self_referential(db, ty) { + return None; + } + Some(TypeVarBoundOrConstraints::UpperBound(ty)) } @@ -9849,6 +9931,15 @@ impl<'db> TypeVarInstance<'db> { } _ => return None, }; + + if constraints + .elements(db) + .iter() + .any(|ty| self.type_is_self_referential(db, *ty)) + { + return None; + } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) } @@ -9895,15 +9986,11 @@ impl<'db> TypeVarInstance<'db> { let definition = self.definition(db)?; let module = parsed_module(db, definition.file(db)).load(db); - match definition.kind(db) { + let ty = match definition.kind(db) { // PEP 695 typevar DefinitionKind::TypeVar(typevar) => { let typevar_node = typevar.node(&module); - Some(definition_expression_type( - db, - definition, - typevar_node.default.as_ref()?, - )) + definition_expression_type(db, definition, typevar_node.default.as_ref()?) } // legacy typevar / ParamSpec DefinitionKind::Assignment(assignment) => { @@ -9913,9 +10000,9 @@ impl<'db> TypeVarInstance<'db> { let expr = &call_expr.arguments.find_keyword("default")?.value; let default_type = definition_expression_type(db, definition, expr); if known_class == Some(KnownClass::ParamSpec) { - Some(convert_type_to_paramspec_value(db, default_type)) + convert_type_to_paramspec_value(db, default_type) } else { - Some(default_type) + default_type } } // PEP 695 ParamSpec @@ -9923,10 +10010,16 @@ impl<'db> TypeVarInstance<'db> { let paramspec_node = paramspec.node(&module); let default_ty = definition_expression_type(db, definition, paramspec_node.default.as_ref()?); - Some(convert_type_to_paramspec_value(db, default_ty)) + convert_type_to_paramspec_value(db, default_ty) } - _ => None, + _ => return None, + }; + + if self.type_is_self_referential(db, ty) { + return None; } + + Some(ty) } pub fn bind_pep695(self, db: &'db dyn Db) -> Option> { @@ -12003,6 +12096,11 @@ pub(crate) enum TypeRelation<'db> { /// are not actually subtypes of each other. (That is, `implies_subtype_of(false, int, str)` /// will return true!) SubtypingAssuming(ConstraintSet<'db>), + + /// A placeholder for the new assignability relation that uses constraint sets to encode + /// relationships with a typevar. This will eventually replace `Assignability`, but allows us + /// to start using the new relation in a controlled manner in some places. + ConstraintSetAssignability, } impl TypeRelation<'_> { @@ -12010,6 +12108,10 @@ impl TypeRelation<'_> { matches!(self, TypeRelation::Assignability) } + pub(crate) const fn is_constraint_set_assignability(self) -> bool { + matches!(self, TypeRelation::ConstraintSetAssignability) + } + pub(crate) const fn is_subtyping(self) -> bool { matches!(self, TypeRelation::Subtyping) } @@ -12503,6 +12605,10 @@ impl<'db> CallableTypes<'db> { } } + fn as_slice(&self) -> &[CallableType<'db>] { + &self.0 + } + fn into_inner(self) -> SmallVec<[CallableType<'db>; 1]> { self.0 } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 67d12c7d69..d5fd42ace7 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -637,7 +637,9 @@ impl<'db> ClassType<'db> { | TypeRelation::SubtypingAssuming(_) => { ConstraintSet::from(other.is_object(db)) } - TypeRelation::Assignability => ConstraintSet::from(!other.is_final(db)), + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => { + ConstraintSet::from(!other.is_final(db)) + } }, // Protocol, Generic, and TypedDict are not represented by a ClassType. diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 82d6c3b0ad..77b96bd74b 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -76,12 +76,14 @@ use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; -use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; +use crate::types::visitor::{ + TypeCollector, TypeVisitor, any_over_type, walk_type_with_recursion_guard, +}; use crate::types::{ BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type, }; -use crate::{Db, FxOrderSet}; +use crate::{Db, FxOrderMap}; /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { @@ -349,6 +351,18 @@ impl<'db> ConstraintSet<'db> { self.node.satisfied_by_all_typevars(db, inferable) } + pub(crate) fn limit_to_valid_specializations(self, db: &'db dyn Db) -> Self { + let mut result = self.node; + let mut seen = FxHashSet::default(); + self.node.for_each_constraint(db, &mut |constraint, _| { + let bound_typevar = constraint.typevar(db); + if seen.insert(bound_typevar) { + result = result.and_with_offset(db, bound_typevar.valid_specializations(db)); + } + }); + Self { node: result } + } + /// Updates this constraint set to hold the union of itself and another constraint set. /// /// In the result, `self` will appear before `other` according to the `source_order` of the BDD @@ -432,6 +446,10 @@ impl<'db> ConstraintSet<'db> { Self { node } } + pub(crate) fn for_each_path(self, db: &'db dyn Db, f: impl FnMut(&PathAssignments<'db>)) { + self.node.for_each_path(db, f); + } + pub(crate) fn range( db: &'db dyn Db, lower: Type<'db>, @@ -490,9 +508,9 @@ impl IntersectionResult<'_> { /// lower and upper bound. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub(crate) struct ConstrainedTypeVar<'db> { - typevar: BoundTypeVarInstance<'db>, - lower: Type<'db>, - upper: Type<'db>, + pub(crate) typevar: BoundTypeVarInstance<'db>, + pub(crate) lower: Type<'db>, + pub(crate) upper: Type<'db>, } // The Salsa heap is tracked separately. @@ -602,7 +620,7 @@ impl<'db> ConstrainedTypeVar<'db> { // If `lower ≰ upper`, then the constraint cannot be satisfied, since there is no type that // is both greater than `lower`, and less than `upper`. - if !lower.is_assignable_to(db, upper) { + if !lower.is_constraint_set_assignable_to(db, upper) { return Node::AlwaysFalse; } @@ -713,7 +731,11 @@ impl<'db> ConstrainedTypeVar<'db> { /// simplifications that we perform that operate on constraints with the same typevar, and this /// ensures that we can find all candidate simplifications more easily. fn ordering(self, db: &'db dyn Db) -> impl Ord { - (self.typevar(db).identity(db), self.as_id()) + ( + self.typevar(db).binding_context(db), + self.typevar(db).identity(db), + self.as_id(), + ) } /// Returns whether this constraint implies another — i.e., whether every type that @@ -725,8 +747,12 @@ impl<'db> ConstrainedTypeVar<'db> { if !self.typevar(db).is_same_typevar_as(db, other.typevar(db)) { return false; } - other.lower(db).is_assignable_to(db, self.lower(db)) - && self.upper(db).is_assignable_to(db, other.upper(db)) + other + .lower(db) + .is_constraint_set_assignable_to(db, self.lower(db)) + && self + .upper(db) + .is_constraint_set_assignable_to(db, other.upper(db)) } /// Returns the intersection of two range constraints, or `None` if the intersection is empty. @@ -737,7 +763,7 @@ impl<'db> ConstrainedTypeVar<'db> { // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. - if !lower.is_assignable_to(db, upper) { + if !lower.is_constraint_set_assignable_to(db, upper) { return IntersectionResult::Disjoint; } @@ -748,7 +774,7 @@ impl<'db> ConstrainedTypeVar<'db> { IntersectionResult::Simplified(Self::new(db, self.typevar(db), lower, upper)) } - fn display(self, db: &'db dyn Db) -> impl Display { + pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { self.display_inner(db, false) } @@ -973,6 +999,41 @@ impl<'db> Node<'db> { } } + fn for_each_path(self, db: &'db dyn Db, mut f: impl FnMut(&PathAssignments<'db>)) { + match self { + Node::AlwaysTrue => {} + Node::AlwaysFalse => {} + Node::Interior(interior) => { + let map = interior.sequent_map(db); + let mut path = PathAssignments::default(); + self.for_each_path_inner(db, &mut f, map, &mut path); + } + } + } + + fn for_each_path_inner( + self, + db: &'db dyn Db, + f: &mut dyn FnMut(&PathAssignments<'db>), + map: &SequentMap<'db>, + path: &mut PathAssignments<'db>, + ) { + match self { + Node::AlwaysTrue => f(path), + Node::AlwaysFalse => {} + Node::Interior(interior) => { + let constraint = interior.constraint(db); + let source_order = interior.source_order(db); + path.walk_edge(db, map, constraint.when_true(), source_order, |path, _| { + interior.if_true(db).for_each_path_inner(db, f, map, path); + }); + path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { + interior.if_false(db).for_each_path_inner(db, f, map, path); + }); + } + } + } + /// Returns whether this BDD represent the constant function `true`. fn is_always_satisfied(self, db: &'db dyn Db) -> bool { match self { @@ -1000,8 +1061,9 @@ impl<'db> Node<'db> { // from it) causes the if_true edge to become impossible. We want to ignore // impossible paths, and so we treat them as passing the "always satisfied" check. let constraint = interior.constraint(db); + let source_order = interior.source_order(db); let true_always_satisfied = path - .walk_edge(db, map, constraint.when_true(), |path, _| { + .walk_edge(db, map, constraint.when_true(), source_order, |path, _| { interior .if_true(db) .is_always_satisfied_inner(db, map, path) @@ -1012,7 +1074,7 @@ impl<'db> Node<'db> { } // Ditto for the if_false branch - path.walk_edge(db, map, constraint.when_false(), |path, _| { + path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { interior .if_false(db) .is_always_satisfied_inner(db, map, path) @@ -1049,8 +1111,9 @@ impl<'db> Node<'db> { // from it) causes the if_true edge to become impossible. We want to ignore // impossible paths, and so we treat them as passing the "never satisfied" check. let constraint = interior.constraint(db); + let source_order = interior.source_order(db); let true_never_satisfied = path - .walk_edge(db, map, constraint.when_true(), |path, _| { + .walk_edge(db, map, constraint.when_true(), source_order, |path, _| { interior.if_true(db).is_never_satisfied_inner(db, map, path) }) .unwrap_or(true); @@ -1059,7 +1122,7 @@ impl<'db> Node<'db> { } // Ditto for the if_false branch - path.walk_edge(db, map, constraint.when_false(), |path, _| { + path.walk_edge(db, map, constraint.when_false(), source_order, |path, _| { interior .if_false(db) .is_never_satisfied_inner(db, map, path) @@ -1408,7 +1471,7 @@ impl<'db> Node<'db> { db, current_bounds.iter().map(|bounds| bounds.upper), ); - greatest_lower_bound.is_assignable_to(db, least_upper_bound) + greatest_lower_bound.is_constraint_set_assignable_to(db, least_upper_bound) }); // We've been tracking the lower and upper bound that the types for this path must @@ -1946,22 +2009,27 @@ impl<'db> InteriorNode<'db> { fn exists_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Node<'db> { let map = self.sequent_map(db); let mut path = PathAssignments::default(); + let mentions_typevar = |ty: Type<'db>| match ty { + Type::TypeVar(haystack) => haystack.identity(db) == bound_typevar, + _ => false, + }; self.abstract_one_inner( db, - // Remove any node that constrains `bound_typevar`, or that has a lower/upper bound of - // `bound_typevar`. + // Remove any node that constrains `bound_typevar`, or that has a lower/upper bound + // that mentions `bound_typevar`. + // TODO: This will currently remove constraints that mention a typevar, but the sequent + // map is not yet propagating all derived facts about those constraints. For instance, + // removing `T` from `T ≤ int ∧ U ≤ Sequence[T]` should produce `U ≤ Sequence[int]`. + // But that requires `T ≤ int ∧ U ≤ Sequence[T] → U ≤ Sequence[int]` to exist in the + // sequent map. It doesn't, and so we currently produce `U ≤ Unknown` in this case. &mut |constraint| { if constraint.typevar(db).identity(db) == bound_typevar { return true; } - if let Type::TypeVar(lower_bound_typevar) = constraint.lower(db) - && lower_bound_typevar.identity(db) == bound_typevar - { + if any_over_type(db, constraint.lower(db), &mentions_typevar, false) { return true; } - if let Type::TypeVar(upper_bound_typevar) = constraint.upper(db) - && upper_bound_typevar.identity(db) == bound_typevar - { + if any_over_type(db, constraint.upper(db), &mentions_typevar, false) { return true; } false @@ -1985,9 +2053,7 @@ impl<'db> InteriorNode<'db> { if constraint.typevar(db).identity(db) != bound_typevar { return true; } - if matches!(constraint.lower(db), Type::TypeVar(_)) - || matches!(constraint.upper(db), Type::TypeVar(_)) - { + if constraint.lower(db).has_typevar(db) || constraint.upper(db).has_typevar(db) { return true; } false @@ -2005,6 +2071,7 @@ impl<'db> InteriorNode<'db> { path: &mut PathAssignments<'db>, ) -> Node<'db> { let self_constraint = self.constraint(db); + let self_source_order = self.source_order(db); if should_remove(self_constraint) { // If we should remove constraints involving this typevar, then we replace this node // with the OR of its if_false/if_true edges. That is, the result is true if there's @@ -2020,59 +2087,83 @@ impl<'db> InteriorNode<'db> { // way of tracking source order for derived facts. let self_source_order = self.source_order(db); let if_true = path - .walk_edge(db, map, self_constraint.when_true(), |path, new_range| { - let branch = self - .if_true(db) - .abstract_one_inner(db, should_remove, map, path); - path.assignments[new_range] - .iter() - .filter(|assignment| { - // Don't add back any derived facts if they are ones that we would have - // removed! - !should_remove(assignment.constraint()) - }) - .fold(branch, |branch, assignment| { - branch.and( - db, - Node::new_satisfied_constraint(db, *assignment, self_source_order), - ) - }) - }) + .walk_edge( + db, + map, + self_constraint.when_true(), + self_source_order, + |path, new_range| { + let branch = + self.if_true(db) + .abstract_one_inner(db, should_remove, map, path); + path.assignments[new_range] + .iter() + .filter(|(assignment, _)| { + // Don't add back any derived facts if they are ones that we would have + // removed! + !should_remove(assignment.constraint()) + }) + .fold(branch, |branch, (assignment, source_order)| { + branch.and( + db, + Node::new_satisfied_constraint(db, *assignment, *source_order), + ) + }) + }, + ) .unwrap_or(Node::AlwaysFalse); let if_false = path - .walk_edge(db, map, self_constraint.when_false(), |path, new_range| { - let branch = self - .if_false(db) - .abstract_one_inner(db, should_remove, map, path); - path.assignments[new_range] - .iter() - .filter(|assignment| { - // Don't add back any derived facts if they are ones that we would have - // removed! - !should_remove(assignment.constraint()) - }) - .fold(branch, |branch, assignment| { - branch.and( - db, - Node::new_satisfied_constraint(db, *assignment, self_source_order), - ) - }) - }) + .walk_edge( + db, + map, + self_constraint.when_false(), + self_source_order, + |path, new_range| { + let branch = + self.if_false(db) + .abstract_one_inner(db, should_remove, map, path); + path.assignments[new_range] + .iter() + .filter(|(assignment, _)| { + // Don't add back any derived facts if they are ones that we would have + // removed! + !should_remove(assignment.constraint()) + }) + .fold(branch, |branch, (assignment, source_order)| { + branch.and( + db, + Node::new_satisfied_constraint(db, *assignment, *source_order), + ) + }) + }, + ) .unwrap_or(Node::AlwaysFalse); if_true.or(db, if_false) } else { // Otherwise, we abstract the if_false/if_true edges recursively. let if_true = path - .walk_edge(db, map, self_constraint.when_true(), |path, _| { - self.if_true(db) - .abstract_one_inner(db, should_remove, map, path) - }) + .walk_edge( + db, + map, + self_constraint.when_true(), + self_source_order, + |path, _| { + self.if_true(db) + .abstract_one_inner(db, should_remove, map, path) + }, + ) .unwrap_or(Node::AlwaysFalse); let if_false = path - .walk_edge(db, map, self_constraint.when_false(), |path, _| { - self.if_false(db) - .abstract_one_inner(db, should_remove, map, path) - }) + .walk_edge( + db, + map, + self_constraint.when_false(), + self_source_order, + |path, _| { + self.if_false(db) + .abstract_one_inner(db, should_remove, map, path) + }, + ) .unwrap_or(Node::AlwaysFalse); // NB: We cannot use `Node::new` here, because the recursive calls might introduce new // derived constraints into the result, and those constraints might appear before this @@ -2555,7 +2646,7 @@ fn sequent_map_cycle_initial<'db>( /// An assignment of one BDD variable to either `true` or `false`. (When evaluating a BDD, we /// must provide an assignment for each variable present in the BDD.) #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -enum ConstraintAssignment<'db> { +pub(crate) enum ConstraintAssignment<'db> { Positive(ConstrainedTypeVar<'db>), Negative(ConstrainedTypeVar<'db>), } @@ -2990,6 +3081,24 @@ impl<'db> SequentMap<'db> { (bound_constraint.lower(db), constrained_upper) } + // (CL ≤ C ≤ pivot) ∧ (pivot ≤ B ≤ BU) → (CL ≤ C ≤ B) + (constrained_lower, constrained_upper) + if constrained_upper == bound_constraint.lower(db) + && !constrained_upper.is_never() + && !constrained_upper.is_object() => + { + (constrained_lower, Type::TypeVar(bound_typevar)) + } + + // (pivot ≤ C ≤ CU) ∧ (BL ≤ B ≤ pivot) → (B ≤ C ≤ CU) + (constrained_lower, constrained_upper) + if constrained_lower == bound_constraint.upper(db) + && !constrained_lower.is_never() + && !constrained_lower.is_object() => + { + (Type::TypeVar(bound_typevar), constrained_upper) + } + _ => return, }; @@ -3012,17 +3121,36 @@ impl<'db> SequentMap<'db> { let left_upper = left_constraint.upper(db); let right_lower = right_constraint.lower(db); let right_upper = right_constraint.upper(db); + let new_constraint = |bound_typevar: BoundTypeVarInstance<'db>, + right_lower: Type<'db>, + right_upper: Type<'db>| { + let right_lower = if let Type::TypeVar(other_bound_typevar) = right_lower + && bound_typevar.is_same_typevar_as(db, other_bound_typevar) + { + Type::Never + } else { + right_lower + }; + let right_upper = if let Type::TypeVar(other_bound_typevar) = right_upper + && bound_typevar.is_same_typevar_as(db, other_bound_typevar) + { + Type::object() + } else { + right_upper + }; + ConstrainedTypeVar::new(db, bound_typevar, right_lower, right_upper) + }; let post_constraint = match (left_lower, left_upper) { (Type::TypeVar(bound_typevar), Type::TypeVar(other_bound_typevar)) if bound_typevar.is_same_typevar_as(db, other_bound_typevar) => { - ConstrainedTypeVar::new(db, bound_typevar, right_lower, right_upper) + new_constraint(bound_typevar, right_lower, right_upper) } (Type::TypeVar(bound_typevar), _) => { - ConstrainedTypeVar::new(db, bound_typevar, Type::Never, right_upper) + new_constraint(bound_typevar, Type::Never, right_upper) } (_, Type::TypeVar(bound_typevar)) => { - ConstrainedTypeVar::new(db, bound_typevar, right_lower, Type::object()) + new_constraint(bound_typevar, right_lower, Type::object()) } _ => return, }; @@ -3169,8 +3297,8 @@ impl<'db> SequentMap<'db> { /// The collection of constraints that we know to be true or false at a certain point when /// traversing a BDD. #[derive(Debug, Default)] -struct PathAssignments<'db> { - assignments: FxOrderSet>, +pub(crate) struct PathAssignments<'db> { + assignments: FxOrderMap, usize>, } impl<'db> PathAssignments<'db> { @@ -3201,6 +3329,7 @@ impl<'db> PathAssignments<'db> { db: &'db dyn Db, map: &SequentMap<'db>, assignment: ConstraintAssignment<'db>, + source_order: usize, f: impl FnOnce(&mut Self, Range) -> R, ) -> Option { // Record a snapshot of the assignments that we already knew held — both so that we can @@ -3213,12 +3342,12 @@ impl<'db> PathAssignments<'db> { target: "ty_python_semantic::types::constraints::PathAssignment", before = %format_args!( "[{}]", - self.assignments[..start].iter().map(|assignment| assignment.display(db)).format(", "), + self.assignments[..start].iter().map(|(assignment, _)| assignment.display(db)).format(", "), ), edge = %assignment.display(db), "walk edge", ); - let found_conflict = self.add_assignment(db, map, assignment); + let found_conflict = self.add_assignment(db, map, assignment, source_order); let result = if found_conflict.is_err() { // If that results in the path now being impossible due to a contradiction, return // without invoking the callback. @@ -3233,7 +3362,7 @@ impl<'db> PathAssignments<'db> { target: "ty_python_semantic::types::constraints::PathAssignment", new = %format_args!( "[{}]", - self.assignments[start..].iter().map(|assignment| assignment.display(db)).format(", "), + self.assignments[start..].iter().map(|(assignment, _)| assignment.display(db)).format(", "), ), "new assignments", ); @@ -3247,8 +3376,19 @@ impl<'db> PathAssignments<'db> { result } + pub(crate) fn positive_constraints( + &self, + ) -> impl Iterator, usize)> + '_ { + self.assignments + .iter() + .filter_map(|(assignment, source_order)| match assignment { + ConstraintAssignment::Positive(constraint) => Some((*constraint, *source_order)), + ConstraintAssignment::Negative(_) => None, + }) + } + fn assignment_holds(&self, assignment: ConstraintAssignment<'db>) -> bool { - self.assignments.contains(&assignment) + self.assignments.contains_key(&assignment) } /// Adds a new assignment, along with any derived information that we can infer from the new @@ -3259,26 +3399,34 @@ impl<'db> PathAssignments<'db> { db: &'db dyn Db, map: &SequentMap<'db>, assignment: ConstraintAssignment<'db>, + source_order: usize, ) -> Result<(), PathAssignmentConflict> { // First add this assignment. If it causes a conflict, return that as an error. If we've // already know this assignment holds, just return. - if self.assignments.contains(&assignment.negated()) { + if self.assignments.contains_key(&assignment.negated()) { tracing::trace!( target: "ty_python_semantic::types::constraints::PathAssignment", assignment = %assignment.display(db), facts = %format_args!( "[{}]", - self.assignments.iter().map(|assignment| assignment.display(db)).format(", "), + self.assignments.iter().map(|(assignment, _)| assignment.display(db)).format(", "), ), "found contradiction", ); return Err(PathAssignmentConflict); } - if !self.assignments.insert(assignment) { + if self.assignments.insert(assignment, source_order).is_some() { return Ok(()); } - // Then use our sequents to add additional facts that we know to be true. + // Then use our sequents to add additional facts that we know to be true. We currently + // reuse the `source_order` of the "real" constraint passed into `walk_edge` when we add + // these derived facts. + // + // TODO: This might not be stable enough, if we add more than one derived fact for this + // constraint. If we still see inconsistent test output, we might need a more complex + // way of tracking source order for derived facts. + // // TODO: This is very naive at the moment, partly for expediency, and partly because we // don't anticipate the sequent maps to be very large. We might consider avoiding the // brute-force search. @@ -3292,7 +3440,7 @@ impl<'db> PathAssignments<'db> { ante = %ante.display(db), facts = %format_args!( "[{}]", - self.assignments.iter().map(|assignment| assignment.display(db)).format(", "), + self.assignments.iter().map(|(assignment, _)| assignment.display(db)).format(", "), ), "found contradiction", ); @@ -3311,7 +3459,7 @@ impl<'db> PathAssignments<'db> { ante2 = %ante2.display(db), facts = %format_args!( "[{}]", - self.assignments.iter().map(|assignment| assignment.display(db)).format(", "), + self.assignments.iter().map(|(assignment, _)| assignment.display(db)).format(", "), ), "found contradiction", ); @@ -3324,7 +3472,7 @@ impl<'db> PathAssignments<'db> { if self.assignment_holds(ante1.when_true()) && self.assignment_holds(ante2.when_true()) { - self.add_assignment(db, map, post.when_true())?; + self.add_assignment(db, map, post.when_true(), source_order)?; } } } @@ -3332,7 +3480,7 @@ impl<'db> PathAssignments<'db> { for (ante, posts) in &map.single_implications { for post in posts { if self.assignment_holds(ante.when_true()) { - self.add_assignment(db, map, post.when_true())?; + self.add_assignment(db, map, post.when_true(), source_order)?; } } } @@ -3765,7 +3913,7 @@ impl<'db> GenericContext<'db> { // If `lower ≰ upper`, then there is no type that satisfies all of the paths in the // BDD. That's an ambiguous specialization, as described above. - if !greatest_lower_bound.is_assignable_to(db, least_upper_bound) { + if !greatest_lower_bound.is_constraint_set_assignable_to(db, least_upper_bound) { tracing::debug!( target: "ty_python_semantic::types::constraints::specialize_constrained", bound_typevar = %identity.display(db), diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index d4e6c7a446..c012ab09f6 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -13,16 +13,16 @@ use crate::types::class::ClassType; use crate::types::class_base::ClassBase; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::instance::{Protocol, ProtocolInstanceType}; -use crate::types::signatures::{Parameters, ParametersKind}; +use crate::types::signatures::Parameters; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; +use crate::types::variance::VarianceInferable; use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarIdentity, BoundTypeVarInstance, - CallableSignature, CallableType, CallableTypeKind, CallableTypes, ClassLiteral, - FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, - KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, Signature, Type, - TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarIdentity, - TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, declaration_type, + ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IntersectionType, + IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, + NormalizedVisitor, Type, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, + TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, declaration_type, walk_type_var_bounds, }; use crate::{Db, FxOrderMap, FxOrderSet}; @@ -571,6 +571,14 @@ impl<'db> GenericContext<'db> { let partial = PartialSpecialization { generic_context: self, types: &types, + // Don't recursively substitute type[i] in itself. Ideally, we could instead + // check if the result is self-referential after we're done applying the + // partial specialization. But when we apply a paramspec, we don't use the + // callable that it maps to directly; we create a new callable that reuses + // parts of it. That means we can't look for the previous type directly. + // Instead we use this to skip specializing the type in itself in the first + // place. + skip: Some(i), }; let updated = types[i].apply_type_mapping( db, @@ -641,6 +649,7 @@ impl<'db> GenericContext<'db> { let partial = PartialSpecialization { generic_context: self, types: &expanded[0..idx], + skip: None, }; let default = default.apply_type_mapping( db, @@ -917,7 +926,11 @@ fn has_relation_in_invariant_position<'db>( disjointness_visitor, ), // And A <~ B (assignability) is Bottom[A] <: Top[B] - (None, Some(base_mat), TypeRelation::Assignability) => is_subtype_in_invariant_position( + ( + None, + Some(base_mat), + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability, + ) => is_subtype_in_invariant_position( db, derived_type, MaterializationKind::Bottom, @@ -927,7 +940,11 @@ fn has_relation_in_invariant_position<'db>( relation_visitor, disjointness_visitor, ), - (Some(derived_mat), None, TypeRelation::Assignability) => is_subtype_in_invariant_position( + ( + Some(derived_mat), + None, + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability, + ) => is_subtype_in_invariant_position( db, derived_type, derived_mat, @@ -1438,6 +1455,9 @@ impl<'db> Specialization<'db> { pub struct PartialSpecialization<'a, 'db> { generic_context: GenericContext<'db>, types: &'a [Type<'db>], + /// An optional typevar to _not_ substitute when applying the specialization. We use this to + /// avoid recursively substituting a type inside of itself. + skip: Option, } impl<'db> PartialSpecialization<'_, 'db> { @@ -1452,6 +1472,9 @@ impl<'db> PartialSpecialization<'_, 'db> { .generic_context .variables_inner(db) .get_index_of(&bound_typevar.identity(db))?; + if self.skip.is_some_and(|skip| skip == index) { + return Some(Type::Never); + } self.types.get(index).copied() } } @@ -1509,7 +1532,7 @@ impl<'db> SpecializationBuilder<'db> { .map(|(identity, _)| self.types.get(identity).copied()); // TODO Infer the tuple spec for a tuple type - generic_context.specialize_partial(self.db, types) + generic_context.specialize_recursive(self.db, types) } fn add_type_mapping( @@ -1543,6 +1566,80 @@ impl<'db> SpecializationBuilder<'db> { } } + /// Finds all of the valid specializations of a constraint set, and adds their type mappings to + /// the specialization that this builder is building up. + /// + /// `formal` should be the top-level formal parameter type that we are inferring. This is used + /// by our literal promotion logic, which needs to know which typevars are affected by each + /// argument, and the variance of those typevars in the corresponding parameter. + /// + /// TODO: This is a stopgap! Eventually, the builder will maintain a single constraint set for + /// the main specialization that we are building, and [`build`][Self::build] will build the + /// specialization directly from that constraint set. This method lets us migrate to that brave + /// new world incrementally, by using the new constraint set mechanism piecemeal for certain + /// type comparisons. + fn add_type_mappings_from_constraint_set( + &mut self, + formal: Type<'db>, + constraints: ConstraintSet<'db>, + mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + ) { + #[derive(Default)] + struct Bounds<'db> { + lower: Vec>, + upper: Vec>, + } + + let constraints = constraints.limit_to_valid_specializations(self.db); + let mut sorted_paths = Vec::new(); + constraints.for_each_path(self.db, |path| { + let mut path: Vec<_> = path.positive_constraints().collect(); + path.sort_unstable_by_key(|(_, source_order)| *source_order); + sorted_paths.push(path); + }); + sorted_paths.sort_unstable_by(|path1, path2| { + let source_orders1 = path1.iter().map(|(_, source_order)| *source_order); + let source_orders2 = path2.iter().map(|(_, source_order)| *source_order); + source_orders1.cmp(source_orders2) + }); + + let mut mappings: FxHashMap, Bounds<'db>> = FxHashMap::default(); + for path in sorted_paths { + mappings.clear(); + for (constraint, _) in path { + let typevar = constraint.typevar(self.db); + let lower = constraint.lower(self.db); + let upper = constraint.upper(self.db); + let bounds = mappings.entry(typevar).or_default(); + bounds.lower.push(lower); + bounds.upper.push(upper); + + if let Type::TypeVar(lower_bound_typevar) = lower { + let bounds = mappings.entry(lower_bound_typevar).or_default(); + bounds.upper.push(Type::TypeVar(typevar)); + } + + if let Type::TypeVar(upper_bound_typevar) = upper { + let bounds = mappings.entry(upper_bound_typevar).or_default(); + bounds.lower.push(Type::TypeVar(typevar)); + } + } + + for (bound_typevar, bounds) in mappings.drain() { + let variance = formal.variance_of(self.db, bound_typevar); + let upper = IntersectionType::from_elements(self.db, bounds.upper); + if !upper.is_object() { + self.add_type_mapping(bound_typevar, upper, variance, &mut f); + continue; + } + let lower = UnionType::from_elements(self.db, bounds.lower); + if !lower.is_never() { + self.add_type_mapping(bound_typevar, lower, variance, &mut f); + } + } + } + } + /// Infer type mappings for the specialization based on a given type and its declared type. pub(crate) fn infer( &mut self, @@ -1572,6 +1669,15 @@ impl<'db> SpecializationBuilder<'db> { polarity: TypeVarVariance, mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, ) -> Result<(), SpecializationError<'db>> { + // TODO: Eventually, the builder will maintain a constraint set, instead of a hash-map of + // type mappings, to represent the specialization that we are building up. At that point, + // this method will just need to compare `actual ≤ formal`, using constraint set + // assignability, and AND the result into the constraint set we are building. + // + // To make progress on that migration, we use constraint set assignability whenever + // possible when adding any new heuristics here. See the `Callable` clause below for an + // example. + if formal == actual { return Ok(()); } @@ -1827,43 +1933,34 @@ impl<'db> SpecializationBuilder<'db> { } (Type::Callable(formal_callable), _) => { - if let Some(actual_callable) = actual - .try_upcast_to_callable(self.db) - .and_then(CallableTypes::exactly_one) - { - // We're only interested in a formal callable of the form `Callable[P, ...]` for - // now where `P` is a `ParamSpec`. - // TODO: This would need to be updated once we support `Concatenate` - // TODO: What to do for overloaded callables? - let [signature] = formal_callable.signatures(self.db).as_slice() else { - return Ok(()); - }; - let formal_parameters = signature.parameters(); - let ParametersKind::ParamSpec(typevar) = formal_parameters.kind() else { - return Ok(()); - }; - let paramspec_value = match actual_callable.signatures(self.db).as_slice() { - [] => return Ok(()), - [actual_signature] => match actual_signature.parameters().kind() { - ParametersKind::ParamSpec(typevar) => Type::TypeVar(typevar), - _ => Type::Callable(CallableType::new( + let Some(actual_callables) = actual.try_upcast_to_callable(self.db) else { + return Ok(()); + }; + + let formal_callable = formal_callable.signatures(self.db); + let formal_is_single_paramspec = formal_callable.is_single_paramspec().is_some(); + + for actual_callable in actual_callables.as_slice() { + if formal_is_single_paramspec { + let when = actual_callable + .signatures(self.db) + .when_constraint_set_assignable_to( self.db, - CallableSignature::single(Signature::new( - actual_signature.parameters().clone(), - None, - )), - CallableTypeKind::ParamSpecValue, - )), - }, - actual_signatures => Type::Callable(CallableType::new( - self.db, - CallableSignature::from_overloads(actual_signatures.iter().map( - |signature| Signature::new(signature.parameters().clone(), None), - )), - CallableTypeKind::ParamSpecValue, - )), - }; - self.add_type_mapping(typevar, paramspec_value, polarity, &mut f); + formal_callable, + self.inferable, + ); + self.add_type_mappings_from_constraint_set(formal, when, &mut f); + } else { + for actual_signature in &actual_callable.signatures(self.db).overloads { + let when = actual_signature + .when_constraint_set_assignable_to_signatures( + self.db, + formal_callable, + self.inferable, + ); + self.add_type_mappings_from_constraint_set(formal, when, &mut f); + } + } } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index dd6fe02334..0a70337930 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -18,11 +18,13 @@ use smallvec::{SmallVec, smallvec_inline}; use super::{DynamicType, Type, TypeVarVariance, definition_expression_type, semantic_index}; use crate::semantic_index::definition::Definition; -use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; +use crate::types::constraints::{ + ConstraintSet, IteratorConstraintsExtension, OptionConstraintsExtension, +}; use crate::types::generics::{GenericContext, InferableTypeVars, walk_generic_context}; use crate::types::infer::{infer_deferred_types, infer_scope_types}; use crate::types::{ - ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, CallableTypeKind, + ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, CallableType, CallableTypeKind, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, ParamSpecAttrKind, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type, @@ -91,10 +93,6 @@ impl<'db> CallableSignature<'db> { self.overloads.iter() } - pub(crate) fn as_slice(&self) -> &[Signature<'db>] { - &self.overloads - } - pub(crate) fn with_inherited_generic_context( &self, db: &'db dyn Db, @@ -338,6 +336,44 @@ impl<'db> CallableSignature<'db> { ) } + pub(crate) fn is_single_paramspec( + &self, + ) -> Option<(BoundTypeVarInstance<'db>, Option>)> { + Self::signatures_is_single_paramspec(&self.overloads) + } + + /// Checks whether the given slice contains a single signature, and that signature is a + /// `ParamSpec` signature. If so, returns the [`BoundTypeVarInstance`] for the `ParamSpec`, + /// along with the return type of the signature. + fn signatures_is_single_paramspec( + signatures: &[Signature<'db>], + ) -> Option<(BoundTypeVarInstance<'db>, Option>)> { + // TODO: This might need updating once we support `Concatenate` + let [signature] = signatures else { + return None; + }; + signature + .parameters + .as_paramspec() + .map(|bound_typevar| (bound_typevar, signature.return_ty)) + } + + pub(crate) fn when_constraint_set_assignable_to( + &self, + db: &'db dyn Db, + other: &Self, + inferable: InferableTypeVars<'_, 'db>, + ) -> ConstraintSet<'db> { + self.has_relation_to_impl( + db, + other, + inferable, + TypeRelation::ConstraintSetAssignability, + &HasRelationToVisitor::default(), + &IsDisjointVisitor::default(), + ) + } + /// Implementation of subtyping and assignability between two, possible overloaded, callable /// types. fn has_relation_to_inner( @@ -349,6 +385,111 @@ impl<'db> CallableSignature<'db> { relation_visitor: &HasRelationToVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>, ) -> ConstraintSet<'db> { + if relation.is_constraint_set_assignability() { + // TODO: Oof, maybe ParamSpec needs to live at CallableSignature, not Signature? + let self_is_single_paramspec = Self::signatures_is_single_paramspec(self_signatures); + let other_is_single_paramspec = Self::signatures_is_single_paramspec(other_signatures); + + // If either callable is a ParamSpec, the constraint set should bind the ParamSpec to + // the other callable's signature. We also need to compare the return types — for + // instance, to verify in `Callable[P, int]` that the return type is assignable to + // `int`, or in `Callable[P, T]` to bind `T` to the return type of the other callable. + match (self_is_single_paramspec, other_is_single_paramspec) { + ( + Some((self_bound_typevar, self_return_type)), + Some((other_bound_typevar, other_return_type)), + ) => { + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self_bound_typevar, + Type::TypeVar(other_bound_typevar), + Type::TypeVar(other_bound_typevar), + ); + let return_types_match = self_return_type.zip(other_return_type).when_some_and( + |(self_return_type, other_return_type)| { + self_return_type.has_relation_to_impl( + db, + other_return_type, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }, + ); + return param_spec_matches.and(db, || return_types_match); + } + + (Some((self_bound_typevar, self_return_type)), None) => { + let upper = + Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(other_signatures.iter().map( + |signature| Signature::new(signature.parameters().clone(), None), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self_bound_typevar, + Type::Never, + upper, + ); + let return_types_match = self_return_type.when_some_and(|self_return_type| { + other_signatures + .iter() + .filter_map(|signature| signature.return_ty) + .when_any(db, |other_return_type| { + self_return_type.has_relation_to_impl( + db, + other_return_type, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }) + }); + return param_spec_matches.and(db, || return_types_match); + } + + (None, Some((other_bound_typevar, other_return_type))) => { + let lower = + Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(self_signatures.iter().map( + |signature| Signature::new(signature.parameters().clone(), None), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + other_bound_typevar, + lower, + Type::object(), + ); + let return_types_match = other_return_type.when_some_and(|other_return_type| { + self_signatures + .iter() + .filter_map(|signature| signature.return_ty) + .when_any(db, |self_return_type| { + self_return_type.has_relation_to_impl( + db, + other_return_type, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }) + }); + return param_spec_matches.and(db, || return_types_match); + } + + (None, None) => {} + } + } + match (self_signatures, other_signatures) { ([self_signature], [other_signature]) => { // Base case: both callable types contain a single signature. @@ -955,6 +1096,65 @@ impl<'db> Signature<'db> { result } + pub(crate) fn when_constraint_set_assignable_to_signatures( + &self, + db: &'db dyn Db, + other: &CallableSignature<'db>, + inferable: InferableTypeVars<'_, 'db>, + ) -> ConstraintSet<'db> { + // If this signature is a paramspec, bind it to the entire overloaded other callable. + if let Some(self_bound_typevar) = self.parameters.as_paramspec() + && other.is_single_paramspec().is_none() + { + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads( + other + .overloads + .iter() + .map(|signature| Signature::new(signature.parameters().clone(), None)), + ), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = + ConstraintSet::constrain_typevar(db, self_bound_typevar, Type::Never, upper); + let return_types_match = self.return_ty.when_some_and(|self_return_type| { + other + .overloads + .iter() + .filter_map(|signature| signature.return_ty) + .when_any(db, |other_return_type| { + self_return_type.when_constraint_set_assignable_to( + db, + other_return_type, + inferable, + ) + }) + }); + return param_spec_matches.and(db, || return_types_match); + } + + other.overloads.iter().when_all(db, |other_signature| { + self.when_constraint_set_assignable_to(db, other_signature, inferable) + }) + } + + fn when_constraint_set_assignable_to( + &self, + db: &'db dyn Db, + other: &Self, + inferable: InferableTypeVars<'_, 'db>, + ) -> ConstraintSet<'db> { + self.has_relation_to_impl( + db, + other, + inferable, + TypeRelation::ConstraintSetAssignability, + &HasRelationToVisitor::default(), + &IsDisjointVisitor::default(), + ) + } + /// Implementation of subtyping and assignability for signature. fn has_relation_to_impl( &self, @@ -1134,7 +1334,67 @@ impl<'db> Signature<'db> { // If either of the parameter lists is gradual (`...`), then it is assignable to and from // any other parameter list, but not a subtype or supertype of any other parameter list. if self.parameters.is_gradual() || other.parameters.is_gradual() { - return ConstraintSet::from(relation.is_assignability()); + result.intersect( + db, + ConstraintSet::from( + relation.is_assignability() || relation.is_constraint_set_assignability(), + ), + ); + return result; + } + + if relation.is_constraint_set_assignability() { + let self_is_paramspec = self.parameters.as_paramspec(); + let other_is_paramspec = other.parameters.as_paramspec(); + + // If either signature is a ParamSpec, the constraint set should bind the ParamSpec to + // the other signature. + match (self_is_paramspec, other_is_paramspec) { + (Some(self_bound_typevar), Some(other_bound_typevar)) => { + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self_bound_typevar, + Type::TypeVar(other_bound_typevar), + Type::TypeVar(other_bound_typevar), + ); + result.intersect(db, param_spec_matches); + return result; + } + + (Some(self_bound_typevar), None) => { + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new(other.parameters.clone(), None)), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self_bound_typevar, + Type::Never, + upper, + ); + result.intersect(db, param_spec_matches); + return result; + } + + (None, Some(other_bound_typevar)) => { + let lower = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new(self.parameters.clone(), None)), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + other_bound_typevar, + lower, + Type::object(), + ); + result.intersect(db, param_spec_matches); + return result; + } + + (None, None) => {} + } } let mut parameters = ParametersZip { @@ -1554,6 +1814,13 @@ impl<'db> Parameters<'db> { matches!(self.kind, ParametersKind::Gradual) } + pub(crate) const fn as_paramspec(&self) -> Option> { + match self.kind { + ParametersKind::ParamSpec(bound_typevar) => Some(bound_typevar), + _ => None, + } + } + /// Return todo parameters: (*args: Todo, **kwargs: Todo) pub(crate) fn todo() -> Self { Self {