diff --git a/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md index 0215ead080..08dba7069c 100644 --- a/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md +++ b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md @@ -38,6 +38,125 @@ reveal_type(s1 > s2) # revealed: bool reveal_type(s1 >= s2) # revealed: bool ``` +## Signature derived from source ordering method + +When the source ordering method accepts a broader type (like `object`) for its `other` parameter, +the synthesized comparison methods should use the same signature. This allows comparisons with types +other than the class itself: + +```py +from functools import total_ordering + +@total_ordering +class Comparable: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + if isinstance(other, Comparable): + return self.value == other.value + if isinstance(other, int): + return self.value == other + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, Comparable): + return self.value < other.value + if isinstance(other, int): + return self.value < other + return NotImplemented + +a = Comparable(10) +b = Comparable(20) + +# Comparisons with the same type work. +reveal_type(a <= b) # revealed: bool +reveal_type(a >= b) # revealed: bool + +# Comparisons with `int` also work because `__lt__` accepts `object`. +reveal_type(a <= 15) # revealed: bool +reveal_type(a >= 5) # revealed: bool +``` + +## Multiple ordering methods with different signatures + +When multiple ordering methods are defined with different signatures, the decorator selects a "root" +method using the priority order: `__lt__` > `__le__` > `__gt__` > `__ge__`. Synthesized methods use +the signature from the highest-priority method. Methods that are explicitly defined are not +overridden. + +```py +from functools import total_ordering + +@total_ordering +class MultiSig: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + return True + # __lt__ accepts `object` (highest priority, used as root) + def __lt__(self, other: object) -> bool: + return True + # __gt__ only accepts `MultiSig` (not overridden by decorator) + def __gt__(self, other: "MultiSig") -> bool: + return True + +a = MultiSig(10) +b = MultiSig(20) + +# __le__ and __ge__ are synthesized with __lt__'s signature (accepts `object`) +reveal_type(a <= b) # revealed: bool +reveal_type(a <= 15) # revealed: bool +reveal_type(a >= b) # revealed: bool +reveal_type(a >= 15) # revealed: bool + +# __gt__ keeps its original signature (only accepts MultiSig) +reveal_type(a > b) # revealed: bool +a > 15 # error: [unsupported-operator] +``` + +## Overloaded ordering method + +When the source ordering method is overloaded, the synthesized comparison methods should preserve +all overloads: + +```py +from functools import total_ordering +from typing import overload + +@total_ordering +class Flexible: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + return True + + @overload + def __lt__(self, other: "Flexible") -> bool: ... + @overload + def __lt__(self, other: int) -> bool: ... + def __lt__(self, other: "Flexible | int") -> bool: + if isinstance(other, Flexible): + return self.value < other.value + return self.value < other + +a = Flexible(10) +b = Flexible(20) + +# Synthesized __le__ preserves overloads from __lt__ +reveal_type(a <= b) # revealed: bool +reveal_type(a <= 15) # revealed: bool + +# Synthesized __ge__ also preserves overloads +reveal_type(a >= b) # revealed: bool +reveal_type(a >= 15) # revealed: bool + +# But comparison with an unsupported type should still error +a <= "string" # error: [unsupported-operator] +``` + ## Using `__gt__` as the root comparison method When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and @@ -127,6 +246,41 @@ reveal_type(c1 > c2) # revealed: bool reveal_type(c1 >= c2) # revealed: bool ``` +## Method precedence with inheritance + +The decorator always prefers `__lt__` > `__le__` > `__gt__` > `__ge__`, regardless of whether the +method is defined locally or inherited. In this example, the inherited `__lt__` takes precedence +over the locally-defined `__gt__`: + +```py +from functools import total_ordering +from typing import Literal + +class Base: + def __lt__(self, other: "Base") -> Literal[True]: + return True + +@total_ordering +class Child(Base): + # __gt__ is defined locally, but __lt__ (inherited) takes precedence + def __gt__(self, other: "Child") -> Literal[False]: + return False + +c1 = Child() +c2 = Child() + +# __lt__ is inherited from Base +reveal_type(c1 < c2) # revealed: Literal[True] + +# __gt__ is defined locally on Child +reveal_type(c1 > c2) # revealed: Literal[False] + +# __le__ and __ge__ are synthesized from __lt__ (the highest-priority method), +# even though __gt__ is defined locally on the class itself +reveal_type(c1 <= c2) # revealed: bool +reveal_type(c1 >= c2) # revealed: bool +``` + ## Explicitly-defined methods are not overridden When a class explicitly defines multiple comparison methods, the decorator does not override them. @@ -245,6 +399,79 @@ n1 <= n2 # error: [unsupported-operator] n1 >= n2 # error: [unsupported-operator] ``` +## Non-bool return type + +When the root ordering method returns a non-bool type (like `int`), the synthesized methods return a +union of that type and `bool`. This is because `@total_ordering` generates methods like: + +```python +def __le__(self, other): + return self < other or self == other +``` + +If `__lt__` returns `int`, then the synthesized `__le__` could return either `int` (from +`self < other`) or `bool` (from `self == other`). Since `bool` is a subtype of `int`, the union +simplifies to `int`: + +```py +from functools import total_ordering + +@total_ordering +class IntReturn: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IntReturn): + return NotImplemented + return self.value == other.value + + def __lt__(self, other: "IntReturn") -> int: + return self.value - other.value + +a = IntReturn(10) +b = IntReturn(20) + +# User-defined __lt__ returns int. +reveal_type(a < b) # revealed: int + +# Synthesized methods return int (the union int | bool simplifies to int +# because bool is a subtype of int in Python). +reveal_type(a <= b) # revealed: int +reveal_type(a > b) # revealed: int +reveal_type(a >= b) # revealed: int +``` + +When the root method returns a type that is not a supertype of `bool`, the union is preserved: + +```py +from functools import total_ordering + +@total_ordering +class StrReturn: + def __init__(self, value: str): + self.value = value + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StrReturn): + return NotImplemented + return self.value == other.value + + def __lt__(self, other: "StrReturn") -> str: + return self.value + +a = StrReturn("a") +b = StrReturn("b") + +# User-defined __lt__ returns str. +reveal_type(a < b) # revealed: str + +# Synthesized methods return str | bool. +reveal_type(a <= b) # revealed: str | bool +reveal_type(a > b) # revealed: str | bool +reveal_type(a >= b) # revealed: str | bool +``` + ## Function call form When `total_ordering` is called as a function (not as a decorator), the same validation is diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index aad064adca..77732c05df 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1586,17 +1586,44 @@ impl<'db> ClassLiteral<'db> { } /// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering - /// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and - /// for synthesizing comparison methods. + /// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation. pub(super) fn has_ordering_method_in_mro( self, db: &'db dyn Db, specialization: Option>, ) -> bool { - self.iter_mro(db, specialization) - .filter_map(ClassBase::into_class) - .filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object)) - .any(|class| class.class_literal(db).0.has_own_ordering_method(db)) + self.total_ordering_root_method(db, specialization) + .is_some() + } + + /// Returns the type of the ordering method used by `@total_ordering`, if any. + /// + /// Following `functools.total_ordering` precedence, we prefer `__lt__` > `__le__` > `__gt__` > + /// `__ge__`, regardless of whether the method is defined locally or inherited. + pub(super) fn total_ordering_root_method( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> Option> { + const ORDERING_METHODS: [&str; 4] = ["__lt__", "__le__", "__gt__", "__ge__"]; + + for name in ORDERING_METHODS { + for base in self.iter_mro(db, specialization) { + let Some(base_class) = base.into_class() else { + continue; + }; + let (base_literal, base_specialization) = base_class.class_literal(db); + if base_literal.is_known(db, KnownClass::Object) { + continue; + } + let member = class_member(db, base_literal.body_scope(db), name); + if let Some(ty) = member.ignore_possibly_undefined() { + return Some(ty.apply_optional_specialization(db, base_specialization)); + } + } + } + + None } pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option> { @@ -2448,26 +2475,44 @@ impl<'db> ClassLiteral<'db> { // ordering method. The decorator requires at least one of __lt__, // __le__, __gt__, or __ge__ to be defined (either in this class or // inherited from a superclass, excluding `object`). - if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { - if self.has_ordering_method_in_mro(db, specialization) { - let instance_ty = - Type::instance(db, self.apply_optional_specialization(db, specialization)); - - let signature = Signature::new( - Parameters::new( - db, - [ - Parameter::positional_or_keyword(Name::new_static("self")) - .with_annotated_type(instance_ty), - Parameter::positional_or_keyword(Name::new_static("other")) - .with_annotated_type(instance_ty), - ], - ), - KnownClass::Bool.to_instance(db), + // + // Only synthesize methods that are not already defined in the MRO. + if self.total_ordering(db) + && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") + && !self + .iter_mro(db, specialization) + .filter_map(ClassBase::into_class) + .filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object)) + .any(|class| { + class_member(db, class.class_literal(db).0.body_scope(db), name) + .ignore_possibly_undefined() + .is_some() + }) + && self.has_ordering_method_in_mro(db, specialization) + && let Some(root_method_ty) = self.total_ordering_root_method(db, specialization) + && let Some(callables) = root_method_ty.try_upcast_to_callable(db) + { + let bool_ty = KnownClass::Bool.to_instance(db); + let synthesized_callables = callables.map(|callable| { + let signatures = CallableSignature::from_overloads( + callable.signatures(db).iter().map(|signature| { + // The generated methods return a union of the root method's return type + // and `bool`. This is because `@total_ordering` synthesizes methods like: + // def __gt__(self, other): return not (self == other or self < other) + // If `__lt__` returns `int`, then `__gt__` could return `int | bool`. + let return_ty = + UnionType::from_elements(db, [signature.return_ty, bool_ty]); + Signature::new_generic( + signature.generic_context, + signature.parameters().clone(), + return_ty, + ) + }), ); + CallableType::new(db, signatures, CallableTypeKind::FunctionLike) + }); - return Some(Type::function_like_callable(db, signature)); - } + return Some(synthesized_callables.into_type(db)); } let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;