From 536d8fb000d353cc2516c59fb566d7eeeef86b87 Mon Sep 17 00:00:00 2001 From: David Peter Date: Mon, 26 May 2025 17:40:17 +0200 Subject: [PATCH] [ty] Normalize tuples of unions as unions of tuples --- .../type_properties/is_equivalent_to.md | 11 +++++ crates/ty_python_semantic/src/types.rs | 48 ++++++++++++++++--- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 558c455c4e..94544b3a47 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -106,6 +106,17 @@ static_assert( ) ``` +## Tuples containing unions, unions containing tuples + +```py +from ty_extensions import is_equivalent_to, static_assert + +class A: ... +class B: ... + +static_assert(is_equivalent_to(tuple[A | B], tuple[A] | tuple[B])) +``` + ## Intersections containing tuples containing unions ```py diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index fcab31b7a7..819d363c3f 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1013,7 +1013,7 @@ impl<'db> Type<'db> { match self { Type::Union(union) => Type::Union(union.normalized(db)), Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)), - Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)), + Type::Tuple(tuple) => tuple.normalized(db), Type::Callable(callable) => Type::Callable(callable.normalized(db)), Type::ProtocolInstance(protocol) => protocol.normalized(db), Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)), @@ -1709,7 +1709,7 @@ impl<'db> Type<'db> { pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { // TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc. - match (self, other) { + match (self.normalized(db), other.normalized(db)) { (Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right), (Type::Intersection(left), Type::Intersection(right)) => { left.is_equivalent_to(db, right) @@ -1756,7 +1756,7 @@ impl<'db> Type<'db> { return true; } - match (self, other) { + match (self.normalized(db), other.normalized(db)) { (Type::Dynamic(_), Type::Dynamic(_)) => true, (Type::SubclassOf(first), Type::SubclassOf(second)) => { @@ -8712,13 +8712,47 @@ impl<'db> TupleType<'db> { /// /// See [`Type::normalized`] for more details. #[must_use] - pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { - let elements: Box<[Type<'db>]> = self + pub(crate) fn normalized(self, db: &'db dyn Db) -> Type<'db> { + // Collect the normalized elements for each tuple slot. + let normalized_elements: Vec>> = self .elements(db) .iter() - .map(|ty| ty.normalized(db)) + .map(|ty| { + let norm = ty.normalized(db); + if let Type::Union(union) = norm { + union.elements(db).to_vec() + } else { + vec![norm] + } + }) .collect(); - TupleType::new(db, elements) + + // Compute the cartesian product of all element choices. + let mut product: Vec>> = vec![vec![]]; + for slot in &normalized_elements { + let mut next = Vec::with_capacity(product.len() * slot.len()); + for prefix in &product { + for elem in slot { + let mut new_tuple = prefix.clone(); + new_tuple.push(*elem); + next.push(new_tuple); + } + } + product = next; + } + + // If only one combination, return a single tuple type. + if product.len() == 1 { + return TupleType::from_elements(db, product.pop().unwrap().into_boxed_slice()); + } + + // Otherwise, return a union of all possible tuple combinations. + UnionType::from_elements( + db, + product + .into_iter() + .map(|elems| TupleType::from_elements(db, elems.into_boxed_slice())), + ) } pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {