From ea4642615794edca896538e4756206c26111ba85 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 13 Jan 2026 19:56:47 -0500 Subject: [PATCH] [ty] Apply narrowing to walrus targets (#22369) ## Summary Closes https://github.com/astral-sh/ty/issues/2300. --- .../resources/mdtest/narrow/callable.md | 24 ++++++++++++++ .../resources/mdtest/narrow/isinstance.md | 16 +++++++++ .../resources/mdtest/narrow/truthiness.md | 16 +++++++++ .../resources/mdtest/narrow/type_guards.md | 29 ++++++++++++++++ .../src/semantic_index/place.rs | 6 ++++ crates/ty_python_semantic/src/types/narrow.rs | 33 ++++++++----------- 6 files changed, 104 insertions(+), 20 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/callable.md b/crates/ty_python_semantic/resources/mdtest/narrow/callable.md index 07f225b051..9326539e0c 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/callable.md @@ -69,6 +69,30 @@ def call_with_args(y: object, a: int, b: str) -> object: return None ``` +## Narrowing with named expressions (walrus operator) + +When `callable()` is used with a named expression, the target of the named expression should be +narrowed. + +```py +from typing import Any + +class Foo: + func: Any | None + +def f(foo: Foo): + first = getattr(foo, "func", None) + if callable(first): + reveal_type(first) # revealed: Any & Top[(...) -> object] + else: + reveal_type(first) # revealed: (Any & ~Top[(...) -> object]) | None + + if callable(second := getattr(foo, "func", None)): + reveal_type(second) # revealed: Any & Top[(...) -> object] + else: + reveal_type(second) # revealed: (Any & ~Top[(...) -> object]) | None +``` + ## Assignability of narrowed callables A narrowed callable `Top[Callable[..., object]]` should be assignable to `Callable[..., Any]`. This diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md index 7121c4c1dd..4e7efb7a10 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md @@ -580,3 +580,19 @@ def test(a: Any, items: list[T]) -> None: if isinstance(v, dict): cast(T, v) # no panic ``` + +## Narrowing with named expressions (walrus operator) + +When `isinstance()` is used with a named expression, the target of the named expression should be +narrowed. + +```py +def get_value() -> int | str: + return 1 + +def f(): + if isinstance(x := get_value(), int): + reveal_type(x) # revealed: int + else: + reveal_type(x) # revealed: str +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md index cb9f0c4545..a7666dcf53 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md @@ -347,3 +347,19 @@ def _(x: LiteralString): else: reveal_type(x) # revealed: LiteralString & ~Literal[""] ``` + +## Narrowing with named expressions (walrus operator) + +When a truthiness check is used with a named expression, the target of the named expression should +be narrowed. + +```py +def get_value() -> str | None: + return "hello" + +def f(): + if x := get_value(): + reveal_type(x) # revealed: str & ~AlwaysFalsy + else: + reveal_type(x) # revealed: (str & ~AlwaysTruthy) | None +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 8e492f4ba2..170dab1162 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -499,3 +499,32 @@ def _(x: object): if f(x) and (g(x) or h(x)): reveal_type(x) # revealed: B | (A & C) ``` + +## Narrowing with named expressions (walrus operator) + +When a type guard is used with a named expression, the target of the named expression should be +narrowed. + +```py +from typing_extensions import TypeGuard, TypeIs + +def is_str(x: object) -> TypeIs[str]: + return isinstance(x, str) + +def guard_str(x: object) -> TypeGuard[str]: + return isinstance(x, str) + +def get_value() -> int | str: + return 1 + +def f(): + if is_str(x := get_value()): + reveal_type(x) # revealed: str + else: + reveal_type(x) # revealed: int + + if guard_str(y := get_value()): + reveal_type(y) # revealed: str + else: + reveal_type(y) # revealed: int | str +``` diff --git a/crates/ty_python_semantic/src/semantic_index/place.rs b/crates/ty_python_semantic/src/semantic_index/place.rs index 04bea97626..850a43bf80 100644 --- a/crates/ty_python_semantic/src/semantic_index/place.rs +++ b/crates/ty_python_semantic/src/semantic_index/place.rs @@ -38,6 +38,12 @@ impl PlaceExpr { pub(crate) fn try_from_expr<'e>(expr: impl Into>) -> Option { let expr = expr.into(); + // For named expressions (walrus operator), extract the target. + let expr = match expr { + ast::ExprRef::Named(named) => named.target.as_ref().into(), + _ => expr, + }; + if let ast::ExprRef::Name(name) = expr { return Some(PlaceExpr::Symbol(Symbol::new(name.id.clone()))); } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 6a044e3e13..0bf9d539f0 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -454,13 +454,6 @@ fn merge_constraints_or<'db>( } } -fn place_expr(expr: &ast::Expr) -> Option { - match expr { - ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()), - _ => PlaceExpr::try_from_expr(expr), - } -} - /// Return `true` if it is possible for any two inhabitants of the given types to /// compare equal to each other; otherwise return `false`. fn could_compare_equal<'db>(db: &'db dyn Db, left_ty: Type<'db>, right_ty: Type<'db>) -> bool { @@ -721,7 +714,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr: &ast::Expr, is_positive: bool, ) -> Option> { - let target = place_expr(expr)?; + let target = PlaceExpr::try_from_expr(expr)?; let place = self.expect_place(&target); let ty = if is_positive { @@ -1030,7 +1023,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if matches!(&**ops, [ast::CmpOp::Is | ast::CmpOp::IsNot]) && let ast::Expr::Subscript(subscript) = &**left && let Type::Union(union) = inference.expression_type(&*subscript.value) - && let Some(subscript_place_expr) = place_expr(&subscript.value) + && let Some(subscript_place_expr) = PlaceExpr::try_from_expr(&subscript.value) && let Type::IntLiteral(index) = inference.expression_type(&*subscript.slice) && let Ok(index) = i32::try_from(index) && let rhs_ty = inference.expression_type(&comparators[0]) @@ -1122,7 +1115,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // reveal_type(u) # revealed: Bar if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn]) && let Type::StringLiteral(key) = inference.expression_type(&**left) - && let Some(rhs_place_expr) = place_expr(&comparators[0]) + && let Some(rhs_place_expr) = PlaceExpr::try_from_expr(&comparators[0]) && let rhs_type = inference.expression_type(&comparators[0]) && is_typeddict_or_union_with_typeddicts(self.db, rhs_type) { @@ -1190,7 +1183,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) | ast::Expr::Named(_) => { - if let Some(left) = place_expr(left) + if let Some(left) = PlaceExpr::try_from_expr(left) && let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) { @@ -1215,7 +1208,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; let target = match &**args { - [first] => match place_expr(first) { + [first] => match PlaceExpr::try_from_expr(first) { Some(target) => target, None => continue, }, @@ -1264,7 +1257,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | ast::Expr::Named(_) ) => { - if let Some(right_place) = place_expr(right) + if let Some(right_place) = PlaceExpr::try_from_expr(right) // Swap lhs_ty and rhs_ty since we're narrowing the right operand && let Some(ty) = self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive) @@ -1315,7 +1308,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // Narrow only the parts of the type that are safe to narrow based on len(). if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) { - let target = place_expr(arg)?; + let target = PlaceExpr::try_from_expr(arg)?; let place = self.expect_place(&target); Some(NarrowingConstraints::from_iter([( place, @@ -1329,7 +1322,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let [first_arg, second_arg] = &*expr_call.arguments.args else { return None; }; - let first_arg = place_expr(first_arg)?; + let first_arg = PlaceExpr::try_from_expr(first_arg)?; let function = function_type.known(self.db)?; let place = self.expect_place(&first_arg); @@ -1427,7 +1420,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { singleton: ast::Singleton, is_positive: bool, ) -> Option> { - let subject = place_expr(subject.node_ref(self.db, self.module))?; + let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); let ty = match singleton { @@ -1456,7 +1449,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let subject = place_expr(subject.node_ref(self.db, self.module))?; + let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); let class_type = @@ -1486,7 +1479,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option> { let subject_node = subject.node_ref(self.db, self.module); let place = { - let subject = place_expr(subject_node)?; + let subject = PlaceExpr::try_from_expr(subject_node)?; self.expect_place(&subject) }; let subject_ty = @@ -1638,7 +1631,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if !is_typeddict_or_union_with_typeddicts(self.db, subscript_value_type) { return None; } - let subscript_place_expr = place_expr(subscript_value_expr)?; + let subscript_place_expr = PlaceExpr::try_from_expr(subscript_value_expr)?; let Type::StringLiteral(key_literal) = subscript_key_type else { return None; }; @@ -1724,7 +1717,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let subscript_place_expr = place_expr(subscript_value_expr)?; + let subscript_place_expr = PlaceExpr::try_from_expr(subscript_value_expr)?; // Skip narrowing if any tuple in the union has an out-of-bounds index. // A diagnostic will be emitted elsewhere for the out-of-bounds access.