mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 05:20:49 -05:00
[ty] Support dataclass_transform as a function call (#22378)
## Summary Instead of just as a decorator. Closes https://github.com/astral-sh/ty/issues/2319.
This commit is contained in:
@@ -1032,4 +1032,125 @@ reveal_type(asdict(p)) # revealed: dict[str, Any]
|
||||
reveal_type(replace(p, name="Bob")) # revealed: Person
|
||||
```
|
||||
|
||||
## Calling decorator function directly with a class argument
|
||||
|
||||
When a function decorated with `@dataclass_transform()` is called directly with a class argument
|
||||
(not used as a decorator), it should return the class with the dataclass transformation applied.
|
||||
|
||||
### Basic case
|
||||
|
||||
```py
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
@dataclass_transform()
|
||||
def my_dataclass[T](cls: type[T]) -> type[T]:
|
||||
return cls
|
||||
|
||||
class A:
|
||||
x: int
|
||||
|
||||
B = my_dataclass(A)
|
||||
|
||||
reveal_type(B) # revealed: <class 'A'>
|
||||
|
||||
B(1)
|
||||
```
|
||||
|
||||
### Function with additional parameters
|
||||
|
||||
```py
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
@dataclass_transform()
|
||||
def my_dataclass[T](cls: type[T], *, order: bool = False) -> type[T]:
|
||||
return cls
|
||||
|
||||
class A:
|
||||
x: int
|
||||
|
||||
B = my_dataclass(A, order=True)
|
||||
|
||||
reveal_type(B) # revealed: <class 'A'>
|
||||
|
||||
reveal_type(B(1) < B(2)) # revealed: bool
|
||||
```
|
||||
|
||||
### Overloaded decorator function
|
||||
|
||||
When the decorator function has overloads (one for direct class application, one for returning a
|
||||
decorator), calling it with a class should return the class type.
|
||||
|
||||
```py
|
||||
from typing_extensions import dataclass_transform, Callable, overload
|
||||
|
||||
@overload
|
||||
@dataclass_transform()
|
||||
def my_dataclass[T](cls: type[T]) -> type[T]: ...
|
||||
@overload
|
||||
def my_dataclass[T]() -> Callable[[type[T]], type[T]]: ...
|
||||
def my_dataclass[T](cls: type[T] | None = None) -> type[T] | Callable[[type[T]], type[T]]:
|
||||
raise NotImplementedError
|
||||
|
||||
class A:
|
||||
x: int
|
||||
|
||||
B = my_dataclass(A)
|
||||
|
||||
reveal_type(B) # revealed: <class 'A'>
|
||||
|
||||
B(1)
|
||||
```
|
||||
|
||||
### Passing a specialized generic class
|
||||
|
||||
When calling a `@dataclass_transform()` decorated function with a specialized generic class, the
|
||||
specialization should be preserved.
|
||||
|
||||
```py
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
@dataclass_transform()
|
||||
def my_dataclass[T](cls: type[T]) -> type[T]:
|
||||
return cls
|
||||
|
||||
class A[T]:
|
||||
x: T
|
||||
|
||||
B = my_dataclass(A[int])
|
||||
|
||||
reveal_type(B) # revealed: <class 'A[int]'>
|
||||
|
||||
B(1)
|
||||
```
|
||||
|
||||
### Decorator factory with class parameter
|
||||
|
||||
When a `@dataclass_transform()` decorated function takes a class as a parameter but is used as a
|
||||
decorator factory (returns a decorator), the dataclass behavior should be applied to the decorated
|
||||
class, not to the parameter class.
|
||||
|
||||
```py
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
@dataclass_transform()
|
||||
def hydrated_dataclass[T](target: type[T], *, frozen: bool = False):
|
||||
def decorator[U](cls: type[U]) -> type[U]:
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
class Target:
|
||||
pass
|
||||
|
||||
decorator = hydrated_dataclass(Target)
|
||||
reveal_type(decorator) # revealed: <decorator produced by dataclass-like function>
|
||||
|
||||
@hydrated_dataclass(Target)
|
||||
class Model:
|
||||
x: int
|
||||
|
||||
# Model should be a dataclass-like class with x as a field
|
||||
Model(x=1)
|
||||
reveal_type(Model.__init__) # revealed: (self: Model, x: int) -> None
|
||||
```
|
||||
|
||||
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform
|
||||
|
||||
@@ -1682,9 +1682,7 @@ def sequence4(cls: type) -> type:
|
||||
class Foo: ...
|
||||
|
||||
ordered_foo = dataclass(order=True)(Foo)
|
||||
reveal_type(ordered_foo) # revealed: type[Foo] & Any
|
||||
# TODO: should be `Foo & Any`
|
||||
reveal_type(ordered_foo()) # revealed: @Todo(Type::Intersection.call)
|
||||
# TODO: should be `Any`
|
||||
reveal_type(ordered_foo() < ordered_foo()) # revealed: @Todo(Type::Intersection.call)
|
||||
reveal_type(ordered_foo) # revealed: <class 'Foo'>
|
||||
reveal_type(ordered_foo()) # revealed: Foo
|
||||
reveal_type(ordered_foo() < ordered_foo()) # revealed: bool
|
||||
```
|
||||
|
||||
@@ -43,8 +43,8 @@ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Paramete
|
||||
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
|
||||
use crate::types::{
|
||||
BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType,
|
||||
CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams,
|
||||
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
|
||||
CallableTypeKind, DATACLASS_FLAGS, DataclassFlags, DataclassParams, FieldInstance,
|
||||
GenericAlias, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
|
||||
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
||||
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
|
||||
enums, list_members, todo_type,
|
||||
@@ -611,6 +611,25 @@ impl<'db> Bindings<'db> {
|
||||
}
|
||||
}
|
||||
|
||||
Type::DataclassDecorator(params) => match overload.parameter_types() {
|
||||
[Some(Type::ClassLiteral(class_literal))] => {
|
||||
overload.set_return_type(Type::from(
|
||||
class_literal.with_dataclass_params(db, Some(params)),
|
||||
));
|
||||
}
|
||||
[Some(Type::GenericAlias(generic_alias))] => {
|
||||
let new_origin = generic_alias
|
||||
.origin(db)
|
||||
.with_dataclass_params(db, Some(params));
|
||||
overload.set_return_type(Type::GenericAlias(GenericAlias::new(
|
||||
db,
|
||||
new_origin,
|
||||
generic_alias.specialization(db),
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
|
||||
Type::BoundMethod(bound_method)
|
||||
if bound_method.self_instance(db).is_property_instance() =>
|
||||
{
|
||||
@@ -1119,17 +1138,9 @@ impl<'db> Bindings<'db> {
|
||||
overload.parameter_types()
|
||||
{
|
||||
let params = DataclassParams::default_params(db);
|
||||
overload.set_return_type(Type::from(ClassLiteral::new(
|
||||
db,
|
||||
class_literal.name(db),
|
||||
class_literal.body_scope(db),
|
||||
class_literal.known(db),
|
||||
class_literal.deprecated(db),
|
||||
class_literal.type_check_only(db),
|
||||
Some(params),
|
||||
class_literal.dataclass_transformer_params(db),
|
||||
class_literal.total_ordering(db),
|
||||
)));
|
||||
overload.set_return_type(Type::from(
|
||||
class_literal.with_dataclass_params(db, Some(params)),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1185,44 +1196,102 @@ impl<'db> Bindings<'db> {
|
||||
// Ideally, either the implementation, or exactly one of the overloads
|
||||
// of the function can have the dataclass_transform decorator applied.
|
||||
// However, we do not yet enforce this, and in the case of multiple
|
||||
// applications of the decorator, we will only consider the last one
|
||||
// for the return value, since the prior ones will be over-written.
|
||||
let return_type = function_type
|
||||
// applications of the decorator, we will only consider the last one.
|
||||
let transformer_params = function_type
|
||||
.iter_overloads_and_implementation(db)
|
||||
.filter_map(|function_overload| {
|
||||
function_overload.dataclass_transformer_params(db).map(
|
||||
|params| {
|
||||
// This is a call to a custom function that was decorated with `@dataclass_transformer`.
|
||||
// If this function was called with a keyword argument like `order=False`, we extract
|
||||
// the argument type and overwrite the corresponding flag in `dataclass_params` after
|
||||
// constructing them from the `dataclass_transformer`-parameter defaults.
|
||||
.rev()
|
||||
.find_map(|function_overload| {
|
||||
function_overload.dataclass_transformer_params(db)
|
||||
});
|
||||
|
||||
let dataclass_params =
|
||||
DataclassParams::from_transformer_params(
|
||||
db, params,
|
||||
);
|
||||
let mut flags = dataclass_params.flags(db);
|
||||
if let Some(params) = transformer_params {
|
||||
// If this function was called with a keyword argument like
|
||||
// `order=False`, we extract the argument type and overwrite
|
||||
// the corresponding flag in `dataclass_params`.
|
||||
let dataclass_params =
|
||||
DataclassParams::from_transformer_params(db, params);
|
||||
let mut flags = dataclass_params.flags(db);
|
||||
|
||||
for (param, flag) in DATACLASS_FLAGS {
|
||||
if let Ok(Some(Type::BooleanLiteral(value))) =
|
||||
overload.parameter_type_by_name(param, false)
|
||||
{
|
||||
flags.set(*flag, value);
|
||||
}
|
||||
}
|
||||
for (param, flag) in DATACLASS_FLAGS {
|
||||
if let Ok(Some(Type::BooleanLiteral(value))) =
|
||||
overload.parameter_type_by_name(param, false)
|
||||
{
|
||||
flags.set(*flag, value);
|
||||
}
|
||||
}
|
||||
|
||||
Type::DataclassDecorator(DataclassParams::new(
|
||||
db,
|
||||
flags,
|
||||
dataclass_params.field_specifiers(db),
|
||||
))
|
||||
},
|
||||
)
|
||||
})
|
||||
.last();
|
||||
let dataclass_params = DataclassParams::new(
|
||||
db,
|
||||
flags,
|
||||
dataclass_params.field_specifiers(db),
|
||||
);
|
||||
|
||||
if let Some(return_type) = return_type {
|
||||
overload.set_return_type(return_type);
|
||||
// The dataclass_transform spec doesn't clarify how to tell whether
|
||||
// a decorated function is a decorator or a decorator factory. We
|
||||
// use heuristics based on the number and type of positional arguments:
|
||||
//
|
||||
// - Zero positional arguments: assume it's a decorator factory.
|
||||
// - More than one positional argument: assume it's a decorator factory.
|
||||
// - Exactly one positional argument that's a class: ambiguous, so check
|
||||
// the return type to disambiguate (class-like means decorate directly).
|
||||
let mut positional_args = overload
|
||||
.signature
|
||||
.parameters()
|
||||
.iter()
|
||||
.zip(overload.parameter_types())
|
||||
.filter(|(param, ty)| ty.is_some() && !param.is_keyword_only())
|
||||
.map(|(_, ty)| ty);
|
||||
|
||||
let first_positional = positional_args.next();
|
||||
let has_more = positional_args.next().is_some();
|
||||
|
||||
// Only attempt direct decoration if exactly one positional argument.
|
||||
if !has_more {
|
||||
// Helper to check if return type is class-like.
|
||||
let returns_class = || {
|
||||
matches!(
|
||||
overload.return_type(),
|
||||
Type::ClassLiteral(_)
|
||||
| Type::GenericAlias(_)
|
||||
| Type::SubclassOf(_)
|
||||
)
|
||||
};
|
||||
|
||||
match first_positional {
|
||||
Some(Some(Type::ClassLiteral(class_literal)))
|
||||
if returns_class() =>
|
||||
{
|
||||
overload.set_return_type(Type::from(
|
||||
class_literal.with_dataclass_params(
|
||||
db,
|
||||
Some(dataclass_params),
|
||||
),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
Some(Some(Type::GenericAlias(generic_alias)))
|
||||
if returns_class() =>
|
||||
{
|
||||
let new_origin = generic_alias
|
||||
.origin(db)
|
||||
.with_dataclass_params(db, Some(dataclass_params));
|
||||
overload.set_return_type(Type::GenericAlias(
|
||||
GenericAlias::new(
|
||||
db,
|
||||
new_origin,
|
||||
generic_alias.specialization(db),
|
||||
),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Zero or more than one positional argument, or the argument is
|
||||
// not a class: assume it's a decorator factory.
|
||||
overload
|
||||
.set_return_type(Type::DataclassDecorator(dataclass_params));
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1555,6 +1555,25 @@ impl<'db> ClassLiteral<'db> {
|
||||
self.is_known(db, KnownClass::Tuple)
|
||||
}
|
||||
|
||||
/// Returns a new `ClassLiteral` with the given dataclass params, preserving all other fields.
|
||||
pub(crate) fn with_dataclass_params(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
dataclass_params: Option<DataclassParams<'db>>,
|
||||
) -> Self {
|
||||
ClassLiteral::new(
|
||||
db,
|
||||
self.name(db).clone(),
|
||||
self.body_scope(db),
|
||||
self.known(db),
|
||||
self.deprecated(db),
|
||||
self.type_check_only(db),
|
||||
dataclass_params,
|
||||
self.dataclass_transformer_params(db),
|
||||
self.total_ordering(db),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns `true` if this class defines any ordering method (`__lt__`, `__le__`, `__gt__`,
|
||||
/// `__ge__`) in its own body (not inherited). Used by `@total_ordering` to determine if
|
||||
/// synthesis is valid.
|
||||
|
||||
@@ -748,9 +748,9 @@ impl<'db> FunctionLiteral<'db> {
|
||||
fn iter_overloads_and_implementation(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
) -> impl Iterator<Item = OverloadLiteral<'db>> + 'db {
|
||||
let (implementation, overloads) = self.overloads_and_implementation(db);
|
||||
overloads.into_iter().chain(implementation.iter().copied())
|
||||
) -> impl DoubleEndedIterator<Item = OverloadLiteral<'db>> + 'db {
|
||||
let (overloads, implementation) = self.overloads_and_implementation(db);
|
||||
overloads.iter().copied().chain(implementation)
|
||||
}
|
||||
|
||||
/// Typed externally-visible signature for this function.
|
||||
@@ -1034,7 +1034,7 @@ impl<'db> FunctionType<'db> {
|
||||
pub(crate) fn iter_overloads_and_implementation(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
) -> impl Iterator<Item = OverloadLiteral<'db>> + 'db {
|
||||
) -> impl DoubleEndedIterator<Item = OverloadLiteral<'db>> + 'db {
|
||||
self.literal(db).iter_overloads_and_implementation(db)
|
||||
}
|
||||
|
||||
|
||||
@@ -2941,6 +2941,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
// params of the last seen usage of `@dataclass_transform`
|
||||
let transformer_params = f
|
||||
.iter_overloads_and_implementation(self.db())
|
||||
.rev()
|
||||
.find_map(|overload| overload.dataclass_transformer_params(self.db()));
|
||||
if let Some(transformer_params) = transformer_params {
|
||||
dataclass_params = Some(DataclassParams::from_transformer_params(
|
||||
|
||||
Reference in New Issue
Block a user