[ty] Preserve argument signature in @total_ordering (#22496)

## Summary

Closes https://github.com/astral-sh/ty/issues/2435.
This commit is contained in:
Charlie Marsh
2026-01-10 14:35:58 -05:00
committed by GitHub
parent 8e29be9c1c
commit 2c68057c4b
2 changed files with 296 additions and 24 deletions

View File

@@ -38,6 +38,125 @@ reveal_type(s1 > s2) # revealed: bool
reveal_type(s1 >= s2) # revealed: bool
```
## Signature derived from source ordering method
When the source ordering method accepts a broader type (like `object`) for its `other` parameter,
the synthesized comparison methods should use the same signature. This allows comparisons with types
other than the class itself:
```py
from functools import total_ordering
@total_ordering
class Comparable:
def __init__(self, value: int):
self.value = value
def __eq__(self, other: object) -> bool:
if isinstance(other, Comparable):
return self.value == other.value
if isinstance(other, int):
return self.value == other
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Comparable):
return self.value < other.value
if isinstance(other, int):
return self.value < other
return NotImplemented
a = Comparable(10)
b = Comparable(20)
# Comparisons with the same type work.
reveal_type(a <= b) # revealed: bool
reveal_type(a >= b) # revealed: bool
# Comparisons with `int` also work because `__lt__` accepts `object`.
reveal_type(a <= 15) # revealed: bool
reveal_type(a >= 5) # revealed: bool
```
## Multiple ordering methods with different signatures
When multiple ordering methods are defined with different signatures, the decorator selects a "root"
method using the priority order: `__lt__` > `__le__` > `__gt__` > `__ge__`. Synthesized methods use
the signature from the highest-priority method. Methods that are explicitly defined are not
overridden.
```py
from functools import total_ordering
@total_ordering
class MultiSig:
def __init__(self, value: int):
self.value = value
def __eq__(self, other: object) -> bool:
return True
# __lt__ accepts `object` (highest priority, used as root)
def __lt__(self, other: object) -> bool:
return True
# __gt__ only accepts `MultiSig` (not overridden by decorator)
def __gt__(self, other: "MultiSig") -> bool:
return True
a = MultiSig(10)
b = MultiSig(20)
# __le__ and __ge__ are synthesized with __lt__'s signature (accepts `object`)
reveal_type(a <= b) # revealed: bool
reveal_type(a <= 15) # revealed: bool
reveal_type(a >= b) # revealed: bool
reveal_type(a >= 15) # revealed: bool
# __gt__ keeps its original signature (only accepts MultiSig)
reveal_type(a > b) # revealed: bool
a > 15 # error: [unsupported-operator]
```
## Overloaded ordering method
When the source ordering method is overloaded, the synthesized comparison methods should preserve
all overloads:
```py
from functools import total_ordering
from typing import overload
@total_ordering
class Flexible:
def __init__(self, value: int):
self.value = value
def __eq__(self, other: object) -> bool:
return True
@overload
def __lt__(self, other: "Flexible") -> bool: ...
@overload
def __lt__(self, other: int) -> bool: ...
def __lt__(self, other: "Flexible | int") -> bool:
if isinstance(other, Flexible):
return self.value < other.value
return self.value < other
a = Flexible(10)
b = Flexible(20)
# Synthesized __le__ preserves overloads from __lt__
reveal_type(a <= b) # revealed: bool
reveal_type(a <= 15) # revealed: bool
# Synthesized __ge__ also preserves overloads
reveal_type(a >= b) # revealed: bool
reveal_type(a >= 15) # revealed: bool
# But comparison with an unsupported type should still error
a <= "string" # error: [unsupported-operator]
```
## Using `__gt__` as the root comparison method
When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and
@@ -127,6 +246,41 @@ reveal_type(c1 > c2) # revealed: bool
reveal_type(c1 >= c2) # revealed: bool
```
## Method precedence with inheritance
The decorator always prefers `__lt__` > `__le__` > `__gt__` > `__ge__`, regardless of whether the
method is defined locally or inherited. In this example, the inherited `__lt__` takes precedence
over the locally-defined `__gt__`:
```py
from functools import total_ordering
from typing import Literal
class Base:
def __lt__(self, other: "Base") -> Literal[True]:
return True
@total_ordering
class Child(Base):
# __gt__ is defined locally, but __lt__ (inherited) takes precedence
def __gt__(self, other: "Child") -> Literal[False]:
return False
c1 = Child()
c2 = Child()
# __lt__ is inherited from Base
reveal_type(c1 < c2) # revealed: Literal[True]
# __gt__ is defined locally on Child
reveal_type(c1 > c2) # revealed: Literal[False]
# __le__ and __ge__ are synthesized from __lt__ (the highest-priority method),
# even though __gt__ is defined locally on the class itself
reveal_type(c1 <= c2) # revealed: bool
reveal_type(c1 >= c2) # revealed: bool
```
## Explicitly-defined methods are not overridden
When a class explicitly defines multiple comparison methods, the decorator does not override them.
@@ -245,6 +399,79 @@ n1 <= n2 # error: [unsupported-operator]
n1 >= n2 # error: [unsupported-operator]
```
## Non-bool return type
When the root ordering method returns a non-bool type (like `int`), the synthesized methods return a
union of that type and `bool`. This is because `@total_ordering` generates methods like:
```python
def __le__(self, other):
return self < other or self == other
```
If `__lt__` returns `int`, then the synthesized `__le__` could return either `int` (from
`self < other`) or `bool` (from `self == other`). Since `bool` is a subtype of `int`, the union
simplifies to `int`:
```py
from functools import total_ordering
@total_ordering
class IntReturn:
def __init__(self, value: int):
self.value = value
def __eq__(self, other: object) -> bool:
if not isinstance(other, IntReturn):
return NotImplemented
return self.value == other.value
def __lt__(self, other: "IntReturn") -> int:
return self.value - other.value
a = IntReturn(10)
b = IntReturn(20)
# User-defined __lt__ returns int.
reveal_type(a < b) # revealed: int
# Synthesized methods return int (the union int | bool simplifies to int
# because bool is a subtype of int in Python).
reveal_type(a <= b) # revealed: int
reveal_type(a > b) # revealed: int
reveal_type(a >= b) # revealed: int
```
When the root method returns a type that is not a supertype of `bool`, the union is preserved:
```py
from functools import total_ordering
@total_ordering
class StrReturn:
def __init__(self, value: str):
self.value = value
def __eq__(self, other: object) -> bool:
if not isinstance(other, StrReturn):
return NotImplemented
return self.value == other.value
def __lt__(self, other: "StrReturn") -> str:
return self.value
a = StrReturn("a")
b = StrReturn("b")
# User-defined __lt__ returns str.
reveal_type(a < b) # revealed: str
# Synthesized methods return str | bool.
reveal_type(a <= b) # revealed: str | bool
reveal_type(a > b) # revealed: str | bool
reveal_type(a >= b) # revealed: str | bool
```
## Function call form
When `total_ordering` is called as a function (not as a decorator), the same validation is

View File

@@ -1586,17 +1586,44 @@ impl<'db> ClassLiteral<'db> {
}
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and
/// for synthesizing comparison methods.
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation.
pub(super) fn has_ordering_method_in_mro(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> bool {
self.iter_mro(db, specialization)
.filter_map(ClassBase::into_class)
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
.any(|class| class.class_literal(db).0.has_own_ordering_method(db))
self.total_ordering_root_method(db, specialization)
.is_some()
}
/// Returns the type of the ordering method used by `@total_ordering`, if any.
///
/// Following `functools.total_ordering` precedence, we prefer `__lt__` > `__le__` > `__gt__` >
/// `__ge__`, regardless of whether the method is defined locally or inherited.
pub(super) fn total_ordering_root_method(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> Option<Type<'db>> {
const ORDERING_METHODS: [&str; 4] = ["__lt__", "__le__", "__gt__", "__ge__"];
for name in ORDERING_METHODS {
for base in self.iter_mro(db, specialization) {
let Some(base_class) = base.into_class() else {
continue;
};
let (base_literal, base_specialization) = base_class.class_literal(db);
if base_literal.is_known(db, KnownClass::Object) {
continue;
}
let member = class_member(db, base_literal.body_scope(db), name);
if let Some(ty) = member.ignore_possibly_undefined() {
return Some(ty.apply_optional_specialization(db, base_specialization));
}
}
}
None
}
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
@@ -2448,26 +2475,44 @@ impl<'db> ClassLiteral<'db> {
// ordering method. The decorator requires at least one of __lt__,
// __le__, __gt__, or __ge__ to be defined (either in this class or
// inherited from a superclass, excluding `object`).
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
if self.has_ordering_method_in_mro(db, specialization) {
let instance_ty =
Type::instance(db, self.apply_optional_specialization(db, specialization));
let signature = Signature::new(
Parameters::new(
db,
[
Parameter::positional_or_keyword(Name::new_static("self"))
.with_annotated_type(instance_ty),
Parameter::positional_or_keyword(Name::new_static("other"))
.with_annotated_type(instance_ty),
],
),
KnownClass::Bool.to_instance(db),
//
// Only synthesize methods that are not already defined in the MRO.
if self.total_ordering(db)
&& matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__")
&& !self
.iter_mro(db, specialization)
.filter_map(ClassBase::into_class)
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
.any(|class| {
class_member(db, class.class_literal(db).0.body_scope(db), name)
.ignore_possibly_undefined()
.is_some()
})
&& self.has_ordering_method_in_mro(db, specialization)
&& let Some(root_method_ty) = self.total_ordering_root_method(db, specialization)
&& let Some(callables) = root_method_ty.try_upcast_to_callable(db)
{
let bool_ty = KnownClass::Bool.to_instance(db);
let synthesized_callables = callables.map(|callable| {
let signatures = CallableSignature::from_overloads(
callable.signatures(db).iter().map(|signature| {
// The generated methods return a union of the root method's return type
// and `bool`. This is because `@total_ordering` synthesizes methods like:
// def __gt__(self, other): return not (self == other or self < other)
// If `__lt__` returns `int`, then `__gt__` could return `int | bool`.
let return_ty =
UnionType::from_elements(db, [signature.return_ty, bool_ty]);
Signature::new_generic(
signature.generic_context,
signature.parameters().clone(),
return_ty,
)
}),
);
CallableType::new(db, signatures, CallableTypeKind::FunctionLike)
});
return Some(Type::function_like_callable(db, signature));
}
return Some(synthesized_callables.into_type(db));
}
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;