From c343e94ac537099de623a25127a1e62f38f4ae58 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Tue, 9 Dec 2025 19:49:17 -0500 Subject: [PATCH] [ty] Simplify union lower bounds and intersection upper bounds in constraint sets (#21871) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In a constraint set, it's not useful for an upper bound to be an intersection type, or for a lower bound to be a union type. Both of those can be rewritten as simpler BDDs: ``` T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β) T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β) α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T) ``` We were seeing performance issues on #21551 when _not_ performing this simplification. For instance, `pandas` was producing some constraint sets involving intersections of 8-9 different types. Our sequent map calculation was timing out calculating all of the different permutations of those types: ``` t1 & t2 & t3 → t1 t1 & t2 & t3 → t2 t1 & t2 & t3 → t3 t1 & t2 & t3 → t1 & t2 t1 & t2 & t3 → t1 & t3 t1 & t2 & t3 → t2 & t3 ``` (and then imagine what that looks like for 9 types instead of 3...) With this change, all of those permutations are now encoded in the BDD structure itself, which is very good at simplifying that kind of thing. Pulling this out of #21551 for separate review. 
--- crates/ty_python_semantic/src/types.rs | 10 ++- .../src/types/constraints.rs | 78 +++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dfc932bc66..23c38444c8 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1252,7 +1252,7 @@ impl<'db> Type<'db> { } } - pub(crate) const fn is_union(&self) -> bool { + pub(crate) const fn is_union(self) -> bool { matches!(self, Type::Union(_)) } @@ -1268,6 +1268,10 @@ impl<'db> Type<'db> { self.as_union().expect("Expected a Type::Union variant") } + pub(crate) const fn is_intersection(self) -> bool { + matches!(self, Type::Intersection(_)) + } + pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> { match self { Type::FunctionLiteral(function_type) => Some(function_type), @@ -14109,6 +14113,10 @@ impl<'db> IntersectionType<'db> { self.positive(db).iter().copied() } + pub fn iter_negative(self, db: &'db dyn Db) -> impl Iterator<Item = Type<'db>> { + self.negative(db).iter().copied() + } + pub(crate) fn has_one_element(self, db: &'db dyn Db) -> bool { (self.positive(db).len() + self.negative(db).len()) == 1 } diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 9ec4cfd25f..7a727f3285 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -458,6 +458,19 @@ impl<'db> BoundTypeVarInstance<'db> { } } +#[derive(Clone, Copy, Debug)] +enum IntersectionResult<'db> { + Simplified(ConstrainedTypeVar<'db>), + CannotSimplify, + Disjoint, +} + +impl IntersectionResult<'_> { + fn is_disjoint(self) -> bool { + matches!(self, IntersectionResult::Disjoint) + } +} + /// An individual constraint in a constraint set. This restricts a single typevar to be within a /// lower and upper bound. 
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] @@ -484,6 +497,39 @@ impl<'db> ConstrainedTypeVar<'db> { debug_assert_eq!(lower, lower.bottom_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db)); + // It's not useful for an upper bound to be an intersection type, or for a lower bound to + // be a union type. Both of those can be rewritten as simpler BDDs: + // + // T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β) + // T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β) + // α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T) + if let Type::Union(lower_union) = lower { + let mut result = Node::AlwaysTrue; + for lower_element in lower_union.elements(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, *lower_element, upper), + ); + } + return result; + } + if let Type::Intersection(upper_intersection) = upper { + let mut result = Node::AlwaysTrue; + for upper_element in upper_intersection.iter_positive(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, lower, upper_element), + ); + } + for upper_element in upper_intersection.iter_negative(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, lower, upper_element).negate(db), + ); + } + return result; + } + // Two identical typevars must always solve to the same type, so it is not useful to have // an upper or lower bound that is the typevar being constrained. match lower { @@ -659,7 +705,7 @@ impl<'db> ConstrainedTypeVar<'db> { } /// Returns the intersection of two range constraints, or `None` if the intersection is empty. 
- fn intersect(self, db: &'db dyn Db, other: Self) -> Option<Self> { + fn intersect(self, db: &'db dyn Db, other: Self) -> IntersectionResult<'db> { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); @@ -667,10 +713,14 @@ impl<'db> ConstrainedTypeVar<'db> { // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. if !lower.is_subtype_of(db, upper) { - return None; + return IntersectionResult::Disjoint; } - Some(Self::new(db, self.typevar(db), lower, upper)) + if lower.is_union() || upper.is_intersection() { + return IntersectionResult::CannotSimplify; + } + + IntersectionResult::Simplified(Self::new(db, self.typevar(db), lower, upper)) } fn display(self, db: &'db dyn Db) -> impl Display { @@ -2037,7 +2087,7 @@ impl<'db> InteriorNode<'db> { // constraints is empty, and others that we can make when the intersection is // non-empty. match left_constraint.intersect(db, right_constraint) { - Some(intersection_constraint) => { + IntersectionResult::Simplified(intersection_constraint) => { let intersection_constraint = intersection_constraint.normalized(db); // If the intersection is non-empty, we need to create a new constraint to @@ -2120,7 +2170,11 @@ impl<'db> InteriorNode<'db> { ); } - None => { + // If the intersection doesn't simplify to a single clause, we shouldn't update the + // BDD. + IntersectionResult::CannotSimplify => {} + + IntersectionResult::Disjoint => { // All of the below hold because we just proved that the intersection of left // and right is empty. 
@@ -2245,7 +2299,9 @@ impl<'db> ConstraintAssignment<'db> { ( ConstraintAssignment::Positive(self_constraint), ConstraintAssignment::Negative(other_constraint), - ) => self_constraint.intersect(db, other_constraint).is_none(), + ) => self_constraint + .intersect(db, other_constraint) + .is_disjoint(), // It's theoretically possible for a negative constraint to imply a positive constraint // if the positive constraint is always satisfied (`Never ≤ T ≤ object`). But we never @@ -2689,7 +2745,7 @@ impl<'db> SequentMap<'db> { } match left_constraint.intersect(db, right_constraint) { - Some(intersection_constraint) => { + IntersectionResult::Simplified(intersection_constraint) => { tracing::debug!( target: "ty_python_semantic::types::constraints::SequentMap", left = %left_constraint.display(db), @@ -2707,7 +2763,13 @@ impl<'db> SequentMap<'db> { self.add_single_implication(db, intersection_constraint, right_constraint); self.enqueue_constraint(intersection_constraint); } - None => { + + // The sequent map only needs to include constraints that might appear in a BDD. If the + // intersection does not collapse to a single constraint, then there's no new + // constraint that we need to add to the sequent map. + IntersectionResult::CannotSimplify => {} + + IntersectionResult::Disjoint => { tracing::debug!( target: "ty_python_semantic::types::constraints::SequentMap", left = %left_constraint.display(db),