[ty] Infer type variables within generic unions (#21862)

## Summary

This PR allows our generics solver to find a solution for `T` in cases
like the following:
```py
def extract_t[T](x: P[T] | Q[T]) -> T:
    raise NotImplementedError

reveal_type(extract_t(P[int]()))  # revealed: int
reveal_type(extract_t(Q[str]()))  # revealed: str
```

closes https://github.com/astral-sh/ty/issues/1772
closes https://github.com/astral-sh/ty/issues/1314

## Ecosystem

The impact here looks very good!

It took me a long time to figure this out, but the new diagnostics on
bokeh are actually true positives. I should have tested with another
type-checker immediately, I guess. All other type checkers also emit
errors on these `__init__` calls. MRE
[here](https://play.ty.dev/5c19d260-65e2-4f70-a75e-1a25780843a2) (no
error on main, diagnostic on this branch)

A lot of false positives on home-assistant go away for calls to
functions like
[`async_listen`](180053fe98/homeassistant/core.py (L1581-L1587))
which take a `event_type: EventType[_DataT] | str` parameter. We can now
solve for `_DataT` here, which was previously falling back to its
default value, and then caused problems because it was used as an
argument to an invariant generic class.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-12-09 16:22:59 +01:00 committed by GitHub
parent c35bf8f441
commit aea2bc2308
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 321 additions and 46 deletions

View File

@ -107,44 +107,34 @@ We can also specify particular columns to select:
```py ```py
stmt = select(User.id, User.name) stmt = select(User.id, User.name)
# TODO: should be `Select[tuple[int, str]]` reveal_type(stmt) # revealed: Select[tuple[int, str]]
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
ids_and_names = session.execute(stmt).all() ids_and_names = session.execute(stmt).all()
# TODO: should be `Sequence[Row[tuple[int, str]]]` reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
for row in session.execute(stmt): for row in session.execute(stmt):
# TODO: should be `Row[tuple[int, str]]` reveal_type(row) # revealed: Row[tuple[int, str]]
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
for user_id, name in session.execute(stmt).tuples(): for user_id, name in session.execute(stmt).tuples():
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
result = session.execute(stmt) result = session.execute(stmt)
row = result.one_or_none() row = result.one_or_none()
assert row is not None assert row is not None
(user_id, name) = row._tuple() (user_id, name) = row._tuple()
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
stmt = select(User.id).where(User.name == "Alice") stmt = select(User.id).where(User.name == "Alice")
# TODO: should be `Select[tuple[int]]` reveal_type(stmt) # revealed: Select[tuple[int]]
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
alice_id = session.scalars(stmt).first() alice_id = session.scalars(stmt).first()
# TODO: should be `int | None` reveal_type(alice_id) # revealed: int | None
reveal_type(alice_id) # revealed: Unknown | None
alice_id = session.scalar(stmt) alice_id = session.scalar(stmt)
# TODO: should be `int | None` reveal_type(alice_id) # revealed: int | None
reveal_type(alice_id) # revealed: Unknown | None
``` ```
Using the legacy `query` API also works: Using the legacy `query` API also works:
@ -166,15 +156,12 @@ And similarly when specifying particular columns:
```py ```py
query = session.query(User.id, User.name) query = session.query(User.id, User.name)
# TODO: should be `RowReturningQuery[tuple[int, str]]` reveal_type(query) # revealed: RowReturningQuery[tuple[int, str]]
reveal_type(query) # revealed: RowReturningQuery[tuple[Unknown, Unknown]]
# TODO: should be `list[Row[tuple[int, str]]]` reveal_type(query.all()) # revealed: list[Row[tuple[int, str]]]
reveal_type(query.all()) # revealed: list[Row[tuple[Unknown, Unknown]]]
for row in query: for row in query:
# TODO: should be `Row[tuple[int, str]]` reveal_type(row) # revealed: Row[tuple[int, str]]
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
``` ```
## Async API ## Async API
@ -203,8 +190,6 @@ async def test_async(session: AsyncSession):
stmt = select(User.id, User.name) stmt = select(User.id, User.name)
result = await session.execute(stmt) result = await session.execute(stmt)
for user_id, name in result.tuples(): for user_id, name in result.tuples():
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
``` ```

View File

@ -347,6 +347,138 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
``` ```
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
the actual argument even for non-final classes.
```py
from typing import TypeVar, Generic
T = TypeVar("T")
class P(Generic[T]):
x: T
class Q(Generic[T]):
x: T
def extract_t(x: P[T] | Q[T]) -> T:
raise NotImplementedError
reveal_type(extract_t(P[int]())) # revealed: int
reveal_type(extract_t(Q[str]())) # revealed: str
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown
```
This also works when different union elements have different typevars:
```py
S = TypeVar("S")
def extract_both(x: P[T] | Q[S]) -> tuple[T, S]:
raise NotImplementedError
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
```
Inference also works when passing subclasses of the generic classes in the union.
```py
class SubP(P[T]):
pass
class SubQ(Q[T]):
pass
reveal_type(extract_t(SubP[int]())) # revealed: int
reveal_type(extract_t(SubQ[str]())) # revealed: str
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
```
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
both types in a call to `extract_both`:
```py
class PandQ(P[int], Q[str]):
pass
# TODO: Ideally, we would return `Unknown` here.
# error: [invalid-argument-type]
reveal_type(extract_t(PandQ())) # revealed: int | str
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
```
When non-generic types are part of the union, we can still infer typevars for the remaining generic
types:
```py
def extract_optional_t(x: None | P[T]) -> T:
raise NotImplementedError
reveal_type(extract_optional_t(None)) # revealed: Unknown
reveal_type(extract_optional_t(P[int]())) # revealed: int
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
```
If the union contains contains parent and child of a generic class, we ideally pick the union
element that is more precise:
```py
class Base(Generic[T]):
x: T
class Sub(Base[T]): ...
def f(t: Base[T] | Sub[T | None]) -> T:
raise NotImplementedError
reveal_type(f(Base[int]())) # revealed: int
# TODO: Should ideally be `str`
reveal_type(f(Sub[str | None]())) # revealed: str | None
```
If we have a case like the following, where only one of the union elements matches due to the
typevar bound, we do not emit a specialization error:
```py
from typing import TypeVar
I_int = TypeVar("I_int", bound=int)
S_str = TypeVar("S_str", bound=str)
class P(Generic[T]):
value: T
def f(t: P[I_int] | P[S_str]) -> tuple[I_int, S_str]:
raise NotImplementedError
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
```
However, if we pass something that does not match _any_ union element, we do emit an error:
```py
# error: [invalid-argument-type]
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
```
## Inferring nested generic function calls ## Inferring nested generic function calls
We can infer type assignments in nested calls to multiple generic functions. If they use the same We can infer type assignments in nested calls to multiple generic functions. If they use the same

View File

@ -310,6 +310,127 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
``` ```
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
the actual argument even for non-final classes.
```py
class P[T]:
x: T # invariant
class Q[T]:
x: T # invariant
def extract_t[T](x: P[T] | Q[T]) -> T:
raise NotImplementedError
reveal_type(extract_t(P[int]())) # revealed: int
reveal_type(extract_t(Q[str]())) # revealed: str
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown
```
This also works when different union elements have different typevars:
```py
def extract_both[T, S](x: P[T] | Q[S]) -> tuple[T, S]:
raise NotImplementedError
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
```
Inference also works when passing subclasses of the generic classes in the union.
```py
class SubP[T](P[T]):
pass
class SubQ[T](Q[T]):
pass
reveal_type(extract_t(SubP[int]())) # revealed: int
reveal_type(extract_t(SubQ[str]())) # revealed: str
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
```
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
both types in a call to `extract_both`:
```py
class PandQ(P[int], Q[str]):
pass
# TODO: Ideally, we would return `Unknown` here.
# error: [invalid-argument-type]
reveal_type(extract_t(PandQ())) # revealed: int | str
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
```
When non-generic types are part of the union, we can still infer typevars for the remaining generic
types:
```py
def extract_optional_t[T](x: None | P[T]) -> T:
raise NotImplementedError
reveal_type(extract_optional_t(None)) # revealed: Unknown
reveal_type(extract_optional_t(P[int]())) # revealed: int
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
```
If the union contains contains parent and child of a generic class, we ideally pick the union
element that is more precise:
```py
class Base[T]:
x: T
class Sub[T](Base[T]): ...
def f[T](t: Base[T] | Sub[T | None]) -> T:
raise NotImplementedError
reveal_type(f(Base[int]())) # revealed: int
# TODO: Should ideally be `str`
reveal_type(f(Sub[str | None]())) # revealed: str | None
```
If we have a case like the following, where only one of the union elements matches due to the
typevar bound, we do not emit a specialization error:
```py
class P[T]:
value: T
def f[I: int, S: str](t: P[I] | P[S]) -> tuple[I, S]:
raise NotImplementedError
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
```
However, if we pass something that does not match _any_ union element, we do emit an error:
```py
# error: [invalid-argument-type]
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
```
## Inferring nested generic function calls ## Inferring nested generic function calls
We can infer type assignments in nested calls to multiple generic functions. If they use the same We can infer type assignments in nested calls to multiple generic functions. If they use the same

View File

@ -1545,35 +1545,72 @@ impl<'db> SpecializationBuilder<'db> {
} }
self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f); self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f);
} }
(Type::Union(formal), _) => { (Type::Union(union_formal), _) => {
// Second, if the formal is a union, and precisely one union element is assignable // Second, if the formal is a union, and the actual type is assignable to precisely
// from the actual type, then we don't add any type mapping. This handles a case like // one union element, then we don't add any type mapping. This handles a case like
// //
// ```py // ```py
// def f[T](t: T | None): ... // def f[T](t: T | None) -> T: ...
// //
// f(None) // reveal_type(f(None)) # revealed: Unknown
// ``` // ```
// //
// without specializing `T` to `None`. // without specializing `T` to `None`.
//
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if !actual.is_never() { if !actual.is_never() {
let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| { let assignable_elements =
actual (union_formal.elements(self.db).iter()).filter(|ty| {
.when_subtype_of(self.db, **ty, self.inferable) actual
.is_always_satisfied(self.db) .when_subtype_of(self.db, **ty, self.inferable)
}); .is_always_satisfied(self.db)
});
if assignable_elements.exactly_one().is_ok() { if assignable_elements.exactly_one().is_ok() {
return Ok(()); return Ok(());
} }
} }
let bound_typevars = let mut bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); (union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
let first_bound_typevar = bound_typevars.next();
let has_more_than_one_typevar = bound_typevars.next().is_some();
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if let Some(bound_typevar) = first_bound_typevar
&& !has_more_than_one_typevar
{
self.add_type_mapping(bound_typevar, actual, polarity, f); self.add_type_mapping(bound_typevar, actual, polarity, f);
return Ok(());
}
// TODO:
// Handling more than one bare typevar is something that we can't handle yet.
if has_more_than_one_typevar {
return Ok(());
}
// Finally, if there are no bare typevars, we try to infer type mappings by
// checking against each union element. This handles cases like
// ```py
// def f[T](t: P[T] | Q[T]) -> T: ...
//
// reveal_type(f(P[str]())) # revealed: str
// reveal_type(f(Q[int]())) # revealed: int
// ```
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
if !formal_element.is_disjoint_from(self.db, actual) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
found_matching_element = true;
}
}
}
if !found_matching_element && let Some(error) = first_error {
return Err(error);
} }
} }