diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/satisfied_by_all_typevars.md b/crates/ty_python_semantic/resources/mdtest/type_properties/satisfied_by_all_typevars.md new file mode 100644 index 0000000000..8d9f563250 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/satisfied_by_all_typevars.md @@ -0,0 +1,220 @@ +# Constraint set satisfaction + +```toml +[environment] +python-version = "3.12" +``` + +Constraint sets exist to help us check assignability and subtyping of types in the presence of +typevars. We construct a constraint set describing the conditions under which assignability holds +between the two types. Then we check whether that constraint set is satisfied for the valid +specializations of the relevant typevars. This file tests that final step. + +## Inferable vs non-inferable typevars + +Typevars can appear in _inferable_ or _non-inferable_ positions. + +When a typevar is in an inferable position, the constraint set only needs to be satisfied for _some_ +valid specialization. The most common inferable position occurs when invoking a generic function: +all of the function's typevars are inferable, because we want to use the argument types to infer +which specialization is being invoked. + +When a typevar is in a non-inferable position, the constraint set must be satisfied for _every_ +valid specialization. The most common non-inferable position occurs in the body of a generic +function or class: here we don't know in advance what type the typevar will be specialized to, and +so we have to ensure that the body is valid for all possible specializations. + +```py +def f[T](t: T) -> T: + # In the function body, T is non-inferable. All assignability checks involving T must be + # satisfied for _all_ valid specializations of T. + return t + +# When invoking the function, T is inferable — we attempt to infer a specialization that is valid +# for the particular arguments that are passed to the function. Assignability checks (in particular, +# that the argument type is assignable to the parameter type) only need to succeed for _at least +# one_ specialization. +f(1) +``` + +In all of the examples below, for ease of reproducibility, we explicitly list the typevars that are +inferable in each `satisfied_by_all_typevars` call; any typevar not listed is assumed to be +non-inferable. + +## Unbounded typevar + +If a typevar has no bound or constraints, then it can specialize to any type. In an inferable +position, that means we just need a single type (any type at all!) that satisfies the constraint +set. In a non-inferable position, that means the constraint set must be satisfied for every possible +type. + +```py +from typing import final, Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +@final +class Unrelated: ... + +def unbounded[T](): + static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(ConstraintSet.always().satisfied_by_all_typevars()) + + static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(not ConstraintSet.never().satisfied_by_all_typevars()) + + # (T = Never) is a valid specialization, which satisfies (T ≤ Unrelated). + static_assert(ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated). + static_assert(not ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Super). + static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Super). + static_assert(not ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Base). + static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Base). + static_assert(not ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars()) + + # (T = Sub) is a valid specialization, which satisfies (T ≤ Sub). + static_assert(ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Sub). + static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars()) +``` + +## Typevar with an upper bound + +If a typevar has an upper bound, then it must specialize to a type that is a subtype of that bound. +For an inferable typevar, that means we need a single type that satisfies both the constraint set +and the upper bound. For a non-inferable typevar, that means the constraint set must be satisfied +for every type that satisfies the upper bound. + +```py +from typing import final, Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +@final +class Unrelated: ... + +def bounded[T: Base](): + static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(ConstraintSet.always().satisfied_by_all_typevars()) + + static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(not ConstraintSet.never().satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Super). + static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T])) + # Every valid specialization satisfies (T ≤ Base). Since (Base ≤ Super), every valid + # specialization also satisfies (T ≤ Super). + static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Base). + static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T])) + # Every valid specialization satisfies (T ≤ Base). + static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars()) + + # (T = Sub) is a valid specialization, which satisfies (T ≤ Sub). + static_assert(ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T ≤ Sub). + static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars()) + + # (T = Never) is a valid specialization, which satisfies (T ≤ Unrelated). + constraints = ConstraintSet.range(Never, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated). + static_assert(not constraints.satisfied_by_all_typevars()) + + # Never is the only type that satisfies both (T ≤ Base) and (T ≤ Unrelated). So there is no + # valid specialization that satisfies (T ≤ Unrelated ∧ T ≠ Never). + constraints = constraints & ~ConstraintSet.range(Never, T, Never) + static_assert(not constraints.satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(not constraints.satisfied_by_all_typevars()) +``` + +## Constrained typevar + +If a typevar has constraints, then it must specialize to one of those specific types. (Not to a +subtype of one of those types!) For an inferable typevar, that means we need the constraint set to +be satisfied by any one of the constraints. For a non-inferable typevar, that means we need the +constraint set to be satisfied by all of those constraints. + +```py +from typing import final, Never +from ty_extensions import ConstraintSet, static_assert + +class Super: ... +class Base(Super): ... +class Sub(Base): ... + +@final +class Unrelated: ... + +def constrained[T: (Base, Unrelated)](): + static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(ConstraintSet.always().satisfied_by_all_typevars()) + + static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(not ConstraintSet.never().satisfied_by_all_typevars()) + + # (T = Unrelated) is a valid specialization, which satisfies (T ≤ Unrelated). + static_assert(ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated). + static_assert(not ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Super). + static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Super). + static_assert(not ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars()) + + # (T = Base) is a valid specialization, which satisfies (T ≤ Base). + static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Base). + static_assert(not ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars()) + + # Neither (T = Base) nor (T = Unrelated) satisfy (T ≤ Sub). + static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars()) + + # (T = Base) and (T = Unrelated) both satisfy (T ≤ Super ∨ T ≤ Unrelated). + constraints = ConstraintSet.range(Never, T, Super) | ConstraintSet.range(Never, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(constraints.satisfied_by_all_typevars()) + + # (T = Base) and (T = Unrelated) both satisfy (T ≤ Base ∨ T ≤ Unrelated). + constraints = ConstraintSet.range(Never, T, Base) | ConstraintSet.range(Never, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(constraints.satisfied_by_all_typevars()) + + # (T = Unrelated) is a valid specialization, which satisfies (T ≤ Sub ∨ T ≤ Unrelated). + constraints = ConstraintSet.range(Never, T, Sub) | ConstraintSet.range(Never, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T ≤ Sub ∨ T ≤ Unrelated). + static_assert(not constraints.satisfied_by_all_typevars()) + + # (T = Unrelated) is a valid specialization, which satisfies (T = Super ∨ T = Unrelated). + constraints = ConstraintSet.range(Super, T, Super) | ConstraintSet.range(Unrelated, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T = Super ∨ T = Unrelated). + static_assert(not constraints.satisfied_by_all_typevars()) + + # (T = Base) and (T = Unrelated) both satisfy (T = Base ∨ T = Unrelated). + constraints = ConstraintSet.range(Base, T, Base) | ConstraintSet.range(Unrelated, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + static_assert(constraints.satisfied_by_all_typevars()) + + # (T = Unrelated) is a valid specialization, which satisfies (T = Sub ∨ T = Unrelated). + constraints = ConstraintSet.range(Sub, T, Sub) | ConstraintSet.range(Unrelated, T, Unrelated) + static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T])) + # (T = Base) is a valid specialization, which does not satisfy (T = Sub ∨ T = Unrelated). + static_assert(not constraints.satisfied_by_all_typevars()) +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index be3816ac12..a4eb563e6a 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4161,6 +4161,14 @@ impl<'db> Type<'db> { )) .into() } + Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) + if name == "satisfied_by_all_typevars" => + { + Place::bound(Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(tracked), + )) + .into() + } Type::ClassLiteral(class) if name == "__get__" && class.is_known(db, KnownClass::FunctionType) => @@ -6923,6 +6931,7 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) ) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_) @@ -7074,7 +7083,8 @@ impl<'db> Type<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), ) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_) @@ -10339,6 +10349,7 @@ pub enum KnownBoundMethodType<'db> { ConstraintSetAlways, ConstraintSetNever, ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>), + ConstraintSetSatisfiedByAllTypeVars(TrackedConstraintSet<'db>), } pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( @@ -10366,7 +10377,8 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => {} + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => {} } } @@ -10434,6 +10446,10 @@ impl<'db> KnownBoundMethodType<'db> { | ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + ) + | ( + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), ) => ConstraintSet::from(true), ( @@ -10446,7 +10462,8 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), KnownBoundMethodType::FunctionTypeDunderGet(_) | KnownBoundMethodType::FunctionTypeDunderCall(_) | KnownBoundMethodType::PropertyDunderGet(_) @@ -10456,7 +10473,8 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), ) => ConstraintSet::from(false), } } @@ -10509,6 +10527,10 @@ impl<'db> KnownBoundMethodType<'db> { ( KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints), KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints), + ) + | ( + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(left_constraints), + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(right_constraints), ) => left_constraints .constraints(db) .iff(db, right_constraints.constraints(db)), @@ -10523,7 +10545,8 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), KnownBoundMethodType::FunctionTypeDunderGet(_) | KnownBoundMethodType::FunctionTypeDunderCall(_) | KnownBoundMethodType::PropertyDunderGet(_) @@ -10533,7 +10556,8 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_), + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_), ) => ConstraintSet::from(false), } } @@ -10557,7 +10581,8 @@ impl<'db> KnownBoundMethodType<'db> { | KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => self, + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => self, } } @@ -10573,7 +10598,10 @@ impl<'db> KnownBoundMethodType<'db> { KnownBoundMethodType::ConstraintSetRange | KnownBoundMethodType::ConstraintSetAlways | KnownBoundMethodType::ConstraintSetNever - | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => KnownClass::ConstraintSet, + | KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) + | KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => { + KnownClass::ConstraintSet + } } } @@ -10712,6 +10740,19 @@ impl<'db> KnownBoundMethodType<'db> { Some(KnownClass::ConstraintSet.to_instance(db)), ))) } + + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => { + Either::Right(std::iter::once(Signature::new( + Parameters::new([Parameter::keyword_only(Name::new_static("inferable")) + .type_form() + .with_annotated_type(UnionType::from_elements( + db, + [Type::homogeneous_tuple(db, Type::any()), Type::none(db)], + )) + .with_default_type(Type::none(db))]), + Some(KnownClass::Bool.to_instance(db)), + ))) + } } } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 1b4629b301..b0a5cc1b91 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -9,6 +9,7 @@ use std::fmt; use itertools::{Either, Itertools}; use ruff_db::parsed::parsed_module; use ruff_python_ast::name::Name; +use rustc_hash::FxHashSet; use smallvec::{SmallVec, smallvec, smallvec_inline}; use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type}; @@ -35,9 +36,10 @@ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Paramete use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance, - KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, - SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType, - WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type, + KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, + PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, + UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, + todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -1174,6 +1176,42 @@ impl<'db> Bindings<'db> { )); } + Type::KnownBoundMethod( + KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(tracked), + ) => { + let extract_inferable = |instance: &NominalInstanceType<'db>| { + if instance.has_known_class(db, KnownClass::NoneType) { + // Caller explicitly passed None, so no typevars are inferable. + return Some(FxHashSet::default()); + } + instance + .tuple_spec(db)? + .fixed_elements() + .map(|ty| { + ty.as_typevar() + .map(|bound_typevar| bound_typevar.identity(db)) + }) + .collect() + }; + + let inferable = match overload.parameter_types() { + // Caller did not provide argument, so no typevars are inferable. + [None] => FxHashSet::default(), + [Some(Type::NominalInstance(instance))] => { + match extract_inferable(instance) { + Some(inferable) => inferable, + None => continue, + } + } + _ => continue, + }; + + let result = tracked + .constraints(db) + .satisfied_by_all_typevars(db, InferableTypeVars::One(&inferable)); + overload.set_return_type(Type::BooleanLiteral(result)); + } + Type::ClassLiteral(class) => match class.known(db) { Some(KnownClass::Bool) => match overload.parameter_types() { [Some(arg)] => overload.set_return_type(arg.bool(db).into_type(db)), diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index ef7632ff2e..ee66cd85f3 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -65,7 +65,10 @@ use salsa::plumbing::AsId; use crate::Db; use crate::types::generics::InferableTypeVars; -use crate::types::{BoundTypeVarInstance, IntersectionType, Type, TypeRelation, UnionType}; +use crate::types::{ + BoundTypeVarInstance, IntersectionType, Type, TypeRelation, TypeVarBoundOrConstraints, + UnionType, +}; /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { @@ -256,6 +259,28 @@ impl<'db> ConstraintSet<'db> { } } + /// Returns whether this constraint set is satisfied by all of the typevars that it mentions. + /// + /// Each typevar has a set of _valid specializations_, which is defined by any upper bound or + /// constraints that the typevar has. + /// + /// Each typevar is also either _inferable_ or _non-inferable_. (You provide a list of the + /// `inferable` typevars; all others are considered non-inferable.) For an inferable typevar, + /// then there must be _some_ valid specialization that satisfies the constraint set. For a + /// non-inferable typevar, then _all_ valid specializations must satisfy it. + /// + /// Note that we don't have to consider typevars that aren't mentioned in the constraint set, + /// since the constraint set cannot be affected by any typevars that it does not mention. That + /// means that those additional typevars trivially satisfy the constraint set, regardless of + /// whether they are inferable or not. + pub(crate) fn satisfied_by_all_typevars( + self, + db: &'db dyn Db, + inferable: InferableTypeVars<'_, 'db>, + ) -> bool { + self.node.satisfied_by_all_typevars(db, inferable) + } + /// Updates this constraint set to hold the union of itself and another constraint set. pub(crate) fn union(&mut self, db: &'db dyn Db, other: Self) -> Self { self.node = self.node.or(db, other.node); @@ -746,6 +771,13 @@ impl<'db> Node<'db> { .or(db, self.negate(db).and(db, else_node)) } + fn satisfies(self, db: &'db dyn Db, other: Self) -> Self { + let simplified_self = self.simplify(db); + let implication = simplified_self.implies(db, other); + let (simplified, domain) = implication.simplify_and_domain(db); + simplified.and(db, domain) + } + fn when_subtype_of_given( self, db: &'db dyn Db, @@ -767,10 +799,48 @@ impl<'db> Node<'db> { _ => return lhs.when_subtype_of(db, rhs, inferable).node, }; - let simplified_self = self.simplify(db); - let implication = simplified_self.implies(db, constraint); - let (simplified, domain) = implication.simplify_and_domain(db); - simplified.and(db, domain) + self.satisfies(db, constraint) + } + + fn satisfied_by_all_typevars( + self, + db: &'db dyn Db, + inferable: InferableTypeVars<'_, 'db>, + ) -> bool { + match self { + Node::AlwaysTrue => return true, + Node::AlwaysFalse => return false, + Node::Interior(_) => {} + } + + let mut typevars = FxHashSet::default(); + self.for_each_constraint(db, &mut |constraint| { + typevars.insert(constraint.typevar(db)); + }); + + for typevar in typevars { + // Determine which valid specializations of this typevar satisfy the constraint set. + let valid_specializations = typevar.valid_specializations(db).node; + let when_satisfied = valid_specializations + .satisfies(db, self) + .and(db, valid_specializations); + let satisfied = if typevar.is_inferable(db, inferable) { + // If the typevar is inferable, then we only need one valid specialization to + // satisfy the constraint set. + !when_satisfied.is_never_satisfied() + } else { + // If the typevar is non-inferable, then we need _all_ valid specializations to + // satisfy the constraint set. + when_satisfied + .iff(db, valid_specializations) + .is_always_satisfied(db) + }; + if !satisfied { + return false; + } + } + + true } /// Returns a new BDD that returns the same results as `self`, but with some inputs fixed to @@ -1861,6 +1931,33 @@ impl<'db> SatisfiedClauses<'db> { } } +/// Returns a constraint set describing the valid specializations of a typevar. +impl<'db> BoundTypeVarInstance<'db> { + pub(crate) fn valid_specializations(self, db: &'db dyn Db) -> ConstraintSet<'db> { + match self.typevar(db).bound_or_constraints(db) { + None => ConstraintSet::from(true), + Some(TypeVarBoundOrConstraints::UpperBound(bound)) => ConstraintSet::constrain_typevar( + db, + self, + Type::Never, + bound, + TypeRelation::Assignability, + ), + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { + constraints.elements(db).iter().when_any(db, |constraint| { + ConstraintSet::constrain_typevar( + db, + self, + *constraint, + *constraint, + TypeRelation::Assignability, + ) + }) + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 7748dd3ab5..8500c142e8 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -535,6 +535,9 @@ impl Display for DisplayRepresentation<'_> { Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)) => { f.write_str("bound method `ConstraintSet.implies_subtype_of`") } + Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars( + _, + )) => f.write_str("bound method `ConstraintSet.satisfied_by_all_typevars`"), Type::WrapperDescriptor(kind) => { let (method, object) = match kind { WrapperDescriptorKind::FunctionTypeDunderGet => ("__get__", "function"), diff --git a/crates/ty_vendored/ty_extensions/ty_extensions.pyi b/crates/ty_vendored/ty_extensions/ty_extensions.pyi index 79cda64bef..d23554f0ae 100644 --- a/crates/ty_vendored/ty_extensions/ty_extensions.pyi +++ b/crates/ty_vendored/ty_extensions/ty_extensions.pyi @@ -67,6 +67,16 @@ class ConstraintSet: .. _subtype: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence """ + def satisfied_by_all_typevars( + self, *, inferable: tuple[Any, ...] | None = None + ) -> bool: + """ + Returns whether this constraint set is satisfied by all of the typevars + that it mentions. You must provide a tuple of the typevars that should + be considered `inferable`. All other typevars mentioned in the + constraint set will be considered non-inferable. + """ + def __bool__(self) -> bool: ... def __eq__(self, other: ConstraintSet) -> bool: ... def __ne__(self, other: ConstraintSet) -> bool: ...