From df58d6797421e6c8f7060f42517e102bed33cdf8 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sat, 17 Jan 2026 10:48:59 -0500 Subject: [PATCH] [ty] Validate constructor arguments when a class is used as a decorator (#22377) ## Summary If a class is used as a decorator, we now use the class constructor. Closes https://github.com/astral-sh/ty/issues/2232. --- .../resources/mdtest/decorators.md | 83 +++++++++ crates/ty_python_semantic/src/types/class.rs | 43 +++++ .../src/types/infer/builder.rs | 170 +++++++++--------- 3 files changed, 208 insertions(+), 88 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/decorators.md b/crates/ty_python_semantic/resources/mdtest/decorators.md index a8e353c5e6..4a46f265a9 100644 --- a/crates/ty_python_semantic/resources/mdtest/decorators.md +++ b/crates/ty_python_semantic/resources/mdtest/decorators.md @@ -235,6 +235,89 @@ def takes_no_argument() -> str: def g(x): ... ``` +### Class, with wrong signature, used as a decorator + +When a class is used as a decorator, its constructor (`__init__` or `__new__`) must accept the +decorated function as an argument. If the class's constructor doesn't accept the right arguments, we +emit an error: + +```py +class NoInit: ... + +# error: [too-many-positional-arguments] "Too many positional arguments to bound method `__init__`: expected 1, got 2" +@NoInit +def foo(): ... + +reveal_type(foo) # revealed: NoInit + +# error: [invalid-argument-type] +@int +def bar(): ... + +reveal_type(bar) # revealed: int +``` + +### Class, with correct signature, used as a decorator + +When a class's constructor accepts the decorated function/class, no error is emitted: + +```py +from typing import Callable + +class Wrapper: + def __init__(self, func: Callable[..., object]) -> None: + self.func = func + +@Wrapper +def my_func() -> int: + return 42 + +reveal_type(my_func) # revealed: Wrapper + +class AcceptsType: + def __init__(self, cls: type) -> None: + self.cls = cls + +# Decorator call is validated, but the type transformation isn't applied yet. +# TODO: Class decorator return types should transform the class binding type. +@AcceptsType +class MyClass: ... + +reveal_type(MyClass) # revealed: +``` + +### Generic class, used as a decorator + +Generic class decorators are validated through constructor calls: + +```py +from typing import Generic, TypeVar, Callable + +T = TypeVar("T") + +class Box(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + +# error: [invalid-argument-type] +@Box[int] +def returns_str() -> str: + return "hello" +``` + +### `type[SomeClass]` used as a decorator + +Using `type[SomeClass]` as a decorator validates against the class's constructor: + +```py +class Base: ... + +def apply_decorator(cls: type[Base]) -> None: + # error: [too-many-positional-arguments] "Too many positional arguments to bound method `__init__`: expected 1, got 2" + @cls + def inner() -> None: ... +``` + ## Class decorators Class decorator calls are validated, emitting diagnostics for invalid arguments: diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index f9107a1405..e7bfd67c00 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1899,6 +1899,49 @@ impl<'db> ClassType<'db> { pub(super) fn definition_span(self, db: &'db dyn Db) -> Span { self.class_literal(db).header_span(db) } + + /// Returns `true` if calls to this class type should use constructor call handling + /// (via `try_call_constructor`) rather than the regular `try_call` path. + /// + /// Some known classes have manual signatures defined in `bindings()` and should use + /// the `try_call` path. For all other class types, we use `try_call_constructor` + /// to properly validate `__new__`/`__init__` signatures. + pub(super) fn should_use_constructor_call(self, db: &'db dyn Db) -> bool { + // For some known classes we have manual signatures defined and use the regular + // `try_call` path instead of constructor call handling. + let has_special_cased_constructor = matches!( + self.known(db), + Some( + KnownClass::Bool + | KnownClass::Str + | KnownClass::Type + | KnownClass::Object + | KnownClass::Property + | KnownClass::Super + | KnownClass::TypeAliasType + | KnownClass::Deprecated + ) + ) || ( + // Constructor calls to `tuple` and subclasses of `tuple` are handled in + // `Type::bindings`, but constructor calls to `tuple[int]`, `tuple[int, ...]`, + // `tuple[int, *tuple[str, ...]]` (etc.) are handled by the default constructor-call + // logic (we synthesize a `__new__` method for them in `ClassType::own_class_member`). + self.is_known(db, KnownClass::Tuple) && !self.is_generic() + ) || self.static_class_literal(db).is_some_and( + |(class_literal, specialization)| { + CodeGeneratorKind::TypedDict.matches(db, class_literal.into(), specialization) + }, + ); + + // Use regular `try_call` for all subclasses of `enum.Enum`. This is a temporary + // special-casing until we support the functional syntax for creating enum classes. + let is_enum_subclass = KnownClass::Enum + .to_class_literal(db) + .to_class_type(db) + .is_some_and(|enum_class| self.is_subclass_of(db, enum_class)); + + !has_special_cased_constructor && !is_enum_subclass + } } fn into_callable_cycle_initial<'db>( diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 85e2bb2a7e..a5f4f96856 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -2536,58 +2536,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.undecorated_type = Some(inferred_ty); for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { - inferred_ty = match decorator_ty - .try_call(self.db(), &CallArguments::positional([inferred_ty])) - .map(|bindings| bindings.return_type(self.db())) - { - Ok(return_ty) => { - fn propagate_callable_kind<'d>( - db: &'d dyn Db, - ty: Type<'d>, - kind: CallableTypeKind, - ) -> Option> { - match ty { - Type::Callable(callable) => Some(Type::Callable(CallableType::new( - db, - callable.signatures(db), - kind, - ))), - Type::Union(union) => union - .try_map(db, |element| propagate_callable_kind(db, *element, kind)), - // Intersections are currently not handled here because that would require - // the decorator to be explicitly annotated as returning an intersection. - _ => None, - } - } - - let propagatable_kind = inferred_ty - .try_upcast_to_callable(self.db()) - .and_then(CallableTypes::exactly_one) - .and_then(|callable| match callable.kind(self.db()) { - kind @ (CallableTypeKind::FunctionLike - | CallableTypeKind::StaticMethodLike - | CallableTypeKind::ClassMethodLike) => Some(kind), - _ => None, - }); - - if let Some(return_ty_modified) = propagatable_kind - .and_then(|kind| propagate_callable_kind(self.db(), return_ty, kind)) - { - // When a method on a class is decorated with a function that returns a - // `Callable`, assume that the returned callable is also function-like (or - // classmethod-like or staticmethod-like). See "Decorating a method with - // a `Callable`-typed decorator" in `callables_as_descriptors.md` for the - // extended explanation. - return_ty_modified - } else { - return_ty - } - } - Err(CallError(_, bindings)) => { - bindings.report_diagnostics(&self.context, (*decorator_node).into()); - bindings.return_type(self.db()) - } - }; + inferred_ty = self.apply_decorator(*decorator_ty, inferred_ty, decorator_node); } self.add_declaration_with_binding( @@ -8707,6 +8656,86 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_expression(expression, TypeContext::default()) } + /// Apply a decorator to a function or class type and return the resulting type. + /// + /// When the decorator is a class (or generic alias, or `type[]`), this uses the constructor + /// call logic to properly validate `__new__` and `__init__` signatures. For other decorator + /// types, it uses the regular call logic. + fn apply_decorator( + &mut self, + decorator_ty: Type<'db>, + decorated_ty: Type<'db>, + decorator_node: &ast::Decorator, + ) -> Type<'db> { + fn propagate_callable_kind<'d>( + db: &'d dyn Db, + ty: Type<'d>, + kind: CallableTypeKind, + ) -> Option> { + match ty { + Type::Callable(callable) => Some(Type::Callable(CallableType::new( + db, + callable.signatures(db), + kind, + ))), + Type::Union(union) => { + union.try_map(db, |element| propagate_callable_kind(db, *element, kind)) + } + // Intersections are currently not handled here because that would require + // the decorator to be explicitly annotated as returning an intersection. + _ => None, + } + } + + let propagatable_kind = decorated_ty + .try_upcast_to_callable(self.db()) + .and_then(CallableTypes::exactly_one) + .and_then(|callable| match callable.kind(self.db()) { + kind @ (CallableTypeKind::FunctionLike + | CallableTypeKind::StaticMethodLike + | CallableTypeKind::ClassMethodLike) => Some(kind), + _ => None, + }); + + // Check if this is a class-like type that should use constructor call handling. + let class = match decorator_ty { + Type::ClassLiteral(class) => Some(ClassType::NonGeneric(class)), + Type::GenericAlias(generic) => Some(ClassType::Generic(generic)), + Type::SubclassOf(subclass) => subclass.subclass_of().into_class(self.db()), + _ => None, + }; + + let use_constructor_call = + class.is_some_and(|class| class.should_use_constructor_call(self.db())); + + let call_arguments = CallArguments::positional([decorated_ty]); + let return_ty = if use_constructor_call { + decorator_ty + .try_call_constructor(self.db(), |_| call_arguments, TypeContext::default()) + .unwrap_or_else(|err| { + err.report_diagnostic(&self.context, decorator_ty, decorator_node.into()); + err.return_type() + }) + } else { + decorator_ty + .try_call(self.db(), &call_arguments) + .map(|bindings| bindings.return_type(self.db())) + .unwrap_or_else(|CallError(_, bindings)| { + bindings.report_diagnostics(&self.context, decorator_node.into()); + bindings.return_type(self.db()) + }) + }; + + // When a method on a class is decorated with a function that returns a + // `Callable`, assume that the returned callable is also function-like (or + // classmethod-like or staticmethod-like). See "Decorating a method with + // a `Callable`-typed decorator" in `callables_as_descriptors.md` for the + // extended explanation. + propagatable_kind + .and_then(|kind| propagate_callable_kind(self.db(), return_ty, kind)) + .unwrap_or(return_ty) + } + /// Infer the argument types for a single binding. fn infer_argument_types<'a>( &mut self, @@ -10644,42 +10673,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // the `try_call` path below. // TODO: it should be possible to move these special cases into the `try_call_constructor` // path instead, or even remove some entirely once we support overloads fully. - let has_special_cased_constructor = matches!( - class.known(self.db()), - Some( - KnownClass::Bool - | KnownClass::Str - | KnownClass::Type - | KnownClass::Object - | KnownClass::Property - | KnownClass::Super - | KnownClass::TypeAliasType - | KnownClass::Deprecated - ) - ) || ( - // Constructor calls to `tuple` and subclasses of `tuple` are handled in `Type::Bindings`, - // but constructor calls to `tuple[int]`, `tuple[int, ...]`, `tuple[int, *tuple[str, ...]]` (etc.) - // are handled by the default constructor-call logic (we synthesize a `__new__` method for them - // in `ClassType::own_class_member()`). - class.is_known(self.db(), KnownClass::Tuple) && !class.is_generic() - ) || class - .static_class_literal(self.db()) - .is_some_and(|(class_literal, specialization)| { - CodeGeneratorKind::TypedDict.matches( - self.db(), - class_literal.into(), - specialization, - ) - }); - - // temporary special-casing for all subclasses of `enum.Enum` - // until we support the functional syntax for creating enum classes - if !has_special_cased_constructor - && KnownClass::Enum - .to_class_literal(self.db()) - .to_class_type(self.db()) - .is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class)) - { + if class.should_use_constructor_call(self.db()) { // Inference of correctly-placed `TypeVar`, `ParamSpec`, and `NewType` definitions // is done in `infer_legacy_typevar`, `infer_paramspec`, and // `infer_newtype_expression`, and doesn't use the full call-binding machinery. If