diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index f0a8735762..b412148e6a 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -940,17 +940,6 @@ impl<'db> Node<'db> { } } - /// Returns the smallest source_order associated with the given constraint. - fn source_order_for(self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) -> usize { - let mut best: Option = None; - self.for_each_constraint(db, &mut |candidate, order| { - if candidate == constraint { - best = Some(best.map_or(order, |b| b.min(order))); - } - }); - best.unwrap_or(1) - } - /// Returns whether this BDD represent the constant function `true`. fn is_always_satisfied(self, db: &'db dyn Db) -> bool { match self { @@ -1437,7 +1426,9 @@ impl<'db> Node<'db> { self, db: &'db dyn Db, left: ConstraintAssignment<'db>, + left_source_order: usize, right: ConstraintAssignment<'db>, + right_source_order: usize, replacement: Node<'db>, ) -> Self { // We perform a Shannon expansion to find out what the input BDD evaluates to when: @@ -1471,13 +1462,8 @@ impl<'db> Node<'db> { // false // // (Note that the `else` branch shouldn't be reachable, but we have to provide something!) - let left_node = - Node::new_satisfied_constraint(db, left, self.source_order_for(db, left.constraint())); - let right_node = Node::new_satisfied_constraint( - db, - right, - self.source_order_for(db, right.constraint()), - ); + let left_node = Node::new_satisfied_constraint(db, left, left_source_order); + let right_node = Node::new_satisfied_constraint(db, right, right_source_order); let right_result = right_node.ite(db, Node::AlwaysFalse, when_left_but_not_right); let left_result = left_node.ite(db, right_result, when_not_left); let result = replacement.ite(db, when_left_and_right, left_result); @@ -1500,7 +1486,9 @@ impl<'db> Node<'db> { self, db: &'db dyn Db, left: ConstraintAssignment<'db>, + left_source_order: usize, right: ConstraintAssignment<'db>, + right_source_order: usize, replacement: Node<'db>, ) -> Self { // We perform a Shannon expansion to find out what the input BDD evaluates to when: @@ -1540,13 +1528,8 @@ impl<'db> Node<'db> { // Lastly, verify that the result is consistent with the input. (It must produce the same // results when `left ∨ right`.) If it doesn't, the substitution isn't valid, and we should // return the original BDD unmodified. - let left_node = - Node::new_satisfied_constraint(db, left, self.source_order_for(db, left.constraint())); - let right_node = Node::new_satisfied_constraint( - db, - right, - self.source_order_for(db, right.constraint()), - ); + let left_node = Node::new_satisfied_constraint(db, left, left_source_order); + let right_node = Node::new_satisfied_constraint(db, right, right_source_order); let validity = replacement.iff(db, left_node.or(db, right_node)); let constrained_original = self.and(db, validity); let constrained_replacement = result.and(db, validity); @@ -2098,8 +2081,10 @@ impl<'db> InteriorNode<'db> { // visit queue with all pairs of those constraints. (We use "combinations" because we don't // need to compare a constraint against itself, and because ordering doesn't matter.) let mut seen_constraints = FxHashSet::default(); - Node::Interior(self).for_each_constraint(db, &mut |constraint, _| { + let mut source_orders = FxHashMap::default(); + Node::Interior(self).for_each_constraint(db, &mut |constraint, source_order| { seen_constraints.insert(constraint); + source_orders.insert(constraint, source_order); }); let mut to_visit: Vec<(_, _)> = (seen_constraints.iter().copied()) .tuple_combinations() @@ -2108,7 +2093,11 @@ impl<'db> InteriorNode<'db> { // Repeatedly pop constraint pairs off of the visit queue, checking whether each pair can // be simplified. let mut simplified = Node::Interior(self); + let mut next_source_order = self.max_source_order(db) + 1; while let Some((left_constraint, right_constraint)) = to_visit.pop() { + let left_source_order = source_orders[&left_constraint]; + let right_source_order = source_orders[&right_constraint]; + // If the constraints refer to different typevars, the only simplifications we can make // are of the form `S ≤ T ∧ T ≤ int → S ≤ int`. let left_typevar = left_constraint.typevar(db); @@ -2174,17 +2163,17 @@ impl<'db> InteriorNode<'db> { if seen_constraints.contains(&new_constraint) { continue; } - let derived_source_order = simplified.max_source_order(db) + 1; - let new_node = Node::new_constraint(db, new_constraint, derived_source_order); + let new_node = Node::new_constraint(db, new_constraint, next_source_order); + next_source_order += 1; let positive_left_node = Node::new_satisfied_constraint( db, left_constraint.when_true(), - simplified.source_order_for(db, left_constraint), + left_source_order, ); let positive_right_node = Node::new_satisfied_constraint( db, right_constraint.when_true(), - simplified.source_order_for(db, right_constraint), + right_source_order, ); let lhs = positive_left_node.and(db, positive_right_node); let intersection = new_node.ite(db, lhs, Node::AlwaysFalse); @@ -2208,29 +2197,47 @@ impl<'db> InteriorNode<'db> { // Containment: The range of one constraint might completely contain the range of the // other. If so, there are several potential simplifications. let larger_smaller = if left_constraint.implies(db, right_constraint) { - Some((right_constraint, left_constraint)) + Some(( + right_constraint, + right_source_order, + left_constraint, + left_source_order, + )) } else if right_constraint.implies(db, left_constraint) { - Some((left_constraint, right_constraint)) + Some(( + left_constraint, + left_source_order, + right_constraint, + right_source_order, + )) } else { None }; - if let Some((larger_constraint, smaller_constraint)) = larger_smaller { + if let Some(( + larger_constraint, + larger_source_order, + smaller_constraint, + smaller_source_order, + )) = larger_smaller + { let positive_larger_node = Node::new_satisfied_constraint( db, larger_constraint.when_true(), - simplified.source_order_for(db, larger_constraint), + larger_source_order, ); let negative_larger_node = Node::new_satisfied_constraint( db, larger_constraint.when_false(), - simplified.source_order_for(db, larger_constraint), + larger_source_order, ); // larger ∨ smaller = larger simplified = simplified.substitute_union( db, larger_constraint.when_true(), + larger_source_order, smaller_constraint.when_true(), + smaller_source_order, positive_larger_node, ); @@ -2238,7 +2245,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, larger_constraint.when_false(), + larger_source_order, smaller_constraint.when_false(), + smaller_source_order, negative_larger_node, ); @@ -2247,7 +2256,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, larger_constraint.when_false(), + larger_source_order, smaller_constraint.when_true(), + smaller_source_order, Node::AlwaysFalse, ); @@ -2256,7 +2267,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_union( db, larger_constraint.when_true(), + larger_source_order, smaller_constraint.when_false(), + smaller_source_order, Node::AlwaysTrue, ); } @@ -2278,45 +2291,48 @@ impl<'db> InteriorNode<'db> { .map(|seen| (seen, intersection_constraint)), ); } - let derived_source_order = simplified.max_source_order(db) + 1; let positive_intersection_node = Node::new_satisfied_constraint( db, intersection_constraint.when_true(), - derived_source_order, + next_source_order, ); + next_source_order += 1; let negative_intersection_node = Node::new_satisfied_constraint( db, intersection_constraint.when_false(), - derived_source_order, + next_source_order, ); + next_source_order += 1; let positive_left_node = Node::new_satisfied_constraint( db, left_constraint.when_true(), - simplified.source_order_for(db, left_constraint), + left_source_order, ); let negative_left_node = Node::new_satisfied_constraint( db, left_constraint.when_false(), - simplified.source_order_for(db, left_constraint), + left_source_order, ); let positive_right_node = Node::new_satisfied_constraint( db, right_constraint.when_true(), - simplified.source_order_for(db, right_constraint), + right_source_order, ); let negative_right_node = Node::new_satisfied_constraint( db, right_constraint.when_false(), - simplified.source_order_for(db, right_constraint), + right_source_order, ); // left ∧ right = intersection simplified = simplified.substitute_intersection( db, left_constraint.when_true(), + left_source_order, right_constraint.when_true(), + right_source_order, positive_intersection_node, ); @@ -2324,7 +2340,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_union( db, left_constraint.when_false(), + left_source_order, right_constraint.when_false(), + right_source_order, negative_intersection_node, ); @@ -2334,7 +2352,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, left_constraint.when_true(), + left_source_order, right_constraint.when_false(), + right_source_order, positive_left_node.and(db, negative_intersection_node), ); @@ -2343,7 +2363,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, left_constraint.when_false(), + left_source_order, right_constraint.when_true(), + right_source_order, positive_right_node.and(db, negative_intersection_node), ); @@ -2353,7 +2375,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_union( db, left_constraint.when_true(), + left_source_order, right_constraint.when_false(), + right_source_order, negative_right_node.or(db, positive_intersection_node), ); @@ -2362,7 +2386,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_union( db, left_constraint.when_false(), + left_source_order, right_constraint.when_true(), + right_source_order, negative_left_node.or(db, positive_intersection_node), ); } @@ -2378,19 +2404,21 @@ impl<'db> InteriorNode<'db> { let positive_left_node = Node::new_satisfied_constraint( db, left_constraint.when_true(), - simplified.source_order_for(db, left_constraint), + left_source_order, ); let positive_right_node = Node::new_satisfied_constraint( db, right_constraint.when_true(), - simplified.source_order_for(db, right_constraint), + right_source_order, ); // left ∧ right = false simplified = simplified.substitute_intersection( db, left_constraint.when_true(), + left_source_order, right_constraint.when_true(), + right_source_order, Node::AlwaysFalse, ); @@ -2398,7 +2426,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_union( db, left_constraint.when_false(), + left_source_order, right_constraint.when_false(), + right_source_order, Node::AlwaysTrue, ); @@ -2407,7 +2437,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, left_constraint.when_true(), + left_source_order, right_constraint.when_false(), + right_source_order, positive_left_node, ); @@ -2416,7 +2448,9 @@ impl<'db> InteriorNode<'db> { simplified = simplified.substitute_intersection( db, left_constraint.when_false(), + left_source_order, right_constraint.when_true(), + right_source_order, positive_right_node, ); }