fix negation logic

This commit is contained in:
Douglas Creager 2025-12-10 10:24:44 -05:00
parent 0cecb9f4ef
commit 1518127d79
3 changed files with 39 additions and 22 deletions

View File

@ -1268,8 +1268,14 @@ impl<'db> Type<'db> {
self.as_union().expect("Expected a Type::Union variant") self.as_union().expect("Expected a Type::Union variant")
} }
pub(crate) const fn is_intersection(self) -> bool { /// Returns whether this is a "real" intersection type. (Negated types are represented by an
matches!(self, Type::Intersection(_)) /// intersection containing a single negative branch, which this method does _not_ consider a
/// "real" intersection.)
pub(crate) fn is_nontrivial_intersection(self, db: &'db dyn Db) -> bool {
match self {
Type::Intersection(intersection) => !intersection.is_simple_negation(db),
_ => false,
}
} }
pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> { pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> {
@ -14123,6 +14129,10 @@ impl<'db> IntersectionType<'db> {
(self.positive(db).len() + self.negative(db).len()) == 1 (self.positive(db).len() + self.negative(db).len()) == 1
} }
pub(crate) fn is_simple_negation(self, db: &'db dyn Db) -> bool {
self.positive(db).len() == 0 && self.negative(db).len() == 1
}
fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize { fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize {
ruff_memory_usage::order_set_heap_size(positive) ruff_memory_usage::order_set_heap_size(positive)
+ ruff_memory_usage::order_set_heap_size(negative) + ruff_memory_usage::order_set_heap_size(negative)

View File

@ -435,6 +435,11 @@ impl<'db> ConstraintSet<'db> {
pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { pub(crate) fn display(self, db: &'db dyn Db) -> impl Display {
self.node.simplify_for_display(db).display(db) self.node.simplify_for_display(db).display(db)
} }
#[expect(dead_code)] // Keep this around for debugging purposes
pub(crate) fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
self.node.display_graph(db, prefix)
}
} }
impl From<bool> for ConstraintSet<'_> { impl From<bool> for ConstraintSet<'_> {
@ -498,11 +503,13 @@ impl<'db> ConstrainedTypeVar<'db> {
debug_assert_eq!(upper, upper.top_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db));
// It's not useful for an upper bound to be an intersection type, or for a lower bound to // It's not useful for an upper bound to be an intersection type, or for a lower bound to
// be a union type. Both of those can be rewritten as simpler BDDs: // be a union type. Because the following equivalences hold, we can break these bounds
// apart and create an equivalent BDD with more nodes but simpler constraints. (Fewer,
// simpler constraints mean that our sequent maps won't grow pathologically large.)
// //
// T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β) // T ≤ (α & β) ⇔ (T ≤ α) ∧ (T ≤ β)
// T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β) // T ≤ (¬α & ¬β) ⇔ (T ≤ ¬α) ∧ (T ≤ ¬β)
// α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T) // (α | β) ≤ T ⇔ (α ≤ T) ∧ (β ≤ T)
if let Type::Union(lower_union) = lower { if let Type::Union(lower_union) = lower {
let mut result = Node::AlwaysTrue; let mut result = Node::AlwaysTrue;
for lower_element in lower_union.elements(db) { for lower_element in lower_union.elements(db) {
@ -513,7 +520,12 @@ impl<'db> ConstrainedTypeVar<'db> {
} }
return result; return result;
} }
if let Type::Intersection(upper_intersection) = upper { // A negated type ¬α is represented as an intersection with no positive elements, and a
// single negative element. We _don't_ want to treat that an "intersection" for the
// purposes of simplifying upper bounds.
if let Type::Intersection(upper_intersection) = upper
&& !upper_intersection.is_simple_negation(db)
{
let mut result = Node::AlwaysTrue; let mut result = Node::AlwaysTrue;
for upper_element in upper_intersection.iter_positive(db) { for upper_element in upper_intersection.iter_positive(db) {
result = result.and( result = result.and(
@ -524,7 +536,7 @@ impl<'db> ConstrainedTypeVar<'db> {
for upper_element in upper_intersection.iter_negative(db) { for upper_element in upper_intersection.iter_negative(db) {
result = result.and( result = result.and(
db, db,
ConstrainedTypeVar::new_node(db, typevar, lower, upper_element).negate(db), ConstrainedTypeVar::new_node(db, typevar, lower, upper_element.negate(db)),
); );
} }
return result; return result;
@ -716,7 +728,7 @@ impl<'db> ConstrainedTypeVar<'db> {
return IntersectionResult::Disjoint; return IntersectionResult::Disjoint;
} }
if lower.is_union() || upper.is_intersection() { if lower.is_union() || upper.is_nontrivial_intersection(db) {
return IntersectionResult::CannotSimplify; return IntersectionResult::CannotSimplify;
} }

View File

@ -10844,19 +10844,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
( (
Type::KnownInstance(KnownInstanceType::ConstraintSet(left)), Type::KnownInstance(KnownInstanceType::ConstraintSet(left)),
Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), Type::KnownInstance(KnownInstanceType::ConstraintSet(right)),
) => { ) => match op {
let result = match op { ast::CmpOp::Eq => Some(Ok(Type::BooleanLiteral(
ast::CmpOp::Eq => Some( left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())) ))),
), ast::CmpOp::NotEq => Some(Ok(Type::BooleanLiteral(
ast::CmpOp::NotEq => Some( !left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).negate(self.db()) ))),
), _ => None,
_ => None,
};
result.map(|constraints| Ok(Type::KnownInstance(KnownInstanceType::ConstraintSet(
TrackedConstraintSet::new(self.db(), constraints)
))))
} }
( (