diff --git a/crates/ty_python_semantic/src/types/unification.rs b/crates/ty_python_semantic/src/types/unification.rs index 9a86072bbd..3e10cf8016 100644 --- a/crates/ty_python_semantic/src/types/unification.rs +++ b/crates/ty_python_semantic/src/types/unification.rs @@ -155,7 +155,7 @@ impl<'db> ConstraintSetSet<'db> { result } - /// union two sets of constraint sets. + /// Union two sets of constraint sets. /// /// This is the ⊔ operator from [[POPL2015][]], Definition 3.5. /// @@ -165,6 +165,21 @@ impl<'db> ConstraintSetSet<'db> { self.add(db, set); } } + + /// Calculate the distributed intersection of an iterator of sets of constraint sets. + fn distributed_intersection(db: &'db dyn Db, sets: impl IntoIterator) -> Self { + sets.into_iter() + .fold(Self::always(), |sets, element| element.intersect(db, &sets)) + } + + /// Calculate the distributed union of an iterator of sets of constraint sets. + fn distributed_union(db: &'db dyn Db, sets: impl IntoIterator) -> Self { + let mut result = Self::never(); + for set in sets { + result.union(db, set); + } + result + } } impl<'db> From> for ConstraintSetSet<'db> { @@ -225,9 +240,10 @@ fn normalized_constraints_from_type_inner<'db>( Type::Union(union) => { // Figure 3, step 6 // A union is a subtype of Never only if every element is. - (union.iter(db)).fold(ConstraintSetSet::always(), |sets, element| { - normalized_constraints_from_type(db, *element).intersect(db, &sets) - }) + ConstraintSetSet::distributed_union( + db, + (union.iter(db)).map(|element| normalized_constraints_from_type(db, *element)), + ) } Type::Intersection(intersection) => {