diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md index 6456b7764b..f9de673a17 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md @@ -238,3 +238,20 @@ def _(s: LiteralString | None, t: LiteralString | Any): # TODO could be `Literal["foo"] | Any` reveal_type(t) # revealed: LiteralString | Any ``` + +## Narrowing with tuple types + +We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other +tuples. So they are excluded from the narrowed type when comparing to non-tuple values. + +```py +from typing import Literal + +def _(x: Literal["a", "b"] | tuple[int, int]): + if x == "a": + # tuple type is excluded because it's disjoint from the string literal + reveal_type(x) # revealed: Literal["a"] + else: + # tuple type remains in the else branch + reveal_type(x) # revealed: Literal["b"] | tuple[int, int] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md index 14e2f6eab4..79bb4c1d92 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -191,3 +191,20 @@ def test(x: Status | int): else: reveal_type(x) # revealed: Literal[Status.REJECTED] | int ``` + +## Union with tuple and `Literal` + +We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other +tuples. So they are excluded from the narrowed type when disjoint from the RHS values. + +```py +from typing import Literal + +def test(x: Literal["none", "auto", "required"] | tuple[list[str], Literal["auto", "required"]]): + if x in ("auto", "required"): + # tuple type is excluded because it's disjoint from the string literals + reveal_type(x) # revealed: Literal["auto", "required"] + else: + # tuple type remains in the else branch + reveal_type(x) # revealed: Literal["none"] | tuple[list[str], Literal["auto", "required"]] +``` diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index cc9c0ca0f6..ae36ea47ed 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -324,6 +324,39 @@ fn place_expr(expr: &ast::Expr) -> Option { } } +/// Return `true` if it is possible for any two inhabitants of the given types to +/// compare equal to each other; otherwise return `false`. +fn could_compare_equal<'db>(db: &'db dyn Db, left_ty: Type<'db>, right_ty: Type<'db>) -> bool { + if !left_ty.is_disjoint_from(db, right_ty) { + // If types overlap, they have inhabitants in common; it's definitely possible + // for an object to compare equal to itself. + return true; + } + match (left_ty, right_ty) { + // In order to be sure a union type cannot compare equal to another type, it + // must be true that no element of the union can compare equal to that type. + (Type::Union(union), _) => union + .elements(db) + .iter() + .any(|ty| could_compare_equal(db, *ty, right_ty)), + (_, Type::Union(union)) => union + .elements(db) + .iter() + .any(|ty| could_compare_equal(db, left_ty, *ty)), + // Boolean literals and int literals are disjoint, and single valued, and yet + // `True == 1` and `False == 0`. + (Type::BooleanLiteral(b), Type::IntLiteral(i)) + | (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i, + // We assume that tuples use `tuple.__eq__` which only returns True + // for other tuples, so they cannot compare equal to non-tuple types. + (Type::NominalInstance(instance), _) if instance.tuple_spec(db).is_some() => false, + (_, Type::NominalInstance(instance)) if instance.tuple_spec(db).is_some() => false, + // Other than the above cases, two single-valued disjoint types cannot compare + // equal. + _ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)), + } +} + struct NarrowingConstraintsBuilder<'db, 'ast> { db: &'db dyn Db, module: &'ast ParsedModuleRef, @@ -573,39 +606,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // something that is not a type, and handle this not-a-type in `symbol_from_bindings`, // instead of intersecting with a type.) - // Return `true` if it is possible for any two inhabitants of the given types to - // compare equal to each other; otherwise return `false`. - fn could_compare_equal<'db>( - db: &'db dyn Db, - left_ty: Type<'db>, - right_ty: Type<'db>, - ) -> bool { - if !left_ty.is_disjoint_from(db, right_ty) { - // If types overlap, they have inhabitants in common; it's definitely possible - // for an object to compare equal to itself. - return true; - } - match (left_ty, right_ty) { - // In order to be sure a union type cannot compare equal to another type, it - // must be true that no element of the union can compare equal to that type. - (Type::Union(union), _) => union - .elements(db) - .iter() - .any(|ty| could_compare_equal(db, *ty, right_ty)), - (_, Type::Union(union)) => union - .elements(db) - .iter() - .any(|ty| could_compare_equal(db, left_ty, *ty)), - // Boolean literals and int literals are disjoint, and single valued, and yet - // `True == 1` and `False == 0`. - (Type::BooleanLiteral(b), Type::IntLiteral(i)) - | (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i, - // Other than the above cases, two single-valued disjoint types cannot compare - // equal. - _ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)), - } - } - // Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot // compare equal to `rhs_ty`. fn can_narrow_to_rhs<'db>( @@ -660,7 +660,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) } _ => { - if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) { + if !could_compare_equal(db, ty, rhs_ty) { + // Cannot compare equal to rhs, so keep this type ty } else { Type::Never @@ -721,14 +722,22 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(lhs_union) = lhs_ty.as_union() { for element in lhs_union.elements(self.db) { - // Keep only the non-single-valued portion of the original type. - if !element.is_single_valued(self.db) - && !element.is_literal_string() - && !element.is_bool(self.db) - && (!element.is_enum(self.db) || element.overrides_equality(self.db)) - { - builder = builder.add(*element); + // Skip single-valued types (handled via RHS matching). + if element.is_single_valued(self.db) { + continue; } + // Skip types that are handled specially (LiteralString, bool, enum). + if element.is_literal_string() + || element.is_bool(self.db) + || (element.is_enum(self.db) && !element.overrides_equality(self.db)) + { + continue; + } + // Skip types that cannot compare equal to any RHS value. + if !could_compare_equal(self.db, *element, rhs_values) { + continue; + } + builder = builder.add(*element); } } Some(builder.build())