[ty] Exclude parameterized tuple types from narrowing when disjoint from comparison values (#22129)

## Summary

IIUC, tuples with a known structure (`tuple_spec`) use the standard
tuple `__eq__` which only returns `True` for other tuples, so they can
be safely excluded when disjoint from string literals or other non-tuple
types.

Closes https://github.com/astral-sh/ty/issues/2140.
This commit is contained in:
Charlie Marsh
2025-12-22 15:44:49 -05:00
committed by GitHub
parent 4a937543b9
commit 664686bdbc
3 changed files with 84 additions and 41 deletions

View File

@@ -238,3 +238,20 @@ def _(s: LiteralString | None, t: LiteralString | Any):
# TODO could be `Literal["foo"] | Any`
reveal_type(t) # revealed: LiteralString | Any
```
## Narrowing with tuple types
We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other
tuples. So they are excluded from the narrowed type when comparing to non-tuple values.
```py
from typing import Literal
def _(x: Literal["a", "b"] | tuple[int, int]):
if x == "a":
# tuple type is excluded because it's disjoint from the string literal
reveal_type(x) # revealed: Literal["a"]
else:
# tuple type remains in the else branch
reveal_type(x) # revealed: Literal["b"] | tuple[int, int]
```

View File

@@ -191,3 +191,20 @@ def test(x: Status | int):
else:
reveal_type(x) # revealed: Literal[Status.REJECTED] | int
```
## Union with tuple and `Literal`
We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other
tuples. So they are excluded from the narrowed type when disjoint from the RHS values.
```py
from typing import Literal
def test(x: Literal["none", "auto", "required"] | tuple[list[str], Literal["auto", "required"]]):
if x in ("auto", "required"):
# tuple type is excluded because it's disjoint from the string literals
reveal_type(x) # revealed: Literal["auto", "required"]
else:
# tuple type remains in the else branch
reveal_type(x) # revealed: Literal["none"] | tuple[list[str], Literal["auto", "required"]]
```

View File

@@ -324,6 +324,39 @@ fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
}
}
/// Return `true` if it is possible for any two inhabitants of the given types to
/// compare equal to each other; otherwise return `false`.
fn could_compare_equal<'db>(db: &'db dyn Db, left_ty: Type<'db>, right_ty: Type<'db>) -> bool {
if !left_ty.is_disjoint_from(db, right_ty) {
// If types overlap, they have inhabitants in common; it's definitely possible
// for an object to compare equal to itself.
return true;
}
match (left_ty, right_ty) {
// In order to be sure a union type cannot compare equal to another type, it
// must be true that no element of the union can compare equal to that type.
(Type::Union(union), _) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, *ty, right_ty)),
(_, Type::Union(union)) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, left_ty, *ty)),
// Boolean literals and int literals are disjoint, and single valued, and yet
// `True == 1` and `False == 0`.
(Type::BooleanLiteral(b), Type::IntLiteral(i))
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
// We assume that tuples use `tuple.__eq__` which only returns True
// for other tuples, so they cannot compare equal to non-tuple types.
(Type::NominalInstance(instance), _) if instance.tuple_spec(db).is_some() => false,
(_, Type::NominalInstance(instance)) if instance.tuple_spec(db).is_some() => false,
// Other than the above cases, two single-valued disjoint types cannot compare
// equal.
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
}
}
struct NarrowingConstraintsBuilder<'db, 'ast> {
db: &'db dyn Db,
module: &'ast ParsedModuleRef,
@@ -573,39 +606,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
// instead of intersecting with a type.)
// Return `true` if it is possible for any two inhabitants of the given types to
// compare equal to each other; otherwise return `false`.
fn could_compare_equal<'db>(
db: &'db dyn Db,
left_ty: Type<'db>,
right_ty: Type<'db>,
) -> bool {
if !left_ty.is_disjoint_from(db, right_ty) {
// If types overlap, they have inhabitants in common; it's definitely possible
// for an object to compare equal to itself.
return true;
}
match (left_ty, right_ty) {
// In order to be sure a union type cannot compare equal to another type, it
// must be true that no element of the union can compare equal to that type.
(Type::Union(union), _) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, *ty, right_ty)),
(_, Type::Union(union)) => union
.elements(db)
.iter()
.any(|ty| could_compare_equal(db, left_ty, *ty)),
// Boolean literals and int literals are disjoint, and single valued, and yet
// `True == 1` and `False == 0`.
(Type::BooleanLiteral(b), Type::IntLiteral(i))
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
// Other than the above cases, two single-valued disjoint types cannot compare
// equal.
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
}
}
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
// compare equal to `rhs_ty`.
fn can_narrow_to_rhs<'db>(
@@ -660,7 +660,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
)
}
_ => {
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
if !could_compare_equal(db, ty, rhs_ty) {
// Cannot compare equal to rhs, so keep this type
ty
} else {
Type::Never
@@ -721,14 +722,22 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if let Some(lhs_union) = lhs_ty.as_union() {
for element in lhs_union.elements(self.db) {
// Keep only the non-single-valued portion of the original type.
if !element.is_single_valued(self.db)
&& !element.is_literal_string()
&& !element.is_bool(self.db)
&& (!element.is_enum(self.db) || element.overrides_equality(self.db))
{
builder = builder.add(*element);
// Skip single-valued types (handled via RHS matching).
if element.is_single_valued(self.db) {
continue;
}
// Skip types that are handled specially (LiteralString, bool, enum).
if element.is_literal_string()
|| element.is_bool(self.db)
|| (element.is_enum(self.db) && !element.overrides_equality(self.db))
{
continue;
}
// Skip types that cannot compare equal to any RHS value.
if !could_compare_equal(self.db, *element, rhs_values) {
continue;
}
builder = builder.add(*element);
}
}
Some(builder.build())