[ty] Generics: Respect typevar bounds when matching against a union (#21893)

## Summary

Respect typevar bounds and constraints when matching against a union.
For example:

```py
def accepts_t_or_int[T_str: str](x: T_str | int) -> T_str:
    raise NotImplementedError

reveal_type(accepts_t_or_int("a"))  # ok, reveals `Literal["a"]`
reveal_type(accepts_t_or_int(1))  # ok, reveals `Unknown`

class Unrelated: ...

# error: [invalid-argument-type] "Argument type `Unrelated` does not
# satisfy upper bound `str` of type variable `T_str`"
accepts_t_or_int(Unrelated())
```

Previously, the last call succeed without any errors. Worse than that,
we also incorrectly solved `T_str = Unrelated`, which often lead to
downstream errors.

closes https://github.com/astral-sh/ty/issues/1837

## Ecosystem impact

Looks good!

* Lots of removed false positives, often because we previously selected
a wrong overload for a generic function (because we didn't respect the
typevar bound in an earlier overload).
* We now understand calls to functions accepting an argument of type
`GenericPath: TypeAlias = AnyStr | PathLike[AnyStr]`. Previously, we
would incorrectly match a `Path` argument against the `AnyStr` typevar
(violating its constraints), but now we match against `PathLike`.

## Performance

Another regression on `colour`. This package uses `numpy` heavily. And
`numpy` is the codebase that originally lead me to this bug. The fix
here allows us to infer more precise `np.array` types in some cases, so
it's reasonable that we just need to perform more work.

The fix here also requires us to look at more union elements when we
would previously short-circuit incorrectly, so some more work needs to
be done in the solver.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-12-10 14:58:57 +01:00 committed by GitHub
parent ff7086d9ad
commit 7bf50e70a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 82 additions and 18 deletions

View File

@ -337,6 +337,44 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
```
This also works if the typevar has a bound:
```py
T_str = TypeVar("T_str", bound=str)
def accepts_t_or_int(x: T_str | int) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
class Unrelated: ...
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
```
```py
T_str = TypeVar("T_str", bound=str)
def accepts_t_or_list_of_t(x: T_str | list[T_str]) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
def _(list_ofstr: list[str], list_of_int: list[int]):
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
# TODO: the error message here could be improved by referring to the second union element
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
```
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
would also be a valid solution:
```py
S = TypeVar("S")

View File

@ -302,6 +302,38 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
```
This also works if the typevar has a bound:
```py
def accepts_t_or_int[T_str: str](x: T_str | int) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
class Unrelated: ...
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
def accepts_t_or_list_of_t[T: str](x: T | list[T]) -> T:
raise NotImplementedError
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T`"
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
def _(list_ofstr: list[str], list_of_int: list[int]):
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
# TODO: the error message here could be improved by referring to the second union element
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T`"
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
```
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
would also be a valid solution:
```py
def tuple_param[T, S](x: T | S, y: tuple[T, S]) -> tuple[T, S]:
return y

View File

@ -1570,21 +1570,9 @@ impl<'db> SpecializationBuilder<'db> {
let mut bound_typevars =
(union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
let first_bound_typevar = bound_typevars.next();
let has_more_than_one_typevar = bound_typevars.next().is_some();
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if let Some(bound_typevar) = first_bound_typevar
&& !has_more_than_one_typevar
{
self.add_type_mapping(bound_typevar, actual, polarity, f);
return Ok(());
}
// TODO:
// Handling more than one bare typevar is something that we can't handle yet.
if has_more_than_one_typevar {
if bound_typevars.nth(1).is_some() {
return Ok(());
}
@ -1599,15 +1587,21 @@ impl<'db> SpecializationBuilder<'db> {
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
if !formal_element.is_disjoint_from(self.db, actual) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
// The recursive call to `infer_map_impl` may succeed even if the actual type is
// not assignable to the formal element.
if !actual
.when_assignable_to(self.db, *formal_element, self.inferable)
.is_never_satisfied(self.db)
{
found_matching_element = true;
}
}
}
if !found_matching_element && let Some(error) = first_error {
return Err(error);
}