[ty] Handle nested types when creating specializations from constraint sets (#21530)

#21414 added the ability to create a specialization from a constraint
set. It handled mutually constrained typevars just fine, e.g. given `T ≤
int ∧ U = T` we can infer `T = int, U = int`.

But it didn't handle _nested_ constraints correctly, e.g. `T ≤ int ∧ U =
list[T]`. Now we do! This requires doing a fixed-point "apply the
specialization to itself" step to propagate the assignments of any
nested typevars, and then a cycle detection check to make sure we don't
have an infinite expansion in the specialization.

This gets at an interesting nuance in our constraint set structure that
@sharkdp has asked about before. Constraint sets are BDDs, and each
internal node represents an _individual constraint_, of the form `lower
≤ T ≤ upper`. `lower` and `upper` are allowed to be other typevars, but
only if they appear "later" in the arbitrary ordering that we establish
over typevars. The main purpose of this is to avoid infinite expansion
for mutually constrained typevars.

However, that restriction doesn't help us here, because it only applies
when `lower` and `upper` _are_ typevars, not when they _contain_
typevars. That distinction is important, since it means the restriction
does not affect our expressiveness: we can always rewrite `Never ≤ T ≤
U` (a constraint on `T`) into `T ≤ U ≤ object` (a constraint on `U`).
The same is not true of `Never ≤ T ≤ list[U]` — there is no "inverse" of
`list` that we could apply to both sides to transform this into a
constraint on a bare `U`.
This commit is contained in:
Douglas Creager 2025-11-19 17:37:16 -05:00 committed by GitHub
parent 0d47334f3b
commit 83134fb380
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 179 additions and 5 deletions

View File

@ -303,3 +303,33 @@ def mutually_bound[T: Base, U]():
# revealed: ty_extensions.Specialization[T@mutually_bound = Base, U@mutually_bound = Sub]
reveal_type(generic_context(mutually_bound).specialize_constrained(ConstraintSet.range(Never, U, Sub) & ConstraintSet.range(Never, U, T)))
```
## Nested typevars
A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U`
must be specialized to `list[T]`, but it cannot affect what `T` is specialized to.
```py
from typing import Never
from ty_extensions import ConstraintSet, generic_context
def mentions[T, U]():
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T])
# revealed: ty_extensions.ConstraintSet[((T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]))]
reveal_type(constraints)
# revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]]
reveal_type(generic_context(mentions).specialize_constrained(constraints))
```
If the constraint set contains mutually recursive bounds, specialization inference will not
converge. This test ensures that our cycle detection prevents an endless loop or stack overflow in
this case.
```py
def divergent[T, U]():
constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T])
# revealed: ty_extensions.ConstraintSet[((T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent]))]
reveal_type(constraints)
# revealed: None
reveal_type(generic_context(divergent).specialize_constrained(constraints))
```

View File

@ -53,6 +53,7 @@
//!
//! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram
use std::cell::RefCell;
use std::cmp::Ordering;
use std::fmt::Display;
use std::ops::Range;
@ -62,9 +63,10 @@ use rustc_hash::{FxHashMap, FxHashSet};
use salsa::plumbing::AsId;
use crate::types::generics::{GenericContext, InferableTypeVars, Specialization};
use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
use crate::types::{
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation,
TypeVarBoundOrConstraints, UnionType,
TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type,
};
use crate::{Db, FxOrderSet};
@ -213,6 +215,100 @@ impl<'db> ConstraintSet<'db> {
self.node.is_always_satisfied(db)
}
/// Returns whether this constraint set contains any cycles between typevars. If it does, then
/// we cannot create a specialization from this constraint set.
///
/// We have restrictions in place that ensure that there are no cycles in the _lower and upper
/// bounds_ of each constraint, but it's still possible for a constraint to _mention_ another
/// typevar without _constraining_ it. For instance, `(T ≤ int) ∧ (U ≤ list[T])` is a valid
/// constraint set, which we can create a specialization from (`T = int, U = list[int]`). But
/// `(T ≤ list[U]) ∧ (U ≤ list[T])` does not violate our lower/upper bounds restrictions, since
/// neither bound _is_ a typevar. And it's not something we can create a specialization from,
/// since we would endlessly substitute until we stack overflow.
pub(crate) fn is_cyclic(self, db: &'db dyn Db) -> bool {
    // Type visitor that records every typevar reachable from a type, including typevars
    // nested inside other type constructors (e.g. the `T` in `list[T]`).
    #[derive(Default)]
    struct CollectReachability<'db> {
        // RefCell because the visitor trait's methods take `&self`.
        reachable_typevars: RefCell<FxHashSet<BoundTypeVarIdentity<'db>>>,
        // Guards against revisiting the same (possibly self-referential) type.
        recursion_guard: TypeCollector<'db>,
    }
    impl<'db> TypeVisitor<'db> for CollectReachability<'db> {
        fn should_visit_lazy_type_attributes(&self) -> bool {
            // Also walk lazily computed type attributes so that typevars mentioned
            // there are not missed.
            true
        }
        fn visit_bound_type_var_type(
            &self,
            db: &'db dyn Db,
            bound_typevar: BoundTypeVarInstance<'db>,
        ) {
            // Record this typevar, then keep walking into its bound/constraints.
            self.reachable_typevars
                .borrow_mut()
                .insert(bound_typevar.identity(db));
            walk_bound_type_var_type(db, bound_typevar, self);
        }
        fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
            walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard);
        }
    }
    // Standard DFS back-edge cycle detection over the "mentions" graph.
    //
    // `discovered` holds the typevars on the current DFS path (the "gray" set);
    // a typevar is removed from `reachable_typevars` when first visited and from
    // `discovered` when its subtree is fully explored, so a typevar absent from
    // both maps is finished ("black"). An edge back into `discovered` is a cycle.
    // Recursion depth is bounded by the number of typevars, since each typevar is
    // visited at most once.
    fn visit_dfs<'db>(
        reachable_typevars: &mut FxHashMap<
            BoundTypeVarIdentity<'db>,
            FxHashSet<BoundTypeVarIdentity<'db>>,
        >,
        discovered: &mut FxHashSet<BoundTypeVarIdentity<'db>>,
        bound_typevar: BoundTypeVarIdentity<'db>,
    ) -> bool {
        discovered.insert(bound_typevar);
        let outgoing = reachable_typevars
            .remove(&bound_typevar)
            .expect("should not visit typevar twice in DFS");
        for outgoing in outgoing {
            if discovered.contains(&outgoing) {
                // Back edge to a typevar on the current path: cycle found.
                return true;
            }
            if reachable_typevars.contains_key(&outgoing) {
                // Not yet visited; explore it. (If it's in neither map, it was
                // already fully explored and cannot be part of a new cycle.)
                if visit_dfs(reachable_typevars, discovered, outgoing) {
                    return true;
                }
            }
        }
        // Subtree fully explored without finding a cycle; pop from the path.
        discovered.remove(&bound_typevar);
        false
    }
    // First find all of the typevars that each constraint directly mentions.
    let mut reachable_typevars: FxHashMap<
        BoundTypeVarIdentity<'db>,
        FxHashSet<BoundTypeVarIdentity<'db>>,
    > = FxHashMap::default();
    self.node.for_each_constraint(db, &mut |constraint| {
        let visitor = CollectReachability::default();
        visitor.visit_type(db, constraint.lower(db));
        visitor.visit_type(db, constraint.upper(db));
        // Edge set: constrained typevar -> every typevar its bounds mention.
        reachable_typevars
            .entry(constraint.typevar(db).identity(db))
            .or_default()
            .extend(visitor.reachable_typevars.into_inner());
    });
    // Then perform a depth-first search to see if there are any cycles.
    let mut discovered: FxHashSet<BoundTypeVarIdentity<'db>> = FxHashSet::default();
    // `visit_dfs` removes every typevar it reaches from `reachable_typevars`,
    // so this loop terminates once every typevar has been explored.
    while let Some(bound_typevar) = reachable_typevars.keys().copied().next() {
        if !discovered.contains(&bound_typevar) {
            let cycle_found =
                visit_dfs(&mut reachable_typevars, &mut discovered, bound_typevar);
            if cycle_found {
                return true;
            }
        }
    }
    false
}
/// Returns the constraints under which `lhs` is a subtype of `rhs`, assuming that the
/// constraints in this constraint set hold. Panics if neither of the types being compared are
/// a typevar. (That case is handled by `Type::has_relation_to`.)
@ -2964,6 +3060,12 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db,
constraints: ConstraintSet<'db>,
) -> Result<Specialization<'db>, ()> {
// If the constraint set is cyclic, don't even try to construct a specialization.
if constraints.is_cyclic(db) {
// TODO: Better error
return Err(());
}
// First we intersect with the valid specializations of all of the typevars. We need all of
// valid specializations to hold simultaneously, so we do this once before abstracting over
// each typevar.
@ -3020,7 +3122,7 @@ impl<'db> GenericContext<'db> {
types[i] = least_upper_bound;
}
Ok(self.specialize(db, types.into_boxed_slice()))
Ok(self.specialize_recursive(db, types.into_boxed_slice()))
}
}

View File

@ -500,9 +500,16 @@ impl<'db> GenericContext<'db> {
}
/// Creates a specialization of this generic context. Panics if the length of `types` does not
/// match the number of typevars in the generic context. You must provide a specific type for
/// each typevar; no defaults are used. (Use [`specialize_partial`](Self::specialize_partial)
/// if you might not have types for every typevar.)
/// match the number of typevars in the generic context.
///
/// You must provide a specific type for each typevar; no defaults are used. (Use
/// [`specialize_partial`](Self::specialize_partial) if you might not have types for every
/// typevar.)
///
/// The types you provide should not mention any of the typevars in this generic context;
/// otherwise, you will be left with a partial specialization. (Use
/// [`specialize_recursive`](Self::specialize_recursive) if your types might mention typevars
/// in this generic context.)
pub(crate) fn specialize(
self,
db: &'db dyn Db,
@ -512,6 +519,41 @@ impl<'db> GenericContext<'db> {
Specialization::new(db, self, types, None, None)
}
/// Creates a specialization of this generic context, where the provided types are
/// allowed to mention the typevars of this context themselves. Panics if the length of
/// `types` does not match the number of typevars in the generic context.
pub(crate) fn specialize_recursive(
    self,
    db: &'db dyn Db,
    mut types: Box<[Type<'db>]>,
) -> Specialization<'db> {
    let len = types.len();
    assert!(self.len(db) == len);
    // Fixed-point substitution: apply the in-progress assignment to each entry until a
    // full pass leaves everything unchanged. This assumes the assignments contain no
    // typevar cycles (callers check `ConstraintSet::is_cyclic` first); otherwise the
    // loop would never reach a fixed point.
    let mut changed = true;
    while changed {
        changed = false;
        for slot in 0..len {
            // Build the mapping as a temporary so its borrow of `types` ends before we
            // write the updated entry back.
            let substituted = types[slot].apply_type_mapping(
                db,
                &TypeMapping::PartialSpecialization(PartialSpecialization {
                    generic_context: self,
                    types: &types,
                }),
                TypeContext::default(),
            );
            if substituted != types[slot] {
                types[slot] = substituted;
                changed = true;
            }
        }
    }
    Specialization::new(db, self, types, None, None)
}
/// Creates a specialization of this generic context for the `tuple` class.
pub(crate) fn specialize_tuple(
self,