[ty] Infer type of implicit `cls` parameter in method bodies (#21685)

## Summary

Extends https://github.com/astral-sh/ruff/pull/20922 to infer
unannotated `cls` parameters as `type[Self]` in method bodies.

Part of https://github.com/astral-sh/ty/issues/159.
This commit is contained in:
Ibraheem Ahmed 2025-12-10 04:31:28 -05:00 committed by GitHub
parent d2aabeaaa2
commit ff7086d9ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 167 additions and 154 deletions

View File

@ -194,7 +194,7 @@ static SYMPY: Benchmark = Benchmark::new(
max_dep_date: "2025-06-17", max_dep_date: "2025-06-17",
python_version: PythonVersion::PY312, python_version: PythonVersion::PY312,
}, },
13000, 13030,
); );
static TANJUN: Benchmark = Benchmark::new( static TANJUN: Benchmark = Benchmark::new(

View File

@ -63,6 +63,12 @@ python-version = "3.12"
from typing import Self from typing import Self
class A: class A:
def __init__(self):
reveal_type(self) # revealed: Self@__init__
def __init_subclass__(cls, default_name, **kwargs):
reveal_type(cls) # revealed: type[Self@__init_subclass__]
def implicit_self(self) -> Self: def implicit_self(self) -> Self:
reveal_type(self) # revealed: Self@implicit_self reveal_type(self) # revealed: Self@implicit_self
@ -91,8 +97,7 @@ class A:
@classmethod @classmethod
def a_classmethod(cls) -> Self: def a_classmethod(cls) -> Self:
# TODO: This should be type[Self@bar] reveal_type(cls) # revealed: type[Self@a_classmethod]
reveal_type(cls) # revealed: Unknown
return cls() return cls()
@staticmethod @staticmethod

View File

@ -174,8 +174,7 @@ class B(A):
@classmethod @classmethod
def f(cls): def f(cls):
# TODO: Once `cls` is supported, this should be `<super: <class 'B'>, <class 'B'>>` reveal_type(super()) # revealed: <super: <class 'B'>, <class 'B'>>
reveal_type(super()) # revealed: <super: <class 'B'>, Unknown>
super().f() super().f()
super(B, B(42)).__init__(42) super(B, B(42)).__init__(42)

View File

@ -27,135 +27,134 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/class/super.md
13 | 13 |
14 | @classmethod 14 | @classmethod
15 | def f(cls): 15 | def f(cls):
16 | # TODO: Once `cls` is supported, this should be `<super: <class 'B'>, <class 'B'>>` 16 | reveal_type(super()) # revealed: <super: <class 'B'>, <class 'B'>>
17 | reveal_type(super()) # revealed: <super: <class 'B'>, Unknown> 17 | super().f()
18 | super().f() 18 |
19 | 19 | super(B, B(42)).__init__(42)
20 | super(B, B(42)).__init__(42) 20 | super(B, B).f()
21 | super(B, B).f() 21 | import enum
22 | import enum 22 | from typing import Any, Self, Never, Protocol, Callable
23 | from typing import Any, Self, Never, Protocol, Callable 23 | from ty_extensions import Intersection
24 | from ty_extensions import Intersection 24 |
25 | 25 | class BuilderMeta(type):
26 | class BuilderMeta(type): 26 | def __new__(
27 | def __new__( 27 | cls: type[Any],
28 | cls: type[Any], 28 | name: str,
29 | name: str, 29 | bases: tuple[type, ...],
30 | bases: tuple[type, ...], 30 | dct: dict[str, Any],
31 | dct: dict[str, Any], 31 | ) -> BuilderMeta:
32 | ) -> BuilderMeta: 32 | # revealed: <super: <class 'BuilderMeta'>, Any>
33 | # revealed: <super: <class 'BuilderMeta'>, Any> 33 | s = reveal_type(super())
34 | s = reveal_type(super()) 34 | # revealed: Any
35 | # revealed: Any 35 | return reveal_type(s.__new__(cls, name, bases, dct))
36 | return reveal_type(s.__new__(cls, name, bases, dct)) 36 |
37 | 37 | class BuilderMeta2(type):
38 | class BuilderMeta2(type): 38 | def __new__(
39 | def __new__( 39 | cls: type[BuilderMeta2],
40 | cls: type[BuilderMeta2], 40 | name: str,
41 | name: str, 41 | bases: tuple[type, ...],
42 | bases: tuple[type, ...], 42 | dct: dict[str, Any],
43 | dct: dict[str, Any], 43 | ) -> BuilderMeta2:
44 | ) -> BuilderMeta2: 44 | # revealed: <super: <class 'BuilderMeta2'>, <class 'BuilderMeta2'>>
45 | # revealed: <super: <class 'BuilderMeta2'>, <class 'BuilderMeta2'>> 45 | s = reveal_type(super())
46 | s = reveal_type(super()) 46 | return reveal_type(s.__new__(cls, name, bases, dct)) # revealed: BuilderMeta2
47 | return reveal_type(s.__new__(cls, name, bases, dct)) # revealed: BuilderMeta2 47 |
48 | 48 | class Foo[T]:
49 | class Foo[T]: 49 | x: T
50 | x: T 50 |
51 | 51 | def method(self: Any):
52 | def method(self: Any): 52 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
53 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any> 53 |
54 | 54 | if isinstance(self, Foo):
55 | if isinstance(self, Foo): 55 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any>
56 | reveal_type(super()) # revealed: <super: <class 'Foo'>, Any> 56 |
57 | 57 | def method2(self: Foo[T]):
58 | def method2(self: Foo[T]): 58 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
59 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]> 59 | reveal_type(super())
60 | reveal_type(super()) 60 |
61 | 61 | def method3(self: Foo):
62 | def method3(self: Foo): 62 | # revealed: <super: <class 'Foo'>, Foo[Unknown]>
63 | # revealed: <super: <class 'Foo'>, Foo[Unknown]> 63 | reveal_type(super())
64 | reveal_type(super()) 64 |
65 | 65 | def method4(self: Self):
66 | def method4(self: Self): 66 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]>
67 | # revealed: <super: <class 'Foo'>, Foo[T@Foo]> 67 | reveal_type(super())
68 | reveal_type(super()) 68 |
69 | 69 | def method5[S: Foo[int]](self: S, other: S) -> S:
70 | def method5[S: Foo[int]](self: S, other: S) -> S: 70 | # revealed: <super: <class 'Foo'>, Foo[int]>
71 | # revealed: <super: <class 'Foo'>, Foo[int]> 71 | reveal_type(super())
72 | reveal_type(super()) 72 | return self
73 | return self 73 |
74 | 74 | def method6[S: (Foo[int], Foo[str])](self: S, other: S) -> S:
75 | def method6[S: (Foo[int], Foo[str])](self: S, other: S) -> S: 75 | # revealed: <super: <class 'Foo'>, Foo[int]> | <super: <class 'Foo'>, Foo[str]>
76 | # revealed: <super: <class 'Foo'>, Foo[int]> | <super: <class 'Foo'>, Foo[str]> 76 | reveal_type(super())
77 | reveal_type(super()) 77 | return self
78 | return self 78 |
79 | 79 | def method7[S](self: S, other: S) -> S:
80 | def method7[S](self: S, other: S) -> S: 80 | # error: [invalid-super-argument]
81 | # error: [invalid-super-argument] 81 | # revealed: Unknown
82 | # revealed: Unknown 82 | reveal_type(super())
83 | reveal_type(super()) 83 | return self
84 | return self 84 |
85 | 85 | def method8[S: int](self: S, other: S) -> S:
86 | def method8[S: int](self: S, other: S) -> S: 86 | # error: [invalid-super-argument]
87 | # error: [invalid-super-argument] 87 | # revealed: Unknown
88 | # revealed: Unknown 88 | reveal_type(super())
89 | reveal_type(super()) 89 | return self
90 | return self 90 |
91 | 91 | def method9[S: (int, str)](self: S, other: S) -> S:
92 | def method9[S: (int, str)](self: S, other: S) -> S: 92 | # error: [invalid-super-argument]
93 | # error: [invalid-super-argument] 93 | # revealed: Unknown
94 | # revealed: Unknown 94 | reveal_type(super())
95 | reveal_type(super()) 95 | return self
96 | return self 96 |
97 | 97 | def method10[S: Callable[..., str]](self: S, other: S) -> S:
98 | def method10[S: Callable[..., str]](self: S, other: S) -> S: 98 | # error: [invalid-super-argument]
99 | # error: [invalid-super-argument] 99 | # revealed: Unknown
100 | # revealed: Unknown 100 | reveal_type(super())
101 | reveal_type(super()) 101 | return self
102 | return self 102 |
103 | 103 | type Alias = Bar
104 | type Alias = Bar 104 |
105 | 105 | class Bar:
106 | class Bar: 106 | def method(self: Alias):
107 | def method(self: Alias): 107 | # revealed: <super: <class 'Bar'>, Bar>
108 | # revealed: <super: <class 'Bar'>, Bar> 108 | reveal_type(super())
109 | reveal_type(super()) 109 |
110 | 110 | def pls_dont_call_me(self: Never):
111 | def pls_dont_call_me(self: Never): 111 | # revealed: <super: <class 'Bar'>, Unknown>
112 | # revealed: <super: <class 'Bar'>, Unknown> 112 | reveal_type(super())
113 | reveal_type(super()) 113 |
114 | 114 | def only_call_me_on_callable_subclasses(self: Intersection[Bar, Callable[..., object]]):
115 | def only_call_me_on_callable_subclasses(self: Intersection[Bar, Callable[..., object]]): 115 | # revealed: <super: <class 'Bar'>, Bar>
116 | # revealed: <super: <class 'Bar'>, Bar> 116 | reveal_type(super())
117 | reveal_type(super()) 117 |
118 | 118 | class P(Protocol):
119 | class P(Protocol): 119 | def method(self: P):
120 | def method(self: P): 120 | # revealed: <super: <class 'P'>, P>
121 | # revealed: <super: <class 'P'>, P> 121 | reveal_type(super())
122 | reveal_type(super()) 122 |
123 | 123 | class E(enum.Enum):
124 | class E(enum.Enum): 124 | X = 1
125 | X = 1 125 |
126 | 126 | def method(self: E):
127 | def method(self: E): 127 | match self:
128 | match self: 128 | case E.X:
129 | case E.X: 129 | # revealed: <super: <class 'E'>, E>
130 | # revealed: <super: <class 'E'>, E> 130 | reveal_type(super())
131 | reveal_type(super())
``` ```
# Diagnostics # Diagnostics
``` ```
error[invalid-super-argument]: `S@method7` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method7)` call error[invalid-super-argument]: `S@method7` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method7)` call
--> src/mdtest_snippet.py:83:21 --> src/mdtest_snippet.py:82:21
| |
81 | # error: [invalid-super-argument] 80 | # error: [invalid-super-argument]
82 | # revealed: Unknown 81 | # revealed: Unknown
83 | reveal_type(super()) 82 | reveal_type(super())
| ^^^^^^^ | ^^^^^^^
84 | return self 83 | return self
| |
info: Type variable `S` has `object` as its implicit upper bound info: Type variable `S` has `object` as its implicit upper bound
info: `object` is not an instance or subclass of `<class 'Foo'>` info: `object` is not an instance or subclass of `<class 'Foo'>`
@ -166,13 +165,13 @@ info: rule `invalid-super-argument` is enabled by default
``` ```
error[invalid-super-argument]: `S@method8` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method8)` call error[invalid-super-argument]: `S@method8` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method8)` call
--> src/mdtest_snippet.py:89:21 --> src/mdtest_snippet.py:88:21
| |
87 | # error: [invalid-super-argument] 86 | # error: [invalid-super-argument]
88 | # revealed: Unknown 87 | # revealed: Unknown
89 | reveal_type(super()) 88 | reveal_type(super())
| ^^^^^^^ | ^^^^^^^
90 | return self 89 | return self
| |
info: Type variable `S` has upper bound `int` info: Type variable `S` has upper bound `int`
info: `int` is not an instance or subclass of `<class 'Foo'>` info: `int` is not an instance or subclass of `<class 'Foo'>`
@ -182,13 +181,13 @@ info: rule `invalid-super-argument` is enabled by default
``` ```
error[invalid-super-argument]: `S@method9` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method9)` call error[invalid-super-argument]: `S@method9` is not an instance or subclass of `<class 'Foo'>` in `super(<class 'Foo'>, S@method9)` call
--> src/mdtest_snippet.py:95:21 --> src/mdtest_snippet.py:94:21
| |
93 | # error: [invalid-super-argument] 92 | # error: [invalid-super-argument]
94 | # revealed: Unknown 93 | # revealed: Unknown
95 | reveal_type(super()) 94 | reveal_type(super())
| ^^^^^^^ | ^^^^^^^
96 | return self 95 | return self
| |
info: Type variable `S` has constraints `int, str` info: Type variable `S` has constraints `int, str`
info: `int | str` is not an instance or subclass of `<class 'Foo'>` info: `int | str` is not an instance or subclass of `<class 'Foo'>`
@ -198,13 +197,13 @@ info: rule `invalid-super-argument` is enabled by default
``` ```
error[invalid-super-argument]: `S@method10` is a type variable with an abstract/structural type as its bounds or constraints, in `super(<class 'Foo'>, S@method10)` call error[invalid-super-argument]: `S@method10` is a type variable with an abstract/structural type as its bounds or constraints, in `super(<class 'Foo'>, S@method10)` call
--> src/mdtest_snippet.py:101:21 --> src/mdtest_snippet.py:100:21
| |
99 | # error: [invalid-super-argument] 98 | # error: [invalid-super-argument]
100 | # revealed: Unknown 99 | # revealed: Unknown
101 | reveal_type(super()) 100 | reveal_type(super())
| ^^^^^^^ | ^^^^^^^
102 | return self 101 | return self
| |
info: Type variable `S` has upper bound `(...) -> str` info: Type variable `S` has upper bound `(...) -> str`
info: rule `invalid-super-argument` is enabled by default info: rule `invalid-super-argument` is enabled by default

View File

@ -7324,7 +7324,9 @@ impl<'db> Type<'db> {
}); });
}; };
Ok(typing_self(db, scope_id, typevar_binding_context, class).unwrap_or(*self)) Ok(typing_self(db, scope_id, typevar_binding_context, class)
.map(Type::TypeVar)
.unwrap_or(*self))
} }
// We ensure that `typing.TypeAlias` used in the expected position (annotating an // We ensure that `typing.TypeAlias` used in the expected position (annotating an
// annotated assignment statement) doesn't reach here. Using it in any other type // annotated assignment statement) doesn't reach here. Using it in any other type

View File

@ -86,7 +86,7 @@ pub(crate) fn typing_self<'db>(
function_scope_id: ScopeId, function_scope_id: ScopeId,
typevar_binding_context: Option<Definition<'db>>, typevar_binding_context: Option<Definition<'db>>,
class: ClassLiteral<'db>, class: ClassLiteral<'db>,
) -> Option<Type<'db>> { ) -> Option<BoundTypeVarInstance<'db>> {
let index = semantic_index(db, function_scope_id.file(db)); let index = semantic_index(db, function_scope_id.file(db));
let identity = TypeVarIdentity::new( let identity = TypeVarIdentity::new(
@ -117,7 +117,6 @@ pub(crate) fn typing_self<'db>(
typevar_binding_context, typevar_binding_context,
typevar, typevar,
) )
.map(Type::TypeVar)
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]

View File

@ -2702,22 +2702,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let function_node = function_definition.node(self.module()); let function_node = function_definition.node(self.module());
let function_name = &function_node.name; let function_name = &function_node.name;
// TODO: handle implicit type of `cls` for classmethods if is_implicit_staticmethod(function_name) {
if is_implicit_classmethod(function_name) || is_implicit_staticmethod(function_name) {
return None; return None;
} }
let mut is_classmethod = is_implicit_classmethod(function_name);
let inference = infer_definition_types(db, method_definition); let inference = infer_definition_types(db, method_definition);
for decorator in &function_node.decorator_list { for decorator in &function_node.decorator_list {
let decorator_ty = inference.expression_type(&decorator.expression); let decorator_ty = inference.expression_type(&decorator.expression);
if decorator_ty.as_class_literal().is_some_and(|class| { if let Some(known_class) = decorator_ty
matches!( .as_class_literal()
class.known(db), .and_then(|class| class.known(db))
Some(KnownClass::Classmethod | KnownClass::Staticmethod) {
) if known_class == KnownClass::Staticmethod {
}) {
return None; return None;
} }
is_classmethod |= known_class == KnownClass::Classmethod;
}
} }
let class_definition = self.index.expect_single_definition(class); let class_definition = self.index.expect_single_definition(class);
@ -2726,7 +2728,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.inner_type() .inner_type()
.as_class_literal()?; .as_class_literal()?;
typing_self(db, self.scope(), Some(method_definition), class_literal) let typing_self = typing_self(db, self.scope(), Some(method_definition), class_literal);
if is_classmethod {
typing_self
.map(|typing_self| SubclassOfType::from(db, SubclassOfInner::TypeVar(typing_self)))
} else {
typing_self.map(Type::TypeVar)
}
} }
/// Set initial declared/inferred types for a `**kwargs` keyword-variadic parameter. /// Set initial declared/inferred types for a `**kwargs` keyword-variadic parameter.

View File

@ -1702,6 +1702,7 @@ impl<'db> Parameters<'db> {
Some( Some(
typing_self(db, scope_id, typevar_binding_context, class) typing_self(db, scope_id, typevar_binding_context, class)
.map(Type::TypeVar)
.expect("We should always find the surrounding class for an implicit self: Self annotation"), .expect("We should always find the surrounding class for an implicit self: Self annotation"),
) )
} else { } else {