[ty] Validate constructor arguments when a class is used as a decorator (#22377)

## Summary

If a class is used as a decorator, we now use the class constructor.

Closes https://github.com/astral-sh/ty/issues/2232.
This commit is contained in:
Charlie Marsh
2026-01-17 10:48:59 -05:00
committed by GitHub
parent 6b16931169
commit df58d67974
3 changed files with 208 additions and 88 deletions

View File

@@ -235,6 +235,89 @@ def takes_no_argument() -> str:
def g(x): ...
```
### Class, with wrong signature, used as a decorator
When a class is used as a decorator, its constructor (`__init__` or `__new__`) must accept the
decorated function as an argument. If the class's constructor doesn't accept the right arguments, we
emit an error:
```py
class NoInit: ...
# error: [too-many-positional-arguments] "Too many positional arguments to bound method `__init__`: expected 1, got 2"
@NoInit
def foo(): ...
reveal_type(foo) # revealed: NoInit
# error: [invalid-argument-type]
@int
def bar(): ...
reveal_type(bar) # revealed: int
```
### Class, with correct signature, used as a decorator
When a class's constructor accepts the decorated function/class, no error is emitted:
```py
from typing import Callable
class Wrapper:
def __init__(self, func: Callable[..., object]) -> None:
self.func = func
@Wrapper
def my_func() -> int:
return 42
reveal_type(my_func) # revealed: Wrapper
class AcceptsType:
def __init__(self, cls: type) -> None:
self.cls = cls
# Decorator call is validated, but the type transformation isn't applied yet.
# TODO: Class decorator return types should transform the class binding type.
@AcceptsType
class MyClass: ...
reveal_type(MyClass) # revealed: <class 'MyClass'>
```
### Generic class, used as a decorator
Generic class decorators are validated through constructor calls:
```py
from typing import Generic, TypeVar, Callable
T = TypeVar("T")
class Box(Generic[T]):
def __init__(self, value: T) -> None:
self.value = value
# error: [invalid-argument-type]
@Box[int]
def returns_str() -> str:
return "hello"
```
### `type[SomeClass]` used as a decorator
Using `type[SomeClass]` as a decorator validates against the class's constructor:
```py
class Base: ...
def apply_decorator(cls: type[Base]) -> None:
# error: [too-many-positional-arguments] "Too many positional arguments to bound method `__init__`: expected 1, got 2"
@cls
def inner() -> None: ...
```
## Class decorators
Class decorator calls are validated, emitting diagnostics for invalid arguments:

View File

@@ -1899,6 +1899,49 @@ impl<'db> ClassType<'db> {
pub(super) fn definition_span(self, db: &'db dyn Db) -> Span {
self.class_literal(db).header_span(db)
}
/// Returns `true` if calls to this class type should use constructor call handling
/// (via `try_call_constructor`) rather than the regular `try_call` path.
///
/// Some known classes have manual signatures defined in `bindings()` and should use
/// the `try_call` path. For all other class types, we use `try_call_constructor`
/// to properly validate `__new__`/`__init__` signatures.
pub(super) fn should_use_constructor_call(self, db: &'db dyn Db) -> bool {
// For some known classes we have manual signatures defined and use the regular
// `try_call` path instead of constructor call handling.
let has_special_cased_constructor = matches!(
self.known(db),
Some(
KnownClass::Bool
| KnownClass::Str
| KnownClass::Type
| KnownClass::Object
| KnownClass::Property
| KnownClass::Super
| KnownClass::TypeAliasType
| KnownClass::Deprecated
)
) || (
// Constructor calls to `tuple` and subclasses of `tuple` are handled in
// `Type::bindings`, but constructor calls to `tuple[int]`, `tuple[int, ...]`,
// `tuple[int, *tuple[str, ...]]` (etc.) are handled by the default constructor-call
// logic (we synthesize a `__new__` method for them in `ClassType::own_class_member`).
self.is_known(db, KnownClass::Tuple) && !self.is_generic()
) || self.static_class_literal(db).is_some_and(
|(class_literal, specialization)| {
CodeGeneratorKind::TypedDict.matches(db, class_literal.into(), specialization)
},
);
// Use regular `try_call` for all subclasses of `enum.Enum`. This is a temporary
// special-casing until we support the functional syntax for creating enum classes.
let is_enum_subclass = KnownClass::Enum
.to_class_literal(db)
.to_class_type(db)
.is_some_and(|enum_class| self.is_subclass_of(db, enum_class));
!has_special_cased_constructor && !is_enum_subclass
}
}
fn into_callable_cycle_initial<'db>(

View File

@@ -2536,58 +2536,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.undecorated_type = Some(inferred_ty);
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {
inferred_ty = match decorator_ty
.try_call(self.db(), &CallArguments::positional([inferred_ty]))
.map(|bindings| bindings.return_type(self.db()))
{
Ok(return_ty) => {
fn propagate_callable_kind<'d>(
db: &'d dyn Db,
ty: Type<'d>,
kind: CallableTypeKind,
) -> Option<Type<'d>> {
match ty {
Type::Callable(callable) => Some(Type::Callable(CallableType::new(
db,
callable.signatures(db),
kind,
))),
Type::Union(union) => union
.try_map(db, |element| propagate_callable_kind(db, *element, kind)),
// Intersections are currently not handled here because that would require
// the decorator to be explicitly annotated as returning an intersection.
_ => None,
}
}
let propagatable_kind = inferred_ty
.try_upcast_to_callable(self.db())
.and_then(CallableTypes::exactly_one)
.and_then(|callable| match callable.kind(self.db()) {
kind @ (CallableTypeKind::FunctionLike
| CallableTypeKind::StaticMethodLike
| CallableTypeKind::ClassMethodLike) => Some(kind),
_ => None,
});
if let Some(return_ty_modified) = propagatable_kind
.and_then(|kind| propagate_callable_kind(self.db(), return_ty, kind))
{
// When a method on a class is decorated with a function that returns a
// `Callable`, assume that the returned callable is also function-like (or
// classmethod-like or staticmethod-like). See "Decorating a method with
// a `Callable`-typed decorator" in `callables_as_descriptors.md` for the
// extended explanation.
return_ty_modified
} else {
return_ty
}
}
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, (*decorator_node).into());
bindings.return_type(self.db())
}
};
inferred_ty = self.apply_decorator(*decorator_ty, inferred_ty, decorator_node);
}
self.add_declaration_with_binding(
@@ -8707,6 +8656,86 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_expression(expression, TypeContext::default())
}
/// Apply a decorator to a function or class type and return the resulting type.
///
/// When the decorator is a class (or generic alias, or `type[]`), this uses the constructor
/// call logic to properly validate `__new__` and `__init__` signatures. For other decorator
/// types, it uses the regular call logic.
fn apply_decorator(
&mut self,
decorator_ty: Type<'db>,
decorated_ty: Type<'db>,
decorator_node: &ast::Decorator,
) -> Type<'db> {
fn propagate_callable_kind<'d>(
db: &'d dyn Db,
ty: Type<'d>,
kind: CallableTypeKind,
) -> Option<Type<'d>> {
match ty {
Type::Callable(callable) => Some(Type::Callable(CallableType::new(
db,
callable.signatures(db),
kind,
))),
Type::Union(union) => {
union.try_map(db, |element| propagate_callable_kind(db, *element, kind))
}
// Intersections are currently not handled here because that would require
// the decorator to be explicitly annotated as returning an intersection.
_ => None,
}
}
let propagatable_kind = decorated_ty
.try_upcast_to_callable(self.db())
.and_then(CallableTypes::exactly_one)
.and_then(|callable| match callable.kind(self.db()) {
kind @ (CallableTypeKind::FunctionLike
| CallableTypeKind::StaticMethodLike
| CallableTypeKind::ClassMethodLike) => Some(kind),
_ => None,
});
// Check if this is a class-like type that should use constructor call handling.
let class = match decorator_ty {
Type::ClassLiteral(class) => Some(ClassType::NonGeneric(class)),
Type::GenericAlias(generic) => Some(ClassType::Generic(generic)),
Type::SubclassOf(subclass) => subclass.subclass_of().into_class(self.db()),
_ => None,
};
let use_constructor_call =
class.is_some_and(|class| class.should_use_constructor_call(self.db()));
let call_arguments = CallArguments::positional([decorated_ty]);
let return_ty = if use_constructor_call {
decorator_ty
.try_call_constructor(self.db(), |_| call_arguments, TypeContext::default())
.unwrap_or_else(|err| {
err.report_diagnostic(&self.context, decorator_ty, decorator_node.into());
err.return_type()
})
} else {
decorator_ty
.try_call(self.db(), &call_arguments)
.map(|bindings| bindings.return_type(self.db()))
.unwrap_or_else(|CallError(_, bindings)| {
bindings.report_diagnostics(&self.context, decorator_node.into());
bindings.return_type(self.db())
})
};
// When a method on a class is decorated with a function that returns a
// `Callable`, assume that the returned callable is also function-like (or
// classmethod-like or staticmethod-like). See "Decorating a method with
// a `Callable`-typed decorator" in `callables_as_descriptors.md` for the
// extended explanation.
propagatable_kind
.and_then(|kind| propagate_callable_kind(self.db(), return_ty, kind))
.unwrap_or(return_ty)
}
/// Infer the argument types for a single binding.
fn infer_argument_types<'a>(
&mut self,
@@ -10644,42 +10673,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// the `try_call` path below.
// TODO: it should be possible to move these special cases into the `try_call_constructor`
// path instead, or even remove some entirely once we support overloads fully.
let has_special_cased_constructor = matches!(
class.known(self.db()),
Some(
KnownClass::Bool
| KnownClass::Str
| KnownClass::Type
| KnownClass::Object
| KnownClass::Property
| KnownClass::Super
| KnownClass::TypeAliasType
| KnownClass::Deprecated
)
) || (
// Constructor calls to `tuple` and subclasses of `tuple` are handled in `Type::Bindings`,
// but constructor calls to `tuple[int]`, `tuple[int, ...]`, `tuple[int, *tuple[str, ...]]` (etc.)
// are handled by the default constructor-call logic (we synthesize a `__new__` method for them
// in `ClassType::own_class_member()`).
class.is_known(self.db(), KnownClass::Tuple) && !class.is_generic()
) || class
.static_class_literal(self.db())
.is_some_and(|(class_literal, specialization)| {
CodeGeneratorKind::TypedDict.matches(
self.db(),
class_literal.into(),
specialization,
)
});
// temporary special-casing for all subclasses of `enum.Enum`
// until we support the functional syntax for creating enum classes
if !has_special_cased_constructor
&& KnownClass::Enum
.to_class_literal(self.db())
.to_class_type(self.db())
.is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class))
{
if class.should_use_constructor_call(self.db()) {
// Inference of correctly-placed `TypeVar`, `ParamSpec`, and `NewType` definitions
// is done in `infer_legacy_typevar`, `infer_paramspec`, and
// `infer_newtype_expression`, and doesn't use the full call-binding machinery. If