mirror of https://github.com/astral-sh/ruff
[ty] Fix negation upper bounds in constraint sets (#21897)
This fixes the logic error that @sharkdp [found](https://github.com/astral-sh/ruff/pull/21871#discussion_r2605755588) in the constraint set upper bound normalization logic I introduced in #21871. I had originally claimed that `(T ≤ α & ~β)` should simplify into `(T ≤ α) ∧ ¬(T ≤ β)`. But that also suggests that `T ≤ ~β` should simplify to `¬(T ≤ β)` on its own, and that's not correct. The correct simplification is that `~α` is an "atomic" type, not an "intersection" for the purposes of our upper bound simplifcation. So `(T ≤ α & ~β)` should simplify to `(T ≤ α) ∧ (T ≤ ~β)`. That is, break apart the elements of a (proper) intersection, regardless of whether each element is negated or not. This PR fixes the logic, adds a test case, and updates the comments to be hopefully more clear and accurate.
This commit is contained in:
parent
5dc0079e78
commit
3e00221a6c
|
|
@ -126,7 +126,7 @@ strict subtype of the lower bound, a strict supertype of the upper bound, or inc
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from typing import Any, final, Never, Sequence
|
from typing import Any, final, Never, Sequence
|
||||||
from ty_extensions import ConstraintSet, static_assert
|
from ty_extensions import ConstraintSet, Not, static_assert
|
||||||
|
|
||||||
class Super: ...
|
class Super: ...
|
||||||
class Base(Super): ...
|
class Base(Super): ...
|
||||||
|
|
@ -207,6 +207,15 @@ def _[T]() -> None:
|
||||||
static_assert(constraints == expected)
|
static_assert(constraints == expected)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
A negated _type_ is not the same thing as a negated _range_.
|
||||||
|
|
||||||
|
```py
|
||||||
|
def _[T]() -> None:
|
||||||
|
negated_type = ConstraintSet.range(Never, T, Not[int])
|
||||||
|
negated_constraint = ~ConstraintSet.range(Never, T, int)
|
||||||
|
static_assert(negated_type != negated_constraint)
|
||||||
|
```
|
||||||
|
|
||||||
## Intersection
|
## Intersection
|
||||||
|
|
||||||
The intersection of two constraint sets requires that the constraints in both sets hold. In many
|
The intersection of two constraint sets requires that the constraints in both sets hold. In many
|
||||||
|
|
|
||||||
|
|
@ -1268,8 +1268,14 @@ impl<'db> Type<'db> {
|
||||||
self.as_union().expect("Expected a Type::Union variant")
|
self.as_union().expect("Expected a Type::Union variant")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) const fn is_intersection(self) -> bool {
|
/// Returns whether this is a "real" intersection type. (Negated types are represented by an
|
||||||
matches!(self, Type::Intersection(_))
|
/// intersection containing a single negative branch, which this method does _not_ consider a
|
||||||
|
/// "real" intersection.)
|
||||||
|
pub(crate) fn is_nontrivial_intersection(self, db: &'db dyn Db) -> bool {
|
||||||
|
match self {
|
||||||
|
Type::Intersection(intersection) => !intersection.is_simple_negation(db),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> {
|
pub(crate) const fn as_function_literal(self) -> Option<FunctionType<'db>> {
|
||||||
|
|
@ -14151,6 +14157,10 @@ impl<'db> IntersectionType<'db> {
|
||||||
(self.positive(db).len() + self.negative(db).len()) == 1
|
(self.positive(db).len() + self.negative(db).len()) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_simple_negation(self, db: &'db dyn Db) -> bool {
|
||||||
|
self.positive(db).is_empty() && self.negative(db).len() == 1
|
||||||
|
}
|
||||||
|
|
||||||
fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize {
|
fn heap_size((positive, negative): &(FxOrderSet<Type<'db>>, FxOrderSet<Type<'db>>)) -> usize {
|
||||||
ruff_memory_usage::order_set_heap_size(positive)
|
ruff_memory_usage::order_set_heap_size(positive)
|
||||||
+ ruff_memory_usage::order_set_heap_size(negative)
|
+ ruff_memory_usage::order_set_heap_size(negative)
|
||||||
|
|
|
||||||
|
|
@ -435,6 +435,11 @@ impl<'db> ConstraintSet<'db> {
|
||||||
pub(crate) fn display(self, db: &'db dyn Db) -> impl Display {
|
pub(crate) fn display(self, db: &'db dyn Db) -> impl Display {
|
||||||
self.node.simplify_for_display(db).display(db)
|
self.node.simplify_for_display(db).display(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[expect(dead_code)] // Keep this around for debugging purposes
|
||||||
|
pub(crate) fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
|
||||||
|
self.node.display_graph(db, prefix)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<bool> for ConstraintSet<'_> {
|
impl From<bool> for ConstraintSet<'_> {
|
||||||
|
|
@ -498,11 +503,13 @@ impl<'db> ConstrainedTypeVar<'db> {
|
||||||
debug_assert_eq!(upper, upper.top_materialization(db));
|
debug_assert_eq!(upper, upper.top_materialization(db));
|
||||||
|
|
||||||
// It's not useful for an upper bound to be an intersection type, or for a lower bound to
|
// It's not useful for an upper bound to be an intersection type, or for a lower bound to
|
||||||
// be a union type. Both of those can be rewritten as simpler BDDs:
|
// be a union type. Because the following equivalences hold, we can break these bounds
|
||||||
|
// apart and create an equivalent BDD with more nodes but simpler constraints. (Fewer,
|
||||||
|
// simpler constraints mean that our sequent maps won't grow pathologically large.)
|
||||||
//
|
//
|
||||||
// T ≤ α & β ⇒ (T ≤ α) ∧ (T ≤ β)
|
// T ≤ (α & β) ⇔ (T ≤ α) ∧ (T ≤ β)
|
||||||
// T ≤ α & ¬β ⇒ (T ≤ α) ∧ ¬(T ≤ β)
|
// T ≤ (¬α & ¬β) ⇔ (T ≤ ¬α) ∧ (T ≤ ¬β)
|
||||||
// α | β ≤ T ⇒ (α ≤ T) ∧ (β ≤ T)
|
// (α | β) ≤ T ⇔ (α ≤ T) ∧ (β ≤ T)
|
||||||
if let Type::Union(lower_union) = lower {
|
if let Type::Union(lower_union) = lower {
|
||||||
let mut result = Node::AlwaysTrue;
|
let mut result = Node::AlwaysTrue;
|
||||||
for lower_element in lower_union.elements(db) {
|
for lower_element in lower_union.elements(db) {
|
||||||
|
|
@ -513,7 +520,12 @@ impl<'db> ConstrainedTypeVar<'db> {
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if let Type::Intersection(upper_intersection) = upper {
|
// A negated type ¬α is represented as an intersection with no positive elements, and a
|
||||||
|
// single negative element. We _don't_ want to treat that an "intersection" for the
|
||||||
|
// purposes of simplifying upper bounds.
|
||||||
|
if let Type::Intersection(upper_intersection) = upper
|
||||||
|
&& !upper_intersection.is_simple_negation(db)
|
||||||
|
{
|
||||||
let mut result = Node::AlwaysTrue;
|
let mut result = Node::AlwaysTrue;
|
||||||
for upper_element in upper_intersection.iter_positive(db) {
|
for upper_element in upper_intersection.iter_positive(db) {
|
||||||
result = result.and(
|
result = result.and(
|
||||||
|
|
@ -524,7 +536,7 @@ impl<'db> ConstrainedTypeVar<'db> {
|
||||||
for upper_element in upper_intersection.iter_negative(db) {
|
for upper_element in upper_intersection.iter_negative(db) {
|
||||||
result = result.and(
|
result = result.and(
|
||||||
db,
|
db,
|
||||||
ConstrainedTypeVar::new_node(db, typevar, lower, upper_element).negate(db),
|
ConstrainedTypeVar::new_node(db, typevar, lower, upper_element.negate(db)),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -716,7 +728,7 @@ impl<'db> ConstrainedTypeVar<'db> {
|
||||||
return IntersectionResult::Disjoint;
|
return IntersectionResult::Disjoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
if lower.is_union() || upper.is_intersection() {
|
if lower.is_union() || upper.is_nontrivial_intersection(db) {
|
||||||
return IntersectionResult::CannotSimplify;
|
return IntersectionResult::CannotSimplify;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1579,7 +1591,6 @@ impl<'db> Node<'db> {
|
||||||
/// │ └─₀ never
|
/// │ └─₀ never
|
||||||
/// └─₀ never
|
/// └─₀ never
|
||||||
/// ```
|
/// ```
|
||||||
#[cfg_attr(not(test), expect(dead_code))] // Keep this around for debugging purposes
|
|
||||||
fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
|
fn display_graph(self, db: &'db dyn Db, prefix: &dyn Display) -> impl Display {
|
||||||
struct DisplayNode<'a, 'db> {
|
struct DisplayNode<'a, 'db> {
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
|
|
|
||||||
|
|
@ -10844,19 +10844,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
(
|
(
|
||||||
Type::KnownInstance(KnownInstanceType::ConstraintSet(left)),
|
Type::KnownInstance(KnownInstanceType::ConstraintSet(left)),
|
||||||
Type::KnownInstance(KnownInstanceType::ConstraintSet(right)),
|
Type::KnownInstance(KnownInstanceType::ConstraintSet(right)),
|
||||||
) => {
|
) => match op {
|
||||||
let result = match op {
|
ast::CmpOp::Eq => Some(Ok(Type::BooleanLiteral(
|
||||||
ast::CmpOp::Eq => Some(
|
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
|
||||||
left.constraints(self.db()).iff(self.db(), right.constraints(self.db()))
|
))),
|
||||||
),
|
ast::CmpOp::NotEq => Some(Ok(Type::BooleanLiteral(
|
||||||
ast::CmpOp::NotEq => Some(
|
!left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).is_always_satisfied(self.db()),
|
||||||
left.constraints(self.db()).iff(self.db(), right.constraints(self.db())).negate(self.db())
|
))),
|
||||||
),
|
_ => None,
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
result.map(|constraints| Ok(Type::KnownInstance(KnownInstanceType::ConstraintSet(
|
|
||||||
TrackedConstraintSet::new(self.db(), constraints)
|
|
||||||
))))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue