[ty] Infer type of implicit `cls` parameter in method bodies (#21685)

## Summary

Extends https://github.com/astral-sh/ruff/pull/20922 to infer
unannotated `cls` parameters as `type[Self]` in method bodies.

Part of https://github.com/astral-sh/ty/issues/159.
This commit is contained in:
Ibraheem Ahmed 2025-12-10 04:31:28 -05:00 committed by GitHub
parent d2aabeaaa2
commit ff7086d9ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 167 additions and 154 deletions

View File

@ -194,7 +194,7 @@ static SYMPY: Benchmark = Benchmark::new(
max_dep_date: "2025-06-17",
python_version: PythonVersion::PY312,
},
13000,
13030,
);
static TANJUN: Benchmark = Benchmark::new(

View File

@ -63,6 +63,12 @@ python-version = "3.12"
from typing import Self
class A:
def __init__(self):
reveal_type(self) # revealed: Self@__init__
def __init_subclass__(cls, default_name, **kwargs):
reveal_type(cls) # revealed: type[Self@__init_subclass__]
def implicit_self(self) -> Self:
reveal_type(self) # revealed: Self@implicit_self
@ -91,8 +97,7 @@ class A:
@classmethod
def a_classmethod(cls) -> Self:
# TODO: This should be type[Self@bar]
reveal_type(cls) # revealed: Unknown
reveal_type(cls) # revealed: type[Self@a_classmethod]
return cls()
@staticmethod

View File

@ -174,8 +174,7 @@ class B(A):
@classmethod
def f(cls):
# TODO: Once `cls` is supported, this should be `<super: <class 'B'>, <class 'B'>>`
reveal_type(super()) # revealed: <super: <class 'B'>, Unknown>
reveal_type(super()) # revealed: <super: <class 'B'>, <class 'B'>>
super().f()
super(B, B(42)).__init__(42)

View File

@ -27,135 +27,134 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/class/super.md
13 |
14 | @classmethod
15 | def f(cls):
16 | # TODO: Once `cls` is supported, this should be `<super: <class 'B'>, <class 'B'>>`
17 | reveal_type(super()) # revealed: <super: <class 'B'>, Unknown>
18 | super().f()
19 |
20 | super(B, B(42)).__init__(42)
21 | super(B, B).f()
22 | import enum
23 | from typing import Any, Self, Never, Protocol, Callable
24 | from ty_extensions import Intersection
25 |
26 | class BuilderMeta(type):
27 | def __new__(
28 | cls: type[Any],
29 | name: str,
30 | bases: tuple[type, ...],
31 | dct: dict[str, Any],
32 | ) -> BuilderMeta:
33 | # revealed: <super: <class 'BuilderMeta'>, Any>
34 | s = reveal_type(super())
35 | # revealed: Any
36 | return reveal_type(s.__new__(cls, name, bases, dct))
37 |
38 | class BuilderMeta2(type):
39 | def __new__(
40 | cls: type[BuilderMeta2],
41 | name: str,
42 | bases: tuple[type, ...],
43 | dct: dict[str, Any],
44 | ) -> BuilderMeta2:
45 | # revealed: <super: <class 'BuilderMeta2'>, <class 'BuilderMeta2'>>
46 | s = reveal_type(super())
47 | return reveal_type(s.__new__(cls, name, bases, dct)) # revealed: BuilderMeta2
48 |
49 | class Foo[T]:
50 | x: T
51 |
52 | def method(self: Any):
53 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
54 |
55 | if isinstance(self, Foo):
56 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
57 |
58 | def method2(self: Foo[T]):
59 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
60 | reveal_type(super())
61 |
62 | def method3(self: Foo):
63 | # revealed: <super: <class 'Foo'>, Foo[Unknown]>
64 | reveal_type(super())
65 |
66 | def method4(self: Self):
67 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
68 | reveal_type(super())
69 |
70 | def method5[S: Foo[int]](self: S, other: S) -> S:
71 | # revealed: <super: <class 'Foo'>, Foo[int]>
72 | reveal_type(super())
73 | return self
74 |
75 | def method6[S: (Foo[int], Foo[str])](self: S, other: S) -> S:
76 | # revealed: <super: <class 'Foo'>, Foo[int]> | <super: <class 'Foo'>, Foo[str]>
77 | reveal_type(super())
78 | return self
79 |
80 | def method7[S](self: S, other: S) -> S:
81 | # error: [invalid-super-argument]
82 | # revealed: Unknown
83 | reveal_type(super())
84 | return self
85 |
86 | def method8[S: int](self: S, other: S) -> S:
87 | # error: [invalid-super-argument]
88 | # revealed: Unknown
89 | reveal_type(super())
90 | return self
91 |
92 | def method9[S: (int, str)](self: S, other: S) -> S:
93 | # error: [invalid-super-argument]
94 | # revealed: Unknown
95 | reveal_type(super())
96 | return self
97 |
98 | def method10[S: Callable[..., str]](self: S, other: S) -> S:
99 | # error: [invalid-super-argument]
100 | # revealed: Unknown
101 | reveal_type(super())
102 | return self
103 |
104 | type Alias = Bar
105 |
106 | class Bar:
107 | def method(self: Alias):
108 | # revealed: <super: <class 'Bar'>, Bar>
109 | reveal_type(super())
110 |
111 | def pls_dont_call_me(self: Never):
112 | # revealed: <super: <class 'Bar'>, Unknown>
113 | reveal_type(super())
114 |
115 | def only_call_me_on_callable_subclasses(self: Intersection[Bar, Callable[..., object]]):
116 | # revealed: <super: <class 'Bar'>, Bar>
117 | reveal_type(super())
118 |
119 | class P(Protocol):
120 | def method(self: P):
121 | # revealed: <super: <class 'P'>, P>
122 | reveal_type(super())
123 |
124 | class E(enum.Enum):
125 | X = 1
126 |
127 | def method(self: E):
128 | match self:
129 | case E.X:
130 | # revealed: <super: <class 'E'>, E>
131 | reveal_type(super())
16 | reveal_type(super()) # revealed: <super: <class 'B'>, <class 'B'>>
17 | super().f()
18 |
19 | super(B, B(42)).__init__(42)
20 | super(B, B).f()
21 | import enum
22 | from typing import Any, Self, Never, Protocol, Callable
23 | from ty_extensions import Intersection
24 |
25 | class BuilderMeta(type):
26 | def __new__(
27 | cls: type[Any],
28 | name: str,
29 | bases: tuple[type, ...],
30 | dct: dict[str, Any],
31 | ) -> BuilderMeta:
32 | # revealed: <super: <class 'BuilderMeta'>, Any>
33 | s = reveal_type(super())
34 | # revealed: Any
35 | return reveal_type(s.__new__(cls, name, bases, dct))
36 |
37 | class BuilderMeta2(type):
38 | def __new__(
39 | cls: type[BuilderMeta2],
40 | name: str,
41 | bases: tuple[type, ...],
42 | dct: dict[str, Any],
43 | ) -> BuilderMeta2:
44 | # revealed: <super: <class 'BuilderMeta2'>, <class 'BuilderMeta2'>>
45 | s = reveal_type(super())
46 | return reveal_type(s.__new__(cls, name, bases, dct)) # revealed: BuilderMeta2
47 |
48 | class Foo[T]:
49 | x: T
50 |
51 | def method(self: Any):
52 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
53 |
54 | if isinstance(self, Foo):
55 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
56 |
57 | def method2(self: Foo[T]):
58 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
59 | reveal_type(super())
60 |
61 | def method3(self: Foo):
62 | # revealed: <super: <class 'Foo'>, Foo[Unknown]>
63 | reveal_type(super())
64 |
65 | def method4(self: Self):
66 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
67 | reveal_type(super())
68 |
69 | def method5[S: Foo[int]](self: S, other: S) -> S:
70 | # revealed: <super: <class 'Foo'>, Foo[int]>
71 | reveal_type(super())
72 | return self
73 |
74 | def method6[S: (Foo[int], Foo[str])](self: S, other: S) -> S:
75 | # revealed: <super: <class 'Foo'>, Foo[int]> | <super: <class 'Foo'>, Foo[str]>
76 | reveal_type(super())
77 | return self
78 |
79 | def method7[S](self: S, other: S) -> S:
80 | # error: [invalid-super-argument]
81 | # revealed: Unknown
82 | reveal_type(super())
83 | return self
84 |
85 | def method8[S: int](self: S, other: S) -> S:
86 | # error: [invalid-super-argument]
87 | # revealed: Unknown
88 | reveal_type(super())
89 | return self
90 |
91 | def method9[S: (int, str)](self: S, other: S) -> S:
92 | # error: [invalid-super-argument]
93 | # revealed: Unknown
94 | reveal_type(super())
95 | return self
96 |
97 | def method10[S: Callable[..., str]](self: S, other: S) -> S:
98 | # error: [invalid-super-argument]
99 | # revealed: Unknown
100 | reveal_type(super())
101 | return self
102 |
103 | type Alias = Bar
104 |
105 | class Bar:
106 | def method(self: Alias):
107 | # revealed: <super: <class 'Bar'>, Bar>
108 | reveal_type(super())
109 |
110 | def pls_dont_call_me(self: Never):
111 | # revealed: <super: <class 'Bar'>, Unknown>
112 | reveal_type(super())
113 |
114 | def only_call_me_on_callable_subclasses(self: Intersection[Bar, Callable[..., object]]):
115 | # revealed: <super: <class 'Bar'>, Bar>
116 | reveal_type(super())
117 |
118 | class P(Protocol):
119 | def method(self: P):
120 | # revealed: <super: <class 'P'>, P>
121 | reveal_type(super())
122 |
123 | class E(enum.Enum):
124 | X = 1
125 |
126 | def method(self: E):
127 | match self:
128 | case E.X:
129 | # revealed: <super: <class 'E'>, E>
130 | reveal_type(super())
```
# Diagnostics
```
error[invalid-super-argument]: `S@method7` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method7)` call
--> src/mdtest_snippet.py:83:21
--> src/mdtest_snippet.py:82:21
|
81 | # error: [invalid-super-argument]
82 | # revealed: Unknown
83 | reveal_type(super())
80 | # error: [invalid-super-argument]
81 | # revealed: Unknown
82 | reveal_type(super())
| ^^^^^^^
84 | return self
83 | return self
|
info: Type variable `S` has `object` as its implicit upper bound
info: `object` is not an instance or subclass of `<class 'Foo'>`
@ -166,13 +165,13 @@ info: rule `invalid-super-argument` is enabled by default
```
error[invalid-super-argument]: `S@method8` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method8)` call
--> src/mdtest_snippet.py:89:21
--> src/mdtest_snippet.py:88:21
|
87 | # error: [invalid-super-argument]
88 | # revealed: Unknown
89 | reveal_type(super())
86 | # error: [invalid-super-argument]
87 | # revealed: Unknown
88 | reveal_type(super())
| ^^^^^^^
90 | return self
89 | return self
|
info: Type variable `S` has upper bound `int`
info: `int` is not an instance or subclass of `<class 'Foo'>`
@ -182,13 +181,13 @@ info: rule `invalid-super-argument` is enabled by default
```
error[invalid-super-argument]: `S@method9` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method9)` call
--> src/mdtest_snippet.py:95:21
--> src/mdtest_snippet.py:94:21
|
93 | # error: [invalid-super-argument]
94 | # revealed: Unknown
95 | reveal_type(super())
92 | # error: [invalid-super-argument]
93 | # revealed: Unknown
94 | reveal_type(super())
| ^^^^^^^
96 | return self
95 | return self
|
info: Type variable `S` has constraints `int, str`
info: `int | str` is not an instance or subclass of `<class 'Foo'>`
@ -198,13 +197,13 @@ info: rule `invalid-super-argument` is enabled by default
```
error[invalid-super-argument]: `S@method10` is a type variable with an abstract/structural type as its bounds or constraints, in `super(<class 'Foo'>, S@method10)` call
--> src/mdtest_snippet.py:101:21
--> src/mdtest_snippet.py:100:21
|
99 | # error: [invalid-super-argument]
100 | # revealed: Unknown
101 | reveal_type(super())
98 | # error: [invalid-super-argument]
99 | # revealed: Unknown
100 | reveal_type(super())
| ^^^^^^^
102 | return self
101 | return self
|
info: Type variable `S` has upper bound `(...) -> str`
info: rule `invalid-super-argument` is enabled by default

View File

@ -7324,7 +7324,9 @@ impl<'db> Type<'db> {
});
};
Ok(typing_self(db, scope_id, typevar_binding_context, class).unwrap_or(*self))
Ok(typing_self(db, scope_id, typevar_binding_context, class)
.map(Type::TypeVar)
.unwrap_or(*self))
}
// We ensure that `typing.TypeAlias` used in the expected position (annotating an
// annotated assignment statement) doesn't reach here. Using it in any other type

View File

@ -86,7 +86,7 @@ pub(crate) fn typing_self<'db>(
function_scope_id: ScopeId,
typevar_binding_context: Option<Definition<'db>>,
class: ClassLiteral<'db>,
) -> Option<Type<'db>> {
) -> Option<BoundTypeVarInstance<'db>> {
let index = semantic_index(db, function_scope_id.file(db));
let identity = TypeVarIdentity::new(
@ -117,7 +117,6 @@ pub(crate) fn typing_self<'db>(
typevar_binding_context,
typevar,
)
.map(Type::TypeVar)
}
#[derive(Clone, Copy, Debug)]

View File

@ -2702,22 +2702,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let function_node = function_definition.node(self.module());
let function_name = &function_node.name;
// TODO: handle implicit type of `cls` for classmethods
if is_implicit_classmethod(function_name) || is_implicit_staticmethod(function_name) {
if is_implicit_staticmethod(function_name) {
return None;
}
let mut is_classmethod = is_implicit_classmethod(function_name);
let inference = infer_definition_types(db, method_definition);
for decorator in &function_node.decorator_list {
let decorator_ty = inference.expression_type(&decorator.expression);
if decorator_ty.as_class_literal().is_some_and(|class| {
matches!(
class.known(db),
Some(KnownClass::Classmethod | KnownClass::Staticmethod)
)
}) {
if let Some(known_class) = decorator_ty
.as_class_literal()
.and_then(|class| class.known(db))
{
if known_class == KnownClass::Staticmethod {
return None;
}
is_classmethod |= known_class == KnownClass::Classmethod;
}
}
let class_definition = self.index.expect_single_definition(class);
@ -2726,7 +2728,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.inner_type()
.as_class_literal()?;
typing_self(db, self.scope(), Some(method_definition), class_literal)
let typing_self = typing_self(db, self.scope(), Some(method_definition), class_literal);
if is_classmethod {
typing_self
.map(|typing_self| SubclassOfType::from(db, SubclassOfInner::TypeVar(typing_self)))
} else {
typing_self.map(Type::TypeVar)
}
}
/// Set initial declared/inferred types for a `**kwargs` keyword-variadic parameter.

View File

@ -1702,6 +1702,7 @@ impl<'db> Parameters<'db> {
Some(
typing_self(db, scope_id, typevar_binding_context, class)
.map(Type::TypeVar)
.expect("We should always find the surrounding class for an implicit self: Self annotation"),
)
} else {