[red-knot] Add support for `@classmethod`s (#16305)

## Summary

Add support for `@classmethod`s.

```py
class C:
    @classmethod
    def f(cls, x: int) -> str:
        return "a"

reveal_type(C.f(1))  # revealed: str
```

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-02-24 09:55:34 +01:00 committed by GitHub
parent 81a57656d8
commit 141ba253da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 272 additions and 60 deletions

View File

@ -73,12 +73,12 @@ qux = (foo, bar)
reveal_type(qux) # revealed: tuple[Literal["foo"], Literal["bar"]]
# TODO: Infer "LiteralString"
reveal_type(foo.join(qux)) # revealed: @Todo(decorated method)
reveal_type(foo.join(qux)) # revealed: @Todo(overloaded method)
template: LiteralString = "{}, {}"
reveal_type(template) # revealed: Literal["{}, {}"]
# TODO: Infer `LiteralString`
reveal_type(template.format(foo, bar)) # revealed: @Todo(decorated method)
reveal_type(template.format(foo, bar)) # revealed: @Todo(overloaded method)
```
### Assignability

View File

@ -1042,8 +1042,8 @@ Most attribute accesses on bool-literal types are delegated to `builtins.bool`,
bools are instances of that class:
```py
reveal_type(True.__and__) # revealed: @Todo(decorated method)
reveal_type(False.__or__) # revealed: @Todo(decorated method)
reveal_type(True.__and__) # revealed: @Todo(overloaded method)
reveal_type(False.__or__) # revealed: @Todo(overloaded method)
```
Some attributes are special-cased, however:

View File

@ -306,7 +306,7 @@ reveal_type(1 + A()) # revealed: A
reveal_type(A() + "foo") # revealed: A
# TODO should be `A` since `str.__add__` doesn't support `A` instances
# TODO overloads
reveal_type("foo" + A()) # revealed: @Todo(return type)
reveal_type("foo" + A()) # revealed: @Todo(return type of decorated function)
reveal_type(A() + b"foo") # revealed: A
# TODO should be `A` since `bytes.__add__` doesn't support `A` instances
@ -314,7 +314,7 @@ reveal_type(b"foo" + A()) # revealed: bytes
reveal_type(A() + ()) # revealed: A
# TODO this should be `A`, since `tuple.__add__` doesn't support `A` instances
reveal_type(() + A()) # revealed: @Todo(return type)
reveal_type(() + A()) # revealed: @Todo(return type of decorated function)
literal_string_instance = "foo" * 1_000_000_000
# the test is not testing what it's meant to be testing if this isn't a `LiteralString`:
@ -323,7 +323,7 @@ reveal_type(literal_string_instance) # revealed: LiteralString
reveal_type(A() + literal_string_instance) # revealed: A
# TODO should be `A` since `str.__add__` doesn't support `A` instances
# TODO overloads
reveal_type(literal_string_instance + A()) # revealed: @Todo(return type)
reveal_type(literal_string_instance + A()) # revealed: @Todo(return type of decorated function)
```
## Operations involving instances of classes inheriting from `Any`

View File

@ -51,9 +51,9 @@ reveal_type(1 ** (largest_u32 + 1)) # revealed: int
reveal_type(2**largest_u32) # revealed: int
def variable(x: int):
reveal_type(x**2) # revealed: @Todo(return type)
reveal_type(2**x) # revealed: @Todo(return type)
reveal_type(x**x) # revealed: @Todo(return type)
reveal_type(x**2) # revealed: @Todo(return type of decorated function)
reveal_type(2**x) # revealed: @Todo(return type of decorated function)
reveal_type(x**x) # revealed: @Todo(return type of decorated function)
```
## Division by Zero

View File

@ -44,7 +44,7 @@ def bar() -> str:
return "bar"
# TODO: should reveal `int`, as the decorator replaces `bar` with `foo`
reveal_type(bar()) # revealed: @Todo(return type)
reveal_type(bar()) # revealed: @Todo(return type of decorated function)
```
## Invalid callable

View File

@ -255,4 +255,126 @@ method_wrapper()
method_wrapper(C(), C, "one too many")
```
## `@classmethod`
### Basic
When a `@classmethod` attribute is accessed, it returns a bound method object, even when accessed on
the class object itself:
```py
from __future__ import annotations
class C:
@classmethod
def f(cls: type[C], x: int) -> str:
return "a"
reveal_type(C.f) # revealed: <bound method `f` of `Literal[C]`>
reveal_type(C().f) # revealed: <bound method `f` of `type[C]`>
```
The `cls` method argument is then implicitly passed as the first argument when calling the method:
```py
reveal_type(C.f(1)) # revealed: str
reveal_type(C().f(1)) # revealed: str
```
When the class method is called incorrectly, we detect it:
```py
C.f("incorrect") # error: [invalid-argument-type]
C.f() # error: [missing-argument]
C.f(1, 2) # error: [too-many-positional-arguments]
```
If the `cls` parameter is wrongly annotated, we emit an error at the call site:
```py
class D:
@classmethod
def f(cls: D):
# This function is wrongly annotated, it should be `type[D]` instead of `D`
pass
# error: [invalid-argument-type] "Object of type `Literal[D]` cannot be assigned to parameter 1 (`cls`); expected type `D`"
D.f()
```
When a class method is accessed on a derived class, it is bound to that derived class:
```py
class Derived(C):
pass
reveal_type(Derived.f) # revealed: <bound method `f` of `Literal[Derived]`>
reveal_type(Derived().f) # revealed: <bound method `f` of `type[Derived]`>
reveal_type(Derived.f(1)) # revealed: str
reveal_type(Derived().f(1)) # revealed: str
```
### Accessing the classmethod as a static member
Accessing a `@classmethod`-decorated function at runtime returns a `classmethod` object. We
currently don't model this explicitly:
```py
from inspect import getattr_static
class C:
@classmethod
def f(cls): ...
reveal_type(getattr_static(C, "f")) # revealed: Literal[f]
reveal_type(getattr_static(C, "f").__get__) # revealed: <method-wrapper `__get__` of `f`>
```
But we correctly model how the `classmethod` descriptor works:
```py
reveal_type(getattr_static(C, "f").__get__(None, C)) # revealed: <bound method `f` of `Literal[C]`>
reveal_type(getattr_static(C, "f").__get__(C(), C)) # revealed: <bound method `f` of `Literal[C]`>
reveal_type(getattr_static(C, "f").__get__(C())) # revealed: <bound method `f` of `type[C]`>
```
The `owner` argument takes precedence over the `instance` argument:
```py
reveal_type(getattr_static(C, "f").__get__("dummy", C)) # revealed: <bound method `f` of `Literal[C]`>
```
### Classmethods mixed with other decorators
When a `@classmethod` is additionally decorated with another decorator, it is still treated as a
class method:
```py
from __future__ import annotations
def does_nothing[T](f: T) -> T:
return f
class C:
@classmethod
@does_nothing
def f1(cls: type[C], x: int) -> str:
return "a"
@does_nothing
@classmethod
def f2(cls: type[C], x: int) -> str:
return "a"
# TODO: We do not support decorators yet (only limited special cases). Eventually,
# these should all return `str`:
reveal_type(C.f1(1)) # revealed: @Todo(return type of decorated function)
reveal_type(C().f1(1)) # revealed: @Todo(decorated method)
reveal_type(C.f2(1)) # revealed: @Todo(return type of decorated function)
reveal_type(C().f2(1)) # revealed: @Todo(decorated method)
```
[functions and methods]: https://docs.python.org/3/howto/descriptor.html#functions-and-methods

View File

@ -201,14 +201,11 @@ class C:
c1 = C.factory("test") # okay
# TODO: should be `C`
reveal_type(c1) # revealed: @Todo(return type)
reveal_type(c1) # revealed: C
# TODO: should be `str`
reveal_type(C.get_name()) # revealed: @Todo(return type)
reveal_type(C.get_name()) # revealed: str
# TODO: should be `str`
reveal_type(C("42").get_name()) # revealed: @Todo(decorated method)
reveal_type(C("42").get_name()) # revealed: str
```
## Descriptors only work when used as class variables

View File

@ -167,7 +167,7 @@ class A:
__slots__ = ()
__slots__ += ("a", "b")
reveal_type(A.__slots__) # revealed: @Todo(return type)
reveal_type(A.__slots__) # revealed: @Todo(return type of decorated function)
class B:
__slots__ = ("c", "d")

View File

@ -25,7 +25,7 @@ reveal_type(y) # revealed: Unknown
def _(n: int):
a = b"abcde"[n]
# TODO: Support overloads... Should be `bytes`
reveal_type(a) # revealed: @Todo(return type)
reveal_type(a) # revealed: @Todo(return type of decorated function)
```
## Slices
@ -44,10 +44,10 @@ b[::0] # error: [zero-stepsize-in-slice]
def _(m: int, n: int):
byte_slice1 = b[m:n]
# TODO: Support overloads... Should be `bytes`
reveal_type(byte_slice1) # revealed: @Todo(return type)
reveal_type(byte_slice1) # revealed: @Todo(return type of decorated function)
def _(s: bytes) -> bytes:
byte_slice2 = s[0:5]
# TODO: Support overloads... Should be `bytes`
reveal_type(byte_slice2) # revealed: @Todo(return type)
reveal_type(byte_slice2) # revealed: @Todo(return type of decorated function)
```

View File

@ -12,13 +12,13 @@ x = [1, 2, 3]
reveal_type(x) # revealed: list
# TODO reveal int
reveal_type(x[0]) # revealed: @Todo(return type)
reveal_type(x[0]) # revealed: @Todo(return type of decorated function)
# TODO reveal list
reveal_type(x[0:1]) # revealed: @Todo(return type)
reveal_type(x[0:1]) # revealed: @Todo(return type of decorated function)
# TODO error
reveal_type(x["a"]) # revealed: @Todo(return type)
reveal_type(x["a"]) # revealed: @Todo(return type of decorated function)
```
## Assignments within list assignment

View File

@ -22,7 +22,7 @@ reveal_type(b) # revealed: Unknown
def _(n: int):
a = "abcde"[n]
# TODO: Support overloads... Should be `str`
reveal_type(a) # revealed: @Todo(return type)
reveal_type(a) # revealed: @Todo(return type of decorated function)
```
## Slices
@ -76,11 +76,11 @@ def _(m: int, n: int, s2: str):
substring1 = s[m:n]
# TODO: Support overloads... Should be `LiteralString`
reveal_type(substring1) # revealed: @Todo(return type)
reveal_type(substring1) # revealed: @Todo(return type of decorated function)
substring2 = s2[0:5]
# TODO: Support overloads... Should be `str`
reveal_type(substring2) # revealed: @Todo(return type)
reveal_type(substring2) # revealed: @Todo(return type of decorated function)
```
## Unsupported slice types

View File

@ -70,7 +70,7 @@ def _(m: int, n: int):
tuple_slice = t[m:n]
# TODO: Support overloads... Should be `tuple[Literal[1, 'a', b"b"] | None, ...]`
reveal_type(tuple_slice) # revealed: @Todo(return type)
reveal_type(tuple_slice) # revealed: @Todo(return type of decorated function)
```
## Inheritance

View File

@ -1948,13 +1948,29 @@ impl<'db> Type<'db> {
)
},
]),
Some(match arguments.first_argument() {
Some(ty) if ty.is_none(db) => Type::FunctionLiteral(function),
Some(instance) => Type::Callable(CallableType::BoundMethod(
BoundMethodType::new(db, function, instance),
)),
_ => Type::unknown(),
}),
if function.has_known_class_decorator(db, KnownClass::Classmethod)
&& function.decorators(db).len() == 1
{
if let Some(owner) = arguments.second_argument() {
Some(Type::Callable(CallableType::BoundMethod(
BoundMethodType::new(db, function, owner),
)))
} else if let Some(instance) = arguments.first_argument() {
Some(Type::Callable(CallableType::BoundMethod(
BoundMethodType::new(db, function, instance.to_meta_type(db)),
)))
} else {
Some(Type::unknown())
}
} else {
Some(match arguments.first_argument() {
Some(ty) if ty.is_none(db) => Type::FunctionLiteral(function),
Some(instance) => Type::Callable(CallableType::BoundMethod(
BoundMethodType::new(db, function, instance),
)),
_ => Type::unknown(),
})
},
);
let binding = bind_call(db, arguments, &signature, self);
@ -2004,18 +2020,40 @@ impl<'db> Type<'db> {
},
]),
Some(
match (arguments.first_argument(), arguments.second_argument()) {
(Some(function @ Type::FunctionLiteral(_)), Some(instance))
if instance.is_none(db) =>
if let Some(function_ty @ Type::FunctionLiteral(function)) =
arguments.first_argument()
{
if function.has_known_class_decorator(db, KnownClass::Classmethod)
&& function.decorators(db).len() == 1
{
function
if let Some(owner) = arguments.third_argument() {
Type::Callable(CallableType::BoundMethod(BoundMethodType::new(
db, function, owner,
)))
} else if let Some(instance) = arguments.second_argument() {
Type::Callable(CallableType::BoundMethod(BoundMethodType::new(
db,
function,
instance.to_meta_type(db),
)))
} else {
Type::unknown()
}
} else {
if let Some(instance) = arguments.second_argument() {
if instance.is_none(db) {
function_ty
} else {
Type::Callable(CallableType::BoundMethod(
BoundMethodType::new(db, function, instance),
))
}
} else {
Type::unknown()
}
}
(Some(Type::FunctionLiteral(function)), Some(instance)) => {
Type::Callable(CallableType::BoundMethod(BoundMethodType::new(
db, function, instance,
)))
}
_ => Type::unknown(),
} else {
Type::unknown()
},
),
);
@ -2108,6 +2146,10 @@ impl<'db> Type<'db> {
};
}
Some(KnownFunction::Overload) => {
binding.set_return_type(todo_type!("overload(..) return type"));
}
Some(KnownFunction::GetattrStatic) => {
let Some((instance_ty, attr_name, default)) =
binding.three_parameter_types()
@ -2853,6 +2895,7 @@ pub enum KnownClass {
Property,
BaseException,
BaseExceptionGroup,
Classmethod,
// Types
GenericAlias,
ModuleType,
@ -2937,7 +2980,8 @@ impl<'db> KnownClass {
| Self::Counter
| Self::DefaultDict
| Self::Deque
| Self::Float => Truthiness::Ambiguous,
| Self::Float
| Self::Classmethod => Truthiness::Ambiguous,
}
}
@ -2961,6 +3005,7 @@ impl<'db> KnownClass {
Self::Property => "property",
Self::BaseException => "BaseException",
Self::BaseExceptionGroup => "BaseExceptionGroup",
Self::Classmethod => "classmethod",
Self::GenericAlias => "GenericAlias",
Self::ModuleType => "ModuleType",
Self::FunctionType => "FunctionType",
@ -3042,6 +3087,7 @@ impl<'db> KnownClass {
| Self::Dict
| Self::BaseException
| Self::BaseExceptionGroup
| Self::Classmethod
| Self::Slice
| Self::Range
| Self::Property => KnownModule::Builtins,
@ -3114,6 +3160,7 @@ impl<'db> KnownClass {
| KnownClass::Property
| KnownClass::BaseException
| KnownClass::BaseExceptionGroup
| KnownClass::Classmethod
| KnownClass::GenericAlias
| KnownClass::ModuleType
| KnownClass::FunctionType
@ -3176,6 +3223,7 @@ impl<'db> KnownClass {
| Self::SupportsIndex
| Self::BaseException
| Self::BaseExceptionGroup
| Self::Classmethod
| Self::TypeVar => false,
}
}
@ -3200,8 +3248,10 @@ impl<'db> KnownClass {
"list" => Self::List,
"slice" => Self::Slice,
"range" => Self::Range,
"property" => Self::Property,
"BaseException" => Self::BaseException,
"BaseExceptionGroup" => Self::BaseExceptionGroup,
"classmethod" => Self::Classmethod,
"GenericAlias" => Self::GenericAlias,
"NoneType" => Self::NoneType,
"ModuleType" => Self::ModuleType,
@ -3266,6 +3316,7 @@ impl<'db> KnownClass {
| Self::BaseException
| Self::EllipsisType
| Self::BaseExceptionGroup
| Self::Classmethod
| Self::FunctionType
| Self::MethodType
| Self::MethodWrapperType
@ -3919,8 +3970,18 @@ pub struct FunctionType<'db> {
#[salsa::tracked]
impl<'db> FunctionType<'db> {
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
pub fn has_known_class_decorator(self, db: &dyn Db, decorator: KnownClass) -> bool {
self.decorators(db).iter().any(|d| {
d.into_class_literal()
.is_some_and(|c| c.class.is_known(db, decorator))
})
}
pub fn has_known_function_decorator(self, db: &dyn Db, decorator: KnownFunction) -> bool {
self.decorators(db).iter().any(|d| {
d.into_function_literal()
.is_some_and(|f| f.is_known(db, decorator))
})
}
/// Typed externally-visible signature for this function.
@ -3937,13 +3998,23 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(return_ref)]
pub fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let function_stmt_node = self.body_scope(db).node(db).expect_function();
let internal_signature = self.internal_signature(db);
if function_stmt_node.decorator_list.is_empty() {
return internal_signature;
let decorators = self.decorators(db);
let mut decorators = decorators.iter();
if let Some(d) = decorators.next() {
if d.into_class_literal()
.is_some_and(|c| c.class.is_known(db, KnownClass::Classmethod))
&& decorators.next().is_none()
{
internal_signature
} else {
Signature::todo("return type of decorated function")
}
} else {
internal_signature
}
// TODO process the effect of decorators on the signature
Signature::todo()
}
/// Typed internally-visible signature for this function.
@ -3989,6 +4060,8 @@ pub enum KnownFunction {
AssertType,
/// `typing(_extensions).cast`
Cast,
/// `typing(_extensions).overload`
Overload,
/// `inspect.getattr_static`
GetattrStatic,
@ -4036,6 +4109,7 @@ impl KnownFunction {
"no_type_check" => Self::NoTypeCheck,
"assert_type" => Self::AssertType,
"cast" => Self::Cast,
"overload" => Self::Overload,
"getattr_static" => Self::GetattrStatic,
"static_assert" => Self::StaticAssert,
"is_subtype_of" => Self::IsSubtypeOf,
@ -4063,7 +4137,12 @@ impl KnownFunction {
}
},
Self::Len | Self::Repr => module.is_builtins(),
Self::AssertType | Self::Cast | Self::RevealType | Self::Final | Self::NoTypeCheck => {
Self::AssertType
| Self::Cast
| Self::Overload
| Self::RevealType
| Self::Final
| Self::NoTypeCheck => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
}
Self::GetattrStatic => {
@ -4100,6 +4179,7 @@ impl KnownFunction {
Self::ConstraintFunction(_)
| Self::Len
| Self::Repr
| Self::Overload
| Self::Final
| Self::NoTypeCheck
| Self::RevealType
@ -4767,10 +4847,17 @@ impl<'db> Class<'db> {
if let Some(function) = declared_ty.into_function_literal() {
// TODO: Eventually, we are going to process all decorators correctly. This is
// just a temporary heuristic to provide a broad categorization into properties
// and non-property methods.
if function.has_decorator(db, KnownClass::Property.to_class_literal(db)) {
// just a temporary heuristic to provide a broad categorization
if function.has_known_class_decorator(db, KnownClass::Classmethod)
&& function.decorators(db).len() == 1
{
SymbolAndQualifiers(Symbol::bound(declared_ty), qualifiers)
} else if function.has_known_class_decorator(db, KnownClass::Property) {
SymbolAndQualifiers::todo("@property")
} else if function.has_known_function_decorator(db, KnownFunction::Overload)
{
SymbolAndQualifiers::todo("overloaded method")
} else if !function.decorators(db).is_empty() {
SymbolAndQualifiers::todo("decorated method")
} else {

View File

@ -47,6 +47,11 @@ impl<'a, 'db> CallArguments<'a, 'db> {
pub(crate) fn second_argument(&self) -> Option<Type<'db>> {
self.0.get(1).map(Argument::ty)
}
// TODO this should be eliminated in favor of [`bind_call`]
pub(crate) fn third_argument(&self) -> Option<Type<'db>> {
self.0.get(2).map(Argument::ty)
}
}
impl<'db, 'a, 'b> IntoIterator for &'b CallArguments<'a, 'db> {

View File

@ -29,10 +29,11 @@ impl<'db> Signature<'db> {
}
/// Return a todo signature: (*args: Todo, **kwargs: Todo) -> Todo
pub(crate) fn todo() -> Self {
#[allow(unused_variables)] // 'reason' only unused in debug builds
pub(crate) fn todo(reason: &'static str) -> Self {
Self {
parameters: Parameters::todo(),
return_ty: Some(todo_type!("return type")),
return_ty: Some(todo_type!(reason)),
}
}
@ -650,7 +651,7 @@ mod tests {
.unwrap();
let func = get_function_f(&db, "/src/a.py");
let expected_sig = Signature::todo();
let expected_sig = Signature::todo("return type of decorated function");
// With no decorators, internal and external signature are the same
assert_eq!(func.signature(&db), &expected_sig);