mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 05:20:49 -05:00
[ty] Preserve argument signature in @total_ordering (#22496)
## Summary Closes https://github.com/astral-sh/ty/issues/2435.
This commit is contained in:
@@ -38,6 +38,125 @@ reveal_type(s1 > s2) # revealed: bool
|
||||
reveal_type(s1 >= s2) # revealed: bool
|
||||
```
|
||||
|
||||
## Signature derived from source ordering method
|
||||
|
||||
When the source ordering method accepts a broader type (like `object`) for its `other` parameter,
|
||||
the synthesized comparison methods should use the same signature. This allows comparisons with types
|
||||
other than the class itself:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class Comparable:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Comparable):
|
||||
return self.value == other.value
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, Comparable):
|
||||
return self.value < other.value
|
||||
if isinstance(other, int):
|
||||
return self.value < other
|
||||
return NotImplemented
|
||||
|
||||
a = Comparable(10)
|
||||
b = Comparable(20)
|
||||
|
||||
# Comparisons with the same type work.
|
||||
reveal_type(a <= b) # revealed: bool
|
||||
reveal_type(a >= b) # revealed: bool
|
||||
|
||||
# Comparisons with `int` also work because `__lt__` accepts `object`.
|
||||
reveal_type(a <= 15) # revealed: bool
|
||||
reveal_type(a >= 5) # revealed: bool
|
||||
```
|
||||
|
||||
## Multiple ordering methods with different signatures
|
||||
|
||||
When multiple ordering methods are defined with different signatures, the decorator selects a "root"
|
||||
method using the priority order: `__lt__` > `__le__` > `__gt__` > `__ge__`. Synthesized methods use
|
||||
the signature from the highest-priority method. Methods that are explicitly defined are not
|
||||
overridden.
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class MultiSig:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return True
|
||||
# __lt__ accepts `object` (highest priority, used as root)
|
||||
def __lt__(self, other: object) -> bool:
|
||||
return True
|
||||
# __gt__ only accepts `MultiSig` (not overridden by decorator)
|
||||
def __gt__(self, other: "MultiSig") -> bool:
|
||||
return True
|
||||
|
||||
a = MultiSig(10)
|
||||
b = MultiSig(20)
|
||||
|
||||
# __le__ and __ge__ are synthesized with __lt__'s signature (accepts `object`)
|
||||
reveal_type(a <= b) # revealed: bool
|
||||
reveal_type(a <= 15) # revealed: bool
|
||||
reveal_type(a >= b) # revealed: bool
|
||||
reveal_type(a >= 15) # revealed: bool
|
||||
|
||||
# __gt__ keeps its original signature (only accepts MultiSig)
|
||||
reveal_type(a > b) # revealed: bool
|
||||
a > 15 # error: [unsupported-operator]
|
||||
```
|
||||
|
||||
## Overloaded ordering method
|
||||
|
||||
When the source ordering method is overloaded, the synthesized comparison methods should preserve
|
||||
all overloads:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
from typing import overload
|
||||
|
||||
@total_ordering
|
||||
class Flexible:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
@overload
|
||||
def __lt__(self, other: "Flexible") -> bool: ...
|
||||
@overload
|
||||
def __lt__(self, other: int) -> bool: ...
|
||||
def __lt__(self, other: "Flexible | int") -> bool:
|
||||
if isinstance(other, Flexible):
|
||||
return self.value < other.value
|
||||
return self.value < other
|
||||
|
||||
a = Flexible(10)
|
||||
b = Flexible(20)
|
||||
|
||||
# Synthesized __le__ preserves overloads from __lt__
|
||||
reveal_type(a <= b) # revealed: bool
|
||||
reveal_type(a <= 15) # revealed: bool
|
||||
|
||||
# Synthesized __ge__ also preserves overloads
|
||||
reveal_type(a >= b) # revealed: bool
|
||||
reveal_type(a >= 15) # revealed: bool
|
||||
|
||||
# But comparison with an unsupported type should still error
|
||||
a <= "string" # error: [unsupported-operator]
|
||||
```
|
||||
|
||||
## Using `__gt__` as the root comparison method
|
||||
|
||||
When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and
|
||||
@@ -127,6 +246,41 @@ reveal_type(c1 > c2) # revealed: bool
|
||||
reveal_type(c1 >= c2) # revealed: bool
|
||||
```
|
||||
|
||||
## Method precedence with inheritance
|
||||
|
||||
The decorator always prefers `__lt__` > `__le__` > `__gt__` > `__ge__`, regardless of whether the
|
||||
method is defined locally or inherited. In this example, the inherited `__lt__` takes precedence
|
||||
over the locally-defined `__gt__`:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
from typing import Literal
|
||||
|
||||
class Base:
|
||||
def __lt__(self, other: "Base") -> Literal[True]:
|
||||
return True
|
||||
|
||||
@total_ordering
|
||||
class Child(Base):
|
||||
# __gt__ is defined locally, but __lt__ (inherited) takes precedence
|
||||
def __gt__(self, other: "Child") -> Literal[False]:
|
||||
return False
|
||||
|
||||
c1 = Child()
|
||||
c2 = Child()
|
||||
|
||||
# __lt__ is inherited from Base
|
||||
reveal_type(c1 < c2) # revealed: Literal[True]
|
||||
|
||||
# __gt__ is defined locally on Child
|
||||
reveal_type(c1 > c2) # revealed: Literal[False]
|
||||
|
||||
# __le__ and __ge__ are synthesized from __lt__ (the highest-priority method),
|
||||
# even though __gt__ is defined locally on the class itself
|
||||
reveal_type(c1 <= c2) # revealed: bool
|
||||
reveal_type(c1 >= c2) # revealed: bool
|
||||
```
|
||||
|
||||
## Explicitly-defined methods are not overridden
|
||||
|
||||
When a class explicitly defines multiple comparison methods, the decorator does not override them.
|
||||
@@ -245,6 +399,79 @@ n1 <= n2 # error: [unsupported-operator]
|
||||
n1 >= n2 # error: [unsupported-operator]
|
||||
```
|
||||
|
||||
## Non-bool return type
|
||||
|
||||
When the root ordering method returns a non-bool type (like `int`), the synthesized methods return a
|
||||
union of that type and `bool`. This is because `@total_ordering` generates methods like:
|
||||
|
||||
```python
|
||||
def __le__(self, other):
|
||||
return self < other or self == other
|
||||
```
|
||||
|
||||
If `__lt__` returns `int`, then the synthesized `__le__` could return either `int` (from
|
||||
`self < other`) or `bool` (from `self == other`). Since `bool` is a subtype of `int`, the union
|
||||
simplifies to `int`:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class IntReturn:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, IntReturn):
|
||||
return NotImplemented
|
||||
return self.value == other.value
|
||||
|
||||
def __lt__(self, other: "IntReturn") -> int:
|
||||
return self.value - other.value
|
||||
|
||||
a = IntReturn(10)
|
||||
b = IntReturn(20)
|
||||
|
||||
# User-defined __lt__ returns int.
|
||||
reveal_type(a < b) # revealed: int
|
||||
|
||||
# Synthesized methods return int (the union int | bool simplifies to int
|
||||
# because bool is a subtype of int in Python).
|
||||
reveal_type(a <= b) # revealed: int
|
||||
reveal_type(a > b) # revealed: int
|
||||
reveal_type(a >= b) # revealed: int
|
||||
```
|
||||
|
||||
When the root method returns a type that is not a supertype of `bool`, the union is preserved:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class StrReturn:
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, StrReturn):
|
||||
return NotImplemented
|
||||
return self.value == other.value
|
||||
|
||||
def __lt__(self, other: "StrReturn") -> str:
|
||||
return self.value
|
||||
|
||||
a = StrReturn("a")
|
||||
b = StrReturn("b")
|
||||
|
||||
# User-defined __lt__ returns str.
|
||||
reveal_type(a < b) # revealed: str
|
||||
|
||||
# Synthesized methods return str | bool.
|
||||
reveal_type(a <= b) # revealed: str | bool
|
||||
reveal_type(a > b) # revealed: str | bool
|
||||
reveal_type(a >= b) # revealed: str | bool
|
||||
```
|
||||
|
||||
## Function call form
|
||||
|
||||
When `total_ordering` is called as a function (not as a decorator), the same validation is
|
||||
|
||||
@@ -1586,17 +1586,44 @@ impl<'db> ClassLiteral<'db> {
|
||||
}
|
||||
|
||||
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
|
||||
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and
|
||||
/// for synthesizing comparison methods.
|
||||
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation.
|
||||
pub(super) fn has_ordering_method_in_mro(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
specialization: Option<Specialization<'db>>,
|
||||
) -> bool {
|
||||
self.iter_mro(db, specialization)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
|
||||
.any(|class| class.class_literal(db).0.has_own_ordering_method(db))
|
||||
self.total_ordering_root_method(db, specialization)
|
||||
.is_some()
|
||||
}
|
||||
|
||||
/// Returns the type of the ordering method used by `@total_ordering`, if any.
|
||||
///
|
||||
/// Following `functools.total_ordering` precedence, we prefer `__lt__` > `__le__` > `__gt__` >
|
||||
/// `__ge__`, regardless of whether the method is defined locally or inherited.
|
||||
pub(super) fn total_ordering_root_method(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
specialization: Option<Specialization<'db>>,
|
||||
) -> Option<Type<'db>> {
|
||||
const ORDERING_METHODS: [&str; 4] = ["__lt__", "__le__", "__gt__", "__ge__"];
|
||||
|
||||
for name in ORDERING_METHODS {
|
||||
for base in self.iter_mro(db, specialization) {
|
||||
let Some(base_class) = base.into_class() else {
|
||||
continue;
|
||||
};
|
||||
let (base_literal, base_specialization) = base_class.class_literal(db);
|
||||
if base_literal.is_known(db, KnownClass::Object) {
|
||||
continue;
|
||||
}
|
||||
let member = class_member(db, base_literal.body_scope(db), name);
|
||||
if let Some(ty) = member.ignore_possibly_undefined() {
|
||||
return Some(ty.apply_optional_specialization(db, base_specialization));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
|
||||
@@ -2448,26 +2475,44 @@ impl<'db> ClassLiteral<'db> {
|
||||
// ordering method. The decorator requires at least one of __lt__,
|
||||
// __le__, __gt__, or __ge__ to be defined (either in this class or
|
||||
// inherited from a superclass, excluding `object`).
|
||||
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
|
||||
if self.has_ordering_method_in_mro(db, specialization) {
|
||||
let instance_ty =
|
||||
Type::instance(db, self.apply_optional_specialization(db, specialization));
|
||||
|
||||
let signature = Signature::new(
|
||||
Parameters::new(
|
||||
db,
|
||||
[
|
||||
Parameter::positional_or_keyword(Name::new_static("self"))
|
||||
.with_annotated_type(instance_ty),
|
||||
Parameter::positional_or_keyword(Name::new_static("other"))
|
||||
.with_annotated_type(instance_ty),
|
||||
],
|
||||
),
|
||||
KnownClass::Bool.to_instance(db),
|
||||
//
|
||||
// Only synthesize methods that are not already defined in the MRO.
|
||||
if self.total_ordering(db)
|
||||
&& matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__")
|
||||
&& !self
|
||||
.iter_mro(db, specialization)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
|
||||
.any(|class| {
|
||||
class_member(db, class.class_literal(db).0.body_scope(db), name)
|
||||
.ignore_possibly_undefined()
|
||||
.is_some()
|
||||
})
|
||||
&& self.has_ordering_method_in_mro(db, specialization)
|
||||
&& let Some(root_method_ty) = self.total_ordering_root_method(db, specialization)
|
||||
&& let Some(callables) = root_method_ty.try_upcast_to_callable(db)
|
||||
{
|
||||
let bool_ty = KnownClass::Bool.to_instance(db);
|
||||
let synthesized_callables = callables.map(|callable| {
|
||||
let signatures = CallableSignature::from_overloads(
|
||||
callable.signatures(db).iter().map(|signature| {
|
||||
// The generated methods return a union of the root method's return type
|
||||
// and `bool`. This is because `@total_ordering` synthesizes methods like:
|
||||
// def __gt__(self, other): return not (self == other or self < other)
|
||||
// If `__lt__` returns `int`, then `__gt__` could return `int | bool`.
|
||||
let return_ty =
|
||||
UnionType::from_elements(db, [signature.return_ty, bool_ty]);
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
return_ty,
|
||||
)
|
||||
}),
|
||||
);
|
||||
CallableType::new(db, signatures, CallableTypeKind::FunctionLike)
|
||||
});
|
||||
|
||||
return Some(Type::function_like_callable(db, signature));
|
||||
}
|
||||
return Some(synthesized_callables.into_type(db));
|
||||
}
|
||||
|
||||
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
|
||||
|
||||
Reference in New Issue
Block a user