[ty] Support dataclass_transform as a function call (#22378)

## Summary

Instead of just as a decorator.

Closes https://github.com/astral-sh/ty/issues/2319.
This commit is contained in:
Charlie Marsh
2026-01-10 08:45:45 -05:00
committed by GitHub
parent cfed34334c
commit 046c5a46d8
6 changed files with 263 additions and 55 deletions

View File

@@ -1032,4 +1032,125 @@ reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
## Calling decorator function directly with a class argument
When a function decorated with `@dataclass_transform()` is called directly with a class argument
(not used as a decorator), it should return the class with the dataclass transformation applied.
### Basic case
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def my_dataclass[T](cls: type[T]) -> type[T]:
return cls
class A:
x: int
B = my_dataclass(A)
reveal_type(B) # revealed: <class 'A'>
B(1)
```
### Function with additional parameters
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def my_dataclass[T](cls: type[T], *, order: bool = False) -> type[T]:
return cls
class A:
x: int
B = my_dataclass(A, order=True)
reveal_type(B) # revealed: <class 'A'>
reveal_type(B(1) < B(2)) # revealed: bool
```
### Overloaded decorator function
When the decorator function has overloads (one for direct class application, one for returning a
decorator), calling it with a class should return the class type.
```py
from typing_extensions import dataclass_transform, Callable, overload
@overload
@dataclass_transform()
def my_dataclass[T](cls: type[T]) -> type[T]: ...
@overload
def my_dataclass[T]() -> Callable[[type[T]], type[T]]: ...
def my_dataclass[T](cls: type[T] | None = None) -> type[T] | Callable[[type[T]], type[T]]:
raise NotImplementedError
class A:
x: int
B = my_dataclass(A)
reveal_type(B) # revealed: <class 'A'>
B(1)
```
### Passing a specialized generic class
When calling a `@dataclass_transform()` decorated function with a specialized generic class, the
specialization should be preserved.
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def my_dataclass[T](cls: type[T]) -> type[T]:
return cls
class A[T]:
x: T
B = my_dataclass(A[int])
reveal_type(B) # revealed: <class 'A[int]'>
B(1)
```
### Decorator factory with class parameter
When a `@dataclass_transform()` decorated function takes a class as a parameter but is used as a
decorator factory (returns a decorator), the dataclass behavior should be applied to the decorated
class, not to the parameter class.
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def hydrated_dataclass[T](target: type[T], *, frozen: bool = False):
def decorator[U](cls: type[U]) -> type[U]:
return cls
return decorator
class Target:
pass
decorator = hydrated_dataclass(Target)
reveal_type(decorator) # revealed: <decorator produced by dataclass-like function>
@hydrated_dataclass(Target)
class Model:
x: int
# Model should be a dataclass-like class with x as a field
Model(x=1)
reveal_type(Model.__init__) # revealed: (self: Model, x: int) -> None
```
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform

View File

@@ -1682,9 +1682,7 @@ def sequence4(cls: type) -> type:
class Foo: ...
ordered_foo = dataclass(order=True)(Foo)
reveal_type(ordered_foo) # revealed: type[Foo] & Any
# TODO: should be `Foo & Any`
reveal_type(ordered_foo()) # revealed: @Todo(Type::Intersection.call)
# TODO: should be `Any`
reveal_type(ordered_foo() < ordered_foo()) # revealed: @Todo(Type::Intersection.call)
reveal_type(ordered_foo) # revealed: <class 'Foo'>
reveal_type(ordered_foo()) # revealed: Foo
reveal_type(ordered_foo() < ordered_foo()) # revealed: bool
```

View File

@@ -43,8 +43,8 @@ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Paramete
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
use crate::types::{
BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType,
CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
CallableTypeKind, DATACLASS_FLAGS, DataclassFlags, DataclassParams, FieldInstance,
GenericAlias, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
enums, list_members, todo_type,
@@ -611,6 +611,25 @@ impl<'db> Bindings<'db> {
}
}
Type::DataclassDecorator(params) => match overload.parameter_types() {
[Some(Type::ClassLiteral(class_literal))] => {
overload.set_return_type(Type::from(
class_literal.with_dataclass_params(db, Some(params)),
));
}
[Some(Type::GenericAlias(generic_alias))] => {
let new_origin = generic_alias
.origin(db)
.with_dataclass_params(db, Some(params));
overload.set_return_type(Type::GenericAlias(GenericAlias::new(
db,
new_origin,
generic_alias.specialization(db),
)));
}
_ => {}
},
Type::BoundMethod(bound_method)
if bound_method.self_instance(db).is_property_instance() =>
{
@@ -1119,17 +1138,9 @@ impl<'db> Bindings<'db> {
overload.parameter_types()
{
let params = DataclassParams::default_params(db);
overload.set_return_type(Type::from(ClassLiteral::new(
db,
class_literal.name(db),
class_literal.body_scope(db),
class_literal.known(db),
class_literal.deprecated(db),
class_literal.type_check_only(db),
Some(params),
class_literal.dataclass_transformer_params(db),
class_literal.total_ordering(db),
)));
overload.set_return_type(Type::from(
class_literal.with_dataclass_params(db, Some(params)),
));
}
}
@@ -1185,44 +1196,102 @@ impl<'db> Bindings<'db> {
// Ideally, either the implementation, or exactly one of the overloads
// of the function can have the dataclass_transform decorator applied.
// However, we do not yet enforce this, and in the case of multiple
// applications of the decorator, we will only consider the last one
// for the return value, since the prior ones will be over-written.
let return_type = function_type
// applications of the decorator, we will only consider the last one.
let transformer_params = function_type
.iter_overloads_and_implementation(db)
.filter_map(|function_overload| {
function_overload.dataclass_transformer_params(db).map(
|params| {
// This is a call to a custom function that was decorated with `@dataclass_transformer`.
// If this function was called with a keyword argument like `order=False`, we extract
// the argument type and overwrite the corresponding flag in `dataclass_params` after
// constructing them from the `dataclass_transformer`-parameter defaults.
.rev()
.find_map(|function_overload| {
function_overload.dataclass_transformer_params(db)
});
let dataclass_params =
DataclassParams::from_transformer_params(
db, params,
);
let mut flags = dataclass_params.flags(db);
if let Some(params) = transformer_params {
// If this function was called with a keyword argument like
// `order=False`, we extract the argument type and overwrite
// the corresponding flag in `dataclass_params`.
let dataclass_params =
DataclassParams::from_transformer_params(db, params);
let mut flags = dataclass_params.flags(db);
for (param, flag) in DATACLASS_FLAGS {
if let Ok(Some(Type::BooleanLiteral(value))) =
overload.parameter_type_by_name(param, false)
{
flags.set(*flag, value);
}
}
for (param, flag) in DATACLASS_FLAGS {
if let Ok(Some(Type::BooleanLiteral(value))) =
overload.parameter_type_by_name(param, false)
{
flags.set(*flag, value);
}
}
Type::DataclassDecorator(DataclassParams::new(
db,
flags,
dataclass_params.field_specifiers(db),
))
},
)
})
.last();
let dataclass_params = DataclassParams::new(
db,
flags,
dataclass_params.field_specifiers(db),
);
if let Some(return_type) = return_type {
overload.set_return_type(return_type);
// The dataclass_transform spec doesn't clarify how to tell whether
// a decorated function is a decorator or a decorator factory. We
// use heuristics based on the number and type of positional arguments:
//
// - Zero positional arguments: assume it's a decorator factory.
// - More than one positional argument: assume it's a decorator factory.
// - Exactly one positional argument that's a class: ambiguous, so check
// the return type to disambiguate (class-like means decorate directly).
let mut positional_args = overload
.signature
.parameters()
.iter()
.zip(overload.parameter_types())
.filter(|(param, ty)| ty.is_some() && !param.is_keyword_only())
.map(|(_, ty)| ty);
let first_positional = positional_args.next();
let has_more = positional_args.next().is_some();
// Only attempt direct decoration if exactly one positional argument.
if !has_more {
// Helper to check if return type is class-like.
let returns_class = || {
matches!(
overload.return_type(),
Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::SubclassOf(_)
)
};
match first_positional {
Some(Some(Type::ClassLiteral(class_literal)))
if returns_class() =>
{
overload.set_return_type(Type::from(
class_literal.with_dataclass_params(
db,
Some(dataclass_params),
),
));
continue;
}
Some(Some(Type::GenericAlias(generic_alias)))
if returns_class() =>
{
let new_origin = generic_alias
.origin(db)
.with_dataclass_params(db, Some(dataclass_params));
overload.set_return_type(Type::GenericAlias(
GenericAlias::new(
db,
new_origin,
generic_alias.specialization(db),
),
));
continue;
}
_ => {}
}
}
// Zero or more than one positional argument, or the argument is
// not a class: assume it's a decorator factory.
overload
.set_return_type(Type::DataclassDecorator(dataclass_params));
}
}
},

View File

@@ -1555,6 +1555,25 @@ impl<'db> ClassLiteral<'db> {
self.is_known(db, KnownClass::Tuple)
}
/// Returns a new `ClassLiteral` with the given dataclass params, preserving all other fields.
pub(crate) fn with_dataclass_params(
self,
db: &'db dyn Db,
dataclass_params: Option<DataclassParams<'db>>,
) -> Self {
ClassLiteral::new(
db,
self.name(db).clone(),
self.body_scope(db),
self.known(db),
self.deprecated(db),
self.type_check_only(db),
dataclass_params,
self.dataclass_transformer_params(db),
self.total_ordering(db),
)
}
/// Returns `true` if this class defines any ordering method (`__lt__`, `__le__`, `__gt__`,
/// `__ge__`) in its own body (not inherited). Used by `@total_ordering` to determine if
/// synthesis is valid.

View File

@@ -748,9 +748,9 @@ impl<'db> FunctionLiteral<'db> {
fn iter_overloads_and_implementation(
self,
db: &'db dyn Db,
) -> impl Iterator<Item = OverloadLiteral<'db>> + 'db {
let (implementation, overloads) = self.overloads_and_implementation(db);
overloads.into_iter().chain(implementation.iter().copied())
) -> impl DoubleEndedIterator<Item = OverloadLiteral<'db>> + 'db {
let (overloads, implementation) = self.overloads_and_implementation(db);
overloads.iter().copied().chain(implementation)
}
/// Typed externally-visible signature for this function.
@@ -1034,7 +1034,7 @@ impl<'db> FunctionType<'db> {
pub(crate) fn iter_overloads_and_implementation(
self,
db: &'db dyn Db,
) -> impl Iterator<Item = OverloadLiteral<'db>> + 'db {
) -> impl DoubleEndedIterator<Item = OverloadLiteral<'db>> + 'db {
self.literal(db).iter_overloads_and_implementation(db)
}

View File

@@ -2941,6 +2941,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// params of the last seen usage of `@dataclass_transform`
let transformer_params = f
.iter_overloads_and_implementation(self.db())
.rev()
.find_map(|overload| overload.dataclass_transformer_params(self.db()));
if let Some(transformer_params) = transformer_params {
dataclass_params = Some(DataclassParams::from_transformer_params(