[red-knot] Update `==` and `!=` narrowing (#17567)

## Summary

Historically we have avoided narrowing on `==` tests because in many
cases it's unsound, since subclasses of a type could compare equal to
who-knows-what. But there are a lot of types (literals and unions of
them, as well as some known instances like `None` -- single-valued
types) whose `__eq__` behavior we know, and which we can safely narrow
away based on equality comparisons.

This PR implements equality narrowing in the cases where it is sound.
The most elegant way to do this (and the way that is most in-line with
our approach up until now) would be to introduce new Type variants
`NeverEqualTo[...]` and `AlwaysEqualTo[...]`, and then implement all
type relations for those variants, narrow by intersection, and let union
and intersection simplification sort it all out. This is analogous to
our existing handling for `AlwaysFalse` and `AlwaysTrue`.

But I'm reluctant to add new `Type` variants for this, mostly because
they could end up un-simplified in some types and make types even more
complex. So let's try this approach, where we handle more of the
narrowing logic as a special case.

## Test Plan

Updated and added tests.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
Co-authored-by: Carl Meyer <carl@oddbird.net>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Matthew Mckee 2025-04-24 15:56:39 +01:00 committed by GitHub
parent ac6219ec38
commit e71f3ed2c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 220 additions and 40 deletions

View File

@ -29,7 +29,7 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
assert x is 2 assert x is 2
reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2]
assert y == 2 assert y == 2
reveal_type(y) # revealed: Literal[1, 2, 3] reveal_type(y) # revealed: Literal[2]
``` ```
## `assert` with `isinstance` ## `assert` with `isinstance`

View File

@ -20,11 +20,9 @@ def _(flag1: bool, flag2: bool):
x = 1 if flag1 else 2 if flag2 else 3 x = 1 if flag1 else 2 if flag2 else 3
if x == 1: if x == 1:
# TODO should be Literal[1] reveal_type(x) # revealed: Literal[1]
reveal_type(x) # revealed: Literal[1, 2, 3]
elif x == 2: elif x == 2:
# TODO should be Literal[2] reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[2, 3]
else: else:
reveal_type(x) # revealed: Literal[3] reveal_type(x) # revealed: Literal[3]
``` ```
@ -38,14 +36,11 @@ def _(flag1: bool, flag2: bool):
if x != 1: if x != 1:
reveal_type(x) # revealed: Literal[2, 3] reveal_type(x) # revealed: Literal[2, 3]
elif x != 2: elif x != 2:
# TODO should be `Literal[1]` reveal_type(x) # revealed: Literal[1]
reveal_type(x) # revealed: Literal[1, 3]
elif x == 3: elif x == 3:
# TODO should be Never reveal_type(x) # revealed: Never
reveal_type(x) # revealed: Literal[1, 2, 3]
else: else:
# TODO should be Never reveal_type(x) # revealed: Never
reveal_type(x) # revealed: Literal[1, 2]
``` ```
## Assignment expressions ## Assignment expressions

View File

