[ty] Infer type variables within generic unions (#21862)

## Summary

This PR allows our generics solver to find a solution for `T` in cases
like the following:
```py
def extract_t[T](x: P[T] | Q[T]) -> T:
    raise NotImplementedError

reveal_type(extract_t(P[int]()))  # revealed: int
reveal_type(extract_t(Q[str]()))  # revealed: str
```

closes https://github.com/astral-sh/ty/issues/1772
closes https://github.com/astral-sh/ty/issues/1314

## Ecosystem

The impact here looks very good!

It took me a long time to figure this out, but the new diagnostics on
bokeh are actually true positives. I should have tested with another
type-checker immediately, I guess. All other type checkers also emit
errors on these `__init__` calls. MRE
[here](https://play.ty.dev/5c19d260-65e2-4f70-a75e-1a25780843a2) (no
error on main, diagnostic on this branch)

A lot of false positives on home-assistant go away for calls to
functions like
[`async_listen`](180053fe98/homeassistant/core.py (L1581-L1587))
which take a `event_type: EventType[_DataT] | str` parameter. We can now
solve for `_DataT` here, which was previously falling back to its
default value, and then caused problems because it was used as an
argument to an invariant generic class.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-12-09 16:22:59 +01:00 committed by GitHub
parent c35bf8f441
commit aea2bc2308
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 321 additions and 46 deletions

View File

@ -107,44 +107,34 @@ We can also specify particular columns to select:
```py
stmt = select(User.id, User.name)
# TODO: should be `Select[tuple[int, str]]`
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
reveal_type(stmt) # revealed: Select[tuple[int, str]]
ids_and_names = session.execute(stmt).all()
# TODO: should be `Sequence[Row[tuple[int, str]]]`
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
for row in session.execute(stmt):
# TODO: should be `Row[tuple[int, str]]`
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
reveal_type(row) # revealed: Row[tuple[int, str]]
for user_id, name in session.execute(stmt).tuples():
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
result = session.execute(stmt)
row = result.one_or_none()
assert row is not None
(user_id, name) = row._tuple()
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
stmt = select(User.id).where(User.name == "Alice")
# TODO: should be `Select[tuple[int]]`
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
reveal_type(stmt) # revealed: Select[tuple[int]]
alice_id = session.scalars(stmt).first()
# TODO: should be `int | None`
reveal_type(alice_id) # revealed: Unknown | None
reveal_type(alice_id) # revealed: int | None
alice_id = session.scalar(stmt)
# TODO: should be `int | None`
reveal_type(alice_id) # revealed: Unknown | None
reveal_type(alice_id) # revealed: int | None
```
Using the legacy `query` API also works:
@ -166,15 +156,12 @@ And similarly when specifying particular columns:
```py
query = session.query(User.id, User.name)
# TODO: should be `RowReturningQuery[tuple[int, str]]`
reveal_type(query) # revealed: RowReturningQuery[tuple[Unknown, Unknown]]
reveal_type(query) # revealed: RowReturningQuery[tuple[int, str]]
# TODO: should be `list[Row[tuple[int, str]]]`
reveal_type(query.all()) # revealed: list[Row[tuple[Unknown, Unknown]]]
reveal_type(query.all()) # revealed: list[Row[tuple[int, str]]]
for row in query:
# TODO: should be `Row[tuple[int, str]]`
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
reveal_type(row) # revealed: Row[tuple[int, str]]
```
## Async API
@ -203,8 +190,6 @@ async def test_async(session: AsyncSession):
stmt = select(User.id, User.name)
result = await session.execute(stmt)
for user_id, name in result.tuples():
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
```

View File

@ -347,6 +347,138 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
```
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
the actual argument even for non-final classes.
```py
from typing import TypeVar, Generic
T = TypeVar("T")
class P(Generic[T]):
x: T
class Q(Generic[T]):
x: T
def extract_t(x: P[T] | Q[T]) -> T:
raise NotImplementedError
reveal_type(extract_t(P[int]())) # revealed: int
reveal_type(extract_t(Q[str]())) # revealed: str
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown
```
This also works when different union elements have different typevars:
```py
S = TypeVar("S")
def extract_both(x: P[T] | Q[S]) -> tuple[T, S]:
raise NotImplementedError
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
```
Inference also works when passing subclasses of the generic classes in the union.
```py
class SubP(P[T]):
pass
class SubQ(Q[T]):
pass
reveal_type(extract_t(SubP[int]())) # revealed: int
reveal_type(extract_t(SubQ[str]())) # revealed: str
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
```
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
both types in a call to `extract_both`:
```py
class PandQ(P[int], Q[str]):
pass
# TODO: Ideally, we would return `Unknown` here.
# error: [invalid-argument-type]
reveal_type(extract_t(PandQ())) # revealed: int | str
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
```
When non-generic types are part of the union, we can still infer typevars for the remaining generic
types:
```py
def extract_optional_t(x: None | P[T]) -> T:
raise NotImplementedError
reveal_type(extract_optional_t(None)) # revealed: Unknown
reveal_type(extract_optional_t(P[int]())) # revealed: int
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
```
If the union contains contains parent and child of a generic class, we ideally pick the union
element that is more precise:
```py
class Base(Generic[T]):
x: T
class Sub(Base[T]): ...
def f(t: Base[T] | Sub[T | None]) -> T:
raise NotImplementedError
reveal_type(f(Base[int]())) # revealed: int
# TODO: Should ideally be `str`
reveal_type(f(Sub[str | None]())) # revealed: str | None
```
If we have a case like the following, where only one of the union elements matches due to the
typevar bound, we do not emit a specialization error:
```py
from typing import TypeVar
I_int = TypeVar("I_int", bound=int)
S_str = TypeVar("S_str", bound=str)
class P(Generic[T]):
value: T
def f(t: P[I_int] | P[S_str]) -> tuple[I_int, S_str]:
raise NotImplementedError
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
```
However, if we pass something that does not match _any_ union element, we do emit an error:
```py
# error: [invalid-argument-type]
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
```
## Inferring nested generic function calls
We can infer type assignments in nested calls to multiple generic functions. If they use the same

View File

@ -310,6 +310,127 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
```
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
the actual argument even for non-final classes.
```py
class P[T]:
x: T # invariant
class Q[T]:
x: T # invariant
def extract_t[T](x: P[T] | Q[T]) -> T:
raise NotImplementedError
reveal_type(extract_t(P[int]())) # revealed: int
reveal_type(extract_t(Q[str]())) # revealed: str
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown
```
This also works when different union elements have different typevars:
```py
def extract_both[T, S](x: P[T] | Q[S]) -> tuple[T, S]:
raise NotImplementedError
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
```
Inference also works when passing subclasses of the generic classes in the union.
```py
class SubP[T](P[T]):
pass
class SubQ[T](Q[T]):
pass
reveal_type(extract_t(SubP[int]())) # revealed: int
reveal_type(extract_t(SubQ[str]())) # revealed: str
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
```
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
both types in a call to `extract_both`:
```py
class PandQ(P[int], Q[str]):
pass
# TODO: Ideally, we would return `Unknown` here.
# error: [invalid-argument-type]
reveal_type(extract_t(PandQ())) # revealed: int | str
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
```
When non-generic types are part of the union, we can still infer typevars for the remaining generic
types:
```py
def extract_optional_t[T](x: None | P[T]) -> T:
raise NotImplementedError
reveal_type(extract_optional_t(None)) # revealed: Unknown
reveal_type(extract_optional_t(P[int]())) # revealed: int
```
Passing anything else results in an error:
```py
# error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
```
If the union contains contains parent and child of a generic class, we ideally pick the union
element that is more precise:
```py
class Base[T]:
x: T
class Sub[T](Base[T]): ...
def f[T](t: Base[T] | Sub[T | None]) -> T:
raise NotImplementedError
reveal_type(f(Base[int]())) # revealed: int
# TODO: Should ideally be `str`
reveal_type(f(Sub[str | None]())) # revealed: str | None
```
If we have a case like the following, where only one of the union elements matches due to the
typevar bound, we do not emit a specialization error:
```py
class P[T]:
value: T
def f[I: int, S: str](t: P[I] | P[S]) -> tuple[I, S]:
raise NotImplementedError
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
```
However, if we pass something that does not match _any_ union element, we do emit an error:
```py
# error: [invalid-argument-type]
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
```
## Inferring nested generic function calls
We can infer type assignments in nested calls to multiple generic functions. If they use the same

View File

@ -1545,22 +1545,20 @@ impl<'db> SpecializationBuilder<'db> {
}
self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element is assignable
// from the actual type, then we don't add any type mapping. This handles a case like
(Type::Union(union_formal), _) => {
// Second, if the formal is a union, and the actual type is assignable to precisely
// one union element, then we don't add any type mapping. This handles a case like
//
// ```py
// def f[T](t: T | None): ...
// def f[T](t: T | None) -> T: ...
//
// f(None)
// reveal_type(f(None)) # revealed: Unknown
// ```
//
// without specializing `T` to `None`.
//
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if !actual.is_never() {
let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| {
let assignable_elements =
(union_formal.elements(self.db).iter()).filter(|ty| {
actual
.when_subtype_of(self.db, **ty, self.inferable)
.is_always_satisfied(self.db)
@ -1570,10 +1568,49 @@ impl<'db> SpecializationBuilder<'db> {
}
}
let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
let mut bound_typevars =
(union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
let first_bound_typevar = bound_typevars.next();
let has_more_than_one_typevar = bound_typevars.next().is_some();
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if let Some(bound_typevar) = first_bound_typevar
&& !has_more_than_one_typevar
{
self.add_type_mapping(bound_typevar, actual, polarity, f);
return Ok(());
}
// TODO:
// Handling more than one bare typevar is something that we can't handle yet.
if has_more_than_one_typevar {
return Ok(());
}
// Finally, if there are no bare typevars, we try to infer type mappings by
// checking against each union element. This handles cases like
// ```py
// def f[T](t: P[T] | Q[T]) -> T: ...
//
// reveal_type(f(P[str]())) # revealed: str
// reveal_type(f(Q[int]())) # revealed: int
// ```
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
if !formal_element.is_disjoint_from(self.db, actual) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
found_matching_element = true;
}
}
}
if !found_matching_element && let Some(error) = first_error {
return Err(error);
}
}