[ty] Right-hand side narrowing for if Foo is type(x) expressions (#22608)

This commit is contained in:
Alex Waygood
2026-01-17 15:49:51 +00:00
committed by GitHub
parent df58d67974
commit 3608c620ac
4 changed files with 178 additions and 74 deletions

View File

@@ -138,3 +138,79 @@ if (x := f()) is None:
else:
reveal_type(x) # revealed: Literal[1, 2]
```
## `is` where the other operand is a call expression
```py
from typing import Literal, final
def foo() -> Literal[42]:
return 42
def f(x: object):
if x is foo():
reveal_type(x) # revealed: Literal[42]
else:
reveal_type(x) # revealed: object
if x is not foo():
reveal_type(x) # revealed: object
else:
reveal_type(x) # revealed: Literal[42]
if foo() is x:
reveal_type(x) # revealed: Literal[42]
else:
reveal_type(x) # revealed: object
if foo() is not x:
reveal_type(x) # revealed: object
else:
reveal_type(x) # revealed: Literal[42]
def bar() -> int:
return 42
def g(x: object):
if x is bar():
reveal_type(x) # revealed: int
else:
reveal_type(x) # revealed: object
if x is not bar():
reveal_type(x) # revealed: object
else:
reveal_type(x) # revealed: int
@final
class FinalClass: ...
def baz() -> FinalClass:
return FinalClass()
def h(x: object):
if x is baz():
reveal_type(x) # revealed: FinalClass
else:
reveal_type(x) # revealed: object
if x is not baz():
reveal_type(x) # revealed: object
else:
reveal_type(x) # revealed: FinalClass
def spam() -> None:
return None
def h(x: object):
if x is spam():
reveal_type(x) # revealed: None
else:
# `else` narrowing can occur because `spam()` returns a singleton type
reveal_type(x) # revealed: ~None
if x is not spam():
reveal_type(x) # revealed: ~None
else:
reveal_type(x) # revealed: None
```

View File

@@ -20,6 +20,11 @@ def _(x: A | B, y: A | C):
# to infer the full union type:
reveal_type(x) # revealed: A | B
if A is type(x):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B
if type(y) is C:
reveal_type(y) # revealed: C
else:
@@ -28,6 +33,11 @@ def _(x: A | B, y: A | C):
# and `C` could exist
reveal_type(y) # revealed: A
if C is type(y):
reveal_type(y) # revealed: C
else:
reveal_type(y) # revealed: A
if type(y) is A:
reveal_type(y) # revealed: A
else:
@@ -36,6 +46,11 @@ def _(x: A | B, y: A | C):
# to `False` even if `y` was an instance of `A`,
# so narrowing cannot occur
reveal_type(y) # revealed: A | C
if A is type(y):
reveal_type(y) # revealed: A
else:
reveal_type(y) # revealed: A | C
```
## `type(x) is not C`

View File

@@ -130,6 +130,12 @@ pub(crate) mod node_key {
}
}
impl From<&Box<ast::Expr>> for ExpressionNodeKey {
fn from(value: &Box<ast::Expr>) -> Self {
Self(NodeKey::from_node(&**value))
}
}
impl From<&ast::ExprCall> for ExpressionNodeKey {
fn from(value: &ast::ExprCall) -> Self {
Self(NodeKey::from_node(value))

View File

@@ -1191,23 +1191,19 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let rhs_ty = inference.expression_type(right);
last_rhs_ty = Some(rhs_ty);
match left {
ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = PlaceExpr::try_from_expr(left)
&& let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, NarrowingConstraint::intersection(ty));
}
}
ast::Expr::Call(ast::ExprCall {
// Narrowing for:
// - `if type(x) is Y`
// - `if type(x) is not Y`
// - `if Y is type(x)`
// - `if Y is not type(x)`
if let (ast::Expr::Call(call), _, _, Type::ClassLiteral(class))
| (_, Type::ClassLiteral(class), ast::Expr::Call(call), _) =
(left, lhs_ty, right, rhs_ty)
{
let ast::ExprCall {
range: _,
node_index: _,
func: callable,
func,
arguments:
ast::Arguments {
args,
@@ -1215,71 +1211,82 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
range: _,
node_index: _,
},
}) if keywords.is_empty() => {
let Type::ClassLiteral(rhs_class) = rhs_ty else {
continue;
};
} = call;
let target = match &**args {
[first] => match PlaceExpr::try_from_expr(first) {
Some(target) => target,
None => continue,
},
_ => continue,
};
let is_positive = match op {
ast::CmpOp::Is => is_positive,
ast::CmpOp::IsNot => !is_positive,
_ => continue,
};
// If this is `None`, it indicates that we cannot do `if type(x) is Y`
// narrowing: we can only do narrowing for `if type(x) is Y` and
// `if type(x) is not Y`, not for `if type(x) == Y` or `if type(x) != Y`.
let is_positive = match op {
ast::CmpOp::Is => Some(is_positive),
ast::CmpOp::IsNot => Some(!is_positive),
_ => None,
};
if let Some(is_positive) = is_positive
&& keywords.is_empty()
&& let [single_argument] = &**args
&& let Some(target) = PlaceExpr::try_from_expr(single_argument)
// `else`-branch narrowing for `if type(x) is Y` can only be done
// if `Y` is a final class
if !rhs_class.is_final(self.db) && !is_positive {
continue;
}
let callable_type = inference.expression_type(&**callable);
if callable_type
.as_class_literal()
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
{
let place = self.expect_place(&target);
constraints.insert(
place,
NarrowingConstraint::intersection(
Type::instance(self.db, rhs_class.top_materialization(self.db))
.negate_if(self.db, !is_positive),
),
);
}
}
// For symmetric operators (==, !=, is, is not), if left is not a narrowable target,
// try to narrow the right operand instead by swapping the operands.
// E.g., `None != x` should narrow `x` the same way as `x != None`.
_ if matches!(
op,
ast::CmpOp::Eq | ast::CmpOp::NotEq | ast::CmpOp::Is | ast::CmpOp::IsNot
) && matches!(
right,
ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_)
) =>
&& (is_positive || class.is_final(self.db))
&& let Type::ClassLiteral(called_class) = inference.expression_type(func)
&& called_class.is_known(self.db, KnownClass::Type)
{
if let Some(right_place) = PlaceExpr::try_from_expr(right)
// Swap lhs_ty and rhs_ty since we're narrowing the right operand
&& let Some(ty) =
self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive)
{
let place = self.expect_place(&right_place);
constraints.insert(place, NarrowingConstraint::intersection(ty));
}
let place = self.expect_place(&target);
constraints.insert(
place,
NarrowingConstraint::intersection(
Type::instance(self.db, class.top_materialization(self.db))
.negate_if(self.db, !is_positive),
),
);
continue;
}
}
// Left-hand-side narrowing for:
// - `if x == y`
// - `if x != y`
// - `if x is y`
// - `if x is not y`
// - `if x in y`
// - `if x not in y`
//
// Right-hand side narrowing for:
// - `if y == x`
// - `if y != x`
// - `if y is x`
// - `if y is not x`
if let (
narrowable @ (ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_)),
narrowable_type,
_,
other_type,
)
| (
_,
other_type,
narrowable @ (ast::Expr::Name(_)
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_)),
narrowable_type,
) = (left, lhs_ty, right, rhs_ty)
{
// The right-hand side can only be narrowed for a symmetric operator.
// `in` and `not in` are not symmetric.
if (narrowable == left || !matches!(op, ast::CmpOp::In | ast::CmpOp::NotIn))
&& let Some(narrowable) = PlaceExpr::try_from_expr(narrowable)
&& let Some(ty) =
self.evaluate_expr_compare_op(narrowable_type, other_type, *op, is_positive)
{
let place = self.expect_place(&narrowable);
constraints.insert(place, NarrowingConstraint::intersection(ty));
continue;
}
_ => {}
}
}
Some(constraints)