This commit is contained in:
Douglas Creager 2025-12-03 16:23:59 -05:00
parent 705e4725ad
commit a59fae85cc
3 changed files with 245 additions and 49 deletions

View File

@ -69,19 +69,21 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::fmt::Display; use std::fmt::Display;
use std::ops::Range; use std::ops::{Deref, DerefMut, Range};
use itertools::Itertools; use itertools::Itertools;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use salsa::plumbing::AsId; use salsa::plumbing::AsId;
use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; use crate::types::generics::{
GenericContext, InferableTypeVars, PartialSpecialization, Specialization,
};
use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
use crate::types::{ use crate::types::{
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation, BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeContext, TypeMapping,
TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type, TypeRelation, TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderMap, FxOrderSet};
/// An extension trait for building constraint sets from [`Option`] values. /// An extension trait for building constraint sets from [`Option`] values.
pub(crate) trait OptionConstraintsExtension<T> { pub(crate) trait OptionConstraintsExtension<T> {
@ -1094,6 +1096,14 @@ impl<'db> Node<'db> {
self.implies(db, constraint) self.implies(db, constraint)
} }
fn typevars(self, db: &'db dyn Db) -> FxHashSet<BoundTypeVarInstance<'db>> {
let mut typevars = FxHashSet::default();
self.for_each_constraint(db, &mut |constraint| {
typevars.insert(constraint.typevar(db));
});
typevars
}
fn satisfied_by_all_typevars( fn satisfied_by_all_typevars(
self, self,
db: &'db dyn Db, db: &'db dyn Db,
@ -1105,11 +1115,6 @@ impl<'db> Node<'db> {
Node::Interior(_) => {} Node::Interior(_) => {}
} }
let mut typevars = FxHashSet::default();
self.for_each_constraint(db, &mut |constraint| {
typevars.insert(constraint.typevar(db));
});
// Returns if some specialization satisfies this constraint set. // Returns if some specialization satisfies this constraint set.
let some_specialization_satisfies = move |specializations: Node<'db>| { let some_specialization_satisfies = move |specializations: Node<'db>| {
let when_satisfied = specializations.implies(db, self).and(db, specializations); let when_satisfied = specializations.implies(db, self).and(db, specializations);
@ -1124,7 +1129,7 @@ impl<'db> Node<'db> {
.is_always_satisfied(db) .is_always_satisfied(db)
}; };
for typevar in typevars { for typevar in self.typevars(db) {
if typevar.is_inferable(db, inferable) { if typevar.is_inferable(db, inferable) {
// If the typevar is in inferable position, we need to verify that some valid // If the typevar is in inferable position, we need to verify that some valid
// specialization satisfies the constraint set. // specialization satisfies the constraint set.
@ -3354,6 +3359,219 @@ impl<'db> BoundTypeVarInstance<'db> {
} }
} }
#[derive(Clone, Debug, Default)]
pub(crate) struct ConstraintSolutions<'db>(Vec<ConstraintSolution<'db>>);
impl<'db> Deref for ConstraintSolutions<'db> {
type Target = Vec<ConstraintSolution<'db>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'db> DerefMut for ConstraintSolutions<'db> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<'a, 'db> IntoIterator for &'a ConstraintSolutions<'db> {
type Item = &'a ConstraintSolution<'db>;
type IntoIter = std::slice::Iter<'a, ConstraintSolution<'db>>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct ConstraintSolution<'db> {
assignments: FxOrderMap<BoundTypeVarInstance<'db>, (Type<'db>, Type<'db>)>,
}
impl<'db> ConstraintSolution<'db> {
pub(crate) fn get_upper_bound(
&self,
bound_typevar: BoundTypeVarInstance<'db>,
) -> Option<Type<'db>> {
self.assignments
.get(&bound_typevar)
.map(|(_, upper_bound)| *upper_bound)
}
fn add_mapping(
&mut self,
db: &'db dyn Db,
typevar: BoundTypeVarInstance<'db>,
lower_bound: Type<'db>,
upper_bound: Type<'db>,
) {
eprintln!(
" -> ADD {} ≤ {} ≤ {}",
lower_bound.display(db),
typevar.identity(db).display(db),
upper_bound.display(db),
);
let (existing_lower_bound, existing_upper_bound) = self
.assignments
.entry(typevar)
.or_insert_with(|| (Type::Never, Type::object()));
let new_lower_bound = UnionType::from_elements(db, [*existing_lower_bound, lower_bound]);
let new_upper_bound =
IntersectionType::from_elements(db, [*existing_upper_bound, upper_bound]);
*existing_lower_bound = new_lower_bound;
*existing_upper_bound = new_upper_bound;
}
fn collapse_to_single_types(&mut self) {
for (_, (lower_bound, upper_bound)) in &mut self.assignments {
// Use the lower bound if it's more "interesting", otherwise use the upper bound.
if upper_bound.is_object() && !lower_bound.is_never() {
*upper_bound = *lower_bound;
} else {
*lower_bound = *upper_bound;
}
}
}
fn close_over_typevars(&mut self, db: &'db dyn Db) {
// We have to pull this out into a separate variable to satisfy the borrow checker.
let typevars: Vec<_> = self.assignments.keys().copied().collect();
loop {
let mut any_changed = false;
for bound_typevar in &typevars {
let (existing_lower, existing_upper) = self.assignments[bound_typevar];
let updated_lower = existing_lower.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(
PartialSpecialization::FromConstraintSolution(self),
),
TypeContext::default(),
);
let updated_upper = existing_upper.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(
PartialSpecialization::FromConstraintSolution(self),
),
TypeContext::default(),
);
if updated_lower != existing_lower || updated_upper != existing_upper {
self.assignments[bound_typevar] = (updated_lower, updated_upper);
any_changed = true;
}
}
if !any_changed {
return;
}
}
}
pub(crate) fn display(&self, db: &'db dyn Db) -> impl Display {
struct DisplayConstraintSolution<'a, 'db> {
solution: &'a ConstraintSolution<'db>,
db: &'db dyn Db,
}
impl Display for DisplayConstraintSolution<'_, '_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"[{}]",
self.solution.assignments.iter().format_with(
", ",
|(bound_typevar, (lower, upper)), f| if lower == upper {
f(&format_args!(
"{} = {}",
bound_typevar.identity(self.db).display(self.db),
lower.display(self.db)
))
} else {
f(&format_args!(
"{} ≤ {} ≤ {}",
lower.display(self.db),
bound_typevar.identity(self.db).display(self.db),
upper.display(self.db),
))
}
)
)
}
}
DisplayConstraintSolution { solution: self, db }
}
}
impl<'a, 'db> IntoIterator for &'a ConstraintSolution<'db> {
type Item = (&'a BoundTypeVarInstance<'db>, &'a (Type<'db>, Type<'db>));
type IntoIter = ordermap::map::Iter<'a, BoundTypeVarInstance<'db>, (Type<'db>, Type<'db>)>;
fn into_iter(self) -> Self::IntoIter {
self.assignments.iter()
}
}
impl<'db> ConstraintSet<'db> {
pub(crate) fn solve_for(
self,
db: &'db dyn Db,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSolutions<'db> {
let mut solutions = ConstraintSolutions::default();
self.for_each_path(db, |path| {
eprintln!(
" -> PATH [{}]",
path.assignments
.iter()
.format_with(", ", |assignment, f| f(&assignment.display(db)))
);
let mut solution = ConstraintSolution::default();
for constraint in path.positive_constraints() {
let typevar = constraint.typevar(db);
let lower = constraint.lower(db);
let upper = constraint.upper(db);
eprintln!(
" ~> constraint {} ≤ {} ≤ {}",
lower.display(db),
typevar.identity(db).display(db),
upper.display(db)
);
if typevar.is_inferable(db, inferable) {
eprintln!(" ~> add1");
solution.add_mapping(db, typevar, lower, upper);
}
if let Type::TypeVar(lower_bound_typevar) = lower
&& lower_bound_typevar.is_inferable(db, inferable)
{
eprintln!(" ~> add2");
solution.add_mapping(
db,
lower_bound_typevar,
Type::Never,
Type::TypeVar(typevar),
);
}
if let Type::TypeVar(upper_bound_typevar) = upper
&& upper_bound_typevar.is_inferable(db, inferable)
{
eprintln!(" ~> add3");
solution.add_mapping(
db,
upper_bound_typevar,
Type::TypeVar(typevar),
Type::object(),
);
}
}
solution.collapse_to_single_types();
solution.close_over_typevars(db);
solutions.push(solution);
});
solutions
}
}
impl<'db> GenericContext<'db> { impl<'db> GenericContext<'db> {
pub(crate) fn specialize_constrained( pub(crate) fn specialize_constrained(
self, self,

View File

@ -11,7 +11,7 @@ use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind, ScopeId};
use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::class::ClassType; use crate::types::class::ClassType;
use crate::types::class_base::ClassBase; use crate::types::class_base::ClassBase;
use crate::types::constraints::ConstraintSet; use crate::types::constraints::{ConstraintSet, ConstraintSolution};
use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::instance::{Protocol, ProtocolInstanceType};
use crate::types::signatures::Parameters; use crate::types::signatures::Parameters;
use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type};
@ -1304,6 +1304,7 @@ pub enum PartialSpecialization<'a, 'db> {
generic_context: GenericContext<'db>, generic_context: GenericContext<'db>,
types: &'a [Type<'db>], types: &'a [Type<'db>],
}, },
FromConstraintSolution(&'a ConstraintSolution<'db>),
} }
impl<'db> PartialSpecialization<'_, 'db> { impl<'db> PartialSpecialization<'_, 'db> {
@ -1324,6 +1325,9 @@ impl<'db> PartialSpecialization<'_, 'db> {
.get_index_of(&bound_typevar.identity(db))?; .get_index_of(&bound_typevar.identity(db))?;
types.get(index).copied() types.get(index).copied()
} }
PartialSpecialization::FromConstraintSolution(solution) => {
solution.get_upper_bound(bound_typevar)
}
} }
} }
} }
@ -1661,6 +1665,7 @@ impl<'db> SpecializationBuilder<'db> {
(Type::Callable(formal_callable), _) => { (Type::Callable(formal_callable), _) => {
eprintln!("==> {}", formal.display(self.db)); eprintln!("==> {}", formal.display(self.db));
eprintln!(" {}", actual.display(self.db)); eprintln!(" {}", actual.display(self.db));
eprintln!(" {}", self.inferable.display(self.db));
let Some(actual_callables) = actual.try_upcast_to_callable(self.db) else { let Some(actual_callables) = actual.try_upcast_to_callable(self.db) else {
eprintln!(" -> NOPE"); eprintln!(" -> NOPE");
return Ok(()); return Ok(());
@ -1688,42 +1693,13 @@ impl<'db> SpecializationBuilder<'db> {
eprintln!(" {}", when.display(self.db)); eprintln!(" {}", when.display(self.db));
eprintln!(" {}", when.display_graph(self.db, &" ")); eprintln!(" {}", when.display_graph(self.db, &" "));
when.for_each_path(self.db, |path| { let solutions = when.solve_for(self.db, self.inferable);
eprintln!( for solution in &solutions {
"--> path [{}]", eprintln!("--> solution [{}]", solution.display(self.db));
path.positive_constraints() for (bound_typevar, (_, ty)) in solution {
.map(|c| c.display(self.db)) self.add_type_mapping(*bound_typevar, *ty, polarity, &mut f);
.format(", ")
);
for constraint in path.positive_constraints() {
let typevar = constraint.typevar(self.db);
let lower = constraint.lower(self.db);
let upper = constraint.upper(self.db);
let upper_has_noninferable_typevar = any_over_type(
self.db,
upper,
&|ty| {
ty.as_typevar().is_some_and(|bound_typevar| {
!bound_typevar.is_inferable(self.db, self.inferable)
})
},
false,
);
if !upper.is_object() && !upper_has_noninferable_typevar {
self.add_type_mapping(typevar, upper, polarity, &mut f);
}
if typevar.is_inferable(self.db, self.inferable)
&& let Type::TypeVar(lower_bound_typevar) = lower
{
self.add_type_mapping(
lower_bound_typevar,
Type::TypeVar(typevar),
polarity,
&mut f,
);
}
} }
}); }
} }
// TODO: Add more forms that we can structurally induct into: type[C], callables // TODO: Add more forms that we can structurally induct into: type[C], callables

View File

@ -752,7 +752,8 @@ impl<'db> Signature<'db> {
// we produce, we reduce it back down to the inferable set that the caller asked about. // we produce, we reduce it back down to the inferable set that the caller asked about.
// If we introduced new inferable typevars, those will be existentially quantified away // If we introduced new inferable typevars, those will be existentially quantified away
// before returning. // before returning.
when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter())) //when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter()))
when
} }
fn is_equivalent_to_inner( fn is_equivalent_to_inner(
@ -898,7 +899,8 @@ impl<'db> Signature<'db> {
// we produce, we reduce it back down to the inferable set that the caller asked about. // we produce, we reduce it back down to the inferable set that the caller asked about.
// If we introduced new inferable typevars, those will be existentially quantified away // If we introduced new inferable typevars, those will be existentially quantified away
// before returning. // before returning.
when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter())) //when.reduce_inferable(db, self_inferable.iter().chain(other_inferable.iter()))
when
} }
fn has_relation_to_inner( fn has_relation_to_inner(