From b72391746345ee4d3695a4f76654f44e9884e890 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 23 Dec 2025 20:15:50 -0500 Subject: [PATCH] [ty] Support tuple narrowing based on member checks (#22167) ## Summary Closes https://github.com/astral-sh/ty/issues/2179. --- .../resources/mdtest/narrow/complex_target.md | 77 ++++++++++++++++++- crates/ty_python_semantic/src/types/narrow.rs | 44 ++++++++++- 2 files changed, 116 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md index 19d35b73b5..42017e1e94 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md @@ -197,15 +197,86 @@ def _(t1: tuple[int | None, int | None], t2: tuple[int, int] | tuple[None, None] n = 0 if t1[n] is not None: - # Non-literal subscript narrowing are currently not supported, as well as mypy, pyright + # Narrowing the individual element type with a non-literal subscript is not supported reveal_type(t1[0]) # revealed: int | None reveal_type(t1[n]) # revealed: int | None reveal_type(t1[1]) # revealed: int | None + # However, we can still discriminate between tuples in a union using a variable index: + if t2[n] is not None: + reveal_type(t2) # revealed: tuple[int, int] + if t2[0] is not None: + reveal_type(t2) # revealed: tuple[int, int] reveal_type(t2[0]) # revealed: int - # TODO: should be int - reveal_type(t2[1]) # revealed: int | None + reveal_type(t2[1]) # revealed: int + else: + reveal_type(t2) # revealed: tuple[None, None] + reveal_type(t2[0]) # revealed: None + reveal_type(t2[1]) # revealed: None + + if t2[0] is None: + reveal_type(t2) # revealed: tuple[None, None] + else: + reveal_type(t2) # revealed: tuple[int, int] + +def _(t3: tuple[int, str] | tuple[None, None] | tuple[bool, bytes]): + # Narrow to tuples where first element is not None + if t3[0] is not None: + reveal_type(t3) # revealed: tuple[int, str] | tuple[bool, bytes] + + # Narrow to tuples where first element is None + if t3[0] is None: + reveal_type(t3) # revealed: tuple[None, None] + +def _(t4: tuple[bool, int] | tuple[bool, str]): + # Both tuples have bool at index 0, which is not disjoint from True, + # so neither gets filtered out when checking `is True` + if t4[0] is True: + reveal_type(t4) # revealed: tuple[bool, int] | tuple[bool, str] + +def _(t5: tuple[int, None] | tuple[None, int]): + # Narrow on second element (index 1) + if t5[1] is not None: + reveal_type(t5) # revealed: tuple[None, int] + else: + reveal_type(t5) # revealed: tuple[int, None] + + # Negative index + if t5[-1] is None: + reveal_type(t5) # revealed: tuple[int, None] + +def _(t6: tuple[int, ...] | tuple[None, None]): + # Variadic tuple at index 0 has element type `int` (not a union), + # so `tuple[None, None]` gets filtered out + if t6[0] is not None: + reveal_type(t6) # revealed: tuple[int, ...] + +def _(t6b: tuple[int, ...] | tuple[None, ...]): + # Both variadic: `int` is disjoint from None, `None` is not disjoint from None + if t6b[0] is not None: + reveal_type(t6b) # revealed: tuple[int, ...] + else: + reveal_type(t6b) # revealed: tuple[None, ...] + +def _(t7: tuple[int, int] | tuple[None, None]): + # Index out of range for both tuples - no narrowing, but errors are emitted + # error: [index-out-of-bounds] "Index 5 is out of bounds for tuple `tuple[int, int]` with length 2" + # error: [index-out-of-bounds] "Index 5 is out of bounds for tuple `tuple[None, None]` with length 2" + if t7[5] is not None: + reveal_type(t7) # revealed: tuple[int, int] | tuple[None, None] + +def _(t8: tuple[int, int, int] | tuple[None, None]): + # Index in range for first tuple but out of range for second + # error: [index-out-of-bounds] "Index 2 is out of bounds for tuple `tuple[None, None]` with length 2" + if t8[2] is not None: + reveal_type(t8) # revealed: tuple[int, int, int] | tuple[None, None] + +def _(t9: tuple[int | None, str] | tuple[str, int]): + # When the element type is a union (like `int | None`), we can't filter + # out the tuple. + if t9[0] is not None: + reveal_type(t9) # revealed: tuple[int | None, str] | tuple[str, int] ``` ### String subscript diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 1ea3c28590..3a589d81bb 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -7,6 +7,7 @@ use crate::semantic_index::predicate::{ PredicateNode, }; use crate::semantic_index::scope::ScopeId; +use crate::subscript::PyIndex; use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; @@ -23,14 +24,13 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::name::Name; use ruff_python_stdlib::identifiers::is_identifier; +use super::UnionType; use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; -use super::UnionType; - /// Return the type constraint that `test` (if true) would place on `symbol`, if any. /// /// For example, if we have this code: @@ -881,6 +881,46 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); let mut constraints = NarrowingConstraints::default(); + // Narrow unions of tuples based on element checks. For example: + // + // def _(t: tuple[int, int] | tuple[None, None]): + // if t[0] is not None: + // reveal_type(t) # tuple[int, int] + if matches!(&**ops, [ast::CmpOp::Is | ast::CmpOp::IsNot]) + && let ast::Expr::Subscript(subscript) = &**left + && let Type::Union(union) = inference.expression_type(&*subscript.value) + && let Some(subscript_place_expr) = place_expr(&subscript.value) + && let Type::IntLiteral(index) = inference.expression_type(&*subscript.slice) + && let Ok(index) = i32::try_from(index) + && let rhs_ty = inference.expression_type(&comparators[0]) + && rhs_ty.is_singleton(self.db) + { + let is_positive_check = is_positive == (ops[0] == ast::CmpOp::Is); + let filtered: Vec<_> = union + .elements(self.db) + .iter() + .filter(|elem| { + elem.as_nominal_instance() + .and_then(|inst| inst.tuple_spec(self.db)) + .and_then(|spec| spec.py_index(self.db, index).ok()) + .is_none_or(|el_ty| { + if is_positive_check { + // `is X` context: keep tuples where element could be X + !el_ty.is_disjoint_from(self.db, rhs_ty) + } else { + // `is not X` context: keep tuples where element is not always X + !el_ty.is_subtype_of(self.db, rhs_ty) + } + }) + }) + .copied() + .collect(); + if filtered.len() < union.elements(self.db).len() { + let place = self.expect_place(&subscript_place_expr); + constraints.insert(place, UnionType::from_elements(self.db, filtered)); + } + } + // Narrow tagged unions of `TypedDict`s with `Literal` keys, for example: // // class Foo(TypedDict):