mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 13:30:49 -05:00
[ty] Add support for @total_ordering (#22181)
## Summary We have some suppressions in the pyx codebase related to this, so wanted to resolve. Closes https://github.com/astral-sh/ty/issues/1202.
This commit is contained in:
@@ -0,0 +1,246 @@
|
||||
# `functools.total_ordering`
|
||||
|
||||
The `@functools.total_ordering` decorator allows a class to define a single comparison method (like
|
||||
`__lt__`), and the decorator automatically generates the remaining comparison methods (`__le__`,
|
||||
`__gt__`, `__ge__`). Defining `__eq__` is optional, as it can be inherited from `object`.
|
||||
|
||||
## Basic usage
|
||||
|
||||
When a class defines `__eq__` and `__lt__`, the decorator synthesizes `__le__`, `__gt__`, and
|
||||
`__ge__`:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class Student:
|
||||
def __init__(self, grade: int):
|
||||
self.grade = grade
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Student):
|
||||
return NotImplemented
|
||||
return self.grade == other.grade
|
||||
|
||||
def __lt__(self, other: "Student") -> bool:
|
||||
return self.grade < other.grade
|
||||
|
||||
s1 = Student(85)
|
||||
s2 = Student(90)
|
||||
|
||||
# User-defined comparison methods work as expected.
|
||||
reveal_type(s1 == s2) # revealed: bool
|
||||
reveal_type(s1 < s2) # revealed: bool
|
||||
|
||||
# Synthesized comparison methods are available.
|
||||
reveal_type(s1 <= s2) # revealed: bool
|
||||
reveal_type(s1 > s2) # revealed: bool
|
||||
reveal_type(s1 >= s2) # revealed: bool
|
||||
```
|
||||
|
||||
## Using `__gt__` as the root comparison method
|
||||
|
||||
When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and
|
||||
`__ge__`:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class Priority:
|
||||
def __init__(self, level: int):
|
||||
self.level = level
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Priority):
|
||||
return NotImplemented
|
||||
return self.level == other.level
|
||||
|
||||
def __gt__(self, other: "Priority") -> bool:
|
||||
return self.level > other.level
|
||||
|
||||
p1 = Priority(1)
|
||||
p2 = Priority(2)
|
||||
|
||||
# User-defined comparison methods work
|
||||
reveal_type(p1 == p2) # revealed: bool
|
||||
reveal_type(p1 > p2) # revealed: bool
|
||||
|
||||
# Synthesized comparison methods are available
|
||||
reveal_type(p1 < p2) # revealed: bool
|
||||
reveal_type(p1 <= p2) # revealed: bool
|
||||
reveal_type(p1 >= p2) # revealed: bool
|
||||
```
|
||||
|
||||
## Inherited `__eq__`
|
||||
|
||||
A class only needs to define a single comparison method. The `__eq__` method can be inherited from
|
||||
`object`:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class Score:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __lt__(self, other: "Score") -> bool:
|
||||
return self.value < other.value
|
||||
|
||||
s1 = Score(85)
|
||||
s2 = Score(90)
|
||||
|
||||
# `__eq__` is inherited from object.
|
||||
reveal_type(s1 == s2) # revealed: bool
|
||||
|
||||
# Synthesized comparison methods are available.
|
||||
reveal_type(s1 <= s2) # revealed: bool
|
||||
reveal_type(s1 > s2) # revealed: bool
|
||||
reveal_type(s1 >= s2) # revealed: bool
|
||||
```
|
||||
|
||||
## Inherited ordering methods
|
||||
|
||||
The decorator also works when the ordering method is inherited from a superclass:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
class Base:
|
||||
def __lt__(self, other: "Base") -> bool:
|
||||
return True
|
||||
|
||||
@total_ordering
|
||||
class Child(Base):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Child):
|
||||
return NotImplemented
|
||||
return True
|
||||
|
||||
c1 = Child()
|
||||
c2 = Child()
|
||||
|
||||
# Synthesized methods work even though `__lt__` is inherited.
|
||||
reveal_type(c1 <= c2) # revealed: bool
|
||||
reveal_type(c1 > c2) # revealed: bool
|
||||
reveal_type(c1 >= c2) # revealed: bool
|
||||
```
|
||||
|
||||
## Explicitly-defined methods are not overridden
|
||||
|
||||
When a class explicitly defines multiple comparison methods, the decorator does not override them.
|
||||
We use a narrower return type (`Literal[True]`) to verify that the explicit methods are preserved:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
from typing import Literal
|
||||
|
||||
@total_ordering
|
||||
class Temperature:
|
||||
def __init__(self, celsius: float):
|
||||
self.celsius = celsius
|
||||
|
||||
def __lt__(self, other: "Temperature") -> Literal[True]:
|
||||
return True
|
||||
|
||||
def __gt__(self, other: "Temperature") -> Literal[True]:
|
||||
return True
|
||||
|
||||
t1 = Temperature(20.0)
|
||||
t2 = Temperature(25.0)
|
||||
|
||||
# User-defined methods preserve their return type.
|
||||
reveal_type(t1 < t2) # revealed: Literal[True]
|
||||
reveal_type(t1 > t2) # revealed: Literal[True]
|
||||
|
||||
# Synthesized methods have `bool` return type.
|
||||
reveal_type(t1 <= t2) # revealed: bool
|
||||
reveal_type(t1 >= t2) # revealed: bool
|
||||
```
|
||||
|
||||
## Combined with `@dataclass`
|
||||
|
||||
The decorator works with `@dataclass`:
|
||||
|
||||
```py
|
||||
from dataclasses import dataclass
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
@dataclass
|
||||
class Point:
|
||||
x: int
|
||||
y: int
|
||||
|
||||
def __lt__(self, other: "Point") -> bool:
|
||||
return (self.x, self.y) < (other.x, other.y)
|
||||
|
||||
p1 = Point(1, 2)
|
||||
p2 = Point(3, 4)
|
||||
|
||||
# Dataclass-synthesized `__eq__` is available.
|
||||
reveal_type(p1 == p2) # revealed: bool
|
||||
|
||||
# User-defined comparison method works.
|
||||
reveal_type(p1 < p2) # revealed: bool
|
||||
|
||||
# Synthesized comparison methods are available.
|
||||
reveal_type(p1 <= p2) # revealed: bool
|
||||
reveal_type(p1 > p2) # revealed: bool
|
||||
reveal_type(p1 >= p2) # revealed: bool
|
||||
```
|
||||
|
||||
## Missing ordering method
|
||||
|
||||
If a class has `@total_ordering` but doesn't define any ordering method (itself or in a superclass),
|
||||
the decorator would fail at runtime. We don't synthesize methods in this case:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
class NoOrdering:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
n1 = NoOrdering()
|
||||
n2 = NoOrdering()
|
||||
|
||||
# These should error because no ordering method is defined.
|
||||
n1 <= n2 # error: [unsupported-operator]
|
||||
n1 >= n2 # error: [unsupported-operator]
|
||||
```
|
||||
|
||||
## Without the decorator
|
||||
|
||||
Without `@total_ordering`, classes that only define `__lt__` will not have `__le__` or `__ge__`
|
||||
synthesized:
|
||||
|
||||
```py
|
||||
class NoDecorator:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, NoDecorator):
|
||||
return NotImplemented
|
||||
return self.value == other.value
|
||||
|
||||
def __lt__(self, other: "NoDecorator") -> bool:
|
||||
return self.value < other.value
|
||||
|
||||
n1 = NoDecorator(1)
|
||||
n2 = NoDecorator(2)
|
||||
|
||||
# User-defined methods work.
|
||||
reveal_type(n1 == n2) # revealed: bool
|
||||
reveal_type(n1 < n2) # revealed: bool
|
||||
|
||||
# Note: `n1 > n2` works because Python reflects it to `n2 < n1`
|
||||
reveal_type(n1 > n2) # revealed: bool
|
||||
|
||||
# These comparison operators are not available.
|
||||
n1 <= n2 # error: [unsupported-operator]
|
||||
n1 >= n2 # error: [unsupported-operator]
|
||||
```
|
||||
@@ -1122,6 +1122,7 @@ impl<'db> Bindings<'db> {
|
||||
class_literal.type_check_only(db),
|
||||
Some(params),
|
||||
class_literal.dataclass_transformer_params(db),
|
||||
class_literal.total_ordering(db),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1516,6 +1516,9 @@ pub struct ClassLiteral<'db> {
|
||||
|
||||
pub(crate) dataclass_params: Option<DataclassParams<'db>>,
|
||||
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams<'db>>,
|
||||
|
||||
/// Whether this class is decorated with `@functools.total_ordering`
|
||||
pub(crate) total_ordering: bool,
|
||||
}
|
||||
|
||||
// The Salsa heap is tracked separately.
|
||||
@@ -1540,6 +1543,17 @@ impl<'db> ClassLiteral<'db> {
|
||||
self.is_known(db, KnownClass::Tuple)
|
||||
}
|
||||
|
||||
/// Returns `true` if this class defines any ordering method (`__lt__`, `__le__`, `__gt__`,
|
||||
/// `__ge__`) in its own body (not inherited). Used by `@total_ordering` to determine if
|
||||
/// synthesis is valid.
|
||||
#[salsa::tracked]
|
||||
pub(crate) fn has_own_ordering_method(self, db: &'db dyn Db) -> bool {
|
||||
let body_scope = self.body_scope(db);
|
||||
["__lt__", "__le__", "__gt__", "__ge__"]
|
||||
.iter()
|
||||
.any(|method| !class_member(db, body_scope, method).is_undefined())
|
||||
}
|
||||
|
||||
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
|
||||
// Several typeshed definitions examine `sys.version_info`. To break cycles, we hard-code
|
||||
// the knowledge that this class is not generic.
|
||||
@@ -2384,6 +2398,41 @@ impl<'db> ClassLiteral<'db> {
|
||||
) -> Option<Type<'db>> {
|
||||
let dataclass_params = self.dataclass_params(db);
|
||||
|
||||
// Handle `@functools.total_ordering`: synthesize comparison methods
|
||||
// for classes that have `@total_ordering` and define at least one
|
||||
// ordering method. The decorator requires at least one of __lt__,
|
||||
// __le__, __gt__, or __ge__ to be defined (either in this class or
|
||||
// inherited from a superclass, excluding `object`).
|
||||
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
|
||||
// Check if any class in the MRO (excluding object) defines at least one
|
||||
// ordering method in its own body (not synthesized).
|
||||
let has_ordering_method = self
|
||||
.iter_mro(db, specialization)
|
||||
.filter_map(super::class_base::ClassBase::into_class)
|
||||
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
|
||||
.any(|class| class.class_literal(db).0.has_own_ordering_method(db));
|
||||
|
||||
if has_ordering_method {
|
||||
let instance_ty =
|
||||
Type::instance(db, self.apply_optional_specialization(db, specialization));
|
||||
|
||||
let signature = Signature::new(
|
||||
Parameters::new(
|
||||
db,
|
||||
[
|
||||
Parameter::positional_or_keyword(Name::new_static("self"))
|
||||
.with_annotated_type(instance_ty),
|
||||
Parameter::positional_or_keyword(Name::new_static("other"))
|
||||
.with_annotated_type(instance_ty),
|
||||
],
|
||||
),
|
||||
Some(KnownClass::Bool.to_instance(db)),
|
||||
);
|
||||
|
||||
return Some(Type::function_like_callable(db, signature));
|
||||
}
|
||||
}
|
||||
|
||||
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
|
||||
|
||||
let mut transformer_params =
|
||||
|
||||
@@ -1413,6 +1413,9 @@ pub enum KnownFunction {
|
||||
/// `dataclasses.field`
|
||||
Field,
|
||||
|
||||
/// `functools.total_ordering`
|
||||
TotalOrdering,
|
||||
|
||||
/// `inspect.getattr_static`
|
||||
GetattrStatic,
|
||||
|
||||
@@ -1501,6 +1504,7 @@ impl KnownFunction {
|
||||
Self::Dataclass | Self::Field => {
|
||||
matches!(module, KnownModule::Dataclasses)
|
||||
}
|
||||
Self::TotalOrdering => module.is_functools(),
|
||||
Self::GetattrStatic => module.is_inspect(),
|
||||
Self::IsAssignableTo
|
||||
| Self::IsDisjointFrom
|
||||
@@ -2068,6 +2072,7 @@ pub(crate) mod tests {
|
||||
|
||||
KnownFunction::ImportModule => KnownModule::ImportLib,
|
||||
KnownFunction::NamedTuple => KnownModule::Collections,
|
||||
KnownFunction::TotalOrdering => KnownModule::Functools,
|
||||
};
|
||||
|
||||
let function_definition = known_module_symbol(&db, module, function_name)
|
||||
|
||||
@@ -2864,6 +2864,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
let mut type_check_only = false;
|
||||
let mut dataclass_params = None;
|
||||
let mut dataclass_transformer_params = None;
|
||||
let mut total_ordering = false;
|
||||
for decorator in decorator_list {
|
||||
let decorator_ty = self.infer_decorator(decorator);
|
||||
if decorator_ty
|
||||
@@ -2874,6 +2875,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
continue;
|
||||
}
|
||||
|
||||
if decorator_ty
|
||||
.as_function_literal()
|
||||
.is_some_and(|function| function.is_known(self.db(), KnownFunction::TotalOrdering))
|
||||
{
|
||||
total_ordering = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Type::DataclassDecorator(params) = decorator_ty {
|
||||
dataclass_params = Some(params);
|
||||
continue;
|
||||
@@ -2961,6 +2970,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
type_check_only,
|
||||
dataclass_params,
|
||||
dataclass_transformer_params,
|
||||
total_ordering,
|
||||
)),
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user