diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index a52e0588b7..11f648ebd7 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -1411,7 +1411,7 @@ quux. __getstate__ :: bound method Quux.__getstate__() -> object __hash__ :: bound method Quux.__hash__() -> int __init__ :: bound method Quux.__init__() -> Unknown - __init_subclass__ :: bound method Quux.__init_subclass__() -> None + __init_subclass__ :: bound method type[Quux].__init_subclass__() -> None __module__ :: str __ne__ :: bound method Quux.__ne__(value: object, /) -> bool __new__ :: bound method Quux.__new__() -> Quux @@ -1456,7 +1456,7 @@ quux.b __getstate__ :: bound method Quux.__getstate__() -> object __hash__ :: bound method Quux.__hash__() -> int __init__ :: bound method Quux.__init__() -> Unknown - __init_subclass__ :: bound method Quux.__init_subclass__() -> None + __init_subclass__ :: bound method type[Quux].__init_subclass__() -> None __module__ :: str __ne__ :: bound method Quux.__ne__(value: object, /) -> bool __new__ :: bound method Quux.__new__() -> Quux @@ -1506,7 +1506,7 @@ C. __getstate__ :: def __getstate__(self) -> object __hash__ :: def __hash__(self) -> int __init__ :: def __init__(self) -> None - __init_subclass__ :: def __init_subclass__(cls) -> None + __init_subclass__ :: bound method .__init_subclass__() -> None __instancecheck__ :: bound method .__instancecheck__(instance: Any, /) -> bool __itemsize__ :: int __module__ :: str @@ -1575,7 +1575,7 @@ Meta. __getstate__ :: def __getstate__(self) -> object __hash__ :: def __hash__(self) -> int __init__ :: Overload[(self, o: object, /) -> None, (self, name: str, bases: tuple[type, ...], dict: dict[str, Any], /, **kwds: Any) -> None] - __init_subclass__ :: def __init_subclass__(cls) -> None + __init_subclass__ :: bound method .__init_subclass__() -> None __instancecheck__ :: def __instancecheck__(self, instance: Any, /) -> bool __itemsize__ :: int __module__ :: str @@ -1682,7 +1682,7 @@ Quux. __getstate__ :: def __getstate__(self) -> object __hash__ :: def __hash__(self) -> int __init__ :: def __init__(self) -> Unknown - __init_subclass__ :: def __init_subclass__(cls) -> None + __init_subclass__ :: bound method .__init_subclass__() -> None __instancecheck__ :: bound method .__instancecheck__(instance: Any, /) -> bool __itemsize__ :: int __module__ :: str @@ -1756,7 +1756,7 @@ Answer. __getstate__ :: def __getstate__(self) -> object __hash__ :: def __hash__(self) -> int __init__ :: def __init__(self) -> None - __init_subclass__ :: def __init_subclass__(cls) -> None + __init_subclass__ :: bound method .__init_subclass__() -> None __instancecheck__ :: bound method .__instancecheck__(instance: Any, /) -> bool __itemsize__ :: int __iter__ :: bound method .__iter__[_EnumMemberT]() -> Iterator[_EnumMemberT@__iter__] diff --git a/crates/ty_python_semantic/resources/mdtest/call/methods.md b/crates/ty_python_semantic/resources/mdtest/call/methods.md index 9bac69ea59..9185e1722e 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/methods.md +++ b/crates/ty_python_semantic/resources/mdtest/call/methods.md @@ -462,6 +462,22 @@ reveal_type(C.f2(1)) # revealed: str reveal_type(C().f2(1)) # revealed: str ``` +### `__init_subclass__` + +The [`__init_subclass__`] method is implicitly a classmethod: + +```py +class Base: + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.custom_attribute: int = 0 + +class Derived(Base): + pass + +reveal_type(Derived.custom_attribute) # revealed: int +``` + ## `@staticmethod` ### Basic @@ -571,3 +587,4 @@ reveal_type(C().f2(1)) # revealed: str ``` [functions and methods]: https://docs.python.org/3/howto/descriptor.html#functions-and-methods +[`__init_subclass__`]: https://docs.python.org/3/reference/datamodel.html#object.__init_subclass__ diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 509ff9af45..23f4422867 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -46,7 +46,7 @@ use crate::types::diagnostic::{INVALID_AWAIT, INVALID_TYPE_FORM, UNSUPPORTED_BOO pub use crate::types::display::DisplaySettings; use crate::types::enums::{enum_metadata, is_single_member_enum}; use crate::types::function::{ - DataclassTransformerParams, FunctionDecorators, FunctionSpans, FunctionType, KnownFunction, + DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction, }; use crate::types::generics::{ GenericContext, PartialSpecialization, Specialization, bind_typevar, walk_generic_context, @@ -8819,10 +8819,7 @@ impl<'db> BoundMethodType<'db> { /// a `@classmethod`, then it should be an instance of that bound-instance type. pub(crate) fn typing_self_type(self, db: &'db dyn Db) -> Type<'db> { let mut self_instance = self.self_instance(db); - if self - .function(db) - .has_known_decorator(db, FunctionDecorators::CLASSMETHOD) - { + if self.function(db).is_classmethod(db) { self_instance = self_instance.to_instance(db).unwrap_or_else(Type::unknown); } self_instance diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index ce218ab90e..f9665cfc98 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -272,7 +272,7 @@ impl<'db> Bindings<'db> { for (overload_index, overload) in binding.matching_overloads_mut() { match binding_type { Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { - if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { + if function.is_classmethod(db) { match overload.parameter_types() { [_, Some(owner)] => { overload.set_return_type(Type::BoundMethod( @@ -308,7 +308,7 @@ impl<'db> Bindings<'db> { if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] = overload.parameter_types() { - if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { + if function.is_classmethod(db) { match overload.parameter_types() { [_, _, Some(owner)] => { overload.set_return_type(Type::BoundMethod( diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 8881729751..d678e749a2 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1258,7 +1258,7 @@ pub(super) enum MethodDecorator { impl MethodDecorator { fn try_from_fn_type(db: &dyn Db, fn_type: FunctionType) -> Result { match ( - fn_type.has_known_decorator(db, FunctionDecorators::CLASSMETHOD), + fn_type.is_classmethod(db), fn_type.has_known_decorator(db, FunctionDecorators::STATICMETHOD), ) { (true, true) => Err(()), // A method can't be static and class method at the same time. diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 9cda288267..6b3c2902c8 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -721,6 +721,13 @@ impl<'db> FunctionType<'db> { self.literal(db).has_known_decorator(db, decorator) } + /// Returns true if this method is decorated with `@classmethod`, or if it is implicitly a + /// classmethod. + pub(crate) fn is_classmethod(self, db: &'db dyn Db) -> bool { + self.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) + || self.name(db) == "__init_subclass__" + } + /// If the implementation of this function is deprecated, returns the `@warnings.deprecated`. /// /// Checking if an overload is deprecated requires deeper call analysis.