diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dfc932bc66..23c38444c8 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1252,7 +1252,7 @@ impl<'db> Type<'db> { } } - pub(crate) const fn is_union(&self) -> bool { + pub(crate) const fn is_union(self) -> bool { matches!(self, Type::Union(_)) } @@ -1268,6 +1268,10 @@ impl<'db> Type<'db> { self.as_union().expect("Expected a Type::Union variant") } + pub(crate) const fn is_intersection(self) -> bool { + matches!(self, Type::Intersection(_)) + } + pub(crate) const fn as_function_literal(self) -> Option> { match self { Type::FunctionLiteral(function_type) => Some(function_type), @@ -14109,6 +14113,10 @@ impl<'db> IntersectionType<'db> { self.positive(db).iter().copied() } + pub fn iter_negative(self, db: &'db dyn Db) -> impl Iterator> { + self.negative(db).iter().copied() + } + pub(crate) fn has_one_element(self, db: &'db dyn Db) -> bool { (self.positive(db).len() + self.negative(db).len()) == 1 } diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 9ec4cfd25f..7a727f3285 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -458,6 +458,19 @@ impl<'db> BoundTypeVarInstance<'db> { } } +#[derive(Clone, Copy, Debug)] +enum IntersectionResult<'db> { + Simplified(ConstrainedTypeVar<'db>), + CannotSimplify, + Disjoint, +} + +impl IntersectionResult<'_> { + fn is_disjoint(self) -> bool { + matches!(self, IntersectionResult::Disjoint) + } +} + /// An individual constraint in a constraint set. This restricts a single typevar to be within a /// lower and upper bound. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] @@ -484,6 +497,39 @@ impl<'db> ConstrainedTypeVar<'db> { debug_assert_eq!(lower, lower.bottom_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db)); + // It's not useful for an upper bound to be an intersection type, or for a lower bound to + // be a union type. Both of those can be rewritten as simpler BDDs: + // + // T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β) + // T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β) + // α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T) + if let Type::Union(lower_union) = lower { + let mut result = Node::AlwaysTrue; + for lower_element in lower_union.elements(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, *lower_element, upper), + ); + } + return result; + } + if let Type::Intersection(upper_intersection) = upper { + let mut result = Node::AlwaysTrue; + for upper_element in upper_intersection.iter_positive(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, lower, upper_element), + ); + } + for upper_element in upper_intersection.iter_negative(db) { + result = result.and( + db, + ConstrainedTypeVar::new_node(db, typevar, lower, upper_element).negate(db), + ); + } + return result; + } + // Two identical typevars must always solve to the same type, so it is not useful to have // an upper or lower bound that is the typevar being constrained. match lower { @@ -659,7 +705,7 @@ impl<'db> ConstrainedTypeVar<'db> { } /// Returns the intersection of two range constraints, or `None` if the intersection is empty. - fn intersect(self, db: &'db dyn Db, other: Self) -> Option { + fn intersect(self, db: &'db dyn Db, other: Self) -> IntersectionResult<'db> { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); let upper = IntersectionType::from_elements(db, [self.upper(db), other.upper(db)]); @@ -667,10 +713,14 @@ impl<'db> ConstrainedTypeVar<'db> { // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. if !lower.is_subtype_of(db, upper) { - return None; + return IntersectionResult::Disjoint; } - Some(Self::new(db, self.typevar(db), lower, upper)) + if lower.is_union() || upper.is_intersection() { + return IntersectionResult::CannotSimplify; + } + + IntersectionResult::Simplified(Self::new(db, self.typevar(db), lower, upper)) } fn display(self, db: &'db dyn Db) -> impl Display { @@ -2037,7 +2087,7 @@ impl<'db> InteriorNode<'db> { // constraints is empty, and others that we can make when the intersection is // non-empty. match left_constraint.intersect(db, right_constraint) { - Some(intersection_constraint) => { + IntersectionResult::Simplified(intersection_constraint) => { let intersection_constraint = intersection_constraint.normalized(db); // If the intersection is non-empty, we need to create a new constraint to @@ -2120,7 +2170,11 @@ impl<'db> InteriorNode<'db> { ); } - None => { + // If the intersection doesn't simplify to a single clause, we shouldn't update the + // BDD. + IntersectionResult::CannotSimplify => {} + + IntersectionResult::Disjoint => { // All of the below hold because we just proved that the intersection of left // and right is empty. @@ -2245,7 +2299,9 @@ impl<'db> ConstraintAssignment<'db> { ( ConstraintAssignment::Positive(self_constraint), ConstraintAssignment::Negative(other_constraint), - ) => self_constraint.intersect(db, other_constraint).is_none(), + ) => self_constraint + .intersect(db, other_constraint) + .is_disjoint(), // It's theoretically possible for a negative constraint to imply a positive constraint // if the positive constraint is always satisfied (`Never ≤ T ≤ object`). But we never @@ -2689,7 +2745,7 @@ impl<'db> SequentMap<'db> { } match left_constraint.intersect(db, right_constraint) { - Some(intersection_constraint) => { + IntersectionResult::Simplified(intersection_constraint) => { tracing::debug!( target: "ty_python_semantic::types::constraints::SequentMap", left = %left_constraint.display(db), @@ -2707,7 +2763,13 @@ impl<'db> SequentMap<'db> { self.add_single_implication(db, intersection_constraint, right_constraint); self.enqueue_constraint(intersection_constraint); } - None => { + + // The sequent map only needs to include constraints that might appear in a BDD. If the + // intersection does not collapse to a single constraint, then there's no new + // constraint that we need to add to the sequent map. + IntersectionResult::CannotSimplify => {} + + IntersectionResult::Disjoint => { tracing::debug!( target: "ty_python_semantic::types::constraints::SequentMap", left = %left_constraint.display(db),