mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 05:20:49 -05:00
[ty] Right-hand side narrowing for if Foo is type(x) expressions (#22608)
This commit is contained in:
@@ -138,3 +138,79 @@ if (x := f()) is None:
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
||||
## `is` where the other operand is a call expression
|
||||
|
||||
```py
|
||||
from typing import Literal, final
|
||||
|
||||
def foo() -> Literal[42]:
|
||||
return 42
|
||||
|
||||
def f(x: object):
|
||||
if x is foo():
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
else:
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
if x is not foo():
|
||||
reveal_type(x) # revealed: object
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
|
||||
if foo() is x:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
else:
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
if foo() is not x:
|
||||
reveal_type(x) # revealed: object
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
|
||||
def bar() -> int:
|
||||
return 42
|
||||
|
||||
def g(x: object):
|
||||
if x is bar():
|
||||
reveal_type(x) # revealed: int
|
||||
else:
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
if x is not bar():
|
||||
reveal_type(x) # revealed: object
|
||||
else:
|
||||
reveal_type(x) # revealed: int
|
||||
|
||||
@final
|
||||
class FinalClass: ...
|
||||
|
||||
def baz() -> FinalClass:
|
||||
return FinalClass()
|
||||
|
||||
def h(x: object):
|
||||
if x is baz():
|
||||
reveal_type(x) # revealed: FinalClass
|
||||
else:
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
if x is not baz():
|
||||
reveal_type(x) # revealed: object
|
||||
else:
|
||||
reveal_type(x) # revealed: FinalClass
|
||||
|
||||
def spam() -> None:
|
||||
return None
|
||||
|
||||
def h(x: object):
|
||||
if x is spam():
|
||||
reveal_type(x) # revealed: None
|
||||
else:
|
||||
# `else` narrowing can occur because `spam()` returns a singleton type
|
||||
reveal_type(x) # revealed: ~None
|
||||
|
||||
if x is not spam():
|
||||
reveal_type(x) # revealed: ~None
|
||||
else:
|
||||
reveal_type(x) # revealed: None
|
||||
```
|
||||
|
||||
@@ -20,6 +20,11 @@ def _(x: A | B, y: A | C):
|
||||
# to infer the full union type:
|
||||
reveal_type(x) # revealed: A | B
|
||||
|
||||
if A is type(x):
|
||||
reveal_type(x) # revealed: A
|
||||
else:
|
||||
reveal_type(x) # revealed: A | B
|
||||
|
||||
if type(y) is C:
|
||||
reveal_type(y) # revealed: C
|
||||
else:
|
||||
@@ -28,6 +33,11 @@ def _(x: A | B, y: A | C):
|
||||
# and `C` could exist
|
||||
reveal_type(y) # revealed: A
|
||||
|
||||
if C is type(y):
|
||||
reveal_type(y) # revealed: C
|
||||
else:
|
||||
reveal_type(y) # revealed: A
|
||||
|
||||
if type(y) is A:
|
||||
reveal_type(y) # revealed: A
|
||||
else:
|
||||
@@ -36,6 +46,11 @@ def _(x: A | B, y: A | C):
|
||||
# to `False` even if `y` was an instance of `A`,
|
||||
# so narrowing cannot occur
|
||||
reveal_type(y) # revealed: A | C
|
||||
|
||||
if A is type(y):
|
||||
reveal_type(y) # revealed: A
|
||||
else:
|
||||
reveal_type(y) # revealed: A | C
|
||||
```
|
||||
|
||||
## `type(x) is not C`
|
||||
|
||||
@@ -130,6 +130,12 @@ pub(crate) mod node_key {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Box<ast::Expr>> for ExpressionNodeKey {
|
||||
fn from(value: &Box<ast::Expr>) -> Self {
|
||||
Self(NodeKey::from_node(&**value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::ExprCall> for ExpressionNodeKey {
|
||||
fn from(value: &ast::ExprCall) -> Self {
|
||||
Self(NodeKey::from_node(value))
|
||||
|
||||
@@ -1191,23 +1191,19 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
let rhs_ty = inference.expression_type(right);
|
||||
last_rhs_ty = Some(rhs_ty);
|
||||
|
||||
match left {
|
||||
ast::Expr::Name(_)
|
||||
| ast::Expr::Attribute(_)
|
||||
| ast::Expr::Subscript(_)
|
||||
| ast::Expr::Named(_) => {
|
||||
if let Some(left) = PlaceExpr::try_from_expr(left)
|
||||
&& let Some(ty) =
|
||||
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
|
||||
{
|
||||
let place = self.expect_place(&left);
|
||||
constraints.insert(place, NarrowingConstraint::intersection(ty));
|
||||
}
|
||||
}
|
||||
ast::Expr::Call(ast::ExprCall {
|
||||
// Narrowing for:
|
||||
// - `if type(x) is Y`
|
||||
// - `if type(x) is not Y`
|
||||
// - `if Y is type(x)`
|
||||
// - `if Y is not type(x)`
|
||||
if let (ast::Expr::Call(call), _, _, Type::ClassLiteral(class))
|
||||
| (_, Type::ClassLiteral(class), ast::Expr::Call(call), _) =
|
||||
(left, lhs_ty, right, rhs_ty)
|
||||
{
|
||||
let ast::ExprCall {
|
||||
range: _,
|
||||
node_index: _,
|
||||
func: callable,
|
||||
func,
|
||||
arguments:
|
||||
ast::Arguments {
|
||||
args,
|
||||
@@ -1215,71 +1211,82 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
range: _,
|
||||
node_index: _,
|
||||
},
|
||||
}) if keywords.is_empty() => {
|
||||
let Type::ClassLiteral(rhs_class) = rhs_ty else {
|
||||
continue;
|
||||
};
|
||||
} = call;
|
||||
|
||||
let target = match &**args {
|
||||
[first] => match PlaceExpr::try_from_expr(first) {
|
||||
Some(target) => target,
|
||||
None => continue,
|
||||
},
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let is_positive = match op {
|
||||
ast::CmpOp::Is => is_positive,
|
||||
ast::CmpOp::IsNot => !is_positive,
|
||||
_ => continue,
|
||||
};
|
||||
// If this is `None`, it indicates that we cannot do `if type(x) is Y`
|
||||
// narrowing: we can only do narrowing for `if type(x) is Y` and
|
||||
// `if type(x) is not Y`, not for `if type(x) == Y` or `if type(x) != Y`.
|
||||
let is_positive = match op {
|
||||
ast::CmpOp::Is => Some(is_positive),
|
||||
ast::CmpOp::IsNot => Some(!is_positive),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(is_positive) = is_positive
|
||||
&& keywords.is_empty()
|
||||
&& let [single_argument] = &**args
|
||||
&& let Some(target) = PlaceExpr::try_from_expr(single_argument)
|
||||
// `else`-branch narrowing for `if type(x) is Y` can only be done
|
||||
// if `Y` is a final class
|
||||
if !rhs_class.is_final(self.db) && !is_positive {
|
||||
continue;
|
||||
}
|
||||
|
||||
let callable_type = inference.expression_type(&**callable);
|
||||
|
||||
if callable_type
|
||||
.as_class_literal()
|
||||
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
|
||||
{
|
||||
let place = self.expect_place(&target);
|
||||
constraints.insert(
|
||||
place,
|
||||
NarrowingConstraint::intersection(
|
||||
Type::instance(self.db, rhs_class.top_materialization(self.db))
|
||||
.negate_if(self.db, !is_positive),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
// For symmetric operators (==, !=, is, is not), if left is not a narrowable target,
|
||||
// try to narrow the right operand instead by swapping the operands.
|
||||
// E.g., `None != x` should narrow `x` the same way as `x != None`.
|
||||
_ if matches!(
|
||||
op,
|
||||
ast::CmpOp::Eq | ast::CmpOp::NotEq | ast::CmpOp::Is | ast::CmpOp::IsNot
|
||||
) && matches!(
|
||||
right,
|
||||
ast::Expr::Name(_)
|
||||
| ast::Expr::Attribute(_)
|
||||
| ast::Expr::Subscript(_)
|
||||
| ast::Expr::Named(_)
|
||||
) =>
|
||||
&& (is_positive || class.is_final(self.db))
|
||||
&& let Type::ClassLiteral(called_class) = inference.expression_type(func)
|
||||
&& called_class.is_known(self.db, KnownClass::Type)
|
||||
{
|
||||
if let Some(right_place) = PlaceExpr::try_from_expr(right)
|
||||
// Swap lhs_ty and rhs_ty since we're narrowing the right operand
|
||||
&& let Some(ty) =
|
||||
self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive)
|
||||
{
|
||||
let place = self.expect_place(&right_place);
|
||||
constraints.insert(place, NarrowingConstraint::intersection(ty));
|
||||
}
|
||||
let place = self.expect_place(&target);
|
||||
constraints.insert(
|
||||
place,
|
||||
NarrowingConstraint::intersection(
|
||||
Type::instance(self.db, class.top_materialization(self.db))
|
||||
.negate_if(self.db, !is_positive),
|
||||
),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Left-hand-side narrowing for:
|
||||
// - `if x == y`
|
||||
// - `if x != y`
|
||||
// - `if x is y`
|
||||
// - `if x is not y`
|
||||
// - `if x in y`
|
||||
// - `if x not in y`
|
||||
//
|
||||
// Right-hand side narrowing for:
|
||||
// - `if y == x`
|
||||
// - `if y != x`
|
||||
// - `if y is x`
|
||||
// - `if y is not x`
|
||||
if let (
|
||||
narrowable @ (ast::Expr::Name(_)
|
||||
| ast::Expr::Attribute(_)
|
||||
| ast::Expr::Subscript(_)
|
||||
| ast::Expr::Named(_)),
|
||||
narrowable_type,
|
||||
_,
|
||||
other_type,
|
||||
)
|
||||
| (
|
||||
_,
|
||||
other_type,
|
||||
narrowable @ (ast::Expr::Name(_)
|
||||
| ast::Expr::Attribute(_)
|
||||
| ast::Expr::Subscript(_)
|
||||
| ast::Expr::Named(_)),
|
||||
narrowable_type,
|
||||
) = (left, lhs_ty, right, rhs_ty)
|
||||
{
|
||||
// The right-hand side can only be narrowed for a symmetric operator.
|
||||
// `in` and `not in` are not symmetric.
|
||||
if (narrowable == left || !matches!(op, ast::CmpOp::In | ast::CmpOp::NotIn))
|
||||
&& let Some(narrowable) = PlaceExpr::try_from_expr(narrowable)
|
||||
&& let Some(ty) =
|
||||
self.evaluate_expr_compare_op(narrowable_type, other_type, *op, is_positive)
|
||||
{
|
||||
let place = self.expect_place(&narrowable);
|
||||
constraints.insert(place, NarrowingConstraint::intersection(ty));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Some(constraints)
|
||||
|
||||
Reference in New Issue
Block a user