From aea2bc23086e3036ef1820cba2204f9dec586729 Mon Sep 17 00:00:00 2001 From: David Peter Date: Tue, 9 Dec 2025 16:22:59 +0100 Subject: [PATCH] [ty] Infer type variables within generic unions (#21862) ## Summary This PR allows our generics solver to find a solution for `T` in cases like the following: ```py def extract_t[T](x: P[T] | Q[T]) -> T: raise NotImplementedError reveal_type(extract_t(P[int]())) # revealed: int reveal_type(extract_t(Q[str]())) # revealed: str ``` closes https://github.com/astral-sh/ty/issues/1772 closes https://github.com/astral-sh/ty/issues/1314 ## Ecosystem The impact here looks very good! It took me a long time to figure this out, but the new diagnostics on bokeh are actually true positives. I should have tested with another type-checker immediately, I guess. All other type checkers also emit errors on these `__init__` calls. MRE [here](https://play.ty.dev/5c19d260-65e2-4f70-a75e-1a25780843a2) (no error on main, diagnostic on this branch) A lot of false positives on home-assistant go away for calls to functions like [`async_listen`](https://github.com/home-assistant/core/blob/180053fe9859f2a201ed2c33375db5316b50b7b5/homeassistant/core.py#L1581-L1587) which take a `event_type: EventType[_DataT] | str` parameter. We can now solve for `_DataT` here, which was previously falling back to its default value, and then caused problems because it was used as an argument to an invariant generic class. ## Test Plan New Markdown tests --- .../resources/mdtest/external/sqlalchemy.md | 45 ++---- .../mdtest/generics/legacy/functions.md | 132 ++++++++++++++++++ .../mdtest/generics/pep695/functions.md | 121 ++++++++++++++++ .../ty_python_semantic/src/types/generics.rs | 69 ++++++--- 4 files changed, 321 insertions(+), 46 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md b/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md index 61e6668de1..43fff45058 100644 --- a/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md +++ b/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md @@ -107,44 +107,34 @@ We can also specify particular columns to select: ```py stmt = select(User.id, User.name) -# TODO: should be `Select[tuple[int, str]]` -reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]] +reveal_type(stmt) # revealed: Select[tuple[int, str]] ids_and_names = session.execute(stmt).all() -# TODO: should be `Sequence[Row[tuple[int, str]]]` -reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]] +reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]] for row in session.execute(stmt): - # TODO: should be `Row[tuple[int, str]]` - reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]] + reveal_type(row) # revealed: Row[tuple[int, str]] for user_id, name in session.execute(stmt).tuples(): - # TODO: should be `int` - reveal_type(user_id) # revealed: Unknown - # TODO: should be `str` - reveal_type(name) # revealed: Unknown + reveal_type(user_id) # revealed: int + reveal_type(name) # revealed: str result = session.execute(stmt) row = result.one_or_none() assert row is not None (user_id, name) = row._tuple() -# TODO: should be `int` -reveal_type(user_id) # revealed: Unknown -# TODO: should be `str` -reveal_type(name) # revealed: Unknown +reveal_type(user_id) # revealed: int +reveal_type(name) # revealed: str stmt = select(User.id).where(User.name == "Alice") -# TODO: should be `Select[tuple[int]]` -reveal_type(stmt) # revealed: Select[tuple[Unknown]] +reveal_type(stmt) # revealed: Select[tuple[int]] alice_id = session.scalars(stmt).first() -# TODO: should be `int | None` -reveal_type(alice_id) # revealed: Unknown | None +reveal_type(alice_id) # revealed: int | None alice_id = session.scalar(stmt) -# TODO: should be `int | None` -reveal_type(alice_id) # revealed: Unknown | None +reveal_type(alice_id) # revealed: int | None ``` Using the legacy `query` API also works: @@ -166,15 +156,12 @@ And similarly when specifying particular columns: ```py query = session.query(User.id, User.name) -# TODO: should be `RowReturningQuery[tuple[int, str]]` -reveal_type(query) # revealed: RowReturningQuery[tuple[Unknown, Unknown]] +reveal_type(query) # revealed: RowReturningQuery[tuple[int, str]] -# TODO: should be `list[Row[tuple[int, str]]]` -reveal_type(query.all()) # revealed: list[Row[tuple[Unknown, Unknown]]] +reveal_type(query.all()) # revealed: list[Row[tuple[int, str]]] for row in query: - # TODO: should be `Row[tuple[int, str]]` - reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]] + reveal_type(row) # revealed: Row[tuple[int, str]] ``` ## Async API @@ -203,8 +190,6 @@ async def test_async(session: AsyncSession): stmt = select(User.id, User.name) result = await session.execute(stmt) for user_id, name in result.tuples(): - # TODO: should be `int` - reveal_type(user_id) # revealed: Unknown - # TODO: should be `str` - reveal_type(name) # revealed: Unknown + reveal_type(user_id) # revealed: int + reveal_type(name) # revealed: str ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 2fce911026..6e89253bd0 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -347,6 +347,138 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] ``` +When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from +the actual argument even for non-final classes. + +```py +from typing import TypeVar, Generic + +T = TypeVar("T") + +class P(Generic[T]): + x: T + +class Q(Generic[T]): + x: T + +def extract_t(x: P[T] | Q[T]) -> T: + raise NotImplementedError + +reveal_type(extract_t(P[int]())) # revealed: int +reveal_type(extract_t(Q[str]())) # revealed: str +``` + +Passing anything else results in an error: + +```py +# error: [invalid-argument-type] +reveal_type(extract_t([1, 2])) # revealed: Unknown +``` + +This also works when different union elements have different typevars: + +```py +S = TypeVar("S") + +def extract_both(x: P[T] | Q[S]) -> tuple[T, S]: + raise NotImplementedError + +reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown] +reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str] +``` + +Inference also works when passing subclasses of the generic classes in the union. + +```py +class SubP(P[T]): + pass + +class SubQ(Q[T]): + pass + +reveal_type(extract_t(SubP[int]())) # revealed: int +reveal_type(extract_t(SubQ[str]())) # revealed: str + +reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown] +reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str] +``` + +When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a +single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer +both types in a call to `extract_both`: + +```py +class PandQ(P[int], Q[str]): + pass + +# TODO: Ideally, we would return `Unknown` here. +# error: [invalid-argument-type] +reveal_type(extract_t(PandQ())) # revealed: int | str + +reveal_type(extract_both(PandQ())) # revealed: tuple[int, str] +``` + +When non-generic types are part of the union, we can still infer typevars for the remaining generic +types: + +```py +def extract_optional_t(x: None | P[T]) -> T: + raise NotImplementedError + +reveal_type(extract_optional_t(None)) # revealed: Unknown +reveal_type(extract_optional_t(P[int]())) # revealed: int +``` + +Passing anything else results in an error: + +```py +# error: [invalid-argument-type] +reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown +``` + +If the union contains contains parent and child of a generic class, we ideally pick the union +element that is more precise: + +```py +class Base(Generic[T]): + x: T + +class Sub(Base[T]): ... + +def f(t: Base[T] | Sub[T | None]) -> T: + raise NotImplementedError + +reveal_type(f(Base[int]())) # revealed: int +# TODO: Should ideally be `str` +reveal_type(f(Sub[str | None]())) # revealed: str | None +``` + +If we have a case like the following, where only one of the union elements matches due to the +typevar bound, we do not emit a specialization error: + +```py +from typing import TypeVar + +I_int = TypeVar("I_int", bound=int) +S_str = TypeVar("S_str", bound=str) + +class P(Generic[T]): + value: T + +def f(t: P[I_int] | P[S_str]) -> tuple[I_int, S_str]: + raise NotImplementedError + +reveal_type(f(P[int]())) # revealed: tuple[int, Unknown] +reveal_type(f(P[str]())) # revealed: tuple[Unknown, str] +``` + +However, if we pass something that does not match _any_ union element, we do emit an error: + +```py +# error: [invalid-argument-type] +reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown] +``` + ## Inferring nested generic function calls We can infer type assignments in nested calls to multiple generic functions. If they use the same diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 8af1b948ee..843bd60d21 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -310,6 +310,127 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] ``` +When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from +the actual argument even for non-final classes. + +```py +class P[T]: + x: T # invariant + +class Q[T]: + x: T # invariant + +def extract_t[T](x: P[T] | Q[T]) -> T: + raise NotImplementedError + +reveal_type(extract_t(P[int]())) # revealed: int +reveal_type(extract_t(Q[str]())) # revealed: str +``` + +Passing anything else results in an error: + +```py +# error: [invalid-argument-type] +reveal_type(extract_t([1, 2])) # revealed: Unknown +``` + +This also works when different union elements have different typevars: + +```py +def extract_both[T, S](x: P[T] | Q[S]) -> tuple[T, S]: + raise NotImplementedError + +reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown] +reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str] +``` + +Inference also works when passing subclasses of the generic classes in the union. + +```py +class SubP[T](P[T]): + pass + +class SubQ[T](Q[T]): + pass + +reveal_type(extract_t(SubP[int]())) # revealed: int +reveal_type(extract_t(SubQ[str]())) # revealed: str + +reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown] +reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str] +``` + +When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a +single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer +both types in a call to `extract_both`: + +```py +class PandQ(P[int], Q[str]): + pass + +# TODO: Ideally, we would return `Unknown` here. +# error: [invalid-argument-type] +reveal_type(extract_t(PandQ())) # revealed: int | str + +reveal_type(extract_both(PandQ())) # revealed: tuple[int, str] +``` + +When non-generic types are part of the union, we can still infer typevars for the remaining generic +types: + +```py +def extract_optional_t[T](x: None | P[T]) -> T: + raise NotImplementedError + +reveal_type(extract_optional_t(None)) # revealed: Unknown +reveal_type(extract_optional_t(P[int]())) # revealed: int +``` + +Passing anything else results in an error: + +```py +# error: [invalid-argument-type] +reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown +``` + +If the union contains contains parent and child of a generic class, we ideally pick the union +element that is more precise: + +```py +class Base[T]: + x: T + +class Sub[T](Base[T]): ... + +def f[T](t: Base[T] | Sub[T | None]) -> T: + raise NotImplementedError + +reveal_type(f(Base[int]())) # revealed: int +# TODO: Should ideally be `str` +reveal_type(f(Sub[str | None]())) # revealed: str | None +``` + +If we have a case like the following, where only one of the union elements matches due to the +typevar bound, we do not emit a specialization error: + +```py +class P[T]: + value: T + +def f[I: int, S: str](t: P[I] | P[S]) -> tuple[I, S]: + raise NotImplementedError + +reveal_type(f(P[int]())) # revealed: tuple[int, Unknown] +reveal_type(f(P[str]())) # revealed: tuple[Unknown, str] +``` + +However, if we pass something that does not match _any_ union element, we do emit an error: + +```py +# error: [invalid-argument-type] +reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown] +``` + ## Inferring nested generic function calls We can infer type assignments in nested calls to multiple generic functions. If they use the same diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 7db5f7e7a2..32488389bb 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1545,35 +1545,72 @@ impl<'db> SpecializationBuilder<'db> { } self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f); } - (Type::Union(formal), _) => { - // Second, if the formal is a union, and precisely one union element is assignable - // from the actual type, then we don't add any type mapping. This handles a case like + (Type::Union(union_formal), _) => { + // Second, if the formal is a union, and the actual type is assignable to precisely + // one union element, then we don't add any type mapping. This handles a case like // // ```py - // def f[T](t: T | None): ... + // def f[T](t: T | None) -> T: ... // - // f(None) + // reveal_type(f(None)) # revealed: Unknown // ``` // // without specializing `T` to `None`. - // - // Otherwise, if precisely one union element _is_ a typevar (not _contains_ a - // typevar), then we add a mapping between that typevar and the actual type. if !actual.is_never() { - let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| { - actual - .when_subtype_of(self.db, **ty, self.inferable) - .is_always_satisfied(self.db) - }); + let assignable_elements = + (union_formal.elements(self.db).iter()).filter(|ty| { + actual + .when_subtype_of(self.db, **ty, self.inferable) + .is_always_satisfied(self.db) + }); if assignable_elements.exactly_one().is_ok() { return Ok(()); } } - let bound_typevars = - (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); - if let Ok(bound_typevar) = bound_typevars.exactly_one() { + let mut bound_typevars = + (union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); + + let first_bound_typevar = bound_typevars.next(); + let has_more_than_one_typevar = bound_typevars.next().is_some(); + + // Otherwise, if precisely one union element _is_ a typevar (not _contains_ a + // typevar), then we add a mapping between that typevar and the actual type. + if let Some(bound_typevar) = first_bound_typevar + && !has_more_than_one_typevar + { self.add_type_mapping(bound_typevar, actual, polarity, f); + return Ok(()); + } + + // TODO: + // Handling more than one bare typevar is something that we can't handle yet. + if has_more_than_one_typevar { + return Ok(()); + } + + // Finally, if there are no bare typevars, we try to infer type mappings by + // checking against each union element. This handles cases like + // ```py + // def f[T](t: P[T] | Q[T]) -> T: ... + // + // reveal_type(f(P[str]())) # revealed: str + // reveal_type(f(Q[int]())) # revealed: int + // ``` + let mut first_error = None; + let mut found_matching_element = false; + for formal_element in union_formal.elements(self.db) { + if !formal_element.is_disjoint_from(self.db, actual) { + let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f); + if let Err(err) = result { + first_error.get_or_insert(err); + } else { + found_matching_element = true; + } + } + } + if !found_matching_element && let Some(error) = first_error { + return Err(error); } }