diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index dd75b2c56a..c9ee5359ce 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -441,7 +441,23 @@ def g[T: A](b: B[T]): return f(b.x) # Fine ``` -## Constrained TypeVar in a union +## Typevars in a union + +```py +def takes_in_union[T](t: T | None) -> T: + raise NotImplementedError + +def takes_in_bigger_union[T](t: T | int | None) -> T: + raise NotImplementedError + +def _(x: str | None) -> None: + reveal_type(takes_in_union(x)) # revealed: str + reveal_type(takes_in_bigger_union(x)) # revealed: str + +def _(x: str | int | None) -> None: + reveal_type(takes_in_union(x)) # revealed: str | int + reveal_type(takes_in_bigger_union(x)) # revealed: str +``` This is a regression test for an issue that surfaced in the primer report of an early version of , where we failed to solve the `TypeVar` here due to diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9d0e04bbba..f9430370a9 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -943,6 +943,17 @@ impl<'db> Type<'db> { self.apply_type_mapping_impl(db, &TypeMapping::Materialize(materialization_kind), visitor) } + pub(crate) const fn is_type_var(self) -> bool { + matches!(self, Type::TypeVar(_)) + } + + pub(crate) const fn into_type_var(self) -> Option> { + match self { + Type::TypeVar(bound_typevar) => Some(bound_typevar), + _ => None, + } + } + pub(crate) const fn into_class_literal(self) -> Option> { match self { Type::ClassLiteral(class_type) => Some(class_type), diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 55b8a26638..55a0d32398 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1138,25 +1138,53 @@ impl<'db> SpecializationBuilder<'db> { } match (formal, actual) { - (Type::Union(formal), _) => { - // TODO: We haven't implemented a full unification solver yet. If typevars appear - // in multiple union elements, we ideally want to express that _only one_ of them - // needs to match, and that we should infer the smallest type mapping that allows - // that. + // TODO: We haven't implemented a full unification solver yet. If typevars appear in + // multiple union elements, we ideally want to express that _only one_ of them needs to + // match, and that we should infer the smallest type mapping that allows that. + // + // For now, we punt on fully handling multiple typevar elements. Instead, we handle two + // common cases specially: + (Type::Union(formal_union), Type::Union(actual_union)) => { + // First, if both formal and actual are unions, and precisely one formal union + // element _is_ a typevar (not _contains_ a typevar), then we remove any actual + // union elements that are a subtype of the formal (as a whole), and map the formal + // typevar to any remaining actual union elements. // - // For now, we punt on handling multiple typevar elements. Instead, if _precisely - // one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead - // and add a mapping between that typevar and the actual type. (Note that we've - // already handled above the case where the actual is assignable to a _non-typevar_ - // union element.) - let mut bound_typevars = - formal.elements(self.db).iter().filter_map(|ty| match ty { - Type::TypeVar(bound_typevar) => Some(*bound_typevar), - _ => None, - }); - let bound_typevar = bound_typevars.next(); - let additional_bound_typevars = bound_typevars.next(); - if let (Some(bound_typevar), None) = (bound_typevar, additional_bound_typevars) { + // In particular, this handles cases like + // + // ```py + // def f[T](t: T | None) -> T: ... + // def g[T](t: T | int | None) -> T | int: ... + // + // def _(x: str | None): + // reveal_type(f(x)) # revealed: str + // + // def _(y: str | int | None): + // reveal_type(g(x)) # revealed: str | int + // ``` + let formal_bound_typevars = + (formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var()); + let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else { + return Ok(()); + }; + if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) { + return Ok(()); + } + let remaining_actual = + actual_union.filter(self.db, |ty| !ty.is_subtype_of(self.db, formal)); + if remaining_actual.is_never() { + return Ok(()); + } + self.add_type_mapping(formal_bound_typevar, remaining_actual); + } + (Type::Union(formal), _) => { + // Second, if the formal is a union, and precisely one union element _is_ a typevar (not + // _contains_ a typevar), then we add a mapping between that typevar and the actual + // type. (Note that we've already handled above the case where the actual is + // assignable to any _non-typevar_ union element.) + let bound_typevars = + (formal.elements(self.db).iter()).filter_map(|ty| ty.into_type_var()); + if let Ok(bound_typevar) = bound_typevars.exactly_one() { self.add_type_mapping(bound_typevar, actual); } }