[ty] Generics: Respect typevar bounds when matching against a union

This commit is contained in:
David Peter 2025-12-10 13:16:31 +01:00
parent ff7086d9ad
commit 7a85fe2556
3 changed files with 82 additions and 18 deletions

View File

@ -337,6 +337,44 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
```
This also works if the typevar has a bound:
```py
T_str = TypeVar("T_str", bound=str)
def accepts_t_or_int(x: T_str | int) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
class Unrelated: ...
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
```
```py
T_str = TypeVar("T_str", bound=str)
def accepts_t_or_list_of_t(x: T_str | list[T_str]) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
def _(list_ofstr: list[str], list_of_int: list[int]):
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
# TODO: the error message here could be improved by referring to the second union element
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
```
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
would also be a valid solution:
```py
S = TypeVar("S")

View File

@ -302,6 +302,38 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
```
This also works if the typevar has a bound:
```py
def accepts_t_or_int[T_str: str](x: T_str | int) -> T_str:
raise NotImplementedError
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
class Unrelated: ...
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
def accepts_t_or_list_of_t[T: str](x: T | list[T]) -> T:
raise NotImplementedError
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T`"
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
def _(list_ofstr: list[str], list_of_int: list[int]):
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
# TODO: the error message here could be improved by referring to the second union element
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T`"
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
```
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
would also be a valid solution:
```py
def tuple_param[T, S](x: T | S, y: tuple[T, S]) -> tuple[T, S]:
return y

View File

@ -1570,21 +1570,9 @@ impl<'db> SpecializationBuilder<'db> {
let mut bound_typevars =
(union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
let first_bound_typevar = bound_typevars.next();
let has_more_than_one_typevar = bound_typevars.next().is_some();
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if let Some(bound_typevar) = first_bound_typevar
&& !has_more_than_one_typevar
{
self.add_type_mapping(bound_typevar, actual, polarity, f);
return Ok(());
}
// TODO:
// Handling more than one bare typevar is something that we can't handle yet.
if has_more_than_one_typevar {
if bound_typevars.nth(1).is_some() {
return Ok(());
}
@ -1599,15 +1587,21 @@ impl<'db> SpecializationBuilder<'db> {
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
if !formal_element.is_disjoint_from(self.db, actual) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
// The recursive call to `infer_map_impl` may succeed even if the actual type is
// not assignable to the formal element.
if !actual
.when_assignable_to(self.db, *formal_element, self.inferable)
.is_never_satisfied(self.db)
{
found_matching_element = true;
}
}
}
if !found_matching_element && let Some(error) = first_error {
return Err(error);
}