mirror of https://github.com/astral-sh/ruff
[ty] Infer type variables within generic unions (#21862)
## Summary
This PR allows our generics solver to find a solution for `T` in cases
like the following:
```py
def extract_t[T](x: P[T] | Q[T]) -> T:
raise NotImplementedError
reveal_type(extract_t(P[int]())) # revealed: int
reveal_type(extract_t(Q[str]())) # revealed: str
```
closes https://github.com/astral-sh/ty/issues/1772
closes https://github.com/astral-sh/ty/issues/1314
## Ecosystem
The impact here looks very good!
It took me a long time to figure this out, but the new diagnostics on
bokeh are actually true positives. I should have tested with another
type-checker immediately, I guess. All other type checkers also emit
errors on these `__init__` calls. MRE
[here](https://play.ty.dev/5c19d260-65e2-4f70-a75e-1a25780843a2) (no
error on main, diagnostic on this branch)
A lot of false positives on home-assistant go away for calls to
functions like
[`async_listen`](180053fe98/homeassistant/core.py (L1581-L1587))
which take a `event_type: EventType[_DataT] | str` parameter. We can now
solve for `_DataT` here, which was previously falling back to its
default value, and then caused problems because it was used as an
argument to an invariant generic class.
## Test Plan
New Markdown tests
This commit is contained in:
parent
c35bf8f441
commit
aea2bc2308
|
|
@ -107,44 +107,34 @@ We can also specify particular columns to select:
|
|||
|
||||
```py
|
||||
stmt = select(User.id, User.name)
|
||||
# TODO: should be `Select[tuple[int, str]]`
|
||||
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
|
||||
reveal_type(stmt) # revealed: Select[tuple[int, str]]
|
||||
|
||||
ids_and_names = session.execute(stmt).all()
|
||||
# TODO: should be `Sequence[Row[tuple[int, str]]]`
|
||||
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
|
||||
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
|
||||
|
||||
for row in session.execute(stmt):
|
||||
# TODO: should be `Row[tuple[int, str]]`
|
||||
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
|
||||
reveal_type(row) # revealed: Row[tuple[int, str]]
|
||||
|
||||
for user_id, name in session.execute(stmt).tuples():
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
|
||||
result = session.execute(stmt)
|
||||
row = result.one_or_none()
|
||||
assert row is not None
|
||||
(user_id, name) = row._tuple()
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
|
||||
stmt = select(User.id).where(User.name == "Alice")
|
||||
|
||||
# TODO: should be `Select[tuple[int]]`
|
||||
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
|
||||
reveal_type(stmt) # revealed: Select[tuple[int]]
|
||||
|
||||
alice_id = session.scalars(stmt).first()
|
||||
# TODO: should be `int | None`
|
||||
reveal_type(alice_id) # revealed: Unknown | None
|
||||
reveal_type(alice_id) # revealed: int | None
|
||||
|
||||
alice_id = session.scalar(stmt)
|
||||
# TODO: should be `int | None`
|
||||
reveal_type(alice_id) # revealed: Unknown | None
|
||||
reveal_type(alice_id) # revealed: int | None
|
||||
```
|
||||
|
||||
Using the legacy `query` API also works:
|
||||
|
|
@ -166,15 +156,12 @@ And similarly when specifying particular columns:
|
|||
|
||||
```py
|
||||
query = session.query(User.id, User.name)
|
||||
# TODO: should be `RowReturningQuery[tuple[int, str]]`
|
||||
reveal_type(query) # revealed: RowReturningQuery[tuple[Unknown, Unknown]]
|
||||
reveal_type(query) # revealed: RowReturningQuery[tuple[int, str]]
|
||||
|
||||
# TODO: should be `list[Row[tuple[int, str]]]`
|
||||
reveal_type(query.all()) # revealed: list[Row[tuple[Unknown, Unknown]]]
|
||||
reveal_type(query.all()) # revealed: list[Row[tuple[int, str]]]
|
||||
|
||||
for row in query:
|
||||
# TODO: should be `Row[tuple[int, str]]`
|
||||
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
|
||||
reveal_type(row) # revealed: Row[tuple[int, str]]
|
||||
```
|
||||
|
||||
## Async API
|
||||
|
|
@ -203,8 +190,6 @@ async def test_async(session: AsyncSession):
|
|||
stmt = select(User.id, User.name)
|
||||
result = await session.execute(stmt)
|
||||
for user_id, name in result.tuples():
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
```
|
||||
|
|
|
|||
|
|
@ -347,6 +347,138 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
|
|||
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
|
||||
```
|
||||
|
||||
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
|
||||
the actual argument even for non-final classes.
|
||||
|
||||
```py
|
||||
from typing import TypeVar, Generic
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class P(Generic[T]):
|
||||
x: T
|
||||
|
||||
class Q(Generic[T]):
|
||||
x: T
|
||||
|
||||
def extract_t(x: P[T] | Q[T]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_t(P[int]())) # revealed: int
|
||||
reveal_type(extract_t(Q[str]())) # revealed: str
|
||||
```
|
||||
|
||||
Passing anything else results in an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_t([1, 2])) # revealed: Unknown
|
||||
```
|
||||
|
||||
This also works when different union elements have different typevars:
|
||||
|
||||
```py
|
||||
S = TypeVar("S")
|
||||
|
||||
def extract_both(x: P[T] | Q[S]) -> tuple[T, S]:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
Inference also works when passing subclasses of the generic classes in the union.
|
||||
|
||||
```py
|
||||
class SubP(P[T]):
|
||||
pass
|
||||
|
||||
class SubQ(Q[T]):
|
||||
pass
|
||||
|
||||
reveal_type(extract_t(SubP[int]())) # revealed: int
|
||||
reveal_type(extract_t(SubQ[str]())) # revealed: str
|
||||
|
||||
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
|
||||
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
|
||||
both types in a call to `extract_both`:
|
||||
|
||||
```py
|
||||
class PandQ(P[int], Q[str]):
|
||||
pass
|
||||
|
||||
# TODO: Ideally, we would return `Unknown` here.
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_t(PandQ())) # revealed: int | str
|
||||
|
||||
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
|
||||
```
|
||||
|
||||
When non-generic types are part of the union, we can still infer typevars for the remaining generic
|
||||
types:
|
||||
|
||||
```py
|
||||
def extract_optional_t(x: None | P[T]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_optional_t(None)) # revealed: Unknown
|
||||
reveal_type(extract_optional_t(P[int]())) # revealed: int
|
||||
```
|
||||
|
||||
Passing anything else results in an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
|
||||
```
|
||||
|
||||
If the union contains a parent and a child of a generic class, we ideally pick the union
|
||||
element that is more precise:
|
||||
|
||||
```py
|
||||
class Base(Generic[T]):
|
||||
x: T
|
||||
|
||||
class Sub(Base[T]): ...
|
||||
|
||||
def f(t: Base[T] | Sub[T | None]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(f(Base[int]())) # revealed: int
|
||||
# TODO: Should ideally be `str`
|
||||
reveal_type(f(Sub[str | None]())) # revealed: str | None
|
||||
```
|
||||
|
||||
If we have a case like the following, where only one of the union elements matches due to the
|
||||
typevar bound, we do not emit a specialization error:
|
||||
|
||||
```py
|
||||
from typing import TypeVar
|
||||
|
||||
I_int = TypeVar("I_int", bound=int)
|
||||
S_str = TypeVar("S_str", bound=str)
|
||||
|
||||
class P(Generic[T]):
|
||||
value: T
|
||||
|
||||
def f(t: P[I_int] | P[S_str]) -> tuple[I_int, S_str]:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
However, if we pass something that does not match _any_ union element, we do emit an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
|
||||
```
|
||||
|
||||
## Inferring nested generic function calls
|
||||
|
||||
We can infer type assignments in nested calls to multiple generic functions. If they use the same
|
||||
|
|
|
|||
|
|
@ -310,6 +310,127 @@ reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal
|
|||
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
|
||||
```
|
||||
|
||||
When a union parameter contains generic classes like `P[T] | Q[T]`, we can infer the typevar from
|
||||
the actual argument even for non-final classes.
|
||||
|
||||
```py
|
||||
class P[T]:
|
||||
x: T # invariant
|
||||
|
||||
class Q[T]:
|
||||
x: T # invariant
|
||||
|
||||
def extract_t[T](x: P[T] | Q[T]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_t(P[int]())) # revealed: int
|
||||
reveal_type(extract_t(Q[str]())) # revealed: str
|
||||
```
|
||||
|
||||
Passing anything else results in an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_t([1, 2])) # revealed: Unknown
|
||||
```
|
||||
|
||||
This also works when different union elements have different typevars:
|
||||
|
||||
```py
|
||||
def extract_both[T, S](x: P[T] | Q[S]) -> tuple[T, S]:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_both(P[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(extract_both(Q[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
Inference also works when passing subclasses of the generic classes in the union.
|
||||
|
||||
```py
|
||||
class SubP[T](P[T]):
|
||||
pass
|
||||
|
||||
class SubQ[T](Q[T]):
|
||||
pass
|
||||
|
||||
reveal_type(extract_t(SubP[int]())) # revealed: int
|
||||
reveal_type(extract_t(SubQ[str]())) # revealed: str
|
||||
|
||||
reveal_type(extract_both(SubP[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(extract_both(SubQ[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
When a type is a subclass of both `P` and `Q` with different specializations, we cannot infer a
|
||||
single type for `T` in `extract_t`, because `P` and `Q` are invariant. However, we can still infer
|
||||
both types in a call to `extract_both`:
|
||||
|
||||
```py
|
||||
class PandQ(P[int], Q[str]):
|
||||
pass
|
||||
|
||||
# TODO: Ideally, we would return `Unknown` here.
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_t(PandQ())) # revealed: int | str
|
||||
|
||||
reveal_type(extract_both(PandQ())) # revealed: tuple[int, str]
|
||||
```
|
||||
|
||||
When non-generic types are part of the union, we can still infer typevars for the remaining generic
|
||||
types:
|
||||
|
||||
```py
|
||||
def extract_optional_t[T](x: None | P[T]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(extract_optional_t(None)) # revealed: Unknown
|
||||
reveal_type(extract_optional_t(P[int]())) # revealed: int
|
||||
```
|
||||
|
||||
Passing anything else results in an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
|
||||
```
|
||||
|
||||
If the union contains a parent and a child of a generic class, we ideally pick the union
|
||||
element that is more precise:
|
||||
|
||||
```py
|
||||
class Base[T]:
|
||||
x: T
|
||||
|
||||
class Sub[T](Base[T]): ...
|
||||
|
||||
def f[T](t: Base[T] | Sub[T | None]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(f(Base[int]())) # revealed: int
|
||||
# TODO: Should ideally be `str`
|
||||
reveal_type(f(Sub[str | None]())) # revealed: str | None
|
||||
```
|
||||
|
||||
If we have a case like the following, where only one of the union elements matches due to the
|
||||
typevar bound, we do not emit a specialization error:
|
||||
|
||||
```py
|
||||
class P[T]:
|
||||
value: T
|
||||
|
||||
def f[I: int, S: str](t: P[I] | P[S]) -> tuple[I, S]:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(f(P[int]())) # revealed: tuple[int, Unknown]
|
||||
reveal_type(f(P[str]())) # revealed: tuple[Unknown, str]
|
||||
```
|
||||
|
||||
However, if we pass something that does not match _any_ union element, we do emit an error:
|
||||
|
||||
```py
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(f(P[bytes]())) # revealed: tuple[Unknown, Unknown]
|
||||
```
|
||||
|
||||
## Inferring nested generic function calls
|
||||
|
||||
We can infer type assignments in nested calls to multiple generic functions. If they use the same
|
||||
|
|
|
|||
|
|
@ -1545,35 +1545,72 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
}
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f);
|
||||
}
|
||||
(Type::Union(formal), _) => {
|
||||
// Second, if the formal is a union, and precisely one union element is assignable
|
||||
// from the actual type, then we don't add any type mapping. This handles a case like
|
||||
(Type::Union(union_formal), _) => {
|
||||
// Second, if the formal is a union, and the actual type is assignable to precisely
|
||||
// one union element, then we don't add any type mapping. This handles a case like
|
||||
//
|
||||
// ```py
|
||||
// def f[T](t: T | None): ...
|
||||
// def f[T](t: T | None) -> T: ...
|
||||
//
|
||||
// f(None)
|
||||
// reveal_type(f(None)) # revealed: Unknown
|
||||
// ```
|
||||
//
|
||||
// without specializing `T` to `None`.
|
||||
//
|
||||
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
|
||||
// typevar), then we add a mapping between that typevar and the actual type.
|
||||
if !actual.is_never() {
|
||||
let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| {
|
||||
actual
|
||||
.when_subtype_of(self.db, **ty, self.inferable)
|
||||
.is_always_satisfied(self.db)
|
||||
});
|
||||
let assignable_elements =
|
||||
(union_formal.elements(self.db).iter()).filter(|ty| {
|
||||
actual
|
||||
.when_subtype_of(self.db, **ty, self.inferable)
|
||||
.is_always_satisfied(self.db)
|
||||
});
|
||||
if assignable_elements.exactly_one().is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let bound_typevars =
|
||||
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
|
||||
let mut bound_typevars =
|
||||
(union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
|
||||
let first_bound_typevar = bound_typevars.next();
|
||||
let has_more_than_one_typevar = bound_typevars.next().is_some();
|
||||
|
||||
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
|
||||
// typevar), then we add a mapping between that typevar and the actual type.
|
||||
if let Some(bound_typevar) = first_bound_typevar
|
||||
&& !has_more_than_one_typevar
|
||||
{
|
||||
self.add_type_mapping(bound_typevar, actual, polarity, f);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// We can't yet handle unions that contain more than one bare typevar.
|
||||
if has_more_than_one_typevar {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Finally, if there are no bare typevars, we try to infer type mappings by
|
||||
// checking against each union element. This handles cases like
|
||||
// ```py
|
||||
// def f[T](t: P[T] | Q[T]) -> T: ...
|
||||
//
|
||||
// reveal_type(f(P[str]())) # revealed: str
|
||||
// reveal_type(f(Q[int]())) # revealed: int
|
||||
// ```
|
||||
let mut first_error = None;
|
||||
let mut found_matching_element = false;
|
||||
for formal_element in union_formal.elements(self.db) {
|
||||
if !formal_element.is_disjoint_from(self.db, actual) {
|
||||
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
|
||||
if let Err(err) = result {
|
||||
first_error.get_or_insert(err);
|
||||
} else {
|
||||
found_matching_element = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found_matching_element && let Some(error) = first_error {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue