diff --git a/crates/ty_module_resolver/src/list.rs b/crates/ty_module_resolver/src/list.rs index 219b76683a..1770d2a536 100644 --- a/crates/ty_module_resolver/src/list.rs +++ b/crates/ty_module_resolver/src/list.rs @@ -619,7 +619,7 @@ mod tests { list_snapshot(&db), @r#" [ - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); @@ -662,7 +662,7 @@ mod tests { @r#" [ Module::File("asyncio", "std-custom", "/typeshed/stdlib/asyncio/__init__.pyi", Package, None), - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), Module::File("random", "std-custom", "/typeshed/stdlib/random.pyi", Module, None), ] "#, @@ -755,7 +755,7 @@ mod tests { [ Module::File("asyncio", "std-custom", "/typeshed/stdlib/asyncio/__init__.pyi", Package, None), Module::File("collections", "std-custom", "/typeshed/stdlib/collections/__init__.pyi", Package, Some(Collections)), - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); @@ -1091,7 +1091,7 @@ mod tests { list_snapshot(&db), @r#" [ - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); @@ -1107,7 +1107,7 @@ mod tests { list_snapshot(&db), @r#" [ - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); @@ -1129,7 +1129,7 @@ mod tests { list_snapshot(&db), @r#" [ - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); @@ -1191,7 +1191,7 @@ mod tests { list_snapshot(&db), @r#" [ - Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, None), + Module::File("functools", "std-custom", "/typeshed/stdlib/functools.pyi", Module, Some(Functools)), ] "#, ); diff --git a/crates/ty_module_resolver/src/module.rs b/crates/ty_module_resolver/src/module.rs index 6f625686c9..02694404e2 100644 --- a/crates/ty_module_resolver/src/module.rs +++ b/crates/ty_module_resolver/src/module.rs @@ -320,6 +320,7 @@ pub enum KnownModule { Abc, Contextlib, Dataclasses, + Functools, Collections, Inspect, #[strum(serialize = "string.templatelib")] @@ -351,6 +352,7 @@ impl KnownModule { Self::Abc => "abc", Self::Contextlib => "contextlib", Self::Dataclasses => "dataclasses", + Self::Functools => "functools", Self::Collections => "collections", Self::Inspect => "inspect", Self::TypeCheckerInternals => "_typeshed._type_checker_internals", @@ -395,6 +397,10 @@ impl KnownModule { pub const fn is_importlib(self) -> bool { matches!(self, Self::ImportLib) } + + pub const fn is_functools(self) -> bool { + matches!(self, Self::Functools) + } } impl std::fmt::Display for KnownModule { diff --git a/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md new file mode 100644 index 0000000000..2739b31d61 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/decorators/total_ordering.md @@ -0,0 +1,246 @@ +# `functools.total_ordering` + +The `@functools.total_ordering` decorator allows a class to define a single comparison method (like +`__lt__`), and the decorator automatically generates the remaining comparison methods (`__le__`, +`__gt__`, `__ge__`). Defining `__eq__` is optional, as it can be inherited from `object`. + +## Basic usage + +When a class defines `__eq__` and `__lt__`, the decorator synthesizes `__le__`, `__gt__`, and +`__ge__`: + +```py +from functools import total_ordering + +@total_ordering +class Student: + def __init__(self, grade: int): + self.grade = grade + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Student): + return NotImplemented + return self.grade == other.grade + + def __lt__(self, other: "Student") -> bool: + return self.grade < other.grade + +s1 = Student(85) +s2 = Student(90) + +# User-defined comparison methods work as expected. +reveal_type(s1 == s2) # revealed: bool +reveal_type(s1 < s2) # revealed: bool + +# Synthesized comparison methods are available. +reveal_type(s1 <= s2) # revealed: bool +reveal_type(s1 > s2) # revealed: bool +reveal_type(s1 >= s2) # revealed: bool +``` + +## Using `__gt__` as the root comparison method + +When a class defines `__eq__` and `__gt__`, the decorator synthesizes `__lt__`, `__le__`, and +`__ge__`: + +```py +from functools import total_ordering + +@total_ordering +class Priority: + def __init__(self, level: int): + self.level = level + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Priority): + return NotImplemented + return self.level == other.level + + def __gt__(self, other: "Priority") -> bool: + return self.level > other.level + +p1 = Priority(1) +p2 = Priority(2) + +# User-defined comparison methods work +reveal_type(p1 == p2) # revealed: bool +reveal_type(p1 > p2) # revealed: bool + +# Synthesized comparison methods are available +reveal_type(p1 < p2) # revealed: bool +reveal_type(p1 <= p2) # revealed: bool +reveal_type(p1 >= p2) # revealed: bool +``` + +## Inherited `__eq__` + +A class only needs to define a single comparison method. The `__eq__` method can be inherited from +`object`: + +```py +from functools import total_ordering + +@total_ordering +class Score: + def __init__(self, value: int): + self.value = value + + def __lt__(self, other: "Score") -> bool: + return self.value < other.value + +s1 = Score(85) +s2 = Score(90) + +# `__eq__` is inherited from object. +reveal_type(s1 == s2) # revealed: bool + +# Synthesized comparison methods are available. +reveal_type(s1 <= s2) # revealed: bool +reveal_type(s1 > s2) # revealed: bool +reveal_type(s1 >= s2) # revealed: bool +``` + +## Inherited ordering methods + +The decorator also works when the ordering method is inherited from a superclass: + +```py +from functools import total_ordering + +class Base: + def __lt__(self, other: "Base") -> bool: + return True + +@total_ordering +class Child(Base): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Child): + return NotImplemented + return True + +c1 = Child() +c2 = Child() + +# Synthesized methods work even though `__lt__` is inherited. +reveal_type(c1 <= c2) # revealed: bool +reveal_type(c1 > c2) # revealed: bool +reveal_type(c1 >= c2) # revealed: bool +``` + +## Explicitly-defined methods are not overridden + +When a class explicitly defines multiple comparison methods, the decorator does not override them. +We use a narrower return type (`Literal[True]`) to verify that the explicit methods are preserved: + +```py +from functools import total_ordering +from typing import Literal + +@total_ordering +class Temperature: + def __init__(self, celsius: float): + self.celsius = celsius + + def __lt__(self, other: "Temperature") -> Literal[True]: + return True + + def __gt__(self, other: "Temperature") -> Literal[True]: + return True + +t1 = Temperature(20.0) +t2 = Temperature(25.0) + +# User-defined methods preserve their return type. +reveal_type(t1 < t2) # revealed: Literal[True] +reveal_type(t1 > t2) # revealed: Literal[True] + +# Synthesized methods have `bool` return type. +reveal_type(t1 <= t2) # revealed: bool +reveal_type(t1 >= t2) # revealed: bool +``` + +## Combined with `@dataclass` + +The decorator works with `@dataclass`: + +```py +from dataclasses import dataclass +from functools import total_ordering + +@total_ordering +@dataclass +class Point: + x: int + y: int + + def __lt__(self, other: "Point") -> bool: + return (self.x, self.y) < (other.x, other.y) + +p1 = Point(1, 2) +p2 = Point(3, 4) + +# Dataclass-synthesized `__eq__` is available. +reveal_type(p1 == p2) # revealed: bool + +# User-defined comparison method works. +reveal_type(p1 < p2) # revealed: bool + +# Synthesized comparison methods are available. +reveal_type(p1 <= p2) # revealed: bool +reveal_type(p1 > p2) # revealed: bool +reveal_type(p1 >= p2) # revealed: bool +``` + +## Missing ordering method + +If a class has `@total_ordering` but doesn't define any ordering method (itself or in a superclass), +the decorator would fail at runtime. We don't synthesize methods in this case: + +```py +from functools import total_ordering + +@total_ordering +class NoOrdering: + def __eq__(self, other: object) -> bool: + return True + +n1 = NoOrdering() +n2 = NoOrdering() + +# These should error because no ordering method is defined. +n1 <= n2 # error: [unsupported-operator] +n1 >= n2 # error: [unsupported-operator] +``` + +## Without the decorator + +Without `@total_ordering`, classes that only define `__lt__` will not have `__le__` or `__ge__` +synthesized: + +```py +class NoDecorator: + def __init__(self, value: int): + self.value = value + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NoDecorator): + return NotImplemented + return self.value == other.value + + def __lt__(self, other: "NoDecorator") -> bool: + return self.value < other.value + +n1 = NoDecorator(1) +n2 = NoDecorator(2) + +# User-defined methods work. +reveal_type(n1 == n2) # revealed: bool +reveal_type(n1 < n2) # revealed: bool + +# Note: `n1 > n2` works because Python reflects it to `n2 < n1` +reveal_type(n1 > n2) # revealed: bool + +# These comparison operators are not available. +n1 <= n2 # error: [unsupported-operator] +n1 >= n2 # error: [unsupported-operator] +``` diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 24c98e94dc..78f29de858 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1122,6 +1122,7 @@ impl<'db> Bindings<'db> { class_literal.type_check_only(db), Some(params), class_literal.dataclass_transformer_params(db), + class_literal.total_ordering(db), ))); } } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index c3f61879b4..3c1864decc 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1516,6 +1516,9 @@ pub struct ClassLiteral<'db> { pub(crate) dataclass_params: Option>, pub(crate) dataclass_transformer_params: Option>, + + /// Whether this class is decorated with `@functools.total_ordering` + pub(crate) total_ordering: bool, } // The Salsa heap is tracked separately. @@ -1540,6 +1543,17 @@ impl<'db> ClassLiteral<'db> { self.is_known(db, KnownClass::Tuple) } + /// Returns `true` if this class defines any ordering method (`__lt__`, `__le__`, `__gt__`, + /// `__ge__`) in its own body (not inherited). Used by `@total_ordering` to determine if + /// synthesis is valid. + #[salsa::tracked] + pub(crate) fn has_own_ordering_method(self, db: &'db dyn Db) -> bool { + let body_scope = self.body_scope(db); + ["__lt__", "__le__", "__gt__", "__ge__"] + .iter() + .any(|method| !class_member(db, body_scope, method).is_undefined()) + } + pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option> { // Several typeshed definitions examine `sys.version_info`. To break cycles, we hard-code // the knowledge that this class is not generic. @@ -2384,6 +2398,41 @@ impl<'db> ClassLiteral<'db> { ) -> Option> { let dataclass_params = self.dataclass_params(db); + // Handle `@functools.total_ordering`: synthesize comparison methods + // for classes that have `@total_ordering` and define at least one + // ordering method. The decorator requires at least one of __lt__, + // __le__, __gt__, or __ge__ to be defined (either in this class or + // inherited from a superclass, excluding `object`). + if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { + // Check if any class in the MRO (excluding object) defines at least one + // ordering method in its own body (not synthesized). + let has_ordering_method = self + .iter_mro(db, specialization) + .filter_map(super::class_base::ClassBase::into_class) + .filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object)) + .any(|class| class.class_literal(db).0.has_own_ordering_method(db)); + + if has_ordering_method { + let instance_ty = + Type::instance(db, self.apply_optional_specialization(db, specialization)); + + let signature = Signature::new( + Parameters::new( + db, + [ + Parameter::positional_or_keyword(Name::new_static("self")) + .with_annotated_type(instance_ty), + Parameter::positional_or_keyword(Name::new_static("other")) + .with_annotated_type(instance_ty), + ], + ), + Some(KnownClass::Bool.to_instance(db)), + ); + + return Some(Type::function_like_callable(db, signature)); + } + } + let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?; let mut transformer_params = diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 1ea2f4dc83..1ae86a08df 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1413,6 +1413,9 @@ pub enum KnownFunction { /// `dataclasses.field` Field, + /// `functools.total_ordering` + TotalOrdering, + /// `inspect.getattr_static` GetattrStatic, @@ -1501,6 +1504,7 @@ impl KnownFunction { Self::Dataclass | Self::Field => { matches!(module, KnownModule::Dataclasses) } + Self::TotalOrdering => module.is_functools(), Self::GetattrStatic => module.is_inspect(), Self::IsAssignableTo | Self::IsDisjointFrom @@ -2068,6 +2072,7 @@ pub(crate) mod tests { KnownFunction::ImportModule => KnownModule::ImportLib, KnownFunction::NamedTuple => KnownModule::Collections, + KnownFunction::TotalOrdering => KnownModule::Functools, }; let function_definition = known_module_symbol(&db, module, function_name) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 8b7b28d2cf..7a057380e9 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -2864,6 +2864,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut type_check_only = false; let mut dataclass_params = None; let mut dataclass_transformer_params = None; + let mut total_ordering = false; for decorator in decorator_list { let decorator_ty = self.infer_decorator(decorator); if decorator_ty @@ -2874,6 +2875,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { continue; } + if decorator_ty + .as_function_literal() + .is_some_and(|function| function.is_known(self.db(), KnownFunction::TotalOrdering)) + { + total_ordering = true; + continue; + } + if let Type::DataclassDecorator(params) = decorator_ty { dataclass_params = Some(params); continue; @@ -2961,6 +2970,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { type_check_only, dataclass_params, dataclass_transformer_params, + total_ordering, )), };