From 8223fea062d2fd148c3e80ccf576f3f2093ce5c3 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 29 Aug 2025 09:02:35 -0700 Subject: [PATCH] [ty] ensure union normalization really normalizes (#20147) ## Summary Now that we have `Type::TypeAlias`, which can wrap a union, and the possibility of unions including non-unpacked type aliases (which is necessary to support recursive type aliases), we can no longer assume in `UnionType::normalized_impl` that normalizing each element of an existing union will result in a set of elements that we can order and then place raw into `UnionType` to create a normalized union. It's now possible for those elements to themselves include union types (unpacked from an alias). So instead, we need to feed those elements into the full `UnionBuilder` (with alias-unpacking turned on) to flatten/normalize them, and then order them. ## Test Plan Added mdtest. --------- Co-authored-by: Alex Waygood --- .../resources/mdtest/pep695_type_aliases.md | 17 +++++++ crates/ty_python_semantic/src/types.rs | 45 ++++++++++++------- .../ty_python_semantic/src/types/builder.rs | 11 +++++ 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md index 2d19300989..0beca03adb 100644 --- a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md @@ -120,6 +120,23 @@ def f(x: IntOrStr, y: str | bytes): reveal_type(z) # revealed: (int & ~AlwaysFalsy) | str | bytes ``` +## Multiple layers of union aliases + +```py +class A: ... +class B: ... +class C: ... +class D: ... + +type W = A | B +type X = C | D +type Y = W | X + +from ty_extensions import is_equivalent_to, static_assert + +static_assert(is_equivalent_to(Y, A | B | C | D)) +``` + ## `TypeAliasType` properties Two `TypeAliasType`s are distinct and disjoint, even if they refer to the same type diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index f75e7beb5e..4afbef3181 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1118,9 +1118,7 @@ impl<'db> Type<'db> { #[must_use] pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - Type::Union(union) => { - visitor.visit(self, || Type::Union(union.normalized_impl(db, visitor))) - } + Type::Union(union) => visitor.visit(self, || union.normalized_impl(db, visitor)), Type::Intersection(intersection) => visitor.visit(self, || { Type::Intersection(intersection.normalized_impl(db, visitor)) }), @@ -1887,14 +1885,14 @@ impl<'db> Type<'db> { } (Type::TypeAlias(self_alias), _) => { - let self_alias_ty = self_alias.value_type(db); + let self_alias_ty = self_alias.value_type(db).normalized(db); visitor.visit((self_alias_ty, other), || { self_alias_ty.is_equivalent_to_impl(db, other, visitor) }) } (_, Type::TypeAlias(other_alias)) => { - let other_alias_ty = other_alias.value_type(db); + let other_alias_ty = other_alias.value_type(db).normalized(db); visitor.visit((self, other_alias_ty), || { self.is_equivalent_to_impl(db, other_alias_ty, visitor) }) @@ -7697,7 +7695,17 @@ impl<'db> TypeVarBoundOrConstraints<'db> { TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor)) } TypeVarBoundOrConstraints::Constraints(constraints) => { - TypeVarBoundOrConstraints::Constraints(constraints.normalized_impl(db, visitor)) + // Constraints are a non-normalized union by design (it's not really a union at + // all, we are just using a union to store the types). Normalize the types but not + // the containing union. + TypeVarBoundOrConstraints::Constraints(UnionType::new( + db, + constraints + .elements(db) + .iter() + .map(|ty| ty.normalized_impl(db, visitor)) + .collect::>(), + )) } } } @@ -9654,18 +9662,25 @@ impl<'db> UnionType<'db> { /// /// See [`Type::normalized`] for more details. #[must_use] - pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { + pub(crate) fn normalized(self, db: &'db dyn Db) -> Type<'db> { self.normalized_impl(db, &NormalizedVisitor::default()) } - pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - let mut new_elements: Vec> = self - .elements(db) + pub(crate) fn normalized_impl( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Type<'db> { + self.elements(db) .iter() - .map(|element| element.normalized_impl(db, visitor)) - .collect(); - new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); - UnionType::new(db, new_elements.into_boxed_slice()) + .map(|ty| ty.normalized_impl(db, visitor)) + .fold( + UnionBuilder::new(db) + .order_elements(true) + .unpack_aliases(true), + UnionBuilder::add, + ) + .build() } pub(crate) fn is_equivalent_to_impl>( @@ -9687,7 +9702,7 @@ impl<'db> UnionType<'db> { let sorted_self = self.normalized(db); - if sorted_self == other { + if sorted_self == Type::Union(other) { return C::always_satisfiable(db); } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 1ea07748a8..1bdb85fa50 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -38,6 +38,7 @@ //! unnecessary `is_subtype_of` checks. use crate::types::enums::{enum_member_literals, enum_metadata}; +use crate::types::type_ordering::union_or_intersection_elements_ordering; use crate::types::{ BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, @@ -211,6 +212,7 @@ pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, unpack_aliases: bool, + order_elements: bool, } impl<'db> UnionBuilder<'db> { @@ -219,6 +221,7 @@ impl<'db> UnionBuilder<'db> { db, elements: vec![], unpack_aliases: true, + order_elements: false, } } @@ -227,6 +230,11 @@ impl<'db> UnionBuilder<'db> { self } + pub(crate) fn order_elements(mut self, val: bool) -> Self { + self.order_elements = val; + self + } + pub(crate) fn is_empty(&self) -> bool { self.elements.is_empty() } @@ -545,6 +553,9 @@ impl<'db> UnionBuilder<'db> { UnionElement::Type(ty) => types.push(ty), } } + if self.order_elements { + types.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(self.db, l, r)); + } match types.len() { 0 => None, 1 => Some(types[0]),