From 12dd27da527bf26f4f43e3f0421d8fa0c4597cc1 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 30 Dec 2025 13:45:07 -0500 Subject: [PATCH] [ty] Support narrowing for tuple matches with literal elements (#22303) ## Summary See: https://github.com/astral-sh/ruff/pull/22299#issuecomment-3699913849. --- .../resources/mdtest/narrow/complex_target.md | 95 +++++++++++ .../resources/mdtest/narrow/match.md | 67 ++++++++ crates/ty_python_semantic/src/types/narrow.rs | 153 +++++++++++++++++- 3 files changed, 312 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md index 42017e1e94..b1a324e557 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md @@ -279,6 +279,101 @@ def _(t9: tuple[int | None, str] | tuple[str, int]): reveal_type(t9) # revealed: tuple[int | None, str] | tuple[str, int] ``` +### Tagged unions of tuples (equality narrowing) + +Narrow unions of tuples based on literal tag elements using `==` comparison: + +```py +from typing import Literal + +class A: ... +class B: ... +class C: ... + +def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]): + if x[0] == "tag1": + reveal_type(x) # revealed: tuple[Literal["tag1"], A] + reveal_type(x[1]) # revealed: A + else: + reveal_type(x) # revealed: tuple[Literal["tag2"], B, C] + reveal_type(x[1]) # revealed: B + reveal_type(x[2]) # revealed: C + +def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]): + if x[0] != "tag1": + reveal_type(x) # revealed: tuple[Literal["tag2"], B, C] + else: + reveal_type(x) # revealed: tuple[Literal["tag1"], A] + +# With int literals +def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]): + if x[0] == 1: + reveal_type(x) # revealed: tuple[Literal[1], A] + else: + reveal_type(x) # revealed: tuple[Literal[2], B] + +# With bytes literals +def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]): + if x[0] == b"a": + reveal_type(x) # revealed: tuple[Literal[b"a"], A] + else: + reveal_type(x) # revealed: tuple[Literal[b"b"], B] + +# Multiple tuple variants +def _(x: tuple[Literal["a"], A] | tuple[Literal["b"], B] | tuple[Literal["c"], C]): + if x[0] == "a": + reveal_type(x) # revealed: tuple[Literal["a"], A] + elif x[0] == "b": + reveal_type(x) # revealed: tuple[Literal["b"], B] + else: + reveal_type(x) # revealed: tuple[Literal["c"], C] + +# Using index 1 instead of 0 +def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]): + if x[1] == "tag1": + reveal_type(x) # revealed: tuple[A, Literal["tag1"]] + else: + reveal_type(x) # revealed: tuple[B, Literal["tag2"]] +``` + +Narrowing is restricted to `Literal` tag elements. If any tuple has a non-literal type at the +discriminating index, we can't safely narrow with equality: + +```py +def _(x: tuple[Literal["tag1"], A] | tuple[str, B]): + # Can't narrow because second tuple has `str` (not literal) at index 0 + if x[0] == "tag1": + reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B] + else: + # But we *can* narrow with inequality + reveal_type(x) # revealed: tuple[str, B] +``` + +If the index is out of bounds for any tuple in the union, we also skip narrowing (a diagnostic will +be emitted elsewhere for the out-of-bounds access): + +```py +def _(x: tuple[A, Literal["a"]] | tuple[B]): + # error: [index-out-of-bounds] + if x[1] == "a": + # Can't narrow because index 1 is out of bounds for second tuple + reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B] + else: + reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B] +``` + +We can still narrow tuples when non-tuple types are present in the union: + +```py +def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B] | list[int]): + if x[0] == "tag1": + # A list of ints could have int subclasses in it, + # and int subclasses could have custom `__eq__` methods such that they + # compare equal to `"tag1"`, so `list[int]` cannot be narrowed out of this + # union. + reveal_type(x) # revealed: tuple[Literal["tag1"], A] | list[int] +``` + ### String subscript ```py diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index f0c107851b..357c20155e 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -375,3 +375,70 @@ try: except ValueError: pass ``` + +## Narrowing tagged unions of tuples + +Narrow unions of tuples based on literal tag elements in `match` statements: + +```py +from typing import Literal + +class A: ... +class B: ... +class C: ... + +def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]): + match x[0]: + case "tag1": + reveal_type(x) # revealed: tuple[Literal["tag1"], A] + reveal_type(x[1]) # revealed: A + case "tag2": + reveal_type(x) # revealed: tuple[Literal["tag2"], B, C] + reveal_type(x[1]) # revealed: B + reveal_type(x[2]) # revealed: C + case _: + reveal_type(x) # revealed: Never + +# With int literals +def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]): + match x[0]: + case 1: + reveal_type(x) # revealed: tuple[Literal[1], A] + case 2: + reveal_type(x) # revealed: tuple[Literal[2], B] + case _: + reveal_type(x) # revealed: Never + +# With bytes literals +def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]): + match x[0]: + case b"a": + reveal_type(x) # revealed: tuple[Literal[b"a"], A] + case b"b": + reveal_type(x) # revealed: tuple[Literal[b"b"], B] + case _: + reveal_type(x) # revealed: Never + +# Using index 1 instead of 0 +def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]): + match x[1]: + case "tag1": + reveal_type(x) # revealed: tuple[A, Literal["tag1"]] + case "tag2": + reveal_type(x) # revealed: tuple[B, Literal["tag2"]] + case _: + reveal_type(x) # revealed: Never +``` + +Narrowing is restricted to `Literal` tag elements: + +```py +def _(x: tuple[Literal["tag1"], A] | tuple[str, B]): + match x[0]: + case "tag1": + # Can't narrow because second tuple has `str` (not literal) at index 0 + reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B] + case _: + # But we *can* narrow with inequality + reveal_type(x) # revealed: tuple[str, B] +``` diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index e1c840a008..19a0087509 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1091,6 +1091,21 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) { constraints.insert(place, constraint); } + + // Narrow tagged unions of tuples with `Literal` elements, for example: + // + // def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]): + // if t[0] == "a": + // reveal_type(t) # tuple[Literal["a"], A] + if let Some((place, constraint)) = self.narrow_tuple_subscript( + inference.expression_type(&*subscript.value), + &subscript.value, + inference.expression_type(&*subscript.slice), + inference.expression_type(&comparators[0]), + constrain_with_equality, + ) { + constraints.insert(place, constraint); + } } let mut last_rhs_ty: Option = None; @@ -1416,6 +1431,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) { constraints.insert(place, constraint); } + // Narrow tagged unions of tuples with `Literal` elements, just like `if` statements. + else if let Some((place, constraint)) = self.narrow_tuple_subscript( + inference.expression_type(&*subscript.value), + &subscript.value, + inference.expression_type(&*subscript.slice), + value_ty, + is_positive, + ) { + constraints.insert(place, constraint); + } } Some(constraints) @@ -1524,7 +1549,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let Type::StringLiteral(key_literal) = subscript_key_type else { return None; }; - if !is_supported_typeddict_tag_literal(rhs_type) { + if !is_supported_tag_literal(rhs_type) { return None; } @@ -1567,6 +1592,93 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let place = self.expect_place(&subscript_place_expr); Some((place, NarrowingConstraint::regular(intersection))) } + + /// Narrow tagged unions of tuples with `Literal` elements. + /// + /// Given a subscript expression like `t[0]` where `t` is a union of tuple types, and a + /// comparison value like `"foo"`, this method creates a constraint on `t` that narrows it + /// based on the element value at that index. + /// + /// For example: + /// ```python + /// def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]): + /// if t[0] == "a": + /// reveal_type(t) # tuple[Literal["a"], A] + /// ``` + /// + /// Returns `Some((place, constraint))` if narrowing is possible, `None` otherwise. + fn narrow_tuple_subscript( + &self, + subscript_value_type: Type<'db>, + subscript_value_expr: &ast::Expr, + subscript_index_type: Type<'db>, + rhs_type: Type<'db>, + constrain_with_equality: bool, + ) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> { + // We need a union type for narrowing to be useful. + let Type::Union(union) = subscript_value_type else { + return None; + }; + + // The subscript index must be an integer literal. + let Type::IntLiteral(index) = subscript_index_type else { + return None; + }; + let index = i32::try_from(index).ok()?; + + // The comparison value must be a supported literal type. + if !is_supported_tag_literal(rhs_type) { + return None; + } + + let subscript_place_expr = place_expr(subscript_value_expr)?; + + // Skip narrowing if any tuple in the union has an out-of-bounds index. + // A diagnostic will be emitted elsewhere for the out-of-bounds access. + if any_tuple_has_out_of_bounds_index(self.db, union, index) { + return None; + } + + // For equality constraints, all matching elements must have literal types to safely narrow. + // For inequality constraints, we can narrow even with non-literal element types. + if constrain_with_equality + && !all_matching_tuple_elements_have_literal_types(self.db, union, index) + { + return None; + } + + // Filter the union based on whether each tuple element at the index could match the rhs. + let filtered: Vec<_> = union + .elements(self.db) + .iter() + .filter(|elem| { + elem.as_nominal_instance() + .and_then(|inst| inst.tuple_spec(self.db)) + .and_then(|spec| spec.py_index(self.db, index).ok()) + .is_none_or(|el_ty| { + if constrain_with_equality { + // Keep tuples where element could be equal to rhs. + !el_ty.is_disjoint_from(self.db, rhs_type) + } else { + // Keep tuples where element is not always equal to rhs. + !el_ty.is_subtype_of(self.db, rhs_type) + } + }) + }) + .copied() + .collect(); + + // Only create a constraint if we actually narrowed something. + if filtered.len() < union.elements(self.db).len() { + let place = self.expect_place(&subscript_place_expr); + Some(( + place, + NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)), + )) + } else { + None + } + } } // Return true if the given type is a `TypedDict`, or if it's a union that includes at least one @@ -1590,7 +1702,7 @@ fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> } } -fn is_supported_typeddict_tag_literal(ty: Type) -> bool { +fn is_supported_tag_literal(ty: Type) -> bool { matches!( ty, // TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like @@ -1610,7 +1722,7 @@ fn all_matching_typeddict_fields_have_literal_types<'db>( typeddict .items(db) .get(field_name) - .is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty)) + .is_none_or(|field| is_supported_tag_literal(field.declared_ty)) }; match ty { @@ -1636,3 +1748,38 @@ fn all_matching_typeddict_fields_have_literal_types<'db>( _ => true, } } + +/// Check if any tuple in the union has an out-of-bounds index. +/// +/// If the index is out of bounds for any tuple, we should skip narrowing entirely +/// since a diagnostic will be emitted elsewhere for the out-of-bounds access. +fn any_tuple_has_out_of_bounds_index<'db>( + db: &'db dyn Db, + union: UnionType<'db>, + index: i32, +) -> bool { + union.elements(db).iter().any(|elem| { + elem.as_nominal_instance() + .and_then(|inst| inst.tuple_spec(db)) + .is_some_and(|spec| spec.py_index(db, index).is_err()) + }) +} + +/// Check that all tuple elements at the given index have literal types. +/// +/// For equality narrowing to be safe, we need to ensure that the element types +/// at the discriminating index are literals (which have well-defined equality). +/// Non-literal types (like `str` or `int`) could have subclasses that override +/// `__eq__` in unexpected ways. +fn all_matching_tuple_elements_have_literal_types<'db>( + db: &'db dyn Db, + union: UnionType<'db>, + index: i32, +) -> bool { + union.elements(db).iter().all(|elem| { + elem.as_nominal_instance() + .and_then(|inst| inst.tuple_spec(db)) + .and_then(|spec| spec.py_index(db, index).ok()) + .is_none_or(is_supported_tag_literal) + }) +}