only fold once

This commit is contained in:
Douglas Creager 2025-12-15 09:55:17 -05:00
parent 358185b5e2
commit 63c75d85d0
1 changed files with 105 additions and 107 deletions

View File

@ -1357,7 +1357,7 @@ impl<'db> Node<'db> {
self, self,
db: &'db dyn Db, db: &'db dyn Db,
bound_typevar: BoundTypeVarIdentity<'db>, bound_typevar: BoundTypeVarIdentity<'db>,
mut f: impl FnMut(Option<(Type<'db>, Type<'db>, usize)>), mut f: impl FnMut(Option<&[RepresentativeBounds<'db>]>),
) { ) {
self.retain_one(db, bound_typevar) self.retain_one(db, bound_typevar)
.find_representative_types_inner(db, &mut Vec::default(), &mut f); .find_representative_types_inner(db, &mut Vec::default(), &mut f);
@ -1367,43 +1367,37 @@ impl<'db> Node<'db> {
self, self,
db: &'db dyn Db, db: &'db dyn Db,
current_bounds: &mut Vec<RepresentativeBounds<'db>>, current_bounds: &mut Vec<RepresentativeBounds<'db>>,
f: &mut dyn FnMut(Option<(Type<'db>, Type<'db>, usize)>), f: &mut dyn FnMut(Option<&[RepresentativeBounds<'db>]>),
) { ) {
match self { match self {
Node::AlwaysTrue => { Node::AlwaysTrue => {
// If we reach the `true` terminal, the path we've been following represents one
// representative type.
if current_bounds.is_empty() { if current_bounds.is_empty() {
f(None); f(None);
return; return;
} }
// If we reach the `true` terminal, the path we've been following represents one // If `lower ≰ upper`, then this path somehow represents in invalid specialization.
// representative type. Before constructing the final lower and upper bound, sort // That should have been removed from the BDD domain as part of the simplification
// the constraints by their source order. This should give us a consistently // process. (Here we are just checking assignability, so we don't need to construct
// ordered specialization, regardless of the variable ordering of the original BDD. // the lower and upper bounds in a consistent order.)
current_bounds.sort_unstable_by_key(|bounds| bounds.source_order); debug_assert!({
let greatest_lower_bound = let greatest_lower_bound = UnionType::from_elements(
UnionType::from_elements(db, current_bounds.iter().map(|bounds| bounds.lower)); db,
current_bounds.iter().map(|bounds| bounds.lower),
);
let least_upper_bound = IntersectionType::from_elements( let least_upper_bound = IntersectionType::from_elements(
db, db,
current_bounds.iter().map(|bounds| bounds.upper), current_bounds.iter().map(|bounds| bounds.upper),
); );
greatest_lower_bound.is_assignable_to(db, least_upper_bound)
// If `lower ≰ upper`, then this path somehow represents in invalid specialization. });
// That should have been removed from the BDD domain as part of the simplification
// process.
debug_assert!(greatest_lower_bound.is_assignable_to(db, least_upper_bound));
// SAFETY: Checked that current_bounds is non-empty above.
let minimum_source_order = current_bounds[0].source_order;
// We've been tracking the lower and upper bound that the types for this path must // We've been tracking the lower and upper bound that the types for this path must
// satisfy. Pass those bounds along and let the caller choose a representative type // satisfy. Pass those bounds along and let the caller choose a representative type
// from within that range. // from within that range.
f(Some(( f(Some(&current_bounds));
greatest_lower_bound,
least_upper_bound,
minimum_source_order,
)));
} }
Node::AlwaysFalse => { Node::AlwaysFalse => {
@ -1771,6 +1765,7 @@ impl<'db> Node<'db> {
} }
} }
#[derive(Clone, Copy, Debug)]
struct RepresentativeBounds<'db> { struct RepresentativeBounds<'db> {
lower: Type<'db>, lower: Type<'db>,
upper: Type<'db>, upper: Type<'db>,
@ -3674,8 +3669,9 @@ impl<'db> GenericContext<'db> {
// Then we find all of the "representative types" for each typevar in the constraint set. // Then we find all of the "representative types" for each typevar in the constraint set.
let mut error_occurred = false; let mut error_occurred = false;
let mut constraints = Vec::new(); let mut representatives = Vec::new();
let types = self.variables(db).map(|bound_typevar| { let types =
self.variables(db).map(|bound_typevar| {
// Each representative type represents one of the ways that the typevar can satisfy the // Each representative type represents one of the ways that the typevar can satisfy the
// constraint, expressed as a lower/upper bound on the types that the typevar can // constraint, expressed as a lower/upper bound on the types that the typevar can
// specialize to. // specialize to.
@ -3694,21 +3690,16 @@ impl<'db> GenericContext<'db> {
abstracted = %abstracted.retain_one(db, identity).display(db), abstracted = %abstracted.retain_one(db, identity).display(db),
"find specialization for typevar", "find specialization for typevar",
); );
constraints.clear(); representatives.clear();
abstracted.find_representative_types(db, identity, |bounds| match bounds { abstracted.find_representative_types(db, identity, |representative| {
Some(bounds @ (lower_bound, upper_bound, _)) => { match representative {
tracing::trace!( Some(representative) => {
target: "ty_python_semantic::types::constraints::specialize_constrained", representatives.extend_from_slice(representative);
bound_typevar = %identity.display(db),
lower_bound = %lower_bound.display(db),
upper_bound = %upper_bound.display(db),
"found representative type",
);
constraints.push(bounds);
} }
None => { None => {
unconstrained = true; unconstrained = true;
} }
}
}); });
// The BDD is satisfiable, but the typevar is unconstrained, then we use `None` to tell // The BDD is satisfiable, but the typevar is unconstrained, then we use `None` to tell
@ -3724,7 +3715,7 @@ impl<'db> GenericContext<'db> {
// If there are no satisfiable paths in the BDD, then there is no valid specialization // If there are no satisfiable paths in the BDD, then there is no valid specialization
// for this constraint set. // for this constraint set.
if constraints.is_empty() { if representatives.is_empty() {
// TODO: Construct a useful error here // TODO: Construct a useful error here
tracing::debug!( tracing::debug!(
target: "ty_python_semantic::types::constraints::specialize_constrained", target: "ty_python_semantic::types::constraints::specialize_constrained",
@ -3735,12 +3726,19 @@ impl<'db> GenericContext<'db> {
return None; return None;
} }
// Before constructing the final lower and upper bound, sort the constraints by
// their source order. This should give us a consistently ordered specialization,
// regardless of the variable ordering of the original BDD.
representatives.sort_unstable_by_key(|bounds| bounds.source_order);
let greatest_lower_bound =
UnionType::from_elements(db, representatives.iter().map(|bounds| bounds.lower));
let least_upper_bound = IntersectionType::from_elements(
db,
representatives.iter().map(|bounds| bounds.upper),
);
// If `lower ≰ upper`, then there is no type that satisfies all of the paths in the // If `lower ≰ upper`, then there is no type that satisfies all of the paths in the
// BDD. That's an ambiguous specialization, as described above. // BDD. That's an ambiguous specialization, as described above.
let greatest_lower_bound =
UnionType::from_elements(db, constraints.iter().map(|(lower, _, _)| *lower));
let least_upper_bound =
IntersectionType::from_elements(db, constraints.iter().map(|(_, upper, _)| *upper));
if !greatest_lower_bound.is_assignable_to(db, least_upper_bound) { if !greatest_lower_bound.is_assignable_to(db, least_upper_bound) {
tracing::debug!( tracing::debug!(
target: "ty_python_semantic::types::constraints::specialize_constrained", target: "ty_python_semantic::types::constraints::specialize_constrained",