mirror of
https://github.com/astral-sh/ruff
synced 2026-01-22 22:10:48 -05:00
[ty] Exclude parameterized tuple types from narrowing when disjoint from comparison values (#22129)
## Summary IIUC, tuples with a known structure (`tuple_spec`) use the standard tuple `__eq__` which only returns `True` for other tuples, so they can be safely excluded when disjoint from string literals or other non-tuple types. Closes https://github.com/astral-sh/ty/issues/2140.
This commit is contained in:
@@ -238,3 +238,20 @@ def _(s: LiteralString | None, t: LiteralString | Any):
|
||||
# TODO could be `Literal["foo"] | Any`
|
||||
reveal_type(t) # revealed: LiteralString | Any
|
||||
```
|
||||
|
||||
## Narrowing with tuple types
|
||||
|
||||
We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other
|
||||
tuples. So they are excluded from the narrowed type when comparing to non-tuple values.
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def _(x: Literal["a", "b"] | tuple[int, int]):
|
||||
if x == "a":
|
||||
# tuple type is excluded because it's disjoint from the string literal
|
||||
reveal_type(x) # revealed: Literal["a"]
|
||||
else:
|
||||
# tuple type remains in the else branch
|
||||
reveal_type(x) # revealed: Literal["b"] | tuple[int, int]
|
||||
```
|
||||
|
||||
@@ -191,3 +191,20 @@ def test(x: Status | int):
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[Status.REJECTED] | int
|
||||
```
|
||||
|
||||
## Union with tuple and `Literal`
|
||||
|
||||
We assume that tuple subclasses don't override `tuple.__eq__`, which only returns True for other
|
||||
tuples. So they are excluded from the narrowed type when disjoint from the RHS values.
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def test(x: Literal["none", "auto", "required"] | tuple[list[str], Literal["auto", "required"]]):
|
||||
if x in ("auto", "required"):
|
||||
# tuple type is excluded because it's disjoint from the string literals
|
||||
reveal_type(x) # revealed: Literal["auto", "required"]
|
||||
else:
|
||||
# tuple type remains in the else branch
|
||||
reveal_type(x) # revealed: Literal["none"] | tuple[list[str], Literal["auto", "required"]]
|
||||
```
|
||||
|
||||
@@ -324,6 +324,39 @@ fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return `true` if it is possible for any two inhabitants of the given types to
|
||||
/// compare equal to each other; otherwise return `false`.
|
||||
fn could_compare_equal<'db>(db: &'db dyn Db, left_ty: Type<'db>, right_ty: Type<'db>) -> bool {
|
||||
if !left_ty.is_disjoint_from(db, right_ty) {
|
||||
// If types overlap, they have inhabitants in common; it's definitely possible
|
||||
// for an object to compare equal to itself.
|
||||
return true;
|
||||
}
|
||||
match (left_ty, right_ty) {
|
||||
// In order to be sure a union type cannot compare equal to another type, it
|
||||
// must be true that no element of the union can compare equal to that type.
|
||||
(Type::Union(union), _) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, *ty, right_ty)),
|
||||
(_, Type::Union(union)) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, left_ty, *ty)),
|
||||
// Boolean literals and int literals are disjoint, and single valued, and yet
|
||||
// `True == 1` and `False == 0`.
|
||||
(Type::BooleanLiteral(b), Type::IntLiteral(i))
|
||||
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
|
||||
// We assume that tuples use `tuple.__eq__` which only returns True
|
||||
// for other tuples, so they cannot compare equal to non-tuple types.
|
||||
(Type::NominalInstance(instance), _) if instance.tuple_spec(db).is_some() => false,
|
||||
(_, Type::NominalInstance(instance)) if instance.tuple_spec(db).is_some() => false,
|
||||
// Other than the above cases, two single-valued disjoint types cannot compare
|
||||
// equal.
|
||||
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
|
||||
}
|
||||
}
|
||||
|
||||
struct NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
db: &'db dyn Db,
|
||||
module: &'ast ParsedModuleRef,
|
||||
@@ -573,39 +606,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
|
||||
// instead of intersecting with a type.)
|
||||
|
||||
// Return `true` if it is possible for any two inhabitants of the given types to
|
||||
// compare equal to each other; otherwise return `false`.
|
||||
fn could_compare_equal<'db>(
|
||||
db: &'db dyn Db,
|
||||
left_ty: Type<'db>,
|
||||
right_ty: Type<'db>,
|
||||
) -> bool {
|
||||
if !left_ty.is_disjoint_from(db, right_ty) {
|
||||
// If types overlap, they have inhabitants in common; it's definitely possible
|
||||
// for an object to compare equal to itself.
|
||||
return true;
|
||||
}
|
||||
match (left_ty, right_ty) {
|
||||
// In order to be sure a union type cannot compare equal to another type, it
|
||||
// must be true that no element of the union can compare equal to that type.
|
||||
(Type::Union(union), _) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, *ty, right_ty)),
|
||||
(_, Type::Union(union)) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|ty| could_compare_equal(db, left_ty, *ty)),
|
||||
// Boolean literals and int literals are disjoint, and single valued, and yet
|
||||
// `True == 1` and `False == 0`.
|
||||
(Type::BooleanLiteral(b), Type::IntLiteral(i))
|
||||
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
|
||||
// Other than the above cases, two single-valued disjoint types cannot compare
|
||||
// equal.
|
||||
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
|
||||
}
|
||||
}
|
||||
|
||||
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
|
||||
// compare equal to `rhs_ty`.
|
||||
fn can_narrow_to_rhs<'db>(
|
||||
@@ -660,7 +660,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
|
||||
if !could_compare_equal(db, ty, rhs_ty) {
|
||||
// Cannot compare equal to rhs, so keep this type
|
||||
ty
|
||||
} else {
|
||||
Type::Never
|
||||
@@ -721,14 +722,22 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
|
||||
if let Some(lhs_union) = lhs_ty.as_union() {
|
||||
for element in lhs_union.elements(self.db) {
|
||||
// Keep only the non-single-valued portion of the original type.
|
||||
if !element.is_single_valued(self.db)
|
||||
&& !element.is_literal_string()
|
||||
&& !element.is_bool(self.db)
|
||||
&& (!element.is_enum(self.db) || element.overrides_equality(self.db))
|
||||
{
|
||||
builder = builder.add(*element);
|
||||
// Skip single-valued types (handled via RHS matching).
|
||||
if element.is_single_valued(self.db) {
|
||||
continue;
|
||||
}
|
||||
// Skip types that are handled specially (LiteralString, bool, enum).
|
||||
if element.is_literal_string()
|
||||
|| element.is_bool(self.db)
|
||||
|| (element.is_enum(self.db) && !element.overrides_equality(self.db))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// Skip types that cannot compare equal to any RHS value.
|
||||
if !could_compare_equal(self.db, *element, rhs_values) {
|
||||
continue;
|
||||
}
|
||||
builder = builder.add(*element);
|
||||
}
|
||||
}
|
||||
Some(builder.build())
|
||||
|
||||
Reference in New Issue
Block a user