diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index 9792849666..78db11169c 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -1032,4 +1032,125 @@ reveal_type(asdict(p)) # revealed: dict[str, Any] reveal_type(replace(p, name="Bob")) # revealed: Person ``` +## Calling decorator function directly with a class argument + +When a function decorated with `@dataclass_transform()` is called directly with a class argument +(not used as a decorator), it should return the class with the dataclass transformation applied. + +### Basic case + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def my_dataclass[T](cls: type[T]) -> type[T]: + return cls + +class A: + x: int + +B = my_dataclass(A) + +reveal_type(B) # revealed: + +B(1) +``` + +### Function with additional parameters + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def my_dataclass[T](cls: type[T], *, order: bool = False) -> type[T]: + return cls + +class A: + x: int + +B = my_dataclass(A, order=True) + +reveal_type(B) # revealed: + +reveal_type(B(1) < B(2)) # revealed: bool +``` + +### Overloaded decorator function + +When the decorator function has overloads (one for direct class application, one for returning a +decorator), calling it with a class should return the class type. + +```py +from typing_extensions import dataclass_transform, Callable, overload + +@overload +@dataclass_transform() +def my_dataclass[T](cls: type[T]) -> type[T]: ... +@overload +def my_dataclass[T]() -> Callable[[type[T]], type[T]]: ... +def my_dataclass[T](cls: type[T] | None = None) -> type[T] | Callable[[type[T]], type[T]]: + raise NotImplementedError + +class A: + x: int + +B = my_dataclass(A) + +reveal_type(B) # revealed: + +B(1) +``` + +### Passing a specialized generic class + +When calling a `@dataclass_transform()` decorated function with a specialized generic class, the +specialization should be preserved. + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def my_dataclass[T](cls: type[T]) -> type[T]: + return cls + +class A[T]: + x: T + +B = my_dataclass(A[int]) + +reveal_type(B) # revealed: + +B(1) +``` + +### Decorator factory with class parameter + +When a `@dataclass_transform()` decorated function takes a class as a parameter but is used as a +decorator factory (returns a decorator), the dataclass behavior should be applied to the decorated +class, not to the parameter class. + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def hydrated_dataclass[T](target: type[T], *, frozen: bool = False): + def decorator[U](cls: type[U]) -> type[U]: + return cls + return decorator + +class Target: + pass + +decorator = hydrated_dataclass(Target) +reveal_type(decorator) # revealed: + +@hydrated_dataclass(Target) +class Model: + x: int + +# Model should be a dataclass-like class with x as a field +Model(x=1) +reveal_type(Model.__init__) # revealed: (self: Model, x: int) -> None +``` + [`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md index e9e661c3a0..5ad4ce1ddd 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md @@ -1682,9 +1682,7 @@ def sequence4(cls: type) -> type: class Foo: ... ordered_foo = dataclass(order=True)(Foo) -reveal_type(ordered_foo) # revealed: type[Foo] & Any -# TODO: should be `Foo & Any` -reveal_type(ordered_foo()) # revealed: @Todo(Type::Intersection.call) -# TODO: should be `Any` -reveal_type(ordered_foo() < ordered_foo()) # revealed: @Todo(Type::Intersection.call) +reveal_type(ordered_foo) # revealed: +reveal_type(ordered_foo()) # revealed: Foo +reveal_type(ordered_foo() < ordered_foo()) # revealed: bool ``` diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 9801fde7b5..bbfc311dc2 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -43,8 +43,8 @@ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Paramete use crate::types::tuple::{TupleLength, TupleSpec, TupleType}; use crate::types::{ BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType, - CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams, - FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, + CallableTypeKind, DATACLASS_FLAGS, DataclassFlags, DataclassParams, FieldInstance, + GenericAlias, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind, enums, list_members, todo_type, @@ -611,6 +611,25 @@ impl<'db> Bindings<'db> { } } + Type::DataclassDecorator(params) => match overload.parameter_types() { + [Some(Type::ClassLiteral(class_literal))] => { + overload.set_return_type(Type::from( + class_literal.with_dataclass_params(db, Some(params)), + )); + } + [Some(Type::GenericAlias(generic_alias))] => { + let new_origin = generic_alias + .origin(db) + .with_dataclass_params(db, Some(params)); + overload.set_return_type(Type::GenericAlias(GenericAlias::new( + db, + new_origin, + generic_alias.specialization(db), + ))); + } + _ => {} + }, + Type::BoundMethod(bound_method) if bound_method.self_instance(db).is_property_instance() => { @@ -1119,17 +1138,9 @@ impl<'db> Bindings<'db> { overload.parameter_types() { let params = DataclassParams::default_params(db); - overload.set_return_type(Type::from(ClassLiteral::new( - db, - class_literal.name(db), - class_literal.body_scope(db), - class_literal.known(db), - class_literal.deprecated(db), - class_literal.type_check_only(db), - Some(params), - class_literal.dataclass_transformer_params(db), - class_literal.total_ordering(db), - ))); + overload.set_return_type(Type::from( + class_literal.with_dataclass_params(db, Some(params)), + )); } } @@ -1185,44 +1196,102 @@ impl<'db> Bindings<'db> { // Ideally, either the implementation, or exactly one of the overloads // of the function can have the dataclass_transform decorator applied. // However, we do not yet enforce this, and in the case of multiple - // applications of the decorator, we will only consider the last one - // for the return value, since the prior ones will be over-written. - let return_type = function_type + // applications of the decorator, we will only consider the last one. + let transformer_params = function_type .iter_overloads_and_implementation(db) - .filter_map(|function_overload| { - function_overload.dataclass_transformer_params(db).map( - |params| { - // This is a call to a custom function that was decorated with `@dataclass_transformer`. - // If this function was called with a keyword argument like `order=False`, we extract - // the argument type and overwrite the corresponding flag in `dataclass_params` after - // constructing them from the `dataclass_transformer`-parameter defaults. + .rev() + .find_map(|function_overload| { + function_overload.dataclass_transformer_params(db) + }); - let dataclass_params = - DataclassParams::from_transformer_params( - db, params, - ); - let mut flags = dataclass_params.flags(db); + if let Some(params) = transformer_params { + // If this function was called with a keyword argument like + // `order=False`, we extract the argument type and overwrite + // the corresponding flag in `dataclass_params`. + let dataclass_params = + DataclassParams::from_transformer_params(db, params); + let mut flags = dataclass_params.flags(db); - for (param, flag) in DATACLASS_FLAGS { - if let Ok(Some(Type::BooleanLiteral(value))) = - overload.parameter_type_by_name(param, false) - { - flags.set(*flag, value); - } - } + for (param, flag) in DATACLASS_FLAGS { + if let Ok(Some(Type::BooleanLiteral(value))) = + overload.parameter_type_by_name(param, false) + { + flags.set(*flag, value); + } + } - Type::DataclassDecorator(DataclassParams::new( - db, - flags, - dataclass_params.field_specifiers(db), - )) - }, - ) - }) - .last(); + let dataclass_params = DataclassParams::new( + db, + flags, + dataclass_params.field_specifiers(db), + ); - if let Some(return_type) = return_type { - overload.set_return_type(return_type); + // The dataclass_transform spec doesn't clarify how to tell whether + // a decorated function is a decorator or a decorator factory. We + // use heuristics based on the number and type of positional arguments: + // + // - Zero positional arguments: assume it's a decorator factory. + // - More than one positional argument: assume it's a decorator factory. + // - Exactly one positional argument that's a class: ambiguous, so check + // the return type to disambiguate (class-like means decorate directly). + let mut positional_args = overload + .signature + .parameters() + .iter() + .zip(overload.parameter_types()) + .filter(|(param, ty)| ty.is_some() && !param.is_keyword_only()) + .map(|(_, ty)| ty); + + let first_positional = positional_args.next(); + let has_more = positional_args.next().is_some(); + + // Only attempt direct decoration if exactly one positional argument. + if !has_more { + // Helper to check if return type is class-like. + let returns_class = || { + matches!( + overload.return_type(), + Type::ClassLiteral(_) + | Type::GenericAlias(_) + | Type::SubclassOf(_) + ) + }; + + match first_positional { + Some(Some(Type::ClassLiteral(class_literal))) + if returns_class() => + { + overload.set_return_type(Type::from( + class_literal.with_dataclass_params( + db, + Some(dataclass_params), + ), + )); + continue; + } + Some(Some(Type::GenericAlias(generic_alias))) + if returns_class() => + { + let new_origin = generic_alias + .origin(db) + .with_dataclass_params(db, Some(dataclass_params)); + overload.set_return_type(Type::GenericAlias( + GenericAlias::new( + db, + new_origin, + generic_alias.specialization(db), + ), + )); + continue; + } + _ => {} + } + } + + // Zero or more than one positional argument, or the argument is + // not a class: assume it's a decorator factory. + overload + .set_return_type(Type::DataclassDecorator(dataclass_params)); } } }, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 95ecc881e6..aad064adca 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1555,6 +1555,25 @@ impl<'db> ClassLiteral<'db> { self.is_known(db, KnownClass::Tuple) } + /// Returns a new `ClassLiteral` with the given dataclass params, preserving all other fields. + pub(crate) fn with_dataclass_params( + self, + db: &'db dyn Db, + dataclass_params: Option>, + ) -> Self { + ClassLiteral::new( + db, + self.name(db).clone(), + self.body_scope(db), + self.known(db), + self.deprecated(db), + self.type_check_only(db), + dataclass_params, + self.dataclass_transformer_params(db), + self.total_ordering(db), + ) + } + /// Returns `true` if this class defines any ordering method (`__lt__`, `__le__`, `__gt__`, /// `__ge__`) in its own body (not inherited). Used by `@total_ordering` to determine if /// synthesis is valid. diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 7094b0795a..e532dbebee 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -748,9 +748,9 @@ impl<'db> FunctionLiteral<'db> { fn iter_overloads_and_implementation( self, db: &'db dyn Db, - ) -> impl Iterator> + 'db { - let (implementation, overloads) = self.overloads_and_implementation(db); - overloads.into_iter().chain(implementation.iter().copied()) + ) -> impl DoubleEndedIterator> + 'db { + let (overloads, implementation) = self.overloads_and_implementation(db); + overloads.iter().copied().chain(implementation) } /// Typed externally-visible signature for this function. @@ -1034,7 +1034,7 @@ impl<'db> FunctionType<'db> { pub(crate) fn iter_overloads_and_implementation( self, db: &'db dyn Db, - ) -> impl Iterator> + 'db { + ) -> impl DoubleEndedIterator> + 'db { self.literal(db).iter_overloads_and_implementation(db) } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index dbcd323d66..91605695b8 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -2941,6 +2941,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // params of the last seen usage of `@dataclass_transform` let transformer_params = f .iter_overloads_and_implementation(self.db()) + .rev() .find_map(|overload| overload.dataclass_transformer_params(self.db())); if let Some(transformer_params) = transformer_params { dataclass_params = Some(DataclassParams::from_transformer_params(