diff --git a/crates/ty_python_semantic/resources/mdtest/call/methods.md b/crates/ty_python_semantic/resources/mdtest/call/methods.md index c6adcbbbdc..63341a6a09 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/methods.md +++ b/crates/ty_python_semantic/resources/mdtest/call/methods.md @@ -504,6 +504,11 @@ with Child().create() as child: The [`__init_subclass__`] method is implicitly a classmethod: +```toml +[environment] +python-version = "3.12" +``` + ```py class Base: def __init_subclass__(cls, **kwargs): @@ -516,6 +521,130 @@ class Derived(Base): reveal_type(Derived.custom_attribute) # revealed: int ``` +Subclasses must be constructed with arguments matching the required arguments of the base +`__init_subclass__` method. + +```py +class Empty: ... + +class RequiresArg: + def __init_subclass__(cls, arg: int): ... + +class NoArg: + def __init_subclass__(cls): ... + +# Single-base definitions +class MissingArg(RequiresArg): ... # error: [missing-argument] +class InvalidType(RequiresArg, arg="foo"): ... # error: [invalid-argument-type] +class Valid(RequiresArg, arg=1): ... + +# error: [missing-argument] +# error: [unknown-argument] +class IncorrectArg(RequiresArg, not_arg="foo"): ... +``` + +For multiple inheritance, the first resolved `__init_subclass__` method is used. + +```py +class Empty: ... + +class RequiresArg: + def __init_subclass__(cls, arg: int): ... + +class NoArg: + def __init_subclass__(cls): ... + +class Valid(NoArg, RequiresArg): ... +class MissingArg(RequiresArg, NoArg): ... # error: [missing-argument] +class InvalidType(RequiresArg, NoArg, arg="foo"): ... # error: [invalid-argument-type] +class Valid(RequiresArg, NoArg, arg=1): ... + +# Ensure base class without __init_subclass__ is ignored +class Valid(Empty, NoArg): ... +class Valid(Empty, RequiresArg, NoArg, arg=1): ... +class MissingArg(Empty, RequiresArg): ... # error: [missing-argument] +class MissingArg(Empty, RequiresArg, NoArg): ... # error: [missing-argument] +class InvalidType(Empty, RequiresArg, NoArg, arg="foo"): ... # error: [invalid-argument-type] + +# Multiple inheritance with args +class Base(Empty, RequiresArg, NoArg, arg=1): ... +class Valid(Base, arg=1): ... +class MissingArg(Base): ... # error: [missing-argument] +class InvalidType(Base, arg="foo"): ... # error: [invalid-argument-type] +``` + +Keyword splats are allowed if their type can be determined: + +```py +from typing import TypedDict + +class RequiresKwarg: + def __init_subclass__(cls, arg: int): ... + +class WrongArg(TypedDict): + kwarg: int + +class InvalidType(TypedDict): + arg: str + +wrong_arg: WrongArg = {"kwarg": 5} + +# error: [missing-argument] +# error: [unknown-argument] +class MissingArg(RequiresKwarg, **wrong_arg): ... + +invalid_type: InvalidType = {"arg": "foo"} + +# error: [invalid-argument-type] +class InvalidType(RequiresKwarg, **invalid_type): ... +``` + +So are generics: + +```py +from typing import Generic, TypeVar, Literal, overload + +class Base[T]: + def __init_subclass__(cls, arg: T): ... + +class Valid(Base[int], arg=1): ... +class InvalidType(Base[int], arg="x"): ... # error: [invalid-argument-type] + +# Old generic syntax +T = TypeVar("T") + +class Base(Generic[T]): + def __init_subclass__(cls, arg: T) -> None: ... + +class Valid(Base[int], arg=1): ... +class InvalidType(Base[int], arg="x"): ... # error: [invalid-argument-type] +``` + +So are overloads: + +```py +class Base: + @overload + def __init_subclass__(cls, mode: Literal["a"], arg: int) -> None: ... + @overload + def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ... + def __init_subclass__(cls, mode: str, arg: int | str) -> None: ... + +class Valid(Base, mode="a", arg=5): ... +class Valid(Base, mode="b", arg="foo"): ... +class InvalidType(Base, mode="b", arg=5): ... # error: [no-matching-overload] +``` + +The `metaclass` keyword is ignored, as it has special meaning and is not passed to +`__init_subclass__` at runtime. + +```py +class Base: + def __init_subclass__(cls, arg: int): ... + +class Valid(Base, arg=5, metaclass=object): ... +``` + ## `@staticmethod` ### Basic diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index c71697c660..c13c217968 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -53,7 +53,7 @@ use crate::semantic_index::{ }; use crate::subscript::{PyIndex, PySlice}; use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex}; -use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; +use crate::types::call::{Argument, Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{ ClassLiteral, CodeGeneratorKind, DynamicClassLiteral, DynamicMetaclassConflict, FieldKind, MetaclassErrorKind, MethodDecorator, @@ -960,7 +960,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - // (6) If the class is generic, verify that its generic context does not violate any of + // (7) Check that the class arguments matches the arguments of the + // base class `__init_subclass__` method. + if let Some(args) = class_node.arguments.as_deref() { + let call_args: CallArguments = args + .keywords + .iter() + .filter_map(|keyword| match keyword.arg.as_ref() { + // We mimic the runtime behaviour and discard the metaclass argument + Some(name) if name.id.as_str() == "metaclass" => None, + Some(name) => { + let ty = self.expression_type(&keyword.value); + Some((Argument::Keyword(name.id.as_str()), Some(ty))) + } + None => { + let ty = self.expression_type(&keyword.value); + Some((Argument::Keywords, Some(ty))) + } + }) + .collect(); + + let init_subclass_type = class + .class_member_from_mro( + self.db(), + "__init_subclass__", + MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + // skip(1) to skip the current class and only consider base classes. + class.iter_mro(self.db(), None).skip(1), + ) + .ignore_possibly_undefined(); + + if let Some(init_subclass) = init_subclass_type { + let call_args = call_args.with_self(Some(Type::from(class))); + if let Err(CallError(CallErrorKind::BindingError, bindings)) = + init_subclass.try_call(self.db(), &call_args) + { + bindings.report_diagnostics(&self.context, class_node.into()); + } + } + } + + // (8) If the class is generic, verify that its generic context does not violate any of // the typevar scoping rules. if let (Some(legacy), Some(inherited)) = ( class.legacy_generic_context(self.db()), @@ -1039,7 +1079,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - // (7) Check that a dataclass does not have more than one `KW_ONLY`. + // (9) Check that a dataclass does not have more than one `KW_ONLY`. if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) = CodeGeneratorKind::from_class(self.db(), class.into(), None) { @@ -1074,7 +1114,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - // (8) Check for violations of the Liskov Substitution Principle, + // (10) Check for violations of the Liskov Substitution Principle, // and for violations of other rules relating to invalid overrides of some sort. overrides::check_class(&self.context, class); @@ -1082,7 +1122,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { protocol.validate_members(&self.context); } - // (9) If it's a `TypedDict` class, check that it doesn't include any invalid + // (11) If it's a `TypedDict` class, check that it doesn't include any invalid // statements: https://typing.python.org/en/latest/spec/typeddict.html#class-based-syntax // // The body of the class definition defines the items of the `TypedDict` type. It