[ty] Fix negation upper bounds in constraint sets (#21897)

This fixes the logic error that @sharkdp
[found](https://github.com/astral-sh/ruff/pull/21871#discussion_r2605755588)
in the constraint set upper bound normalization logic I introduced in
#21871.

I had originally claimed that `(T ≤ α & ~β)` should simplify into `(T ≤
α) ∧ ¬(T ≤ β)`. But that also suggests that `T ≤ ~β` should simplify to
`¬(T ≤ β)` on its own, and that's not correct.

The correct simplification is that `~α` is an "atomic" type, not an
"intersection" for the purposes of our upper bound simplifcation. So `(T
≤ α & ~β)` should simplify to `(T ≤ α) ∧ (T ≤ ~β)`. That is, break apart
the elements of a (proper) intersection, regardless of whether each
element is negated or not.

This PR fixes the logic, adds a test case, and updates the comments to
be hopefully more clear and accurate.
This commit is contained in:
Douglas Creager 2025-12-10 15:07:50 -05:00 committed by GitHub
parent 5dc0079e78
commit 3e00221a6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 49 additions and 24 deletions

View File

@ -126,7 +126,7 @@ strict subtype of the lower bound, a strict supertype of the upper bound, or inc
```py ```py
from typing import Any, final, Never, Sequence from typing import Any, final, Never, Sequence
from ty_extensions import ConstraintSet, static_assert from ty_extensions import ConstraintSet, Not, static_assert
class Super: ... class Super: ...
class Base(Super): ... class Base(Super): ...
@ -207,6 +207,15 @@ def _[T]() -> None:
static_assert(constraints == expected) static_assert(constraints == expected)
``` ```
A negated _type_ is not the same thing as a negated _range_.
```py
def _[T]() -> None:
negated_type = ConstraintSet.range(Never, T, Not[int])
negated_constraint = ~ConstraintSet.range(Never, T, int)
static_assert(negated_type != negated_constraint)
```
## Intersection ## Intersection
The intersection of two constraint sets requires that the constraints in both sets hold. In many The intersection of two constraint sets requires that the constraints in both sets hold. In many

View File

@ -1268,8 +1268,14 @@ impl<'db> Type<'db> {
self.as_union().expect("Expected a Type::Union variant") self.as_union().expect("Expected a Type::Union variant")
} }
pub(crate) const fn is_intersection(self) -> bool { /// Returns whether this is a "real" intersection type. (Negated types are represented by an
matches!(self, Type::Intersection(_)) /// intersection containing a single negative branch, which this method does _not_ consider a
/// "real" intersection.)
pub(crate) fn is_nontrivial_intersection(self, db: &'db dyn Db) -> bool {
match self {
Type::Intersection(intersection) => !intersection.is_simple_negation(db),
_ => false,
}
} }
pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> { pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> {
@ -14151,6 +14157,10 @@ impl<'db> IntersectionType<'db> {
(self.positive(db).len() + self.negative(db).len()) == 1 (self.positive(db).len() + self.negative(db).len()) == 1
} }
pub(crate) fn is_simple_negation(self, db: &'db dyn Db) -> bool {
self.positive(db).is_empty() && self.negative(db).len() == 1
}
fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize { fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize {
ruff_memory_usage::order_set_heap_size(positive) ruff_memory_usage::order_set_heap_size(positive)
+ ruff_memory_usage::order_set_heap_size(negative) + ruff_memory_usage::order_set_heap_size(negative)

View File

@ -435,6 +435,11 @@ impl<'db> ConstraintSet<'db> {
pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { pub(crate) fn display(self, db: &'db dyn Db) -> impl Display {
self.node.simplify_for_display(db).display(db) self.node.simplify_for_display(db).display(db)
} }
#[expect(dead_code)] // Keep this around for debugging purposes
pub(crate) fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
self.node.display_graph(db, prefix)
}
} }
impl From<bool> for ConstraintSet<'_> { impl From<bool> for ConstraintSet<'_> {
@ -498,11 +503,13 @@ impl<'db> ConstrainedTypeVar<'db> {
debug_assert_eq!(upper, upper.top_materialization(db)); debug_assert_eq!(upper, upper.top_materialization(db));
// It's not useful for an upper bound to be an intersection type, or for a lower bound to // It's not useful for an upper bound to be an intersection type, or for a lower bound to
// be a union type. Both of those can be rewritten as simpler BDDs: // be a union type. Because the following equivalences hold, we can break these bounds
// apart and create an equivalent BDD with more nodes but simpler constraints. (Fewer,
// simpler constraints mean that our sequent maps won't grow pathologically large.)
// //
// T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β) // T ≤ (α & β) ⇔ (T ≤ α) ∧ (T ≤ β)
// T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β) // T ≤ (¬α & ¬β) ⇔ (T ≤ ¬α) ∧ (T ≤ ¬β)
// α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T) // (α | β) ≤ T ⇔ (α ≤ T) ∧ (β ≤ T)
if let Type::Union(lower_union) = lower { if let Type::Union(lower_union) = lower {
let mut result = Node::AlwaysTrue; let mut result = Node::AlwaysTrue;
for lower_element in lower_union.elements(db) { for lower_element in lower_union.elements(db) {
@ -513,7 +520,12 @@ impl<'db> ConstrainedTypeVar<'db> {
} }
return result; return result;
} }
if let Type::Intersection(upper_intersection) = upper { // A negated type ¬α is represented as an intersection with no positive elements, and a
// single negative element. We _don't_ want to treat that an "intersection" for the
// purposes of simplifying upper bounds.
if let Type::Intersection(upper_intersection) = upper
&& !upper_intersection.is_simple_negation(db)
{
let mut result = Node::AlwaysTrue; let mut result = Node::AlwaysTrue;
for upper_element in upper_intersection.iter_positive(db) { for upper_element in upper_intersection.iter_positive(db) {
result = result.and( result = result.and(
@ -524,7 +536,7 @@ impl<'db> ConstrainedTypeVar<'db> {
for upper_element in upper_intersection.iter_negative(db) { for upper_element in upper_intersection.iter_negative(db) {
result = result.and( result = result.and(
db, db,
ConstrainedTypeVar::new_node(db, typevar, lower, upper_element).negate(db), ConstrainedTypeVar::new_node(db, typevar, lower, upper_element.negate(db)),
); );
} }
return result; return result;
@ -716,7 +728,7 @@ impl<'db> ConstrainedTypeVar<'db> {
return IntersectionResult::Disjoint; return IntersectionResult::Disjoint;
} }
if lower.is_union() || upper.is_intersection() { if lower.is_union() || upper.is_nontrivial_intersection(db) {
return IntersectionResult::CannotSimplify; return IntersectionResult::CannotSimplify;
} }
@ -1579,7 +1591,6 @@ impl<'db> Node<'db> {
/// │ └─₀ never /// │ └─₀ never
/// └─₀ never /// └─₀ never
/// ``` /// ```
#[cfg_attr(not(test), expect(dead_code))] // Keep this around for debugging purposes
fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display { fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
struct DisplayNode<'a, 'db> { struct DisplayNode<'a, 'db> {
db: &'db dyn Db, db: &'db dyn Db,

View File

@ -10844,19 +10844,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
( (
Type::KnownInstance(KnownInstanceType::ConstraintSet(left)), Type::KnownInstance(KnownInstanceType::ConstraintSet(left)),
Type::KnownInstance(KnownInstanceType::ConstraintSet(right)), Type::KnownInstance(KnownInstanceType::ConstraintSet(right)),
) => { ) => match op {
let result = match op { ast::CmpOp::Eq => Some(Ok(Type::BooleanLiteral(
ast::CmpOp::Eq => Some( left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())) ))),
), ast::CmpOp::NotEq => Some(Ok(Type::BooleanLiteral(
ast::CmpOp::NotEq => Some( !left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).negate(self.db()) ))),
),
_ => None, _ => None,
};
result.map(|constraints| Ok(Type::KnownInstance(KnownInstanceType::ConstraintSet(
TrackedConstraintSet::new(self.db(), constraints)
))))
} }
( (