diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index f96e9d0d16..b29e1364c1 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -1042,11 +1042,11 @@ impl<'db> Node<'db> { db: &'db dyn Db, inferable: InferableTypeVars<'_, 'db>, ) -> bool { - match self { + let interior = match self { Node::AlwaysTrue => return true, Node::AlwaysFalse => return false, - Node::Interior(_) => {} - } + Node::Interior(interior) => interior, + }; let mut typevars = FxHashSet::default(); self.for_each_constraint(db, &mut |constraint| { @@ -1059,8 +1059,6 @@ impl<'db> Node<'db> { } }); - let without_inferable = self.exists(db, inferable.iter()); - // Returns if some specialization satisfies this constraint set. let some_specialization_satisfies = move |specializations: Node<'db>| { let when_satisfied = specializations.implies(db, self).and(db, specializations); @@ -1068,14 +1066,15 @@ impl<'db> Node<'db> { }; // Returns if all specializations satisfy this constraint set. - let all_specializations_satisfy = move |specializations: Node<'db>| { - let when_satisfied = specializations - .implies(db, without_inferable) - .and(db, specializations); - when_satisfied - .iff(db, specializations) - .is_always_satisfied(db) - }; + let all_specializations_satisfy = + move |restricted: Node<'db>, specializations: Node<'db>| { + let when_satisfied = specializations + .implies(db, restricted) + .and(db, specializations); + when_satisfied + .iff(db, specializations) + .is_always_satisfied(db) + }; for typevar in typevars { if typevar.is_inferable(db, inferable) { @@ -1087,10 +1086,19 @@ impl<'db> Node<'db> { } } else { // If the typevar is in non-inferable position, we need to verify that all required - // specializations satisfy the constraint set. Complicating things, the typevar - // might have gradual constraints. For those, we need to know the range of valid - // materializations, but we only need some materialization to satisfy the - // constraint set. + // specializations satisfy the constraint set. If the typevar depends on any other + // typevars that are inferable, we are allowed to choose different specializations + // of those inferable typevars for each specialization of this non-inferable one. + // To handle this, we use the sequent map to find which inferable typevars this + // typevar depends on, and existentially abstract them away. + let inferable_dependencies = interior + .sequent_map(db) + .get_typevar_dependencies(typevar.identity(db)); + let restricted = self.exists(db, inferable_dependencies); + + // Complicating things, the typevar might have gradual constraints. For those, we + // need to know the range of valid materializations, but we only need some + // materialization to satisfy the constraint set. // // NB: We could also model this by introducing a synthetic typevar for the gradual // constraint, treating that synthetic typevar as always inferable (so that we only @@ -1099,7 +1107,7 @@ impl<'db> Node<'db> { // constraint. let (static_specializations, gradual_constraints) = typevar.required_specializations(db); - if !all_specializations_satisfy(static_specializations) { + if !all_specializations_satisfy(restricted, static_specializations) { return false; } for gradual_constraint in gradual_constraints { @@ -2308,6 +2316,12 @@ struct SequentMap<'db> { >, /// Sequents of the form `C → D` single_implications: FxHashMap, FxHashSet>>, + + /// A dependency map recording which typevars depend on each other. (A typevar `T` depends on a + /// typevar `U` if there is any constraint `T ≤ U` or `U ≤ T` in the constraint set.) + typevar_dependencies: + FxHashMap, FxHashSet>>, + /// Constraints that we have already processed processed: FxHashSet>, /// Constraints that enqueued to be processed @@ -2315,6 +2329,18 @@ struct SequentMap<'db> { } impl<'db> SequentMap<'db> { + fn get_typevar_dependencies( + &self, + bound_typevar: BoundTypeVarIdentity<'db>, + ) -> impl Iterator> + '_ { + self.typevar_dependencies + .get(&bound_typevar) + .map(|dependencies| dependencies.into_iter()) + .into_iter() + .flatten() + .copied() + } + fn add(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { self.enqueue_constraint(constraint); @@ -2404,6 +2430,22 @@ impl<'db> SequentMap<'db> { .insert(post); } + fn add_typevar_dependency( + &mut self, + left: BoundTypeVarIdentity<'db>, + right: BoundTypeVarIdentity<'db>, + ) { + // The typevar dependency map is reflexive, so add the dependency in both directions. + self.typevar_dependencies + .entry(left) + .or_default() + .insert(right); + self.typevar_dependencies + .entry(right) + .or_default() + .insert(left); + } + fn add_sequents_for_single(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { // If this constraint binds its typevar to `Never ≤ T ≤ object`, then the typevar can take // on any type, and the constraint is always satisfied. @@ -2424,9 +2466,14 @@ impl<'db> SequentMap<'db> { // Technically, (1) also allows `(S = T) → (S = S)`, but the rhs of that is vacuously true, // so we don't add a sequent for that case. + let typevar = constraint.typevar(db); + let lower = constraint.lower(db); + let upper = constraint.upper(db); let post_constraint = match (lower, upper) { // Case 1 (Type::TypeVar(lower_typevar), Type::TypeVar(upper_typevar)) => { + self.add_typevar_dependency(typevar.identity(db), lower_typevar.identity(db)); + self.add_typevar_dependency(typevar.identity(db), upper_typevar.identity(db)); if !lower_typevar.is_same_typevar_as(db, upper_typevar) { ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper) } else { @@ -2436,11 +2483,13 @@ impl<'db> SequentMap<'db> { // Case 2 (Type::TypeVar(lower_typevar), _) => { + self.add_typevar_dependency(typevar.identity(db), lower_typevar.identity(db)); ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper) } // Case 3 (_, Type::TypeVar(upper_typevar)) => { + self.add_typevar_dependency(typevar.identity(db), upper_typevar.identity(db)); ConstrainedTypeVar::new(db, upper_typevar, lower, Type::object()) }