diff --git a/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md b/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md index 2adfd228c5..eb45a6697e 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md +++ b/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md @@ -204,3 +204,37 @@ class Calculator: reveal_type(Calculator().square_then_round(3.14)) # revealed: Unknown | int ``` + +## Use case: Treating dunder methods as bound-method descriptors + +pytorch defines a `__pow__` dunder attribute on [`TensorBase`] in a similar way to the following +example. We generally treat dunder attributes as bound-method descriptors since they all take a +`self` argument. This allows us to type-check the following code correctly: + +```py +from typing import Callable + +def pow_impl(tensor: Tensor, exponent: int) -> Tensor: + raise NotImplementedError + +class Tensor: + __pow__: Callable[[Tensor, int], Tensor] = pow_impl + +Tensor() ** 2 +``` + +The following example is also taken from a real world project. Here, the `__lt__` dunder attribute +is not declared. The attribute type is therefore inferred as `Unknown | Callable[…]`, but we still +treat it as a bound-method descriptor: + +```py +def make_comparison_operator(name: str) -> Callable[[Matrix, Matrix], bool]: + raise NotImplementedError + +class Matrix: + __lt__ = make_comparison_operator("lt") + +Matrix() < Matrix() +``` + +[`tensorbase`]: https://github.com/pytorch/pytorch/blob/f3913ea641d871f04fa2b6588a77f63efeeb9f10/torch/_tensor.py#L1084-L1092 diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 17da33f1d3..a39b6a6f16 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -1181,18 +1181,17 @@ static_assert(not is_assignable_to(EggsLegacy, Callable[..., Any])) # error: [s An instance type is assignable to a compatible callable type if the instance type's class has a callable `__call__` attribute. -TODO: for the moment, we don't consider the callable type as a bound-method descriptor, but this may -change for better compatibility with mypy/pyright. - ```py +from __future__ import annotations + from typing import Callable from ty_extensions import static_assert, is_assignable_to -def call_impl(a: int) -> str: +def call_impl(a: A, x: int) -> str: return "" class A: - __call__: Callable[[int], str] = call_impl + __call__: Callable[[A, int], str] = call_impl static_assert(is_assignable_to(A, Callable[[int], str])) static_assert(not is_assignable_to(A, Callable[[int], int])) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 131bd5563b..6034a52529 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -1635,18 +1635,17 @@ f(a) An instance type can be a subtype of a compatible callable type if the instance type's class has a callable `__call__` attribute. -TODO: for the moment, we don't consider the callable type as a bound-method descriptor, but this may -change for better compatibility with mypy/pyright. - ```py +from __future__ import annotations + from typing import Callable from ty_extensions import static_assert, is_subtype_of -def call_impl(a: int) -> str: +def call_impl(a: A, x: int) -> str: return "" class A: - __call__: Callable[[int], str] = call_impl + __call__: Callable[[A, int], str] = call_impl static_assert(is_subtype_of(A, Callable[[int], str])) static_assert(not is_subtype_of(A, Callable[[int], int])) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 32bcc7aa63..410192de9b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -11116,6 +11116,23 @@ impl<'db> IntersectionType<'db> { } } + /// Map a type transformation over all positive elements of the intersection. Leave the + /// negative elements unchanged. + pub(crate) fn map_positive( + self, + db: &'db dyn Db, + mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, + ) -> Type<'db> { + let mut builder = IntersectionBuilder::new(db); + for ty in self.positive(db) { + builder = builder.add_positive(transform_fn(ty)); + } + for ty in self.negative(db) { + builder = builder.add_negative(*ty); + } + builder.build() + } + pub(crate) fn map_with_boundness( self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 01647b1c01..b9b63a50a1 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -2014,7 +2014,29 @@ impl<'db> ClassLiteral<'db> { name: &str, policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - self.class_member_inner(db, None, name, policy) + fn into_function_like_callable<'d>(db: &'d dyn Db, ty: Type<'d>) -> Type<'d> { + match ty { + Type::Callable(callable_ty) => { + Type::Callable(CallableType::new(db, callable_ty.signatures(db), true)) + } + Type::Union(union) => { + union.map(db, |element| into_function_like_callable(db, *element)) + } + Type::Intersection(intersection) => intersection + .map_positive(db, |element| into_function_like_callable(db, *element)), + _ => ty, + } + } + + let mut member = self.class_member_inner(db, None, name, policy); + + // We generally treat dunder attributes with `Callable` types as function-like callables. + // See `callables_as_descriptors.md` for more details. + if name.starts_with("__") && name.ends_with("__") { + member = member.map_type(|ty| into_function_like_callable(db, ty)); + } + + member } fn class_member_inner(