From 336d01957da07c22eb476d457d8a17ed10c7d359 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Tue, 18 Nov 2025 14:53:27 -0500 Subject: [PATCH] start using constraint set specs --- .../ty_python_semantic/src/types/call/bind.rs | 85 ++++++++++++++----- .../src/types/constraints.rs | 36 +++++--- 2 files changed, 86 insertions(+), 35 deletions(-) diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 890bdb8570..95f1db711f 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2809,15 +2809,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { .zip(self.call_expression_tcx.annotation); self.inferable_typevars = generic_context.inferable_typevars(self.db); + let valid_specializations = generic_context.valid_specializations(self.db); + let mut constraints = ConstraintSet::from(true); let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars); // Prefer the declared type of generic classes. let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) .class_specialization(self.db)?; - - builder.infer(return_ty, tcx).ok()?; - Some(builder.type_mappings().clone()) + let return_type_mappings = + return_ty.when_assignable_to(self.db, tcx, self.inferable_typevars); + if return_type_mappings.is_never_satisfied(self.db) { + return None; + } + constraints.intersect(self.db, return_type_mappings); + Some(return_type_mappings) }); // For a given type variable, we track the variance of any assignments to that type variable @@ -2837,22 +2843,11 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { continue; }; + let argument_type = variadic_argument_type.unwrap_or(argument_type); let specialization_result = builder.infer_map( expected_type, - variadic_argument_type.unwrap_or(argument_type), + argument_type, |(identity, variance, inferred_ty)| { - // Avoid widening the inferred type if it is already assignable to the - // preferred declared type. - if preferred_type_mappings - .as_ref() - .and_then(|types| types.get(&identity)) - .is_some_and(|preferred_ty| { - inferred_ty.is_assignable_to(self.db, *preferred_ty) - }) - { - return None; - } - variance_in_arguments .entry(identity) .and_modify(|current| *current = current.join(variance)) @@ -2868,6 +2863,35 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_index: adjusted_argument_index, }); } + + let argument_constraints = expected_type.when_assignable_to( + self.db, + variadic_argument_type.unwrap_or(argument_type), + self.inferable_typevars, + ); + if argument_constraints.is_never_satisfied(self.db) { + // This argument is never assignable to its parameter, without considering any + // typevars. This will be caught by `check_argument_types` later. + continue; + } + + let valid_argument_constraints = + argument_constraints.and(self.db, || valid_specializations); + if valid_argument_constraints.is_never_satisfied(self.db) { + // There are specializations that make this argument assignable to its + // parameter, but none of them are _valid_ specializations. + // XXX: Figure out which typevars are violated and create a nice + // SpecializationError. + continue; + } + + // Avoid widening the inferred type if it is already assignable to the + // preferred declared type. + // XXX: Because constraint sets are ANDed together this might not be needed? AND + // should prefer the tighter specialization. + + // XXX: Determine typevar variance per argument + constraints.intersect(self.db, valid_argument_constraints); } } @@ -2921,9 +2945,13 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { }; // Build the specialization first without inferring the complete type context. - let isolated_specialization = builder - .mapped(generic_context, maybe_promote) - .build(generic_context); + let Ok(isolated_specialization) = + generic_context.specialize_constrained_mapped(self.db, constraints, maybe_promote) + else { + // XXX: better error + return; + }; + // XXX: maybe_promote let isolated_return_ty = self .return_ty .apply_specialization(self.db, isolated_specialization); @@ -2945,12 +2973,23 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an // annotated assignment, to closer match the order of any unions written in the type annotation. - builder.infer(return_ty, call_expression_tcx).ok()?; + let return_constraints = return_ty + .when_assignable_to(self.db, call_expression_tcx, self.inferable_typevars) + .and(self.db, || valid_specializations); + if return_constraints.is_never_satisfied(self.db) { + return None; + } // Otherwise, build the specialization again after inferring the complete type context. - let specialization = builder - .mapped(generic_context, maybe_promote) - .build(generic_context); + let Ok(specialization) = generic_context.specialize_constrained_mapped( + self.db, + constraints.and(self.db, || return_constraints), + maybe_promote, + ) else { + // XXX: better return + return None; + }; + // XXX: maybe_promote let return_ty = return_ty.apply_specialization(self.db, specialization); Some((Some(specialization), return_ty)) diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index a6079ccfb9..4413441edb 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3060,10 +3060,25 @@ impl<'db> BoundTypeVarInstance<'db> { } impl<'db> GenericContext<'db> { + /// Returns the valid specializations of all of the typevars in this generic context. + pub(crate) fn valid_specializations(self, db: &'db dyn Db) -> ConstraintSet<'db> { + self.variables(db) + .when_all(db, |bound_typevar| bound_typevar.valid_specializations(db)) + } + pub(crate) fn specialize_constrained( self, db: &'db dyn Db, constraints: ConstraintSet<'db>, + ) -> Result, ()> { + self.specialize_constrained_mapped(db, constraints, |_, _, ty| ty) + } + + pub(crate) fn specialize_constrained_mapped( + self, + db: &'db dyn Db, + constraints: ConstraintSet<'db>, + f: impl Fn(BoundTypeVarIdentity<'db>, BoundTypeVarInstance<'db>, Type<'db>) -> Type<'db>, ) -> Result, ()> { // If the constraint set is cyclic, don't even try to construct a specialization. if constraints.is_cyclic(db) { @@ -3096,17 +3111,14 @@ impl<'db> GenericContext<'db> { let mut satisfied = false; let mut greatest_lower_bound = Type::Never; let mut least_upper_bound = Type::object(); - abstracted.find_representative_types( - db, - bound_typevar.identity(db), - |lower_bound, upper_bound| { - satisfied = true; - greatest_lower_bound = - UnionType::from_elements(db, [greatest_lower_bound, lower_bound]); - least_upper_bound = - IntersectionType::from_elements(db, [least_upper_bound, upper_bound]); - }, - ); + let identity = bound_typevar.identity(db); + abstracted.find_representative_types(db, identity, |lower_bound, upper_bound| { + satisfied = true; + greatest_lower_bound = + UnionType::from_elements(db, [greatest_lower_bound, lower_bound]); + least_upper_bound = + IntersectionType::from_elements(db, [least_upper_bound, upper_bound]); + }); // If there are no satisfiable paths in the BDD, then there is no valid specialization // for this constraint set. @@ -3124,7 +3136,7 @@ impl<'db> GenericContext<'db> { // Of all of the types that satisfy all of the paths in the BDD, we choose the // "largest" one (i.e., "closest to `object`") as the specialization. - types[i] = least_upper_bound; + types[i] = f(identity, bound_typevar, least_upper_bound); } Ok(self.specialize_recursive(db, types.into_boxed_slice()))