mirror of https://github.com/astral-sh/ruff
[ty] Generics: Respect typevar bounds when matching against a union
This commit is contained in:
parent
ff7086d9ad
commit
7a85fe2556
|
|
@ -337,6 +337,44 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
|
|||
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
|
||||
```
|
||||
|
||||
This also works if the typevar has a bound:
|
||||
|
||||
```py
|
||||
T_str = TypeVar("T_str", bound=str)
|
||||
|
||||
def accepts_t_or_int(x: T_str | int) -> T_str:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
|
||||
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
|
||||
|
||||
class Unrelated: ...
|
||||
|
||||
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
|
||||
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
|
||||
```
|
||||
|
||||
```py
|
||||
T_str = TypeVar("T_str", bound=str)
|
||||
|
||||
def accepts_t_or_list_of_t(x: T_str | list[T_str]) -> T_str:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
|
||||
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T_str`"
|
||||
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
|
||||
|
||||
def _(list_ofstr: list[str], list_of_int: list[int]):
|
||||
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
|
||||
|
||||
# TODO: the error message here could be improved by referring to the second union element
|
||||
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T_str`"
|
||||
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
|
||||
```
|
||||
|
||||
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
|
||||
would also be a valid solution:
|
||||
|
||||
```py
|
||||
S = TypeVar("S")
|
||||
|
||||
|
|
|
|||
|
|
@ -302,6 +302,38 @@ reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
|
|||
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
|
||||
```
|
||||
|
||||
This also works if the typevar has a bound:
|
||||
|
||||
```py
|
||||
def accepts_t_or_int[T_str: str](x: T_str | int) -> T_str:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(accepts_t_or_int("a")) # revealed: Literal["a"]
|
||||
reveal_type(accepts_t_or_int(1)) # revealed: Unknown
|
||||
|
||||
class Unrelated: ...
|
||||
|
||||
# error: [invalid-argument-type] "Argument type `Unrelated` does not satisfy upper bound `str` of type variable `T_str`"
|
||||
reveal_type(accepts_t_or_int(Unrelated())) # revealed: Unknown
|
||||
|
||||
def accepts_t_or_list_of_t[T: str](x: T | list[T]) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
reveal_type(accepts_t_or_list_of_t("a")) # revealed: Literal["a"]
|
||||
# error: [invalid-argument-type] "Argument type `Literal[1]` does not satisfy upper bound `str` of type variable `T`"
|
||||
reveal_type(accepts_t_or_list_of_t(1)) # revealed: Unknown
|
||||
|
||||
def _(list_ofstr: list[str], list_of_int: list[int]):
|
||||
reveal_type(accepts_t_or_list_of_t(list_ofstr)) # revealed: str
|
||||
|
||||
# TODO: the error message here could be improved by referring to the second union element
|
||||
# error: [invalid-argument-type] "Argument type `list[int]` does not satisfy upper bound `str` of type variable `T`"
|
||||
reveal_type(accepts_t_or_list_of_t(list_of_int)) # revealed: Unknown
|
||||
```
|
||||
|
||||
Here, we make sure that `S` is solved as `Literal[1]` instead of a union of the two literals, which
|
||||
would also be a valid solution:
|
||||
|
||||
```py
|
||||
def tuple_param[T, S](x: T | S, y: tuple[T, S]) -> tuple[T, S]:
|
||||
return y
|
||||
|
|
|
|||
|
|
@ -1570,21 +1570,9 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let mut bound_typevars =
|
||||
(union_formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
|
||||
let first_bound_typevar = bound_typevars.next();
|
||||
let has_more_than_one_typevar = bound_typevars.next().is_some();
|
||||
|
||||
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
|
||||
// typevar), then we add a mapping between that typevar and the actual type.
|
||||
if let Some(bound_typevar) = first_bound_typevar
|
||||
&& !has_more_than_one_typevar
|
||||
{
|
||||
self.add_type_mapping(bound_typevar, actual, polarity, f);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Handling more than one bare typevar is something that we can't handle yet.
|
||||
if has_more_than_one_typevar {
|
||||
if bound_typevars.nth(1).is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
|
@ -1599,15 +1587,21 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let mut first_error = None;
|
||||
let mut found_matching_element = false;
|
||||
for formal_element in union_formal.elements(self.db) {
|
||||
if !formal_element.is_disjoint_from(self.db, actual) {
|
||||
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
|
||||
if let Err(err) = result {
|
||||
first_error.get_or_insert(err);
|
||||
} else {
|
||||
// The recursive call to `infer_map_impl` may succeed even if the actual type is
|
||||
// not assignable to the formal element.
|
||||
if !actual
|
||||
.when_assignable_to(self.db, *formal_element, self.inferable)
|
||||
.is_never_satisfied(self.db)
|
||||
{
|
||||
found_matching_element = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found_matching_element && let Some(error) = first_error {
|
||||
return Err(error);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue