mirror of https://github.com/astral-sh/ruff
[red-knot] Update `==` and `!=` narrowing (#17567)
## Summary Historically we have avoided narrowing on `==` tests because in many cases it's unsound, since subclasses of a type could compare equal to who-knows-what. But there are a lot of types (literals and unions of them, as well as some known instances like `None` -- single-valued types) whose `__eq__` behavior we know, and which we can safely narrow away based on equality comparisons. This PR implements equality narrowing in the cases where it is sound. The most elegant way to do this (and the way that is most in-line with our approach up until now) would be to introduce new Type variants `NeverEqualTo[...]` and `AlwaysEqualTo[...]`, and then implement all type relations for those variants, narrow by intersection, and let union and intersection simplification sort it all out. This is analogous to our existing handling for `AlwaysFalse` and `AlwaysTrue`. But I'm reluctant to add new `Type` variants for this, mostly because they could end up un-simplified in some types and make types even more complex. So let's try this approach, where we handle more of the narrowing logic as a special case. ## Test Plan Updated and added tests. --------- Co-authored-by: Carl Meyer <carl@astral.sh> Co-authored-by: Carl Meyer <carl@oddbird.net> Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
parent
ac6219ec38
commit
e71f3ed2c5
|
|
@ -29,7 +29,7 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
|
|||
assert x is 2
|
||||
reveal_type(x) # revealed: Literal[2]
|
||||
assert y == 2
|
||||
reveal_type(y) # revealed: Literal[1, 2, 3]
|
||||
reveal_type(y) # revealed: Literal[2]
|
||||
```
|
||||
|
||||
## `assert` with `isinstance`
|
||||
|
|
|
|||
|
|
@ -20,11 +20,9 @@ def _(flag1: bool, flag2: bool):
|
|||
x = 1 if flag1 else 2 if flag2 else 3
|
||||
|
||||
if x == 1:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
elif x == 2:
|
||||
# TODO should be Literal[2]
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
reveal_type(x) # revealed: Literal[2]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[3]
|
||||
```
|
||||
|
|
@ -38,14 +36,11 @@ def _(flag1: bool, flag2: bool):
|
|||
if x != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
elif x != 2:
|
||||
# TODO should be `Literal[1]`
|
||||
reveal_type(x) # revealed: Literal[1, 3]
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
elif x == 3:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
reveal_type(x) # revealed: Never
|
||||
else:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
reveal_type(x) # revealed: Never
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ def _(flag: bool):
|
|||
if x != None:
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
# TODO should be None
|
||||
reveal_type(x) # revealed: None | Literal[1]
|
||||
reveal_type(x) # revealed: None
|
||||
```
|
||||
|
||||
## `!=` for other singleton types
|
||||
|
|
@ -22,8 +21,7 @@ def _(flag: bool):
|
|||
if x != False:
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
else:
|
||||
# TODO should be Literal[False]
|
||||
reveal_type(x) # revealed: bool
|
||||
reveal_type(x) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
## `x != y` where `y` is of literal type
|
||||
|
|
@ -47,8 +45,7 @@ def _(flag: bool):
|
|||
if C != A:
|
||||
reveal_type(C) # revealed: Literal[B]
|
||||
else:
|
||||
# TODO should be Literal[A]
|
||||
reveal_type(C) # revealed: Literal[A, B]
|
||||
reveal_type(C) # revealed: Literal[A]
|
||||
```
|
||||
|
||||
## `x != y` where `y` has multiple single-valued options
|
||||
|
|
@ -61,8 +58,7 @@ def _(flag1: bool, flag2: bool):
|
|||
if x != y:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
else:
|
||||
# TODO should be Literal[2]
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
reveal_type(x) # revealed: Literal[2]
|
||||
```
|
||||
|
||||
## `!=` for non-single-valued types
|
||||
|
|
@ -101,6 +97,61 @@ def f() -> Literal[1, 2, 3]:
|
|||
if (x := f()) != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
else:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
```
|
||||
|
||||
## Union with `Any`
|
||||
|
||||
```py
|
||||
from typing import Any
|
||||
|
||||
def _(x: Any | None, y: Any | None):
|
||||
if x != 1:
|
||||
reveal_type(x) # revealed: (Any & ~Literal[1]) | None
|
||||
if y == 1:
|
||||
reveal_type(y) # revealed: Any & ~None
|
||||
```
|
||||
|
||||
## Booleans and integers
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def _(b: bool, i: Literal[1, 2]):
|
||||
if b == 1:
|
||||
reveal_type(b) # revealed: Literal[True]
|
||||
else:
|
||||
reveal_type(b) # revealed: Literal[False]
|
||||
|
||||
if b == 6:
|
||||
reveal_type(b) # revealed: Never
|
||||
else:
|
||||
reveal_type(b) # revealed: bool
|
||||
|
||||
if b == 0:
|
||||
reveal_type(b) # revealed: Literal[False]
|
||||
else:
|
||||
reveal_type(b) # revealed: Literal[True]
|
||||
|
||||
if i == True:
|
||||
reveal_type(i) # revealed: Literal[1]
|
||||
else:
|
||||
reveal_type(i) # revealed: Literal[2]
|
||||
```
|
||||
|
||||
## Narrowing `LiteralString` in union
|
||||
|
||||
```py
|
||||
from typing_extensions import Literal, LiteralString, Any
|
||||
|
||||
def _(s: LiteralString | None, t: LiteralString | Any):
|
||||
if s == "foo":
|
||||
reveal_type(s) # revealed: Literal["foo"]
|
||||
|
||||
if s == 1:
|
||||
reveal_type(s) # revealed: Never
|
||||
|
||||
if t == "foo":
|
||||
# TODO could be `Literal["foo"] | Any`
|
||||
reveal_type(t) # revealed: LiteralString | Any
|
||||
```
|
||||
|
|
@ -31,17 +31,14 @@ def _(flag1: bool, flag2: bool):
|
|||
if x != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
if x == 2:
|
||||
# TODO should be `Literal[2]`
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
reveal_type(x) # revealed: Literal[2]
|
||||
elif x == 3:
|
||||
reveal_type(x) # revealed: Literal[3]
|
||||
else:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
elif x != 2:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 3]
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
reveal_type(x) # revealed: Never
|
||||
```
|
||||
|
|
|
|||
|
|
@ -542,6 +542,11 @@ impl<'db> Type<'db> {
|
|||
.is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType))
|
||||
}
|
||||
|
||||
fn is_bool(&self, db: &'db dyn Db) -> bool {
|
||||
self.into_instance()
|
||||
.is_some_and(|instance| instance.class().is_known(db, KnownClass::Bool))
|
||||
}
|
||||
|
||||
pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
|
||||
self.into_instance().is_some_and(|instance| {
|
||||
instance
|
||||
|
|
@ -776,8 +781,13 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
|
||||
pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool {
|
||||
self.into_union()
|
||||
.is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db)))
|
||||
self.into_union().is_some_and(|union| {
|
||||
union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.all(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
|
||||
}) || self.is_bool(db)
|
||||
|| self.is_literal_string()
|
||||
}
|
||||
|
||||
pub const fn into_int_literal(self) -> Option<i64> {
|
||||
|
|
|
|||
|
|
@ -394,6 +394,142 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
|
||||
// We can only narrow on equality checks against single-valued types.
|
||||
if rhs_ty.is_single_valued(self.db) || rhs_ty.is_union_of_single_valued(self.db) {
|
||||
// The fully-general (and more efficient) approach here would be to introduce a
|
||||
// `NeverEqualTo` type that can wrap a single-valued type, and then simply return
|
||||
// `~NeverEqualTo(rhs_ty)` here and let union/intersection builder sort it out. This is
|
||||
// how we handle `AlwaysTruthy` and `AlwaysFalsy`. But this means we have to deal with
|
||||
// this type everywhere, and possibly have it show up unsimplified in some cases, and
|
||||
// so we instead prefer to just do the simplification here. (Another hybrid option that
|
||||
// would be similar to this, but more efficient, would be to allow narrowing to return
|
||||
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
|
||||
// instead of intersecting with a type.)
|
||||
|
||||
// Return `true` if it is possible for any two inhabitants of the given types to
|
||||
// compare equal to each other; otherwise return `false`.
|
||||
fn could_compare_equal<'db>(
|
||||
db: &'db dyn Db,
|
||||
left_ty: Type<'db>,
|
||||
right_ty: Type<'db>,
|
||||
) -> bool {
|
||||
if !left_ty.is_disjoint_from(db, right_ty) {
|
||||
// If types overlap, they have inhabitants in common; it's definitely possible
|
||||
// for an object to compare equal to itself.
|
||||
return true;
|
||||
}
|
||||
match (left_ty, right_ty) {
|
||||
// In order to be sure a union type cannot compare equal to another type, it
|
||||
// must be true that no element of the union can compare equal to that type.
|
||||
(Type::Union(union), _) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, *ty, right_ty)),
|
||||
(_, Type::Union(union)) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, left_ty, *ty)),
|
||||
// Boolean literals and int literals are disjoint, and single valued, and yet
|
||||
// `True == 1` and `False == 0`.
|
||||
(Type::BooleanLiteral(b), Type::IntLiteral(i))
|
||||
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
|
||||
// Other than the above cases, two single-valued disjoint types cannot compare
|
||||
// equal.
|
||||
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
|
||||
}
|
||||
}
|
||||
|
||||
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
|
||||
// compare equal to `rhs_ty`.
|
||||
fn can_narrow_to_rhs<'db>(
|
||||
db: &'db dyn Db,
|
||||
lhs_ty: Type<'db>,
|
||||
rhs_ty: Type<'db>,
|
||||
) -> bool {
|
||||
match lhs_ty {
|
||||
Type::Union(union) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.all(|ty| can_narrow_to_rhs(db, *ty, rhs_ty)),
|
||||
// Either `rhs_ty` is a string literal, in which case we can narrow to it (no
|
||||
// other string literal could compare equal to it), or it is not a string
|
||||
// literal, in which case (given that it is single-valued), LiteralString
|
||||
// cannot compare equal to it.
|
||||
Type::LiteralString => true,
|
||||
_ => !could_compare_equal(db, lhs_ty, rhs_ty),
|
||||
}
|
||||
}
|
||||
|
||||
// Filter `ty` to just the types that cannot be equal to `rhs_ty`.
|
||||
fn filter_to_cannot_be_equal<'db>(
|
||||
db: &'db dyn Db,
|
||||
ty: Type<'db>,
|
||||
rhs_ty: Type<'db>,
|
||||
) -> Type<'db> {
|
||||
match ty {
|
||||
Type::Union(union) => {
|
||||
union.map(db, |ty| filter_to_cannot_be_equal(db, *ty, rhs_ty))
|
||||
}
|
||||
// Treat `bool` as `Literal[True, False]`.
|
||||
Type::Instance(instance) if instance.class().is_known(db, KnownClass::Bool) => {
|
||||
UnionType::from_elements(
|
||||
db,
|
||||
[Type::BooleanLiteral(true), Type::BooleanLiteral(false)]
|
||||
.into_iter()
|
||||
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
|
||||
ty
|
||||
} else {
|
||||
Type::Never
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(if can_narrow_to_rhs(self.db, lhs_ty, rhs_ty) {
|
||||
rhs_ty
|
||||
} else {
|
||||
filter_to_cannot_be_equal(self.db, lhs_ty, rhs_ty).negate(self.db)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_ne(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
|
||||
match (lhs_ty, rhs_ty) {
|
||||
(Type::Instance(instance), Type::IntLiteral(i))
|
||||
if instance.class().is_known(self.db, KnownClass::Bool) =>
|
||||
{
|
||||
if i == 0 {
|
||||
Some(Type::BooleanLiteral(false).negate(self.db))
|
||||
} else if i == 1 {
|
||||
Some(Type::BooleanLiteral(true).negate(self.db))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
(_, Type::BooleanLiteral(b)) => {
|
||||
if b {
|
||||
Some(
|
||||
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(1)])
|
||||
.negate(self.db),
|
||||
)
|
||||
} else {
|
||||
Some(
|
||||
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(0)])
|
||||
.negate(self.db),
|
||||
)
|
||||
}
|
||||
}
|
||||
_ if rhs_ty.is_single_valued(self.db) => Some(rhs_ty.negate(self.db)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
|
||||
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
|
||||
match rhs_ty {
|
||||
|
|
@ -435,17 +571,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
ast::CmpOp::Is => Some(rhs_ty),
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
|
||||
ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
|
||||
ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
|
||||
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
|
||||
ast::CmpOp::NotIn => self
|
||||
.evaluate_expr_in(lhs_ty, rhs_ty)
|
||||
|
|
|
|||
Loading…
Reference in New Issue