[ty] Apply narrowing to walrus targets (#22369)

## Summary

Closes https://github.com/astral-sh/ty/issues/2300.
This commit is contained in:
Charlie Marsh
2026-01-13 19:56:47 -05:00
committed by GitHub
parent ddd2fc7a90
commit ea46426157
6 changed files with 104 additions and 20 deletions

View File

@@ -69,6 +69,30 @@ def call_with_args(y: object, a: int, b: str) -> object:
return None
```
## Narrowing with named expressions (walrus operator)
When `callable()` is used with a named expression, the target of the named expression should be
narrowed.
```py
from typing import Any
class Foo:
func: Any | None
def f(foo: Foo):
first = getattr(foo, "func", None)
if callable(first):
reveal_type(first) # revealed: Any & Top[(...) -> object]
else:
reveal_type(first) # revealed: (Any & ~Top[(...) -> object]) | None
if callable(second := getattr(foo, "func", None)):
reveal_type(second) # revealed: Any & Top[(...) -> object]
else:
reveal_type(second) # revealed: (Any & ~Top[(...) -> object]) | None
```
## Assignability of narrowed callables
A narrowed callable `Top[Callable[..., object]]` should be assignable to `Callable[..., Any]`. This

View File

@@ -580,3 +580,19 @@ def test(a: Any, items: list[T]) -> None:
if isinstance(v, dict):
cast(T, v) # no panic
```
## Narrowing with named expressions (walrus operator)
When `isinstance()` is used with a named expression, the target of the named expression should be
narrowed.
```py
def get_value() -> int | str:
return 1
def f():
if isinstance(x := get_value(), int):
reveal_type(x) # revealed: int
else:
reveal_type(x) # revealed: str
```

View File

@@ -347,3 +347,19 @@ def _(x: LiteralString):
else:
reveal_type(x) # revealed: LiteralString & ~Literal[""]
```
## Narrowing with named expressions (walrus operator)
When a truthiness check is used with a named expression, the target of the named expression should
be narrowed.
```py
def get_value() -> str | None:
return "hello"
def f():
if x := get_value():
reveal_type(x) # revealed: str & ~AlwaysFalsy
else:
reveal_type(x) # revealed: (str & ~AlwaysTruthy) | None
```

View File

@@ -499,3 +499,32 @@ def _(x: object):
if f(x) and (g(x) or h(x)):
reveal_type(x) # revealed: B | (A & C)
```
## Narrowing with named expressions (walrus operator)
When a type guard is used with a named expression, the target of the named expression should be
narrowed.
```py
from typing_extensions import TypeGuard, TypeIs
def is_str(x: object) -> TypeIs[str]:
return isinstance(x, str)
def guard_str(x: object) -> TypeGuard[str]:
return isinstance(x, str)
def get_value() -> int | str:
return 1
def f():
if is_str(x := get_value()):
reveal_type(x) # revealed: str
else:
reveal_type(x) # revealed: int
if guard_str(y := get_value()):
reveal_type(y) # revealed: str
else:
reveal_type(y) # revealed: int | str
```

View File

@@ -38,6 +38,12 @@ impl PlaceExpr {
pub(crate) fn try_from_expr<'e>(expr: impl Into<ast::ExprRef<'e>>) -> Option<Self> {
let expr = expr.into();
// For named expressions (walrus operator), extract the target.
let expr = match expr {
ast::ExprRef::Named(named) => named.target.as_ref().into(),
_ => expr,
};
if let ast::ExprRef::Name(name) = expr {
return Some(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}

View File

@@ -454,13 +454,6 @@ fn merge_constraints_or<'db>(
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr {
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
_ => PlaceExpr::try_from_expr(expr),
}
}
/// Return `true` if it is possible for any two inhabitants of the given types to
/// compare equal to each other; otherwise return `false`.
fn could_compare_equal<'db>(db: &'db dyn Db, left_ty: Type<'db>, right_ty: Type<'db>) -> bool {
@@ -721,7 +714,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expr: &ast::Expr,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let target = place_expr(expr)?;
let target = PlaceExpr::try_from_expr(expr)?;
let place = self.expect_place(&target);
let ty = if is_positive {
@@ -1030,7 +1023,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if matches!(&**ops, [ast::CmpOp::Is | ast::CmpOp::IsNot])
&& let ast::Expr::Subscript(subscript) = &**left
&& let Type::Union(union) = inference.expression_type(&*subscript.value)
&& let Some(subscript_place_expr) = place_expr(&subscript.value)
&& let Some(subscript_place_expr) = PlaceExpr::try_from_expr(&subscript.value)
&& let Type::IntLiteral(index) = inference.expression_type(&*subscript.slice)
&& let Ok(index) = i32::try_from(index)
&& let rhs_ty = inference.expression_type(&comparators[0])
@@ -1122,7 +1115,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// reveal_type(u) # revealed: Bar
if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn])
&& let Type::StringLiteral(key) = inference.expression_type(&**left)
&& let Some(rhs_place_expr) = place_expr(&comparators[0])
&& let Some(rhs_place_expr) = PlaceExpr::try_from_expr(&comparators[0])
&& let rhs_type = inference.expression_type(&comparators[0])
&& is_typeddict_or_union_with_typeddicts(self.db, rhs_type)
{
@@ -1190,7 +1183,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = place_expr(left)
if let Some(left) = PlaceExpr::try_from_expr(left)
&& let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
@@ -1215,7 +1208,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
};
let target = match &**args {
[first] => match place_expr(first) {
[first] => match PlaceExpr::try_from_expr(first) {
Some(target) => target,
None => continue,
},
@@ -1264,7 +1257,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Named(_)
) =>
{
if let Some(right_place) = place_expr(right)
if let Some(right_place) = PlaceExpr::try_from_expr(right)
// Swap lhs_ty and rhs_ty since we're narrowing the right operand
&& let Some(ty) =
self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive)
@@ -1315,7 +1308,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// Narrow only the parts of the type that are safe to narrow based on len().
if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) {
let target = place_expr(arg)?;
let target = PlaceExpr::try_from_expr(arg)?;
let place = self.expect_place(&target);
Some(NarrowingConstraints::from_iter([(
place,
@@ -1329,7 +1322,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let [first_arg, second_arg] = &*expr_call.arguments.args else {
return None;
};
let first_arg = place_expr(first_arg)?;
let first_arg = PlaceExpr::try_from_expr(first_arg)?;
let function = function_type.known(self.db)?;
let place = self.expect_place(&first_arg);
@@ -1427,7 +1420,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
singleton: ast::Singleton,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = match singleton {
@@ -1456,7 +1449,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None;
}
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let class_type =
@@ -1486,7 +1479,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
let subject_node = subject.node_ref(self.db, self.module);
let place = {
let subject = place_expr(subject_node)?;
let subject = PlaceExpr::try_from_expr(subject_node)?;
self.expect_place(&subject)
};
let subject_ty =
@@ -1638,7 +1631,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if !is_typeddict_or_union_with_typeddicts(self.db, subscript_value_type) {
return None;
}
let subscript_place_expr = place_expr(subscript_value_expr)?;
let subscript_place_expr = PlaceExpr::try_from_expr(subscript_value_expr)?;
let Type::StringLiteral(key_literal) = subscript_key_type else {
return None;
};
@@ -1724,7 +1717,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None;
}
let subscript_place_expr = place_expr(subscript_value_expr)?;
let subscript_place_expr = PlaceExpr::try_from_expr(subscript_value_expr)?;
// Skip narrowing if any tuple in the union has an out-of-bounds index.
// A diagnostic will be emitted elsewhere for the out-of-bounds access.