diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 6eb506e602..413edf54b1 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -1166,6 +1166,56 @@ static_assert(is_subtype_of(TypeOf[C], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[C], Callable[[], str])) ``` +#### Classes with `__new__` + +```py +from typing import Callable +from knot_extensions import TypeOf, static_assert, is_subtype_of + +class A: + def __new__(cls, a: int) -> int: + return a + +static_assert(is_subtype_of(TypeOf[A], Callable[[int], int])) +static_assert(not is_subtype_of(TypeOf[A], Callable[[], int])) + +class B: ... +class C(B): ... + +class D: + def __new__(cls) -> B: + return B() + +class E(D): + def __new__(cls) -> C: + return C() + +static_assert(is_subtype_of(TypeOf[E], Callable[[], C])) +static_assert(is_subtype_of(TypeOf[E], Callable[[], B])) +static_assert(not is_subtype_of(TypeOf[D], Callable[[], C])) +static_assert(is_subtype_of(TypeOf[D], Callable[[], B])) +``` + +#### Classes with `__call__` and `__new__` + +If `__call__` and `__new__` are both present, `__call__` takes precedence. + +```py +from typing import Callable +from knot_extensions import TypeOf, static_assert, is_subtype_of + +class MetaWithIntReturn(type): + def __call__(cls) -> int: + return super().__call__() + +class F(metaclass=MetaWithIntReturn): + def __new__(cls) -> str: + return super().__new__(cls) + +static_assert(is_subtype_of(TypeOf[F], Callable[[], int])) +static_assert(not is_subtype_of(TypeOf[F], Callable[[], str])) +``` + ### Bound methods ```py diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f65054c182..ec35873e32 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1175,25 +1175,9 @@ impl<'db> Type<'db> { self_subclass_ty.is_subtype_of(db, target_subclass_ty) } - (Type::ClassLiteral(_), Type::Callable(_)) => { - let metaclass_call_function_symbol = self - .member_lookup_with_policy( - db, - "__call__".into(), - MemberLookupPolicy::NO_INSTANCE_FALLBACK - | MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK, - ) - .symbol; - - if let Symbol::Type(Type::BoundMethod(metaclass_call_function), _) = - metaclass_call_function_symbol - { - // TODO: this intentionally diverges from step 1 in - // https://typing.python.org/en/latest/spec/constructors.html#converting-a-constructor-to-callable - // by always respecting the signature of the metaclass `__call__`, rather than - // using a heuristic which makes unwarranted assumptions to sometimes ignore it. - let metaclass_call_function = metaclass_call_function.into_callable_type(db); - return metaclass_call_function.is_subtype_of(db, target); + (Type::ClassLiteral(class_literal), Type::Callable(_)) => { + if let Some(callable) = class_literal.into_callable(db) { + return callable.is_subtype_of(db, target); } false } @@ -5961,6 +5945,15 @@ impl<'db> FunctionType<'db> { )) } + /// Convert the `FunctionType` into a [`Type::BoundMethod`]. + pub(crate) fn into_bound_method_type( + self, + db: &'db dyn Db, + self_instance: Type<'db>, + ) -> Type<'db> { + Type::BoundMethod(BoundMethodType::new(db, self, self_instance)) + } + /// Returns the [`FileRange`] of the function's name. pub fn focus_range(self, db: &dyn Db) -> FileRange { FileRange::new( diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index e90ef080d4..8eaa0c7988 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -818,6 +818,43 @@ impl<'db> ClassLiteralType<'db> { )) } + pub(super) fn into_callable(self, db: &'db dyn Db) -> Option> { + let self_ty = Type::from(self); + let metaclass_call_function_symbol = self_ty + .member_lookup_with_policy( + db, + "__call__".into(), + MemberLookupPolicy::NO_INSTANCE_FALLBACK + | MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK, + ) + .symbol; + + if let Symbol::Type(Type::BoundMethod(metaclass_call_function), _) = + metaclass_call_function_symbol + { + // TODO: this intentionally diverges from step 1 in + // https://typing.python.org/en/latest/spec/constructors.html#converting-a-constructor-to-callable + // by always respecting the signature of the metaclass `__call__`, rather than + // using a heuristic which makes unwarranted assumptions to sometimes ignore it. + return Some(metaclass_call_function.into_callable_type(db)); + } + + let new_function_symbol = self_ty + .member_lookup_with_policy( + db, + "__new__".into(), + MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK + | MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK, + ) + .symbol; + + if let Symbol::Type(Type::FunctionLiteral(new_function), _) = new_function_symbol { + return Some(new_function.into_bound_method_type(db, self.into())); + } + // TODO handle `__init__` also + None + } + /// Returns the class member of this class named `name`. /// /// The member resolves to a member on the class itself or any of its proper superclasses.