[ty] Support narrowing for tuple matches with literal elements (#22303)

## Summary

See:
https://github.com/astral-sh/ruff/pull/22299#issuecomment-3699913849.
This commit is contained in:
Charlie Marsh
2025-12-30 13:45:07 -05:00
committed by GitHub
parent e0e1e9535e
commit 12dd27da52
3 changed files with 312 additions and 3 deletions

View File

@@ -279,6 +279,101 @@ def _(t9: tuple[int | None, str] | tuple[str, int]):
reveal_type(t9) # revealed: tuple[int | None, str] | tuple[str, int]
```
### Tagged unions of tuples (equality narrowing)
Narrow unions of tuples based on literal tag elements using `==` comparison:
```py
from typing import Literal
class A: ...
class B: ...
class C: ...
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
if x[0] == "tag1":
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
reveal_type(x[1]) # revealed: A
else:
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
reveal_type(x[1]) # revealed: B
reveal_type(x[2]) # revealed: C
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
if x[0] != "tag1":
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
else:
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
# With int literals
def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]):
if x[0] == 1:
reveal_type(x) # revealed: tuple[Literal[1], A]
else:
reveal_type(x) # revealed: tuple[Literal[2], B]
# With bytes literals
def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]):
if x[0] == b"a":
reveal_type(x) # revealed: tuple[Literal[b"a"], A]
else:
reveal_type(x) # revealed: tuple[Literal[b"b"], B]
# Multiple tuple variants
def _(x: tuple[Literal["a"], A] | tuple[Literal["b"], B] | tuple[Literal["c"], C]):
if x[0] == "a":
reveal_type(x) # revealed: tuple[Literal["a"], A]
elif x[0] == "b":
reveal_type(x) # revealed: tuple[Literal["b"], B]
else:
reveal_type(x) # revealed: tuple[Literal["c"], C]
# Using index 1 instead of 0
def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]):
if x[1] == "tag1":
reveal_type(x) # revealed: tuple[A, Literal["tag1"]]
else:
reveal_type(x) # revealed: tuple[B, Literal["tag2"]]
```
Narrowing is restricted to `Literal` tag elements. If any tuple has a non-literal type at the
discriminating index, we can't safely narrow with equality:
```py
def _(x: tuple[Literal["tag1"], A] | tuple[str, B]):
# Can't narrow because second tuple has `str` (not literal) at index 0
if x[0] == "tag1":
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B]
else:
# But we *can* narrow with inequality
reveal_type(x) # revealed: tuple[str, B]
```
If the index is out of bounds for any tuple in the union, we also skip narrowing (a diagnostic will
be emitted elsewhere for the out-of-bounds access):
```py
def _(x: tuple[A, Literal["a"]] | tuple[B]):
# error: [index-out-of-bounds]
if x[1] == "a":
# Can't narrow because index 1 is out of bounds for second tuple
reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B]
else:
reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B]
```
We can still narrow tuples when non-tuple types are present in the union:
```py
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B] | list[int]):
if x[0] == "tag1":
# A list of ints could have int subclasses in it,
# and int subclasses could have custom `__eq__` methods such that they
# compare equal to `"tag1"`, so `list[int]` cannot be narrowed out of this
# union.
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | list[int]
```
### String subscript
```py

View File

@@ -375,3 +375,70 @@ try:
except ValueError:
pass
```
## Narrowing tagged unions of tuples
Narrow unions of tuples based on literal tag elements in `match` statements:
```py
from typing import Literal
class A: ...
class B: ...
class C: ...
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
match x[0]:
case "tag1":
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
reveal_type(x[1]) # revealed: A
case "tag2":
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
reveal_type(x[1]) # revealed: B
reveal_type(x[2]) # revealed: C
case _:
reveal_type(x) # revealed: Never
# With int literals
def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]):
match x[0]:
case 1:
reveal_type(x) # revealed: tuple[Literal[1], A]
case 2:
reveal_type(x) # revealed: tuple[Literal[2], B]
case _:
reveal_type(x) # revealed: Never
# With bytes literals
def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]):
match x[0]:
case b"a":
reveal_type(x) # revealed: tuple[Literal[b"a"], A]
case b"b":
reveal_type(x) # revealed: tuple[Literal[b"b"], B]
case _:
reveal_type(x) # revealed: Never
# Using index 1 instead of 0
def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]):
match x[1]:
case "tag1":
reveal_type(x) # revealed: tuple[A, Literal["tag1"]]
case "tag2":
reveal_type(x) # revealed: tuple[B, Literal["tag2"]]
case _:
reveal_type(x) # revealed: Never
```
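Narrowing in a matching `case` is restricted to `Literal` tag elements. If any tuple has a non-literal type at the discriminating index, only the fallthrough (`case _`) branch is narrowed: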
```py
def _(x: tuple[Literal["tag1"], A] | tuple[str, B]):
match x[0]:
case "tag1":
# Can't narrow because second tuple has `str` (not literal) at index 0
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B]
case _:
# But we *can* narrow with inequality
reveal_type(x) # revealed: tuple[str, B]
```

View File

@@ -1091,6 +1091,21 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) {
constraints.insert(place, constraint);
}
// Narrow tagged unions of tuples with `Literal` elements, for example:
//
// def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]):
// if t[0] == "a":
// reveal_type(t) # tuple[Literal["a"], A]
if let Some((place, constraint)) = self.narrow_tuple_subscript(
inference.expression_type(&*subscript.value),
&subscript.value,
inference.expression_type(&*subscript.slice),
inference.expression_type(&comparators[0]),
constrain_with_equality,
) {
constraints.insert(place, constraint);
}
}
let mut last_rhs_ty: Option<Type> = None;
@@ -1416,6 +1431,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) {
constraints.insert(place, constraint);
}
// Narrow tagged unions of tuples with `Literal` elements, just like `if` statements.
else if let Some((place, constraint)) = self.narrow_tuple_subscript(
inference.expression_type(&*subscript.value),
&subscript.value,
inference.expression_type(&*subscript.slice),
value_ty,
is_positive,
) {
constraints.insert(place, constraint);
}
}
Some(constraints)
@@ -1524,7 +1549,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let Type::StringLiteral(key_literal) = subscript_key_type else {
return None;
};
if !is_supported_typeddict_tag_literal(rhs_type) {
if !is_supported_tag_literal(rhs_type) {
return None;
}
@@ -1567,6 +1592,93 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let place = self.expect_place(&subscript_place_expr);
Some((place, NarrowingConstraint::regular(intersection)))
}
/// Narrow a tagged union of tuples by comparing one of its elements against a literal.
///
/// The inputs describe a comparison such as `t[0] == "foo"`: the type and expression of the
/// subscripted value (`t`), the type of the subscript index (`0`), and the type of the value
/// being compared against (`"foo"`). When `t` is a union, the members that are tuples are
/// filtered by their element type at that index; members that are not tuples are always kept.
///
/// For example:
/// ```python
/// def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]):
///     if t[0] == "a":
///         reveal_type(t)  # tuple[Literal["a"], A]
/// ```
///
/// `constrain_with_equality` selects between the equality form (`true`) and the inequality
/// form (`false`) of the constraint.
///
/// Returns `Some((place, constraint))` if at least one union member was removed, and `None`
/// if the inputs are unsupported or nothing could be narrowed.
fn narrow_tuple_subscript(
    &self,
    subscript_value_type: Type<'db>,
    subscript_value_expr: &ast::Expr,
    subscript_index_type: Type<'db>,
    rhs_type: Type<'db>,
    constrain_with_equality: bool,
) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> {
    let db = self.db;

    // Only a union subscripted by an integer literal can be narrowed here.
    let (Type::Union(union), Type::IntLiteral(raw_index)) =
        (subscript_value_type, subscript_index_type)
    else {
        return None;
    };
    let index = i32::try_from(raw_index).ok()?;

    // The right-hand side of the comparison has to be one of the supported literal kinds.
    if !is_supported_tag_literal(rhs_type) {
        return None;
    }

    let subscript_place_expr = place_expr(subscript_value_expr)?;

    // An out-of-bounds index on any tuple member disables narrowing altogether; the
    // out-of-bounds access itself is reported by a diagnostic elsewhere.
    if any_tuple_has_out_of_bounds_index(db, union, index) {
        return None;
    }

    // Equality narrowing requires every tuple member to have a literal at `index`.
    // Inequality narrowing has no such requirement.
    if constrain_with_equality
        && !all_matching_tuple_elements_have_literal_types(db, union, index)
    {
        return None;
    }

    // Decides whether a single union member survives the constraint.
    let survives = |member: &Type<'db>| -> bool {
        let Some(element_ty) = member
            .as_nominal_instance()
            .and_then(|instance| instance.tuple_spec(db))
            .and_then(|spec| spec.py_index(db, index).ok())
        else {
            // Not a tuple (or no element type available at `index`): always keep it.
            return true;
        };

        if constrain_with_equality {
            // Keep the tuple if its element could be equal to the right-hand side.
            !element_ty.is_disjoint_from(db, rhs_type)
        } else {
            // Keep the tuple unless its element is always equal to the right-hand side.
            !element_ty.is_subtype_of(db, rhs_type)
        }
    };

    let remaining: Vec<_> = union.elements(db).iter().copied().filter(survives).collect();

    // If every member survived, there is nothing to record.
    if remaining.len() >= union.elements(db).len() {
        return None;
    }

    let place = self.expect_place(&subscript_place_expr);
    let narrowed = UnionType::from_elements(db, remaining);
    Some((place, NarrowingConstraint::regular(narrowed)))
}
}
// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one
@@ -1590,7 +1702,7 @@ fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) ->
}
}
fn is_supported_typeddict_tag_literal(ty: Type) -> bool {
fn is_supported_tag_literal(ty: Type) -> bool {
matches!(
ty,
// TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like
@@ -1610,7 +1722,7 @@ fn all_matching_typeddict_fields_have_literal_types<'db>(
typeddict
.items(db)
.get(field_name)
.is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty))
.is_none_or(|field| is_supported_tag_literal(field.declared_ty))
};
match ty {
@@ -1636,3 +1748,38 @@ fn all_matching_typeddict_fields_have_literal_types<'db>(
_ => true,
}
}
/// Report whether `index` is out of bounds for at least one tuple member of `union`.
///
/// Union members that are not tuples are ignored. Callers use a `true` result to skip
/// narrowing entirely, since the out-of-bounds access is reported by a diagnostic elsewhere.
fn any_tuple_has_out_of_bounds_index<'db>(
    db: &'db dyn Db,
    union: UnionType<'db>,
    index: i32,
) -> bool {
    for member in union.elements(db).iter() {
        // Skip members that have no tuple spec.
        let Some(spec) = member
            .as_nominal_instance()
            .and_then(|instance| instance.tuple_spec(db))
        else {
            continue;
        };

        if spec.py_index(db, index).is_err() {
            return true;
        }
    }

    false
}
/// Report whether every tuple member of `union` has a supported literal type at `index`.
///
/// Equality narrowing is only safe when the element types at the discriminating index are
/// literals, which have well-defined equality. Non-literal types (like `str` or `int`)
/// could have subclasses that override `__eq__` in unexpected ways.
///
/// Union members that are not tuples, or for which indexing at `index` fails, do not
/// affect the result.
fn all_matching_tuple_elements_have_literal_types<'db>(
    db: &'db dyn Db,
    union: UnionType<'db>,
    index: i32,
) -> bool {
    union
        .elements(db)
        .iter()
        // Collect the element type at `index` for each member that is an indexable tuple.
        .filter_map(|member| {
            let spec = member.as_nominal_instance()?.tuple_spec(db)?;
            spec.py_index(db, index).ok()
        })
        .all(is_supported_tag_literal)
}