[red-knot] `typing.dataclass_transform` (#17445)

## Summary

* Add initial support for `typing.dataclass_transform`
* Support decorating a function decorator with `@dataclass_transform(…)`
(used by `attrs`, `strawberry`)
* Support decorating a metaclass with `@dataclass_transform(…)` (used by
`pydantic`, but doesn't work yet, because we don't seem to model
`__new__` calls correctly?)
* *No* support yet for decorating base classes with
`@dataclass_transform(…)`. I haven't figured out how this even supposed
to work. And haven't seen it being used.
* Add `strawberry` as an ecosystem project, as it makes heavy use of
`@dataclass_transform`

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-04-22 10:33:02 +02:00 committed by GitHub
parent f83295fe51
commit 37a0836bd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 634 additions and 117 deletions

View File

@ -45,7 +45,7 @@ jobs:
- name: Install mypy_primer
run: |
uv tool install "git+https://github.com/astral-sh/mypy_primer.git@add-red-knot-support-v5"
uv tool install "git+https://github.com/astral-sh/mypy_primer.git@add-red-knot-support-v6"
- name: Run mypy_primer
shell: bash

View File

@ -231,6 +231,10 @@ unused_peekable = "warn"
# Diagnostics are not actionable: Enable once https://github.com/rust-lang/rust-clippy/issues/13774 is resolved.
large_stack_arrays = "allow"
# Salsa generates functions with parameters for each field of a `salsa::interned` struct.
# If we don't allow this, we get warnings for structs with too many fields.
too_many_arguments = "allow"
[profile.release]
# Note that we set these explicitly, and these values
# were chosen based on a trade-off between compile times

View File

@ -0,0 +1,293 @@
# `typing.dataclass_transform`
```toml
[environment]
python-version = "3.12"
```
`dataclass_transform` is a decorator that can be used to let type checkers know that a function,
class, or metaclass is a `dataclass`-like construct.
## Basic example
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def my_dataclass[T](cls: type[T]) -> type[T]:
# modify cls
return cls
@my_dataclass
class Person:
name: str
age: int | None = None
Person("Alice", 20)
Person("Bob", None)
Person("Bob")
# error: [missing-argument]
Person()
```
## Decorating decorators that take parameters themselves
If we want our `dataclass`-like decorator to also take parameters, that is also possible:
```py
from typing_extensions import dataclass_transform, Callable
@dataclass_transform()
def versioned_class[T](*, version: int = 1):
def decorator(cls):
# modify cls
return cls
return decorator
@versioned_class(version=2)
class Person:
name: str
age: int | None = None
Person("Alice", 20)
# error: [missing-argument]
Person()
```
We properly type-check the arguments to the decorator:
```py
from typing_extensions import dataclass_transform, Callable
# error: [invalid-argument-type]
@versioned_class(version="a string")
class C:
name: str
```
## Types of decorators
The examples from this section are straight from the Python documentation on
[`typing.dataclass_transform`].
### Decorating a decorator function
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
...
return cls
@create_model
class CustomerModel:
id: int
name: str
CustomerModel(id=1, name="Test")
```
### Decorating a metaclass
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class CustomerModel(ModelBase):
id: int
name: str
CustomerModel(id=1, name="Test")
# error: [missing-argument]
CustomerModel()
```
### Decorating a base class
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
class ModelBase: ...
class CustomerModel(ModelBase):
id: int
name: str
# TODO: this is not supported yet
# error: [unknown-argument]
# error: [unknown-argument]
CustomerModel(id=1, name="Test")
```
## Arguments to `dataclass_transform`
### `eq_default`
`eq=True/False` does not have a observable effect (apart from a minor change regarding whether
`other` is positional-only or not, which is not modelled at the moment).
### `order_default`
The `order_default` argument controls whether methods such as `__lt__` are generated by default.
This can be overwritten using the `order` argument to the custom decorator:
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
def normal(*, order: bool = False):
raise NotImplementedError
@dataclass_transform(order_default=False)
def order_default_false(*, order: bool = False):
raise NotImplementedError
@dataclass_transform(order_default=True)
def order_default_true(*, order: bool = True):
raise NotImplementedError
@normal
class Normal:
inner: int
Normal(1) < Normal(2) # error: [unsupported-operator]
@normal(order=True)
class NormalOverwritten:
inner: int
NormalOverwritten(1) < NormalOverwritten(2)
@order_default_false
class OrderFalse:
inner: int
OrderFalse(1) < OrderFalse(2) # error: [unsupported-operator]
@order_default_false(order=True)
class OrderFalseOverwritten:
inner: int
OrderFalseOverwritten(1) < OrderFalseOverwritten(2)
@order_default_true
class OrderTrue:
inner: int
OrderTrue(1) < OrderTrue(2)
@order_default_true(order=False)
class OrderTrueOverwritten:
inner: int
# error: [unsupported-operator]
OrderTrueOverwritten(1) < OrderTrueOverwritten(2)
```
### `kw_only_default`
To do
### `field_specifiers`
To do
## Overloaded dataclass-like decorators
In the case of an overloaded decorator, the `dataclass_transform` decorator can be applied to the
implementation, or to *one* of the overloads.
### Applying `dataclass_transform` to the implementation
```py
from typing_extensions import dataclass_transform, TypeVar, Callable, overload
T = TypeVar("T", bound=type)
@overload
def versioned_class(
cls: T,
*,
version: int = 1,
) -> T: ...
@overload
def versioned_class(
*,
version: int = 1,
) -> Callable[[T], T]: ...
@dataclass_transform()
def versioned_class(
cls: T | None = None,
*,
version: int = 1,
) -> T | Callable[[T], T]:
raise NotImplementedError
@versioned_class
class D1:
x: str
@versioned_class(version=2)
class D2:
x: str
D1("a")
D2("a")
D1(1.2) # error: [invalid-argument-type]
D2(1.2) # error: [invalid-argument-type]
```
### Applying `dataclass_transform` to an overload
```py
from typing_extensions import dataclass_transform, TypeVar, Callable, overload
T = TypeVar("T", bound=type)
@overload
@dataclass_transform()
def versioned_class(
cls: T,
*,
version: int = 1,
) -> T: ...
@overload
def versioned_class(
*,
version: int = 1,
) -> Callable[[T], T]: ...
def versioned_class(
cls: T | None = None,
*,
version: int = 1,
) -> T | Callable[[T], T]:
raise NotImplementedError
@versioned_class
class D1:
x: str
@versioned_class(version=2)
class D2:
x: str
# TODO: these should not be errors
D1("a") # error: [too-many-positional-arguments]
D2("a") # error: [too-many-positional-arguments]
# TODO: these should be invalid-argument-type errors
D1(1.2) # error: [too-many-positional-arguments]
D2(1.2) # error: [too-many-positional-arguments]
```
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform

View File

@ -689,7 +689,7 @@ from dataclasses import dataclass
dataclass_with_order = dataclass(order=True)
reveal_type(dataclass_with_order) # revealed: <decorator produced by dataclasses.dataclass>
reveal_type(dataclass_with_order) # revealed: <decorator produced by dataclass-like function>
@dataclass_with_order
class C:

View File

@ -18,6 +18,7 @@ python-chess
python-htmlgen
rich
scrapy
strawberry
typeshed-stats
werkzeug
zipp

View File

@ -339,12 +339,12 @@ impl<'db> PropertyInstanceType<'db> {
}
bitflags! {
/// Used as the return type of `dataclass(…)` calls. Keeps track of the arguments
/// Used for the return type of `dataclass(…)` calls. Keeps track of the arguments
/// that were passed in. For the precise meaning of the fields, see [1].
///
/// [1]: https://docs.python.org/3/library/dataclasses.html
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DataclassMetadata: u16 {
pub struct DataclassParams: u16 {
const INIT = 0b0000_0000_0001;
const REPR = 0b0000_0000_0010;
const EQ = 0b0000_0000_0100;
@ -358,12 +358,57 @@ bitflags! {
}
}
impl Default for DataclassMetadata {
impl Default for DataclassParams {
fn default() -> Self {
Self::INIT | Self::REPR | Self::EQ | Self::MATCH_ARGS
}
}
impl From<DataclassTransformerParams> for DataclassParams {
fn from(params: DataclassTransformerParams) -> Self {
let mut result = Self::default();
result.set(
Self::EQ,
params.contains(DataclassTransformerParams::EQ_DEFAULT),
);
result.set(
Self::ORDER,
params.contains(DataclassTransformerParams::ORDER_DEFAULT),
);
result.set(
Self::KW_ONLY,
params.contains(DataclassTransformerParams::KW_ONLY_DEFAULT),
);
result.set(
Self::FROZEN,
params.contains(DataclassTransformerParams::FROZEN_DEFAULT),
);
result
}
}
bitflags! {
/// Used for the return type of `dataclass_transform(…)` calls. Keeps track of the
/// arguments that were passed in. For the precise meaning of the fields, see [1].
///
/// [1]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub struct DataclassTransformerParams: u8 {
const EQ_DEFAULT = 0b0000_0001;
const ORDER_DEFAULT = 0b0000_0010;
const KW_ONLY_DEFAULT = 0b0000_0100;
const FROZEN_DEFAULT = 0b0000_1000;
}
}
impl Default for DataclassTransformerParams {
fn default() -> Self {
Self::EQ_DEFAULT
}
}
/// Representation of a type: a set of possible values at runtime.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
pub enum Type<'db> {
@ -404,7 +449,9 @@ pub enum Type<'db> {
/// A special callable that is returned by a `dataclass(…)` call. It is usually
/// used as a decorator. Note that this is only used as a return type for actual
/// `dataclass` calls, not for the argumentless `@dataclass` decorator.
DataclassDecorator(DataclassMetadata),
DataclassDecorator(DataclassParams),
/// A special callable that is returned by a `dataclass_transform(…)` call.
DataclassTransformer(DataclassTransformerParams),
/// The type of an arbitrary callable object with a certain specified signature.
Callable(CallableType<'db>),
/// A specific module object
@ -524,7 +571,8 @@ impl<'db> Type<'db> {
| Self::BoundMethod(_)
| Self::WrapperDescriptor(_)
| Self::MethodWrapper(_)
| Self::DataclassDecorator(_) => false,
| Self::DataclassDecorator(_)
| Self::DataclassTransformer(_) => false,
Self::GenericAlias(generic) => generic
.specialization(db)
@ -837,7 +885,8 @@ impl<'db> Type<'db> {
| Type::MethodWrapper(_)
| Type::BoundMethod(_)
| Type::WrapperDescriptor(_)
| Self::DataclassDecorator(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::KnownInstance(_)
@ -1073,7 +1122,7 @@ impl<'db> Type<'db> {
self_callable.is_subtype_of(db, other_callable)
}
(Type::DataclassDecorator(_), _) => {
(Type::DataclassDecorator(_) | Type::DataclassTransformer(_), _) => {
// TODO: Implement subtyping using an equivalent `Callable` type.
false
}
@ -1628,6 +1677,7 @@ impl<'db> Type<'db> {
| Type::MethodWrapper(..)
| Type::WrapperDescriptor(..)
| Type::DataclassDecorator(..)
| Type::DataclassTransformer(..)
| Type::IntLiteral(..)
| Type::SliceLiteral(..)
| Type::StringLiteral(..)
@ -1644,6 +1694,7 @@ impl<'db> Type<'db> {
| Type::MethodWrapper(..)
| Type::WrapperDescriptor(..)
| Type::DataclassDecorator(..)
| Type::DataclassTransformer(..)
| Type::IntLiteral(..)
| Type::SliceLiteral(..)
| Type::StringLiteral(..)
@ -1838,8 +1889,14 @@ impl<'db> Type<'db> {
true
}
(Type::Callable(_) | Type::DataclassDecorator(_), _)
| (_, Type::Callable(_) | Type::DataclassDecorator(_)) => {
(
Type::Callable(_) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_),
_,
)
| (
_,
Type::Callable(_) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_),
) => {
// TODO: Implement disjointness for general callable type with other types
false
}
@ -1902,6 +1959,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(..)
| Type::IntLiteral(_)
| Type::BooleanLiteral(_)
@ -2033,7 +2091,7 @@ impl<'db> Type<'db> {
// (this variant represents `f.__get__`, where `f` is any function)
false
}
Type::DataclassDecorator(_) => false,
Type::DataclassDecorator(_) | Type::DataclassTransformer(_) => false,
Type::Instance(InstanceType { class }) => {
class.known(db).is_some_and(KnownClass::is_singleton)
}
@ -2126,7 +2184,8 @@ impl<'db> Type<'db> {
| Type::AlwaysFalsy
| Type::Callable(_)
| Type::PropertyInstance(_)
| Type::DataclassDecorator(_) => false,
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_) => false,
}
}
@ -2262,6 +2321,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::KnownInstance(_)
| Type::AlwaysTruthy
@ -2357,7 +2417,9 @@ impl<'db> Type<'db> {
Type::DataclassDecorator(_) => KnownClass::FunctionType
.to_instance(db)
.instance_member(db, name),
Type::Callable(_) => KnownClass::Object.to_instance(db).instance_member(db, name),
Type::Callable(_) | Type::DataclassTransformer(_) => {
KnownClass::Object.to_instance(db).instance_member(db, name)
}
Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
None => KnownClass::Object.to_instance(db).instance_member(db, name),
@ -2774,7 +2836,7 @@ impl<'db> Type<'db> {
Type::DataclassDecorator(_) => KnownClass::FunctionType
.to_instance(db)
.member_lookup_with_policy(db, name, policy),
Type::Callable(_) => KnownClass::Object
Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Object
.to_instance(db)
.member_lookup_with_policy(db, name, policy),
@ -3080,6 +3142,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::SliceLiteral(_)
| Type::AlwaysTruthy => Truthiness::AlwaysTrue,
@ -3387,6 +3450,18 @@ impl<'db> Type<'db> {
))
}
// TODO: We should probably also check the original return type of the function
// that was decorated with `@dataclass_transform`, to see if it is consistent with
// with what we configure here.
Type::DataclassTransformer(_) => Signatures::single(CallableSignature::single(
self,
Signature::new(
Parameters::new([Parameter::positional_only(Some(Name::new_static("func")))
.with_annotated_type(Type::object(db))]),
None,
),
)),
Type::FunctionLiteral(function_type) => match function_type.known(db) {
Some(
KnownFunction::IsEquivalentTo
@ -3500,8 +3575,7 @@ impl<'db> Type<'db> {
Parameters::new([Parameter::positional_only(Some(
Name::new_static("cls"),
))
// TODO: type[_T]
.with_annotated_type(Type::any())]),
.with_annotated_type(KnownClass::Type.to_instance(db))]),
None,
),
// TODO: make this overload Python-version-dependent
@ -4289,6 +4363,7 @@ impl<'db> Type<'db> {
| Type::BoundMethod(_)
| Type::WrapperDescriptor(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::Instance(_)
| Type::KnownInstance(_)
| Type::PropertyInstance(_)
@ -4359,6 +4434,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::Never
| Type::FunctionLiteral(_)
| Type::BoundSuper(_)
@ -4574,7 +4650,7 @@ impl<'db> Type<'db> {
Type::MethodWrapper(_) => KnownClass::MethodWrapperType.to_class_literal(db),
Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType.to_class_literal(db),
Type::DataclassDecorator(_) => KnownClass::FunctionType.to_class_literal(db),
Type::Callable(_) => KnownClass::Type.to_instance(db),
Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Type.to_instance(db),
Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class_literal(db),
Type::Tuple(_) => KnownClass::Tuple.to_class_literal(db),
@ -4714,6 +4790,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(MethodWrapperKind::StrStartswith(_))
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
// A non-generic class never needs to be specialized. A generic class is specialized
// explicitly (via a subscript expression) or implicitly (via a call), and not because
@ -4820,6 +4897,7 @@ impl<'db> Type<'db> {
| Self::MethodWrapper(_)
| Self::WrapperDescriptor(_)
| Self::DataclassDecorator(_)
| Self::DataclassTransformer(_)
| Self::PropertyInstance(_)
| Self::BoundSuper(_)
| Self::Tuple(_) => self.to_meta_type(db).definition(db),
@ -5883,6 +5961,10 @@ pub struct FunctionType<'db> {
/// A set of special decorators that were applied to this function
decorators: FunctionDecorators,
/// The arguments to `dataclass_transformer`, if this function was annotated
/// with `@dataclass_transformer(...)`.
dataclass_transformer_params: Option<DataclassTransformerParams>,
/// The generic context of a generic function.
generic_context: Option<GenericContext<'db>>,
@ -6019,6 +6101,7 @@ impl<'db> FunctionType<'db> {
self.known(db),
self.body_scope(db),
self.decorators(db),
self.dataclass_transformer_params(db),
Some(generic_context),
self.specialization(db),
)
@ -6035,6 +6118,7 @@ impl<'db> FunctionType<'db> {
self.known(db),
self.body_scope(db),
self.decorators(db),
self.dataclass_transformer_params(db),
self.generic_context(db),
Some(specialization),
)
@ -6079,6 +6163,8 @@ pub enum KnownFunction {
GetProtocolMembers,
/// `typing(_extensions).runtime_checkable`
RuntimeCheckable,
/// `typing(_extensions).dataclass_transform`
DataclassTransform,
/// `abc.abstractmethod`
#[strum(serialize = "abstractmethod")]
@ -6143,6 +6229,7 @@ impl KnownFunction {
| Self::IsProtocol
| Self::GetProtocolMembers
| Self::RuntimeCheckable
| Self::DataclassTransform
| Self::NoTypeCheck => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
}
@ -7516,6 +7603,7 @@ pub(crate) mod tests {
| KnownFunction::IsProtocol
| KnownFunction::GetProtocolMembers
| KnownFunction::RuntimeCheckable
| KnownFunction::DataclassTransform
| KnownFunction::NoTypeCheck => KnownModule::TypingExtensions,
KnownFunction::IsSingleton

View File

@ -19,8 +19,9 @@ use crate::types::diagnostic::{
use crate::types::generics::{Specialization, SpecializationBuilder};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{
BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction,
KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind,
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType,
KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType,
UnionType, WrapperDescriptorKind,
};
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
use ruff_python_ast as ast;
@ -210,8 +211,17 @@ impl<'db> Bindings<'db> {
/// Evaluates the return type of certain known callables, where we have special-case logic to
/// determine the return type in a way that isn't directly expressible in the type system.
fn evaluate_known_cases(&mut self, db: &'db dyn Db) {
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
} else {
// TODO: emit a diagnostic if we receive `bool`
default
}
};
// Each special case listed here should have a corresponding clause in `Type::signatures`.
for binding in &mut self.elements {
for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) {
let binding_type = binding.callable_type;
let Some((overload_index, overload)) = binding.matching_overload_mut() else {
continue;
@ -413,6 +423,21 @@ impl<'db> Bindings<'db> {
}
}
Type::DataclassTransformer(params) => {
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() {
overload.set_return_type(Type::FunctionLiteral(FunctionType::new(
db,
function.name(db),
function.known(db),
function.body_scope(db),
function.decorators(db),
Some(params),
function.generic_context(db),
function.specialization(db),
)));
}
}
Type::BoundMethod(bound_method)
if bound_method.self_instance(db).is_property_instance() =>
{
@ -598,53 +623,90 @@ impl<'db> Bindings<'db> {
if let [init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot] =
overload.parameter_types()
{
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
} else {
// TODO: emit a diagnostic if we receive `bool`
default
}
};
let mut metadata = DataclassMetadata::empty();
let mut params = DataclassParams::empty();
if to_bool(init, true) {
metadata |= DataclassMetadata::INIT;
params |= DataclassParams::INIT;
}
if to_bool(repr, true) {
metadata |= DataclassMetadata::REPR;
params |= DataclassParams::REPR;
}
if to_bool(eq, true) {
metadata |= DataclassMetadata::EQ;
params |= DataclassParams::EQ;
}
if to_bool(order, false) {
metadata |= DataclassMetadata::ORDER;
params |= DataclassParams::ORDER;
}
if to_bool(unsafe_hash, false) {
metadata |= DataclassMetadata::UNSAFE_HASH;
params |= DataclassParams::UNSAFE_HASH;
}
if to_bool(frozen, false) {
metadata |= DataclassMetadata::FROZEN;
params |= DataclassParams::FROZEN;
}
if to_bool(match_args, true) {
metadata |= DataclassMetadata::MATCH_ARGS;
params |= DataclassParams::MATCH_ARGS;
}
if to_bool(kw_only, false) {
metadata |= DataclassMetadata::KW_ONLY;
params |= DataclassParams::KW_ONLY;
}
if to_bool(slots, false) {
metadata |= DataclassMetadata::SLOTS;
params |= DataclassParams::SLOTS;
}
if to_bool(weakref_slot, false) {
metadata |= DataclassMetadata::WEAKREF_SLOT;
params |= DataclassParams::WEAKREF_SLOT;
}
overload.set_return_type(Type::DataclassDecorator(metadata));
overload.set_return_type(Type::DataclassDecorator(params));
}
}
_ => {}
Some(KnownFunction::DataclassTransform) => {
if let [eq_default, order_default, kw_only_default, frozen_default, _field_specifiers, _kwargs] =
overload.parameter_types()
{
let mut params = DataclassTransformerParams::empty();
if to_bool(eq_default, true) {
params |= DataclassTransformerParams::EQ_DEFAULT;
}
if to_bool(order_default, false) {
params |= DataclassTransformerParams::ORDER_DEFAULT;
}
if to_bool(kw_only_default, false) {
params |= DataclassTransformerParams::KW_ONLY_DEFAULT;
}
if to_bool(frozen_default, false) {
params |= DataclassTransformerParams::FROZEN_DEFAULT;
}
overload.set_return_type(Type::DataclassTransformer(params));
}
}
_ => {
if let Some(params) = function_type.dataclass_transformer_params(db) {
// This is a call to a custom function that was decorated with `@dataclass_transformer`.
// If this function was called with a keyword argument like `order=False`, we extract
// the argument type and overwrite the corresponding flag in `dataclass_params` after
// constructing them from the `dataclass_transformer`-parameter defaults.
let mut dataclass_params = DataclassParams::from(params);
if let Some(Some(Type::BooleanLiteral(order))) = callable_signature
.iter()
.nth(overload_index)
.and_then(|signature| {
let (idx, _) =
signature.parameters().keyword_by_name("order")?;
overload.parameter_types().get(idx)
})
{
dataclass_params.set(DataclassParams::ORDER, *order);
}
overload.set_return_type(Type::DataclassDecorator(dataclass_params));
}
}
},
Type::ClassLiteral(class) => match class.known(db) {

View File

@ -10,7 +10,7 @@ use crate::semantic_index::definition::Definition;
use crate::semantic_index::DeclarationWithConstraint;
use crate::types::generics::{GenericContext, Specialization};
use crate::types::signatures::{Parameter, Parameters};
use crate::types::{CallableType, DataclassMetadata, Signature};
use crate::types::{CallableType, DataclassParams, DataclassTransformerParams, Signature};
use crate::{
module_resolver::file_to_module,
semantic_index::{
@ -106,7 +106,8 @@ pub struct Class<'db> {
pub(crate) known: Option<KnownClass>,
pub(crate) dataclass_metadata: Option<DataclassMetadata>,
pub(crate) dataclass_params: Option<DataclassParams>,
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams>,
}
impl<'db> Class<'db> {
@ -469,8 +470,8 @@ impl<'db> ClassLiteralType<'db> {
self.class(db).known
}
pub(crate) fn dataclass_metadata(self, db: &'db dyn Db) -> Option<DataclassMetadata> {
self.class(db).dataclass_metadata
pub(crate) fn dataclass_params(self, db: &'db dyn Db) -> Option<DataclassParams> {
self.class(db).dataclass_params
}
/// Return `true` if this class represents `known_class`
@ -699,6 +700,7 @@ impl<'db> ClassLiteralType<'db> {
/// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred.
pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> {
self.try_metaclass(db)
.map(|(ty, _)| ty)
.unwrap_or_else(|_| SubclassOfType::subclass_of_unknown())
}
@ -712,7 +714,10 @@ impl<'db> ClassLiteralType<'db> {
/// Return the metaclass of this class, or an error if the metaclass cannot be inferred.
#[salsa::tracked]
pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result<Type<'db>, MetaclassError<'db>> {
pub(super) fn try_metaclass(
self,
db: &'db dyn Db,
) -> Result<(Type<'db>, Option<DataclassTransformerParams>), MetaclassError<'db>> {
let class = self.class(db);
tracing::trace!("ClassLiteralType::try_metaclass: {}", class.name);
@ -723,7 +728,7 @@ impl<'db> ClassLiteralType<'db> {
// We emit diagnostics for cyclic class definitions elsewhere.
// Avoid attempting to infer the metaclass if the class is cyclically defined:
// it would be easy to enter an infinite loop.
return Ok(SubclassOfType::subclass_of_unknown());
return Ok((SubclassOfType::subclass_of_unknown(), None));
}
let explicit_metaclass = self.explicit_metaclass(db);
@ -768,7 +773,7 @@ impl<'db> ClassLiteralType<'db> {
}),
};
return return_ty_result.map(|ty| ty.to_meta_type(db));
return return_ty_result.map(|ty| (ty.to_meta_type(db), None));
};
// Reconcile all base classes' metaclasses with the candidate metaclass.
@ -805,7 +810,10 @@ impl<'db> ClassLiteralType<'db> {
});
}
Ok(candidate.metaclass.into())
Ok((
candidate.metaclass.into(),
candidate.metaclass.class(db).dataclass_transformer_params,
))
}
/// Returns the class member of this class named `name`.
@ -969,12 +977,8 @@ impl<'db> ClassLiteralType<'db> {
});
if symbol.symbol.is_unbound() {
if let Some(metadata) = self.dataclass_metadata(db) {
if let Some(dataclass_member) =
self.own_dataclass_member(db, specialization, metadata, name)
{
return Symbol::bound(dataclass_member).into();
}
if let Some(dataclass_member) = self.own_dataclass_member(db, specialization, name) {
return Symbol::bound(dataclass_member).into();
}
}
@ -986,70 +990,97 @@ impl<'db> ClassLiteralType<'db> {
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
metadata: DataclassMetadata,
name: &str,
) -> Option<Type<'db>> {
if name == "__init__" && metadata.contains(DataclassMetadata::INIT) {
let mut parameters = vec![];
let params = self.dataclass_params(db);
let has_dataclass_param = |param| params.is_some_and(|params| params.contains(param));
for (name, (mut attr_ty, mut default_ty)) in self.dataclass_fields(db, specialization) {
// The descriptor handling below is guarded by this fully-static check, because dynamic
// types like `Any` are valid (data) descriptors: since they have all possible attributes,
// they also have a (callable) `__set__` method. The problem is that we can't determine
// the type of the value parameter this way. Instead, we want to use the dynamic type
// itself in this case, so we skip the special descriptor handling.
if attr_ty.is_fully_static(db) {
let dunder_set = attr_ty.class_member(db, "__set__".into());
if let Some(dunder_set) = dunder_set.symbol.ignore_possibly_unbound() {
// This type of this attribute is a data descriptor. Instead of overwriting the
// descriptor attribute, data-classes will (implicitly) call the `__set__` method
// of the descriptor. This means that the synthesized `__init__` parameter for
// this attribute is determined by possible `value` parameter types with which
// the `__set__` method can be called. We build a union of all possible options
// to account for possible overloads.
let mut value_types = UnionBuilder::new(db);
for signature in &dunder_set.signatures(db) {
for overload in signature {
if let Some(value_param) = overload.parameters().get_positional(2) {
value_types = value_types.add(
value_param.annotated_type().unwrap_or_else(Type::unknown),
);
} else if overload.parameters().is_gradual() {
value_types = value_types.add(Type::unknown());
match name {
"__init__" => {
let has_synthesized_dunder_init = has_dataclass_param(DataclassParams::INIT)
|| self
.try_metaclass(db)
.is_ok_and(|(_, transformer_params)| transformer_params.is_some());
if !has_synthesized_dunder_init {
return None;
}
let mut parameters = vec![];
for (name, (mut attr_ty, mut default_ty)) in
self.dataclass_fields(db, specialization)
{
// The descriptor handling below is guarded by this fully-static check, because dynamic
// types like `Any` are valid (data) descriptors: since they have all possible attributes,
// they also have a (callable) `__set__` method. The problem is that we can't determine
// the type of the value parameter this way. Instead, we want to use the dynamic type
// itself in this case, so we skip the special descriptor handling.
if attr_ty.is_fully_static(db) {
let dunder_set = attr_ty.class_member(db, "__set__".into());
if let Some(dunder_set) = dunder_set.symbol.ignore_possibly_unbound() {
// This type of this attribute is a data descriptor. Instead of overwriting the
// descriptor attribute, data-classes will (implicitly) call the `__set__` method
// of the descriptor. This means that the synthesized `__init__` parameter for
// this attribute is determined by possible `value` parameter types with which
// the `__set__` method can be called. We build a union of all possible options
// to account for possible overloads.
let mut value_types = UnionBuilder::new(db);
for signature in &dunder_set.signatures(db) {
for overload in signature {
if let Some(value_param) =
overload.parameters().get_positional(2)
{
value_types = value_types.add(
value_param
.annotated_type()
.unwrap_or_else(Type::unknown),
);
} else if overload.parameters().is_gradual() {
value_types = value_types.add(Type::unknown());
}
}
}
}
attr_ty = value_types.build();
attr_ty = value_types.build();
// The default value of the attribute is *not* determined by the right hand side
// of the class-body assignment. Instead, the runtime invokes `__get__` on the
// descriptor, as if it had been called on the class itself, i.e. it passes `None`
// for the `instance` argument.
// The default value of the attribute is *not* determined by the right hand side
// of the class-body assignment. Instead, the runtime invokes `__get__` on the
// descriptor, as if it had been called on the class itself, i.e. it passes `None`
// for the `instance` argument.
if let Some(ref mut default_ty) = default_ty {
*default_ty = default_ty
.try_call_dunder_get(db, Type::none(db), Type::ClassLiteral(self))
.map(|(return_ty, _)| return_ty)
.unwrap_or_else(Type::unknown);
if let Some(ref mut default_ty) = default_ty {
*default_ty = default_ty
.try_call_dunder_get(
db,
Type::none(db),
Type::ClassLiteral(self),
)
.map(|(return_ty, _)| return_ty)
.unwrap_or_else(Type::unknown);
}
}
}
let mut parameter =
Parameter::positional_or_keyword(name).with_annotated_type(attr_ty);
if let Some(default_ty) = default_ty {
parameter = parameter.with_default_type(default_ty);
}
parameters.push(parameter);
}
let mut parameter =
Parameter::positional_or_keyword(name).with_annotated_type(attr_ty);
let init_signature =
Signature::new(Parameters::new(parameters), Some(Type::none(db)));
if let Some(default_ty) = default_ty {
parameter = parameter.with_default_type(default_ty);
}
parameters.push(parameter);
Some(Type::Callable(CallableType::single(db, init_signature)))
}
"__lt__" | "__le__" | "__gt__" | "__ge__" => {
if !has_dataclass_param(DataclassParams::ORDER) {
return None;
}
let init_signature = Signature::new(Parameters::new(parameters), Some(Type::none(db)));
return Some(Type::Callable(CallableType::single(db, init_signature)));
} else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
if metadata.contains(DataclassMetadata::ORDER) {
let signature = Signature::new(
Parameters::new([Parameter::positional_or_keyword(Name::new_static("other"))
// TODO: could be `Self`.
@ -1059,11 +1090,17 @@ impl<'db> ClassLiteralType<'db> {
Some(KnownClass::Bool.to_instance(db)),
);
return Some(Type::Callable(CallableType::single(db, signature)));
Some(Type::Callable(CallableType::single(db, signature)))
}
_ => None,
}
}
None
fn is_dataclass(self, db: &'db dyn Db) -> bool {
self.dataclass_params(db).is_some()
|| self
.try_metaclass(db)
.is_ok_and(|(_, transformer_params)| transformer_params.is_some())
}
/// Returns a list of all annotated attributes defined in this class, or any of its superclasses.
@ -1079,7 +1116,7 @@ impl<'db> ClassLiteralType<'db> {
.filter_map(|superclass| {
if let Some(class) = superclass.into_class() {
let class_literal = class.class_literal(db).0;
if class_literal.dataclass_metadata(db).is_some() {
if class_literal.is_dataclass(db) {
Some(class_literal)
} else {
None

View File

@ -90,6 +90,7 @@ impl<'db> ClassBase<'db> {
| Type::MethodWrapper(_)
| Type::WrapperDescriptor(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::BytesLiteral(_)
| Type::IntLiteral(_)
| Type::StringLiteral(_)

View File

@ -195,7 +195,10 @@ impl Display for DisplayRepresentation<'_> {
write!(f, "<wrapper-descriptor `{method}` of `{object}` objects>")
}
Type::DataclassDecorator(_) => {
f.write_str("<decorator produced by dataclasses.dataclass>")
f.write_str("<decorator produced by dataclass-like function>")
}
Type::DataclassTransformer(_) => {
f.write_str("<decorator produced by typing.dataclass_transform>")
}
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),

View File

@ -82,7 +82,7 @@ use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType,
ClassLiteralType, ClassType, DataclassParams, DynamicType, FunctionDecorators, FunctionType,
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
@ -1457,6 +1457,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let mut decorator_types_and_nodes = Vec::with_capacity(decorator_list.len());
let mut function_decorators = FunctionDecorators::empty();
let mut dataclass_transformer_params = None;
for decorator in decorator_list {
let decorator_ty = self.infer_decorator(decorator);
@ -1477,6 +1478,8 @@ impl<'db> TypeInferenceBuilder<'db> {
function_decorators |= FunctionDecorators::CLASSMETHOD;
continue;
}
} else if let Type::DataclassTransformer(params) = decorator_ty {
dataclass_transformer_params = Some(params);
}
decorator_types_and_nodes.push((decorator_ty, decorator));
@ -1523,6 +1526,7 @@ impl<'db> TypeInferenceBuilder<'db> {
function_kind,
body_scope,
function_decorators,
dataclass_transformer_params,
generic_context,
specialization,
));
@ -1757,19 +1761,32 @@ impl<'db> TypeInferenceBuilder<'db> {
body: _,
} = class_node;
let mut dataclass_metadata = None;
let mut dataclass_params = None;
let mut dataclass_transformer_params = None;
for decorator in decorator_list {
let decorator_ty = self.infer_decorator(decorator);
if decorator_ty
.into_function_literal()
.is_some_and(|function| function.is_known(self.db(), KnownFunction::Dataclass))
{
dataclass_metadata = Some(DataclassMetadata::default());
dataclass_params = Some(DataclassParams::default());
continue;
}
if let Type::DataclassDecorator(metadata) = decorator_ty {
dataclass_metadata = Some(metadata);
if let Type::DataclassDecorator(params) = decorator_ty {
dataclass_params = Some(params);
continue;
}
if let Type::FunctionLiteral(f) = decorator_ty {
if let Some(params) = f.dataclass_transformer_params(self.db()) {
dataclass_params = Some(params.into());
continue;
}
}
if let Type::DataclassTransformer(params) = decorator_ty {
dataclass_transformer_params = Some(params);
continue;
}
}
@ -1789,7 +1806,8 @@ impl<'db> TypeInferenceBuilder<'db> {
name: name.id.clone(),
body_scope,
known: maybe_known_class,
dataclass_metadata,
dataclass_params,
dataclass_transformer_params,
};
let class_literal = match generic_context {
Some(generic_context) => {
@ -2502,6 +2520,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::MethodWrapper(_)
| Type::WrapperDescriptor(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::TypeVar(..)
| Type::AlwaysTruthy
| Type::AlwaysFalsy => {
@ -4882,6 +4901,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::BoundMethod(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
@ -5164,6 +5184,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
@ -5188,6 +5209,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)

View File

@ -79,6 +79,12 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
(Type::DataclassDecorator(_), _) => Ordering::Less,
(_, Type::DataclassDecorator(_)) => Ordering::Greater,
(Type::DataclassTransformer(left), Type::DataclassTransformer(right)) => {
left.bits().cmp(&right.bits())
}
(Type::DataclassTransformer(_), _) => Ordering::Less,
(_, Type::DataclassTransformer(_)) => Ordering::Greater,
(Type::Callable(left), Type::Callable(right)) => {
debug_assert_eq!(*left, left.normalized(db));
debug_assert_eq!(*right, right.normalized(db));