diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md index 1fad9290a5..c452e3c71d 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md @@ -29,7 +29,7 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]): assert x is 2 reveal_type(x) # revealed: Literal[2] assert y == 2 - reveal_type(y) # revealed: Literal[1, 2, 3] + reveal_type(y) # revealed: Literal[2] ``` ## `assert` with `isinstance` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md index 376c24f1e9..bfa741428b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md @@ -20,11 +20,9 @@ def _(flag1: bool, flag2: bool): x = 1 if flag1 else 2 if flag2 else 3 if x == 1: - # TODO should be Literal[1] - reveal_type(x) # revealed: Literal[1, 2, 3] + reveal_type(x) # revealed: Literal[1] elif x == 2: - # TODO should be Literal[2] - reveal_type(x) # revealed: Literal[2, 3] + reveal_type(x) # revealed: Literal[2] else: reveal_type(x) # revealed: Literal[3] ``` @@ -38,14 +36,11 @@ def _(flag1: bool, flag2: bool): if x != 1: reveal_type(x) # revealed: Literal[2, 3] elif x != 2: - # TODO should be `Literal[1]` - reveal_type(x) # revealed: Literal[1, 3] + reveal_type(x) # revealed: Literal[1] elif x == 3: - # TODO should be Never - reveal_type(x) # revealed: Literal[1, 2, 3] + reveal_type(x) # revealed: Never else: - # TODO should be Never - reveal_type(x) # revealed: Literal[1, 2] + reveal_type(x) # revealed: Never ``` ## Assignment expressions diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/eq.md similarity index 52% rename from crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md rename to crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/eq.md index 20f25d9ed4..bf3184b048 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/eq.md @@ -9,8 +9,7 @@ def _(flag: bool): if x != None: reveal_type(x) # revealed: Literal[1] else: - # TODO should be None - reveal_type(x) # revealed: None | Literal[1] + reveal_type(x) # revealed: None ``` ## `!=` for other singleton types @@ -22,8 +21,7 @@ def _(flag: bool): if x != False: reveal_type(x) # revealed: Literal[True] else: - # TODO should be Literal[False] - reveal_type(x) # revealed: bool + reveal_type(x) # revealed: Literal[False] ``` ## `x != y` where `y` is of literal type @@ -47,8 +45,7 @@ def _(flag: bool): if C != A: reveal_type(C) # revealed: Literal[B] else: - # TODO should be Literal[A] - reveal_type(C) # revealed: Literal[A, B] + reveal_type(C) # revealed: Literal[A] ``` ## `x != y` where `y` has multiple single-valued options @@ -61,8 +58,7 @@ def _(flag1: bool, flag2: bool): if x != y: reveal_type(x) # revealed: Literal[1, 2] else: - # TODO should be Literal[2] - reveal_type(x) # revealed: Literal[1, 2] + reveal_type(x) # revealed: Literal[2] ``` ## `!=` for non-single-valued types @@ -101,6 +97,61 @@ def f() -> Literal[1, 2, 3]: if (x := f()) != 1: reveal_type(x) # revealed: Literal[2, 3] else: - # TODO should be Literal[1] - reveal_type(x) # revealed: Literal[1, 2, 3] + reveal_type(x) # revealed: Literal[1] +``` + +## Union with `Any` + +```py +from typing import Any + +def _(x: Any | None, y: Any | None): + if x != 1: + reveal_type(x) # revealed: (Any & ~Literal[1]) | None + if y == 1: + reveal_type(y) # revealed: Any & ~None +``` + +## Booleans and integers + +```py +from typing import Literal + +def _(b: bool, i: Literal[1, 2]): + if b == 1: + reveal_type(b) # revealed: Literal[True] + else: + reveal_type(b) # revealed: Literal[False] + + if b == 6: + reveal_type(b) # revealed: Never + else: + reveal_type(b) # revealed: bool + + if b == 0: + reveal_type(b) # revealed: Literal[False] + else: + reveal_type(b) # revealed: Literal[True] + + if i == True: + reveal_type(i) # revealed: Literal[1] + else: + reveal_type(i) # revealed: Literal[2] +``` + +## Narrowing `LiteralString` in union + +```py +from typing_extensions import Literal, LiteralString, Any + +def _(s: LiteralString | None, t: LiteralString | Any): + if s == "foo": + reveal_type(s) # revealed: Literal["foo"] + + if s == 1: + reveal_type(s) # revealed: Never + + if t == "foo": + # TODO could be `Literal["foo"] | Any` + reveal_type(t) # revealed: LiteralString | Any ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/nested.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/nested.md index fa69fe8863..ab026cef67 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/nested.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/nested.md @@ -31,17 +31,14 @@ def _(flag1: bool, flag2: bool): if x != 1: reveal_type(x) # revealed: Literal[2, 3] if x == 2: - # TODO should be `Literal[2]` - reveal_type(x) # revealed: Literal[2, 3] + reveal_type(x) # revealed: Literal[2] elif x == 3: reveal_type(x) # revealed: Literal[3] else: reveal_type(x) # revealed: Never elif x != 2: - # TODO should be Literal[1] - reveal_type(x) # revealed: Literal[1, 3] + reveal_type(x) # revealed: Literal[1] else: - # TODO should be Never - reveal_type(x) # revealed: Literal[1, 2, 3] + reveal_type(x) # revealed: Never ``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 1050446ad2..695353ef7b 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -542,6 +542,11 @@ impl<'db> Type<'db> { .is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType)) } + fn is_bool(&self, db: &'db dyn Db) -> bool { + self.into_instance() + .is_some_and(|instance| instance.class().is_known(db, KnownClass::Bool)) + } + pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool { self.into_instance().is_some_and(|instance| { instance @@ -776,8 +781,13 @@ impl<'db> Type<'db> { } pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool { - self.into_union() - .is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db))) + self.into_union().is_some_and(|union| { + union + .elements(db) + .iter() + .all(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string()) + }) || self.is_bool(db) + || self.is_literal_string() } pub const fn into_int_literal(self) -> Option { diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index b884d1da85..1dfac5b041 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -394,6 +394,142 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } + fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { + // We can only narrow on equality checks against single-valued types. + if rhs_ty.is_single_valued(self.db) || rhs_ty.is_union_of_single_valued(self.db) { + // The fully-general (and more efficient) approach here would be to introduce a + // `NeverEqualTo` type that can wrap a single-valued type, and then simply return + // `~NeverEqualTo(rhs_ty)` here and let union/intersection builder sort it out. This is + // how we handle `AlwaysTruthy` and `AlwaysFalsy`. But this means we have to deal with + // this type everywhere, and possibly have it show up unsimplified in some cases, and + // so we instead prefer to just do the simplification here. (Another hybrid option that + // would be similar to this, but more efficient, would be to allow narrowing to return + // something that is not a type, and handle this not-a-type in `symbol_from_bindings`, + // instead of intersecting with a type.) + + // Return `true` if it is possible for any two inhabitants of the given types to + // compare equal to each other; otherwise return `false`. + fn could_compare_equal<'db>( + db: &'db dyn Db, + left_ty: Type<'db>, + right_ty: Type<'db>, + ) -> bool { + if !left_ty.is_disjoint_from(db, right_ty) { + // If types overlap, they have inhabitants in common; it's definitely possible + // for an object to compare equal to itself. + return true; + } + match (left_ty, right_ty) { + // In order to be sure a union type cannot compare equal to another type, it + // must be true that no element of the union can compare equal to that type. + (Type::Union(union), _) => union + .elements(db) + .iter() + .any(|ty| could_compare_equal(db, *ty, right_ty)), + (_, Type::Union(union)) => union + .elements(db) + .iter() + .any(|ty| could_compare_equal(db, left_ty, *ty)), + // Boolean literals and int literals are disjoint, and single valued, and yet + // `True == 1` and `False == 0`. + (Type::BooleanLiteral(b), Type::IntLiteral(i)) + | (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i, + // Other than the above cases, two single-valued disjoint types cannot compare + // equal. + _ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)), + } + } + + // Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot + // compare equal to `rhs_ty`. + fn can_narrow_to_rhs<'db>( + db: &'db dyn Db, + lhs_ty: Type<'db>, + rhs_ty: Type<'db>, + ) -> bool { + match lhs_ty { + Type::Union(union) => union + .elements(db) + .iter() + .all(|ty| can_narrow_to_rhs(db, *ty, rhs_ty)), + // Either `rhs_ty` is a string literal, in which case we can narrow to it (no + // other string literal could compare equal to it), or it is not a string + // literal, in which case (given that it is single-valued), LiteralString + // cannot compare equal to it. + Type::LiteralString => true, + _ => !could_compare_equal(db, lhs_ty, rhs_ty), + } + } + + // Filter `ty` to just the types that cannot be equal to `rhs_ty`. + fn filter_to_cannot_be_equal<'db>( + db: &'db dyn Db, + ty: Type<'db>, + rhs_ty: Type<'db>, + ) -> Type<'db> { + match ty { + Type::Union(union) => { + union.map(db, |ty| filter_to_cannot_be_equal(db, *ty, rhs_ty)) + } + // Treat `bool` as `Literal[True, False]`. + Type::Instance(instance) if instance.class().is_known(db, KnownClass::Bool) => { + UnionType::from_elements( + db, + [Type::BooleanLiteral(true), Type::BooleanLiteral(false)] + .into_iter() + .map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)), + ) + } + _ => { + if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) { + ty + } else { + Type::Never + } + } + } + } + Some(if can_narrow_to_rhs(self.db, lhs_ty, rhs_ty) { + rhs_ty + } else { + filter_to_cannot_be_equal(self.db, lhs_ty, rhs_ty).negate(self.db) + }) + } else { + None + } + } + + fn evaluate_expr_ne(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { + match (lhs_ty, rhs_ty) { + (Type::Instance(instance), Type::IntLiteral(i)) + if instance.class().is_known(self.db, KnownClass::Bool) => + { + if i == 0 { + Some(Type::BooleanLiteral(false).negate(self.db)) + } else if i == 1 { + Some(Type::BooleanLiteral(true).negate(self.db)) + } else { + None + } + } + (_, Type::BooleanLiteral(b)) => { + if b { + Some( + UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(1)]) + .negate(self.db), + ) + } else { + Some( + UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(0)]) + .negate(self.db), + ) + } + } + _ if rhs_ty.is_single_valued(self.db) => Some(rhs_ty.negate(self.db)), + _ => None, + } + } + fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { match rhs_ty { @@ -435,17 +571,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } ast::CmpOp::Is => Some(rhs_ty), - ast::CmpOp::NotEq => { - if rhs_ty.is_single_valued(self.db) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(rhs_ty) - .build(); - Some(ty) - } else { - None - } - } - ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty), + ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty), + ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty), ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty), ast::CmpOp::NotIn => self .evaluate_expr_in(lhs_ty, rhs_ty)