diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md index e3c40104f1..41e4bf56ac 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md @@ -138,3 +138,79 @@ if (x := f()) is None: else: reveal_type(x) # revealed: Literal[1, 2] ``` + +## `is` where the other operand is a call expression + +```py +from typing import Literal, final + +def foo() -> Literal[42]: + return 42 + +def f(x: object): + if x is foo(): + reveal_type(x) # revealed: Literal[42] + else: + reveal_type(x) # revealed: object + + if x is not foo(): + reveal_type(x) # revealed: object + else: + reveal_type(x) # revealed: Literal[42] + + if foo() is x: + reveal_type(x) # revealed: Literal[42] + else: + reveal_type(x) # revealed: object + + if foo() is not x: + reveal_type(x) # revealed: object + else: + reveal_type(x) # revealed: Literal[42] + +def bar() -> int: + return 42 + +def g(x: object): + if x is bar(): + reveal_type(x) # revealed: int + else: + reveal_type(x) # revealed: object + + if x is not bar(): + reveal_type(x) # revealed: object + else: + reveal_type(x) # revealed: int + +@final +class FinalClass: ... + +def baz() -> FinalClass: + return FinalClass() + +def h(x: object): + if x is baz(): + reveal_type(x) # revealed: FinalClass + else: + reveal_type(x) # revealed: object + + if x is not baz(): + reveal_type(x) # revealed: object + else: + reveal_type(x) # revealed: FinalClass + +def spam() -> None: + return None + +def h(x: object): + if x is spam(): + reveal_type(x) # revealed: None + else: + # `else` narrowing can occur because `spam()` returns a singleton type + reveal_type(x) # revealed: ~None + + if x is not spam(): + reveal_type(x) # revealed: ~None + else: + reveal_type(x) # revealed: None +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index 2a68069abe..14b90a20c8 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -20,6 +20,11 @@ def _(x: A | B, y: A | C): # to infer the full union type: reveal_type(x) # revealed: A | B + if A is type(x): + reveal_type(x) # revealed: A + else: + reveal_type(x) # revealed: A | B + if type(y) is C: reveal_type(y) # revealed: C else: @@ -28,6 +33,11 @@ def _(x: A | B, y: A | C): # and `C` could exist reveal_type(y) # revealed: A + if C is type(y): + reveal_type(y) # revealed: C + else: + reveal_type(y) # revealed: A + if type(y) is A: reveal_type(y) # revealed: A else: @@ -36,6 +46,11 @@ def _(x: A | B, y: A | C): # to `False` even if `y` was an instance of `A`, # so narrowing cannot occur reveal_type(y) # revealed: A | C + + if A is type(y): + reveal_type(y) # revealed: A + else: + reveal_type(y) # revealed: A | C ``` ## `type(x) is not C` diff --git a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs index cc2c65526e..5b8e83a2b0 100644 --- a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs @@ -130,6 +130,12 @@ pub(crate) mod node_key { } } + impl From<&Box> for ExpressionNodeKey { + fn from(value: &Box) -> Self { + Self(NodeKey::from_node(&**value)) + } + } + impl From<&ast::ExprCall> for ExpressionNodeKey { fn from(value: &ast::ExprCall) -> Self { Self(NodeKey::from_node(value)) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index d56e2c022e..e8f7fd2c06 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1191,23 +1191,19 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let rhs_ty = inference.expression_type(right); last_rhs_ty = Some(rhs_ty); - match left { - ast::Expr::Name(_) - | ast::Expr::Attribute(_) - | ast::Expr::Subscript(_) - | ast::Expr::Named(_) => { - if let Some(left) = PlaceExpr::try_from_expr(left) - && let Some(ty) = - self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) - { - let place = self.expect_place(&left); - constraints.insert(place, NarrowingConstraint::intersection(ty)); - } - } - ast::Expr::Call(ast::ExprCall { + // Narrowing for: + // - `if type(x) is Y` + // - `if type(x) is not Y` + // - `if Y is type(x)` + // - `if Y is not type(x)` + if let (ast::Expr::Call(call), _, _, Type::ClassLiteral(class)) + | (_, Type::ClassLiteral(class), ast::Expr::Call(call), _) = + (left, lhs_ty, right, rhs_ty) + { + let ast::ExprCall { range: _, node_index: _, - func: callable, + func, arguments: ast::Arguments { args, @@ -1215,71 +1211,82 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { range: _, node_index: _, }, - }) if keywords.is_empty() => { - let Type::ClassLiteral(rhs_class) = rhs_ty else { - continue; - }; + } = call; - let target = match &**args { - [first] => match PlaceExpr::try_from_expr(first) { - Some(target) => target, - None => continue, - }, - _ => continue, - }; - - let is_positive = match op { - ast::CmpOp::Is => is_positive, - ast::CmpOp::IsNot => !is_positive, - _ => continue, - }; + // If this is `None`, it indicates that we cannot do `if type(x) is Y` + // narrowing: we can only do narrowing for `if type(x) is Y` and + // `if type(x) is not Y`, not for `if type(x) == Y` or `if type(x) != Y`. + let is_positive = match op { + ast::CmpOp::Is => Some(is_positive), + ast::CmpOp::IsNot => Some(!is_positive), + _ => None, + }; + if let Some(is_positive) = is_positive + && keywords.is_empty() + && let [single_argument] = &**args + && let Some(target) = PlaceExpr::try_from_expr(single_argument) // `else`-branch narrowing for `if type(x) is Y` can only be done // if `Y` is a final class - if !rhs_class.is_final(self.db) && !is_positive { - continue; - } - - let callable_type = inference.expression_type(&**callable); - - if callable_type - .as_class_literal() - .is_some_and(|c| c.is_known(self.db, KnownClass::Type)) - { - let place = self.expect_place(&target); - constraints.insert( - place, - NarrowingConstraint::intersection( - Type::instance(self.db, rhs_class.top_materialization(self.db)) - .negate_if(self.db, !is_positive), - ), - ); - } - } - // For symmetric operators (==, !=, is, is not), if left is not a narrowable target, - // try to narrow the right operand instead by swapping the operands. - // E.g., `None != x` should narrow `x` the same way as `x != None`. - _ if matches!( - op, - ast::CmpOp::Eq | ast::CmpOp::NotEq | ast::CmpOp::Is | ast::CmpOp::IsNot - ) && matches!( - right, - ast::Expr::Name(_) - | ast::Expr::Attribute(_) - | ast::Expr::Subscript(_) - | ast::Expr::Named(_) - ) => + && (is_positive || class.is_final(self.db)) + && let Type::ClassLiteral(called_class) = inference.expression_type(func) + && called_class.is_known(self.db, KnownClass::Type) { - if let Some(right_place) = PlaceExpr::try_from_expr(right) - // Swap lhs_ty and rhs_ty since we're narrowing the right operand - && let Some(ty) = - self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive) - { - let place = self.expect_place(&right_place); - constraints.insert(place, NarrowingConstraint::intersection(ty)); - } + let place = self.expect_place(&target); + constraints.insert( + place, + NarrowingConstraint::intersection( + Type::instance(self.db, class.top_materialization(self.db)) + .negate_if(self.db, !is_positive), + ), + ); + continue; + } + } + + // Left-hand-side narrowing for: + // - `if x == y` + // - `if x != y` + // - `if x is y` + // - `if x is not y` + // - `if x in y` + // - `if x not in y` + // + // Right-hand side narrowing for: + // - `if y == x` + // - `if y != x` + // - `if y is x` + // - `if y is not x` + if let ( + narrowable @ (ast::Expr::Name(_) + | ast::Expr::Attribute(_) + | ast::Expr::Subscript(_) + | ast::Expr::Named(_)), + narrowable_type, + _, + other_type, + ) + | ( + _, + other_type, + narrowable @ (ast::Expr::Name(_) + | ast::Expr::Attribute(_) + | ast::Expr::Subscript(_) + | ast::Expr::Named(_)), + narrowable_type, + ) = (left, lhs_ty, right, rhs_ty) + { + // The right-hand side can only be narrowed for a symmetric operator. + // `in` and `not in` are not symmetric. + if (narrowable == left || !matches!(op, ast::CmpOp::In | ast::CmpOp::NotIn)) + && let Some(narrowable) = PlaceExpr::try_from_expr(narrowable) + && let Some(ty) = + self.evaluate_expr_compare_op(narrowable_type, other_type, *op, is_positive) + { + let place = self.expect_place(&narrowable); + constraints.insert(place, NarrowingConstraint::intersection(ty)); + continue; } - _ => {} } } Some(constraints)