@ -9,8 +9,7 @@ def _(flag: bool):
if x != None: if x != None:
reveal_type(x) # revealed: Literal[1] reveal_type(x) # revealed: Literal[1]
else: else:
# TODO should be None reveal_type(x) # revealed: None
reveal_type(x) # revealed: None | Literal[1]
``` ```
## `!=` for other singleton types ## `!=` for other singleton types
@ -22,8 +21,7 @@ def _(flag: bool):
if x != False: if x != False:
reveal_type(x) # revealed: Literal[True] reveal_type(x) # revealed: Literal[True]
else: else:
# TODO should be Literal[False] reveal_type(x) # revealed: Literal[False]
reveal_type(x) # revealed: bool
``` ```
## `x != y` where `y` is of literal type ## `x != y` where `y` is of literal type
@ -47,8 +45,7 @@ def _(flag: bool):
if C != A: if C != A:
reveal_type(C) # revealed: Literal[B] reveal_type(C) # revealed: Literal[B]
else: else:
# TODO should be Literal[A] reveal_type(C) # revealed: Literal[A]
reveal_type(C) # revealed: Literal[A, B]
``` ```
## `x != y` where `y` has multiple single-valued options ## `x != y` where `y` has multiple single-valued options
@ -61,8 +58,7 @@ def _(flag1: bool, flag2: bool):
if x != y: if x != y:
reveal_type(x) # revealed: Literal[1, 2] reveal_type(x) # revealed: Literal[1, 2]
else: else:
# TODO should be Literal[2] reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[1, 2]
``` ```
## `!=` for non-single-valued types ## `!=` for non-single-valued types
@ -101,6 +97,61 @@ def f() -> Literal[1, 2, 3]:
if (x := f()) != 1: if (x := f()) != 1:
reveal_type(x) # revealed: Literal[2, 3] reveal_type(x) # revealed: Literal[2, 3]
else: else:
# TODO should be Literal[1] reveal_type(x) # revealed: Literal[1]
reveal_type(x) # revealed: Literal[1, 2, 3] ```
## Union with `Any`
```py
from typing import Any
def _(x: Any | None, y: Any | None):
if x != 1:
reveal_type(x) # revealed: (Any & ~Literal[1]) | None
if y == 1:
reveal_type(y) # revealed: Any & ~None
```
## Booleans and integers
```py
from typing import Literal
def _(b: bool, i: Literal[1, 2]):
if b == 1:
reveal_type(b) # revealed: Literal[True]
else:
reveal_type(b) # revealed: Literal[False]
if b == 6:
reveal_type(b) # revealed: Never
else:
reveal_type(b) # revealed: bool
if b == 0:
reveal_type(b) # revealed: Literal[False]
else:
reveal_type(b) # revealed: Literal[True]
if i == True:
reveal_type(i) # revealed: Literal[1]
else:
reveal_type(i) # revealed: Literal[2]
```
## Narrowing `LiteralString` in union
```py
from typing_extensions import Literal, LiteralString, Any
def _(s: LiteralString | None, t: LiteralString | Any):
if s == "foo":
reveal_type(s) # revealed: Literal["foo"]
if s == 1:
reveal_type(s) # revealed: Never
if t == "foo":
# TODO could be `Literal["foo"] | Any`
reveal_type(t) # revealed: LiteralString | Any
``` ```

View File

@ -31,17 +31,14 @@ def _(flag1: bool, flag2: bool):
if x != 1: if x != 1:
reveal_type(x) # revealed: Literal[2, 3] reveal_type(x) # revealed: Literal[2, 3]
if x == 2: if x == 2:
# TODO should be `Literal[2]` reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[2, 3]
elif x == 3: elif x == 3:
reveal_type(x) # revealed: Literal[3] reveal_type(x) # revealed: Literal[3]
else: else:
reveal_type(x) # revealed: Never reveal_type(x) # revealed: Never
elif x != 2: elif x != 2:
# TODO should be Literal[1] reveal_type(x) # revealed: Literal[1]
reveal_type(x) # revealed: Literal[1, 3]
else: else:
# TODO should be Never reveal_type(x) # revealed: Never
reveal_type(x) # revealed: Literal[1, 2, 3]
``` ```

View File

@ -542,6 +542,11 @@ impl<'db> Type<'db> {
.is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType)) .is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType))
} }
fn is_bool(&self, db: &'db dyn Db) -> bool {
self.into_instance()
.is_some_and(|instance| instance.class().is_known(db, KnownClass::Bool))
}
pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool { pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
self.into_instance().is_some_and(|instance| { self.into_instance().is_some_and(|instance| {
instance instance
@ -776,8 +781,13 @@ impl<'db> Type<'db> {
} }
pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool { pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool {
self.into_union() self.into_union().is_some_and(|union| {
.is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db))) union
.elements(db)
.iter()
.all(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
}) || self.is_bool(db)
|| self.is_literal_string()
} }
pub const fn into_int_literal(self) -> Option<i64> { pub const fn into_int_literal(self) -> Option<i64> {

View File

@ -394,6 +394,142 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
} }
} }
fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
// We can only narrow on equality checks against single-valued types.
if rhs_ty.is_single_valued(self.db) || rhs_ty.is_union_of_single_valued(self.db) {
// The fully-general (and more efficient) approach here would be to introduce a
// `NeverEqualTo` type that can wrap a single-valued type, and then simply return
// `~NeverEqualTo(rhs_ty)` here and let union/intersection builder sort it out. This is
// how we handle `AlwaysTruthy` and `AlwaysFalsy`. But this means we have to deal with
// this type everywhere, and possibly have it show up unsimplified in some cases, and
// so we instead prefer to just do the simplification here. (Another hybrid option that
// would be similar to this, but more efficient, would be to allow narrowing to return
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
// instead of intersecting with a type.)
// Return `true` if it is possible for any two inhabitants of the given types to
// compare equal to each other; otherwise return `false`.
fn could_compare_equal<'db>(
db: &'db dyn Db,
left_ty: Type<'db>,
right_ty: Type<'db>,
) -> bool {
if !left_ty.is_disjoint_from(db, right_ty) {
// If types overlap, they have inhabitants in common; it's definitely possible
// for an object to compare equal to itself.
return true;
}
match (left_ty, right_ty) {
// In order to be sure a union type cannot compare equal to another type, it
// must be true that no element of the union can compare equal to that type.
(Type::Union(union), _) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, *ty, right_ty)),
(_, Type::Union(union)) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, left_ty, *ty)),
// Boolean literals and int literals are disjoint, and single valued, and yet
// `True == 1` and `False == 0`.
(Type::BooleanLiteral(b), Type::IntLiteral(i))
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
// Other than the above cases, two single-valued disjoint types cannot compare
// equal.
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
}
}
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
// compare equal to `rhs_ty`.
fn can_narrow_to_rhs<'db>(
db: &'db dyn Db,
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
) -> bool {
match lhs_ty {
Type::Union(union) => union
.elements(db)
.iter()
.all(|ty| can_narrow_to_rhs(db, *ty, rhs_ty)),
// Either `rhs_ty` is a string literal, in which case we can narrow to it (no
// other string literal could compare equal to it), or it is not a string
// literal, in which case (given that it is single-valued), LiteralString
// cannot compare equal to it.
Type::LiteralString => true,
_ => !could_compare_equal(db, lhs_ty, rhs_ty),
}
}
// Filter `ty` to just the types that cannot be equal to `rhs_ty`.
fn filter_to_cannot_be_equal<'db>(
db: &'db dyn Db,
ty: Type<'db>,
rhs_ty: Type<'db>,
) -> Type<'db> {
match ty {
Type::Union(union) => {
union.map(db, |ty| filter_to_cannot_be_equal(db, *ty, rhs_ty))
}
// Treat `bool` as `Literal[True, False]`.
Type::Instance(instance) if instance.class().is_known(db, KnownClass::Bool) => {
UnionType::from_elements(
db,
[Type::BooleanLiteral(true), Type::BooleanLiteral(false)]
.into_iter()
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
_ => {
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
ty
} else {
Type::Never
}
}
}
}
Some(if can_narrow_to_rhs(self.db, lhs_ty, rhs_ty) {
rhs_ty
} else {
filter_to_cannot_be_equal(self.db, lhs_ty, rhs_ty).negate(self.db)
})
} else {
None
}
}
fn evaluate_expr_ne(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
match (lhs_ty, rhs_ty) {
(Type::Instance(instance), Type::IntLiteral(i))
if instance.class().is_known(self.db, KnownClass::Bool) =>
{
if i == 0 {
Some(Type::BooleanLiteral(false).negate(self.db))
} else if i == 1 {
Some(Type::BooleanLiteral(true).negate(self.db))
} else {
None
}
}
(_, Type::BooleanLiteral(b)) => {
if b {
Some(
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(1)])
.negate(self.db),
)
} else {
Some(
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(0)])
.negate(self.db),
)
}
}
_ if rhs_ty.is_single_valued(self.db) => Some(rhs_ty.negate(self.db)),
_ => None,
}
}
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> { fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
match rhs_ty { match rhs_ty {
@ -435,17 +571,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
} }
} }
ast::CmpOp::Is => Some(rhs_ty), ast::CmpOp::Is => Some(rhs_ty),
ast::CmpOp::NotEq => { ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
if rhs_ty.is_single_valued(self.db) { ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
None
}
}
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty), ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
ast::CmpOp::NotIn => self ast::CmpOp::NotIn => self
.evaluate_expr_in(lhs_ty, rhs_ty) .evaluate_expr_in(lhs_ty, rhs_ty)