mirror of
https://github.com/astral-sh/ruff
synced 2026-01-22 05:51:03 -05:00
[ty] Support narrowing for tuple matches with literal elements (#22303)
## Summary See: https://github.com/astral-sh/ruff/pull/22299#issuecomment-3699913849.
This commit is contained in:
@@ -279,6 +279,101 @@ def _(t9: tuple[int | None, str] | tuple[str, int]):
|
||||
reveal_type(t9) # revealed: tuple[int | None, str] | tuple[str, int]
|
||||
```
|
||||
|
||||
### Tagged unions of tuples (equality narrowing)
|
||||
|
||||
Narrow unions of tuples based on literal tag elements using `==` comparison:
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
|
||||
if x[0] == "tag1":
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
|
||||
reveal_type(x[1]) # revealed: A
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
|
||||
reveal_type(x[1]) # revealed: B
|
||||
reveal_type(x[2]) # revealed: C
|
||||
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
|
||||
if x[0] != "tag1":
|
||||
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
|
||||
|
||||
# With int literals
|
||||
def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]):
|
||||
if x[0] == 1:
|
||||
reveal_type(x) # revealed: tuple[Literal[1], A]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[Literal[2], B]
|
||||
|
||||
# With bytes literals
|
||||
def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]):
|
||||
if x[0] == b"a":
|
||||
reveal_type(x) # revealed: tuple[Literal[b"a"], A]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[Literal[b"b"], B]
|
||||
|
||||
# Multiple tuple variants
|
||||
def _(x: tuple[Literal["a"], A] | tuple[Literal["b"], B] | tuple[Literal["c"], C]):
|
||||
if x[0] == "a":
|
||||
reveal_type(x) # revealed: tuple[Literal["a"], A]
|
||||
elif x[0] == "b":
|
||||
reveal_type(x) # revealed: tuple[Literal["b"], B]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[Literal["c"], C]
|
||||
|
||||
# Using index 1 instead of 0
|
||||
def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]):
|
||||
if x[1] == "tag1":
|
||||
reveal_type(x) # revealed: tuple[A, Literal["tag1"]]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[B, Literal["tag2"]]
|
||||
```
|
||||
|
||||
Narrowing is restricted to `Literal` tag elements. If any tuple has a non-literal type at the
|
||||
discriminating index, we can't safely narrow with equality:
|
||||
|
||||
```py
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[str, B]):
|
||||
# Can't narrow because second tuple has `str` (not literal) at index 0
|
||||
if x[0] == "tag1":
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B]
|
||||
else:
|
||||
# But we *can* narrow with inequality
|
||||
reveal_type(x) # revealed: tuple[str, B]
|
||||
```
|
||||
|
||||
If the index is out of bounds for any tuple in the union, we also skip narrowing (a diagnostic will
|
||||
be emitted elsewhere for the out-of-bounds access):
|
||||
|
||||
```py
|
||||
def _(x: tuple[A, Literal["a"]] | tuple[B]):
|
||||
# error: [index-out-of-bounds]
|
||||
if x[1] == "a":
|
||||
# Can't narrow because index 1 is out of bounds for second tuple
|
||||
reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B]
|
||||
else:
|
||||
reveal_type(x) # revealed: tuple[A, Literal["a"]] | tuple[B]
|
||||
```
|
||||
|
||||
We can still narrow tuples when non-tuple types are present in the union:
|
||||
|
||||
```py
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B] | list[int]):
|
||||
if x[0] == "tag1":
|
||||
# A list of ints could have int subclasses in it,
|
||||
# and int subclasses could have custom `__eq__` methods such that they
|
||||
# compare equal to `"tag1"`, so `list[int]` cannot be narrowed out of this
|
||||
# union.
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | list[int]
|
||||
```
|
||||
|
||||
### String subscript
|
||||
|
||||
```py
|
||||
|
||||
@@ -375,3 +375,70 @@ try:
|
||||
except ValueError:
|
||||
pass
|
||||
```
|
||||
|
||||
## Narrowing tagged unions of tuples
|
||||
|
||||
Narrow unions of tuples based on literal tag elements in `match` statements:
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[Literal["tag2"], B, C]):
|
||||
match x[0]:
|
||||
case "tag1":
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A]
|
||||
reveal_type(x[1]) # revealed: A
|
||||
case "tag2":
|
||||
reveal_type(x) # revealed: tuple[Literal["tag2"], B, C]
|
||||
reveal_type(x[1]) # revealed: B
|
||||
reveal_type(x[2]) # revealed: C
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
# With int literals
|
||||
def _(x: tuple[Literal[1], A] | tuple[Literal[2], B]):
|
||||
match x[0]:
|
||||
case 1:
|
||||
reveal_type(x) # revealed: tuple[Literal[1], A]
|
||||
case 2:
|
||||
reveal_type(x) # revealed: tuple[Literal[2], B]
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
# With bytes literals
|
||||
def _(x: tuple[Literal[b"a"], A] | tuple[Literal[b"b"], B]):
|
||||
match x[0]:
|
||||
case b"a":
|
||||
reveal_type(x) # revealed: tuple[Literal[b"a"], A]
|
||||
case b"b":
|
||||
reveal_type(x) # revealed: tuple[Literal[b"b"], B]
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
# Using index 1 instead of 0
|
||||
def _(x: tuple[A, Literal["tag1"]] | tuple[B, Literal["tag2"]]):
|
||||
match x[1]:
|
||||
case "tag1":
|
||||
reveal_type(x) # revealed: tuple[A, Literal["tag1"]]
|
||||
case "tag2":
|
||||
reveal_type(x) # revealed: tuple[B, Literal["tag2"]]
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
```
|
||||
|
||||
Narrowing is restricted to `Literal` tag elements:
|
||||
|
||||
```py
|
||||
def _(x: tuple[Literal["tag1"], A] | tuple[str, B]):
|
||||
match x[0]:
|
||||
case "tag1":
|
||||
# Can't narrow because second tuple has `str` (not literal) at index 0
|
||||
reveal_type(x) # revealed: tuple[Literal["tag1"], A] | tuple[str, B]
|
||||
case _:
|
||||
# But we *can* narrow with inequality
|
||||
reveal_type(x) # revealed: tuple[str, B]
|
||||
```
|
||||
|
||||
@@ -1091,6 +1091,21 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
|
||||
// Narrow tagged unions of tuples with `Literal` elements, for example:
|
||||
//
|
||||
// def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]):
|
||||
// if t[0] == "a":
|
||||
// reveal_type(t) # tuple[Literal["a"], A]
|
||||
if let Some((place, constraint)) = self.narrow_tuple_subscript(
|
||||
inference.expression_type(&*subscript.value),
|
||||
&subscript.value,
|
||||
inference.expression_type(&*subscript.slice),
|
||||
inference.expression_type(&comparators[0]),
|
||||
constrain_with_equality,
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_rhs_ty: Option<Type> = None;
|
||||
@@ -1416,6 +1431,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
// Narrow tagged unions of tuples with `Literal` elements, just like `if` statements.
|
||||
else if let Some((place, constraint)) = self.narrow_tuple_subscript(
|
||||
inference.expression_type(&*subscript.value),
|
||||
&subscript.value,
|
||||
inference.expression_type(&*subscript.slice),
|
||||
value_ty,
|
||||
is_positive,
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
}
|
||||
|
||||
Some(constraints)
|
||||
@@ -1524,7 +1549,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
let Type::StringLiteral(key_literal) = subscript_key_type else {
|
||||
return None;
|
||||
};
|
||||
if !is_supported_typeddict_tag_literal(rhs_type) {
|
||||
if !is_supported_tag_literal(rhs_type) {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -1567,6 +1592,93 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
Some((place, NarrowingConstraint::regular(intersection)))
|
||||
}
|
||||
|
||||
/// Narrow tagged unions of tuples with `Literal` elements.
|
||||
///
|
||||
/// Given a subscript expression like `t[0]` where `t` is a union of tuple types, and a
|
||||
/// comparison value like `"foo"`, this method creates a constraint on `t` that narrows it
|
||||
/// based on the element value at that index.
|
||||
///
|
||||
/// For example:
|
||||
/// ```python
|
||||
/// def _(t: tuple[Literal["a"], A] | tuple[Literal["b"], B]):
|
||||
/// if t[0] == "a":
|
||||
/// reveal_type(t) # tuple[Literal["a"], A]
|
||||
/// ```
|
||||
///
|
||||
/// Returns `Some((place, constraint))` if narrowing is possible, `None` otherwise.
|
||||
fn narrow_tuple_subscript(
|
||||
&self,
|
||||
subscript_value_type: Type<'db>,
|
||||
subscript_value_expr: &ast::Expr,
|
||||
subscript_index_type: Type<'db>,
|
||||
rhs_type: Type<'db>,
|
||||
constrain_with_equality: bool,
|
||||
) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> {
|
||||
// We need a union type for narrowing to be useful.
|
||||
let Type::Union(union) = subscript_value_type else {
|
||||
return None;
|
||||
};
|
||||
|
||||
// The subscript index must be an integer literal.
|
||||
let Type::IntLiteral(index) = subscript_index_type else {
|
||||
return None;
|
||||
};
|
||||
let index = i32::try_from(index).ok()?;
|
||||
|
||||
// The comparison value must be a supported literal type.
|
||||
if !is_supported_tag_literal(rhs_type) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let subscript_place_expr = place_expr(subscript_value_expr)?;
|
||||
|
||||
// Skip narrowing if any tuple in the union has an out-of-bounds index.
|
||||
// A diagnostic will be emitted elsewhere for the out-of-bounds access.
|
||||
if any_tuple_has_out_of_bounds_index(self.db, union, index) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// For equality constraints, all matching elements must have literal types to safely narrow.
|
||||
// For inequality constraints, we can narrow even with non-literal element types.
|
||||
if constrain_with_equality
|
||||
&& !all_matching_tuple_elements_have_literal_types(self.db, union, index)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Filter the union based on whether each tuple element at the index could match the rhs.
|
||||
let filtered: Vec<_> = union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.filter(|elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(self.db))
|
||||
.and_then(|spec| spec.py_index(self.db, index).ok())
|
||||
.is_none_or(|el_ty| {
|
||||
if constrain_with_equality {
|
||||
// Keep tuples where element could be equal to rhs.
|
||||
!el_ty.is_disjoint_from(self.db, rhs_type)
|
||||
} else {
|
||||
// Keep tuples where element is not always equal to rhs.
|
||||
!el_ty.is_subtype_of(self.db, rhs_type)
|
||||
}
|
||||
})
|
||||
})
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
// Only create a constraint if we actually narrowed something.
|
||||
if filtered.len() < union.elements(self.db).len() {
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
Some((
|
||||
place,
|
||||
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one
|
||||
@@ -1590,7 +1702,7 @@ fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) ->
|
||||
}
|
||||
}
|
||||
|
||||
fn is_supported_typeddict_tag_literal(ty: Type) -> bool {
|
||||
fn is_supported_tag_literal(ty: Type) -> bool {
|
||||
matches!(
|
||||
ty,
|
||||
// TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like
|
||||
@@ -1610,7 +1722,7 @@ fn all_matching_typeddict_fields_have_literal_types<'db>(
|
||||
typeddict
|
||||
.items(db)
|
||||
.get(field_name)
|
||||
.is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty))
|
||||
.is_none_or(|field| is_supported_tag_literal(field.declared_ty))
|
||||
};
|
||||
|
||||
match ty {
|
||||
@@ -1636,3 +1748,38 @@ fn all_matching_typeddict_fields_have_literal_types<'db>(
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if any tuple in the union has an out-of-bounds index.
|
||||
///
|
||||
/// If the index is out of bounds for any tuple, we should skip narrowing entirely
|
||||
/// since a diagnostic will be emitted elsewhere for the out-of-bounds access.
|
||||
fn any_tuple_has_out_of_bounds_index<'db>(
|
||||
db: &'db dyn Db,
|
||||
union: UnionType<'db>,
|
||||
index: i32,
|
||||
) -> bool {
|
||||
union.elements(db).iter().any(|elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(db))
|
||||
.is_some_and(|spec| spec.py_index(db, index).is_err())
|
||||
})
|
||||
}
|
||||
|
||||
/// Check that all tuple elements at the given index have literal types.
|
||||
///
|
||||
/// For equality narrowing to be safe, we need to ensure that the element types
|
||||
/// at the discriminating index are literals (which have well-defined equality).
|
||||
/// Non-literal types (like `str` or `int`) could have subclasses that override
|
||||
/// `__eq__` in unexpected ways.
|
||||
fn all_matching_tuple_elements_have_literal_types<'db>(
|
||||
db: &'db dyn Db,
|
||||
union: UnionType<'db>,
|
||||
index: i32,
|
||||
) -> bool {
|
||||
union.elements(db).iter().all(|elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(db))
|
||||
.and_then(|spec| spec.py_index(db, index).ok())
|
||||
.is_none_or(is_supported_tag_literal)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user