From 37a0836bd2bb9bb456a4e4bb37ee4cb55d1f0509 Mon Sep 17 00:00:00 2001 From: David Peter Date: Tue, 22 Apr 2025 10:33:02 +0200 Subject: [PATCH] [red-knot] `typing.dataclass_transform` (#17445) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary * Add initial support for `typing.dataclass_transform` * Support decorating a function decorator with `@dataclass_transform(…)` (used by `attrs`, `strawberry`) * Support decorating a metaclass with `@dataclass_transform(…)` (used by `pydantic`, but doesn't work yet, because we don't seem to model `__new__` calls correctly?) * *No* support yet for decorating base classes with `@dataclass_transform(…)`. I haven't figured out how this even supposed to work. And haven't seen it being used. * Add `strawberry` as an ecosystem project, as it makes heavy use of `@dataclass_transform` ## Test Plan New Markdown tests --- .github/workflows/mypy_primer.yaml | 2 +- Cargo.toml | 4 + .../resources/mdtest/dataclass_transform.md | 293 ++++++++++++++++++ .../resources/mdtest/dataclasses.md | 2 +- .../resources/primer/good.txt | 1 + crates/red_knot_python_semantic/src/types.rs | 120 ++++++- .../src/types/call/bind.rs | 112 +++++-- .../src/types/class.rs | 171 ++++++---- .../src/types/class_base.rs | 1 + .../src/types/display.rs | 5 +- .../src/types/infer.rs | 34 +- .../src/types/type_ordering.rs | 6 + 12 files changed, 634 insertions(+), 117 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/dataclass_transform.md diff --git a/.github/workflows/mypy_primer.yaml b/.github/workflows/mypy_primer.yaml index 490ae970fa..ac0dc7a2a1 100644 --- a/.github/workflows/mypy_primer.yaml +++ b/.github/workflows/mypy_primer.yaml @@ -45,7 +45,7 @@ jobs: - name: Install mypy_primer run: | - uv tool install "git+https://github.com/astral-sh/mypy_primer.git@add-red-knot-support-v5" + uv tool install "git+https://github.com/astral-sh/mypy_primer.git@add-red-knot-support-v6" - name: Run mypy_primer shell: bash diff --git a/Cargo.toml b/Cargo.toml index 933ae24c5f..7a12fd6c3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -231,6 +231,10 @@ unused_peekable = "warn" # Diagnostics are not actionable: Enable once https://github.com/rust-lang/rust-clippy/issues/13774 is resolved. large_stack_arrays = "allow" +# Salsa generates functions with parameters for each field of a `salsa::interned` struct. +# If we don't allow this, we get warnings for structs with too many fields. +too_many_arguments = "allow" + [profile.release] # Note that we set these explicitly, and these values # were chosen based on a trade-off between compile times diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclass_transform.md b/crates/red_knot_python_semantic/resources/mdtest/dataclass_transform.md new file mode 100644 index 0000000000..a57fcd612e --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclass_transform.md @@ -0,0 +1,293 @@ +# `typing.dataclass_transform` + +```toml +[environment] +python-version = "3.12" +``` + +`dataclass_transform` is a decorator that can be used to let type checkers know that a function, +class, or metaclass is a `dataclass`-like construct. + +## Basic example + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def my_dataclass[T](cls: type[T]) -> type[T]: + # modify cls + return cls + +@my_dataclass +class Person: + name: str + age: int | None = None + +Person("Alice", 20) +Person("Bob", None) +Person("Bob") + +# error: [missing-argument] +Person() +``` + +## Decorating decorators that take parameters themselves + +If we want our `dataclass`-like decorator to also take parameters, that is also possible: + +```py +from typing_extensions import dataclass_transform, Callable + +@dataclass_transform() +def versioned_class[T](*, version: int = 1): + def decorator(cls): + # modify cls + return cls + return decorator + +@versioned_class(version=2) +class Person: + name: str + age: int | None = None + +Person("Alice", 20) + +# error: [missing-argument] +Person() +``` + +We properly type-check the arguments to the decorator: + +```py +from typing_extensions import dataclass_transform, Callable + +# error: [invalid-argument-type] +@versioned_class(version="a string") +class C: + name: str +``` + +## Types of decorators + +The examples from this section are straight from the Python documentation on +[`typing.dataclass_transform`]. + +### Decorating a decorator function + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def create_model[T](cls: type[T]) -> type[T]: + ... + return cls + +@create_model +class CustomerModel: + id: int + name: str + +CustomerModel(id=1, name="Test") +``` + +### Decorating a metaclass + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +class ModelMeta(type): ... + +class ModelBase(metaclass=ModelMeta): ... + +class CustomerModel(ModelBase): + id: int + name: str + +CustomerModel(id=1, name="Test") + +# error: [missing-argument] +CustomerModel() +``` + +### Decorating a base class + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +class ModelBase: ... + +class CustomerModel(ModelBase): + id: int + name: str + +# TODO: this is not supported yet +# error: [unknown-argument] +# error: [unknown-argument] +CustomerModel(id=1, name="Test") +``` + +## Arguments to `dataclass_transform` + +### `eq_default` + +`eq=True/False` does not have a observable effect (apart from a minor change regarding whether +`other` is positional-only or not, which is not modelled at the moment). + +### `order_default` + +The `order_default` argument controls whether methods such as `__lt__` are generated by default. +This can be overwritten using the `order` argument to the custom decorator: + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +def normal(*, order: bool = False): + raise NotImplementedError + +@dataclass_transform(order_default=False) +def order_default_false(*, order: bool = False): + raise NotImplementedError + +@dataclass_transform(order_default=True) +def order_default_true(*, order: bool = True): + raise NotImplementedError + +@normal +class Normal: + inner: int + +Normal(1) < Normal(2) # error: [unsupported-operator] + +@normal(order=True) +class NormalOverwritten: + inner: int + +NormalOverwritten(1) < NormalOverwritten(2) + +@order_default_false +class OrderFalse: + inner: int + +OrderFalse(1) < OrderFalse(2) # error: [unsupported-operator] + +@order_default_false(order=True) +class OrderFalseOverwritten: + inner: int + +OrderFalseOverwritten(1) < OrderFalseOverwritten(2) + +@order_default_true +class OrderTrue: + inner: int + +OrderTrue(1) < OrderTrue(2) + +@order_default_true(order=False) +class OrderTrueOverwritten: + inner: int + +# error: [unsupported-operator] +OrderTrueOverwritten(1) < OrderTrueOverwritten(2) +``` + +### `kw_only_default` + +To do + +### `field_specifiers` + +To do + +## Overloaded dataclass-like decorators + +In the case of an overloaded decorator, the `dataclass_transform` decorator can be applied to the +implementation, or to *one* of the overloads. + +### Applying `dataclass_transform` to the implementation + +```py +from typing_extensions import dataclass_transform, TypeVar, Callable, overload + +T = TypeVar("T", bound=type) + +@overload +def versioned_class( + cls: T, + *, + version: int = 1, +) -> T: ... +@overload +def versioned_class( + *, + version: int = 1, +) -> Callable[[T], T]: ... +@dataclass_transform() +def versioned_class( + cls: T | None = None, + *, + version: int = 1, +) -> T | Callable[[T], T]: + raise NotImplementedError + +@versioned_class +class D1: + x: str + +@versioned_class(version=2) +class D2: + x: str + +D1("a") +D2("a") + +D1(1.2) # error: [invalid-argument-type] +D2(1.2) # error: [invalid-argument-type] +``` + +### Applying `dataclass_transform` to an overload + +```py +from typing_extensions import dataclass_transform, TypeVar, Callable, overload + +T = TypeVar("T", bound=type) + +@overload +@dataclass_transform() +def versioned_class( + cls: T, + *, + version: int = 1, +) -> T: ... +@overload +def versioned_class( + *, + version: int = 1, +) -> Callable[[T], T]: ... +def versioned_class( + cls: T | None = None, + *, + version: int = 1, +) -> T | Callable[[T], T]: + raise NotImplementedError + +@versioned_class +class D1: + x: str + +@versioned_class(version=2) +class D2: + x: str + +# TODO: these should not be errors +D1("a") # error: [too-many-positional-arguments] +D2("a") # error: [too-many-positional-arguments] + +# TODO: these should be invalid-argument-type errors +D1(1.2) # error: [too-many-positional-arguments] +D2(1.2) # error: [too-many-positional-arguments] +``` + +[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md index 94958c883b..0f23feea9f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md @@ -689,7 +689,7 @@ from dataclasses import dataclass dataclass_with_order = dataclass(order=True) -reveal_type(dataclass_with_order) # revealed: +reveal_type(dataclass_with_order) # revealed: @dataclass_with_order class C: diff --git a/crates/red_knot_python_semantic/resources/primer/good.txt b/crates/red_knot_python_semantic/resources/primer/good.txt index 360e1e72ef..3dca5c49f9 100644 --- a/crates/red_knot_python_semantic/resources/primer/good.txt +++ b/crates/red_knot_python_semantic/resources/primer/good.txt @@ -18,6 +18,7 @@ python-chess python-htmlgen rich scrapy +strawberry typeshed-stats werkzeug zipp diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index c0ccc6e7ee..ca20e3c55e 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -339,12 +339,12 @@ impl<'db> PropertyInstanceType<'db> { } bitflags! { - /// Used as the return type of `dataclass(…)` calls. Keeps track of the arguments + /// Used for the return type of `dataclass(…)` calls. Keeps track of the arguments /// that were passed in. For the precise meaning of the fields, see [1]. /// /// [1]: https://docs.python.org/3/library/dataclasses.html #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] - pub struct DataclassMetadata: u16 { + pub struct DataclassParams: u16 { const INIT = 0b0000_0000_0001; const REPR = 0b0000_0000_0010; const EQ = 0b0000_0000_0100; @@ -358,12 +358,57 @@ bitflags! { } } -impl Default for DataclassMetadata { +impl Default for DataclassParams { fn default() -> Self { Self::INIT | Self::REPR | Self::EQ | Self::MATCH_ARGS } } +impl From for DataclassParams { + fn from(params: DataclassTransformerParams) -> Self { + let mut result = Self::default(); + + result.set( + Self::EQ, + params.contains(DataclassTransformerParams::EQ_DEFAULT), + ); + result.set( + Self::ORDER, + params.contains(DataclassTransformerParams::ORDER_DEFAULT), + ); + result.set( + Self::KW_ONLY, + params.contains(DataclassTransformerParams::KW_ONLY_DEFAULT), + ); + result.set( + Self::FROZEN, + params.contains(DataclassTransformerParams::FROZEN_DEFAULT), + ); + + result + } +} + +bitflags! { + /// Used for the return type of `dataclass_transform(…)` calls. Keeps track of the + /// arguments that were passed in. For the precise meaning of the fields, see [1]. + /// + /// [1]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)] + pub struct DataclassTransformerParams: u8 { + const EQ_DEFAULT = 0b0000_0001; + const ORDER_DEFAULT = 0b0000_0010; + const KW_ONLY_DEFAULT = 0b0000_0100; + const FROZEN_DEFAULT = 0b0000_1000; + } +} + +impl Default for DataclassTransformerParams { + fn default() -> Self { + Self::EQ_DEFAULT + } +} + /// Representation of a type: a set of possible values at runtime. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] pub enum Type<'db> { @@ -404,7 +449,9 @@ pub enum Type<'db> { /// A special callable that is returned by a `dataclass(…)` call. It is usually /// used as a decorator. Note that this is only used as a return type for actual /// `dataclass` calls, not for the argumentless `@dataclass` decorator. - DataclassDecorator(DataclassMetadata), + DataclassDecorator(DataclassParams), + /// A special callable that is returned by a `dataclass_transform(…)` call. + DataclassTransformer(DataclassTransformerParams), /// The type of an arbitrary callable object with a certain specified signature. Callable(CallableType<'db>), /// A specific module object @@ -524,7 +571,8 @@ impl<'db> Type<'db> { | Self::BoundMethod(_) | Self::WrapperDescriptor(_) | Self::MethodWrapper(_) - | Self::DataclassDecorator(_) => false, + | Self::DataclassDecorator(_) + | Self::DataclassTransformer(_) => false, Self::GenericAlias(generic) => generic .specialization(db) @@ -837,7 +885,8 @@ impl<'db> Type<'db> { | Type::MethodWrapper(_) | Type::BoundMethod(_) | Type::WrapperDescriptor(_) - | Self::DataclassDecorator(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::KnownInstance(_) @@ -1073,7 +1122,7 @@ impl<'db> Type<'db> { self_callable.is_subtype_of(db, other_callable) } - (Type::DataclassDecorator(_), _) => { + (Type::DataclassDecorator(_) | Type::DataclassTransformer(_), _) => { // TODO: Implement subtyping using an equivalent `Callable` type. false } @@ -1628,6 +1677,7 @@ impl<'db> Type<'db> { | Type::MethodWrapper(..) | Type::WrapperDescriptor(..) | Type::DataclassDecorator(..) + | Type::DataclassTransformer(..) | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1644,6 +1694,7 @@ impl<'db> Type<'db> { | Type::MethodWrapper(..) | Type::WrapperDescriptor(..) | Type::DataclassDecorator(..) + | Type::DataclassTransformer(..) | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1838,8 +1889,14 @@ impl<'db> Type<'db> { true } - (Type::Callable(_) | Type::DataclassDecorator(_), _) - | (_, Type::Callable(_) | Type::DataclassDecorator(_)) => { + ( + Type::Callable(_) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_), + _, + ) + | ( + _, + Type::Callable(_) | Type::DataclassDecorator(_) | Type::DataclassTransformer(_), + ) => { // TODO: Implement disjointness for general callable type with other types false } @@ -1902,6 +1959,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(..) | Type::IntLiteral(_) | Type::BooleanLiteral(_) @@ -2033,7 +2091,7 @@ impl<'db> Type<'db> { // (this variant represents `f.__get__`, where `f` is any function) false } - Type::DataclassDecorator(_) => false, + Type::DataclassDecorator(_) | Type::DataclassTransformer(_) => false, Type::Instance(InstanceType { class }) => { class.known(db).is_some_and(KnownClass::is_singleton) } @@ -2126,7 +2184,8 @@ impl<'db> Type<'db> { | Type::AlwaysFalsy | Type::Callable(_) | Type::PropertyInstance(_) - | Type::DataclassDecorator(_) => false, + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) => false, } } @@ -2262,6 +2321,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) | Type::KnownInstance(_) | Type::AlwaysTruthy @@ -2357,7 +2417,9 @@ impl<'db> Type<'db> { Type::DataclassDecorator(_) => KnownClass::FunctionType .to_instance(db) .instance_member(db, name), - Type::Callable(_) => KnownClass::Object.to_instance(db).instance_member(db, name), + Type::Callable(_) | Type::DataclassTransformer(_) => { + KnownClass::Object.to_instance(db).instance_member(db, name) + } Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) { None => KnownClass::Object.to_instance(db).instance_member(db, name), @@ -2774,7 +2836,7 @@ impl<'db> Type<'db> { Type::DataclassDecorator(_) => KnownClass::FunctionType .to_instance(db) .member_lookup_with_policy(db, name, policy), - Type::Callable(_) => KnownClass::Object + Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Object .to_instance(db) .member_lookup_with_policy(db, name, policy), @@ -3080,6 +3142,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) | Type::SliceLiteral(_) | Type::AlwaysTruthy => Truthiness::AlwaysTrue, @@ -3387,6 +3450,18 @@ impl<'db> Type<'db> { )) } + // TODO: We should probably also check the original return type of the function + // that was decorated with `@dataclass_transform`, to see if it is consistent with + // with what we configure here. + Type::DataclassTransformer(_) => Signatures::single(CallableSignature::single( + self, + Signature::new( + Parameters::new([Parameter::positional_only(Some(Name::new_static("func"))) + .with_annotated_type(Type::object(db))]), + None, + ), + )), + Type::FunctionLiteral(function_type) => match function_type.known(db) { Some( KnownFunction::IsEquivalentTo @@ -3500,8 +3575,7 @@ impl<'db> Type<'db> { Parameters::new([Parameter::positional_only(Some( Name::new_static("cls"), )) - // TODO: type[_T] - .with_annotated_type(Type::any())]), + .with_annotated_type(KnownClass::Type.to_instance(db))]), None, ), // TODO: make this overload Python-version-dependent @@ -4289,6 +4363,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::Instance(_) | Type::KnownInstance(_) | Type::PropertyInstance(_) @@ -4359,6 +4434,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::Never | Type::FunctionLiteral(_) | Type::BoundSuper(_) @@ -4574,7 +4650,7 @@ impl<'db> Type<'db> { Type::MethodWrapper(_) => KnownClass::MethodWrapperType.to_class_literal(db), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType.to_class_literal(db), Type::DataclassDecorator(_) => KnownClass::FunctionType.to_class_literal(db), - Type::Callable(_) => KnownClass::Type.to_instance(db), + Type::Callable(_) | Type::DataclassTransformer(_) => KnownClass::Type.to_instance(db), Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class_literal(db), Type::Tuple(_) => KnownClass::Tuple.to_class_literal(db), @@ -4714,6 +4790,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) // A non-generic class never needs to be specialized. A generic class is specialized // explicitly (via a subscript expression) or implicitly (via a call), and not because @@ -4820,6 +4897,7 @@ impl<'db> Type<'db> { | Self::MethodWrapper(_) | Self::WrapperDescriptor(_) | Self::DataclassDecorator(_) + | Self::DataclassTransformer(_) | Self::PropertyInstance(_) | Self::BoundSuper(_) | Self::Tuple(_) => self.to_meta_type(db).definition(db), @@ -5883,6 +5961,10 @@ pub struct FunctionType<'db> { /// A set of special decorators that were applied to this function decorators: FunctionDecorators, + /// The arguments to `dataclass_transformer`, if this function was annotated + /// with `@dataclass_transformer(...)`. + dataclass_transformer_params: Option, + /// The generic context of a generic function. generic_context: Option>, @@ -6019,6 +6101,7 @@ impl<'db> FunctionType<'db> { self.known(db), self.body_scope(db), self.decorators(db), + self.dataclass_transformer_params(db), Some(generic_context), self.specialization(db), ) @@ -6035,6 +6118,7 @@ impl<'db> FunctionType<'db> { self.known(db), self.body_scope(db), self.decorators(db), + self.dataclass_transformer_params(db), self.generic_context(db), Some(specialization), ) @@ -6079,6 +6163,8 @@ pub enum KnownFunction { GetProtocolMembers, /// `typing(_extensions).runtime_checkable` RuntimeCheckable, + /// `typing(_extensions).dataclass_transform` + DataclassTransform, /// `abc.abstractmethod` #[strum(serialize = "abstractmethod")] @@ -6143,6 +6229,7 @@ impl KnownFunction { | Self::IsProtocol | Self::GetProtocolMembers | Self::RuntimeCheckable + | Self::DataclassTransform | Self::NoTypeCheck => { matches!(module, KnownModule::Typing | KnownModule::TypingExtensions) } @@ -7516,6 +7603,7 @@ pub(crate) mod tests { | KnownFunction::IsProtocol | KnownFunction::GetProtocolMembers | KnownFunction::RuntimeCheckable + | KnownFunction::DataclassTransform | KnownFunction::NoTypeCheck => KnownModule::TypingExtensions, KnownFunction::IsSingleton diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index a577dba34b..6687bd2846 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -19,8 +19,9 @@ use crate::types::diagnostic::{ use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, - KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, + BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType, + KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, + UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_python_ast as ast; @@ -210,8 +211,17 @@ impl<'db> Bindings<'db> { /// Evaluates the return type of certain known callables, where we have special-case logic to /// determine the return type in a way that isn't directly expressible in the type system. fn evaluate_known_cases(&mut self, db: &'db dyn Db) { + let to_bool = |ty: &Option>, default: bool| -> bool { + if let Some(Type::BooleanLiteral(value)) = ty { + *value + } else { + // TODO: emit a diagnostic if we receive `bool` + default + } + }; + // Each special case listed here should have a corresponding clause in `Type::signatures`. - for binding in &mut self.elements { + for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) { let binding_type = binding.callable_type; let Some((overload_index, overload)) = binding.matching_overload_mut() else { continue; @@ -413,6 +423,21 @@ impl<'db> Bindings<'db> { } } + Type::DataclassTransformer(params) => { + if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() { + overload.set_return_type(Type::FunctionLiteral(FunctionType::new( + db, + function.name(db), + function.known(db), + function.body_scope(db), + function.decorators(db), + Some(params), + function.generic_context(db), + function.specialization(db), + ))); + } + } + Type::BoundMethod(bound_method) if bound_method.self_instance(db).is_property_instance() => { @@ -598,53 +623,90 @@ impl<'db> Bindings<'db> { if let [init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot] = overload.parameter_types() { - let to_bool = |ty: &Option>, default: bool| -> bool { - if let Some(Type::BooleanLiteral(value)) = ty { - *value - } else { - // TODO: emit a diagnostic if we receive `bool` - default - } - }; - - let mut metadata = DataclassMetadata::empty(); + let mut params = DataclassParams::empty(); if to_bool(init, true) { - metadata |= DataclassMetadata::INIT; + params |= DataclassParams::INIT; } if to_bool(repr, true) { - metadata |= DataclassMetadata::REPR; + params |= DataclassParams::REPR; } if to_bool(eq, true) { - metadata |= DataclassMetadata::EQ; + params |= DataclassParams::EQ; } if to_bool(order, false) { - metadata |= DataclassMetadata::ORDER; + params |= DataclassParams::ORDER; } if to_bool(unsafe_hash, false) { - metadata |= DataclassMetadata::UNSAFE_HASH; + params |= DataclassParams::UNSAFE_HASH; } if to_bool(frozen, false) { - metadata |= DataclassMetadata::FROZEN; + params |= DataclassParams::FROZEN; } if to_bool(match_args, true) { - metadata |= DataclassMetadata::MATCH_ARGS; + params |= DataclassParams::MATCH_ARGS; } if to_bool(kw_only, false) { - metadata |= DataclassMetadata::KW_ONLY; + params |= DataclassParams::KW_ONLY; } if to_bool(slots, false) { - metadata |= DataclassMetadata::SLOTS; + params |= DataclassParams::SLOTS; } if to_bool(weakref_slot, false) { - metadata |= DataclassMetadata::WEAKREF_SLOT; + params |= DataclassParams::WEAKREF_SLOT; } - overload.set_return_type(Type::DataclassDecorator(metadata)); + overload.set_return_type(Type::DataclassDecorator(params)); } } - _ => {} + Some(KnownFunction::DataclassTransform) => { + if let [eq_default, order_default, kw_only_default, frozen_default, _field_specifiers, _kwargs] = + overload.parameter_types() + { + let mut params = DataclassTransformerParams::empty(); + + if to_bool(eq_default, true) { + params |= DataclassTransformerParams::EQ_DEFAULT; + } + if to_bool(order_default, false) { + params |= DataclassTransformerParams::ORDER_DEFAULT; + } + if to_bool(kw_only_default, false) { + params |= DataclassTransformerParams::KW_ONLY_DEFAULT; + } + if to_bool(frozen_default, false) { + params |= DataclassTransformerParams::FROZEN_DEFAULT; + } + + overload.set_return_type(Type::DataclassTransformer(params)); + } + } + + _ => { + if let Some(params) = function_type.dataclass_transformer_params(db) { + // This is a call to a custom function that was decorated with `@dataclass_transformer`. + // If this function was called with a keyword argument like `order=False`, we extract + // the argument type and overwrite the corresponding flag in `dataclass_params` after + // constructing them from the `dataclass_transformer`-parameter defaults. + + let mut dataclass_params = DataclassParams::from(params); + + if let Some(Some(Type::BooleanLiteral(order))) = callable_signature + .iter() + .nth(overload_index) + .and_then(|signature| { + let (idx, _) = + signature.parameters().keyword_by_name("order")?; + overload.parameter_types().get(idx) + }) + { + dataclass_params.set(DataclassParams::ORDER, *order); + } + + overload.set_return_type(Type::DataclassDecorator(dataclass_params)); + } + } }, Type::ClassLiteral(class) => match class.known(db) { diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 23a4a1fde6..e8b92667ac 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -10,7 +10,7 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::DeclarationWithConstraint; use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters}; -use crate::types::{CallableType, DataclassMetadata, Signature}; +use crate::types::{CallableType, DataclassParams, DataclassTransformerParams, Signature}; use crate::{ module_resolver::file_to_module, semantic_index::{ @@ -106,7 +106,8 @@ pub struct Class<'db> { pub(crate) known: Option, - pub(crate) dataclass_metadata: Option, + pub(crate) dataclass_params: Option, + pub(crate) dataclass_transformer_params: Option, } impl<'db> Class<'db> { @@ -469,8 +470,8 @@ impl<'db> ClassLiteralType<'db> { self.class(db).known } - pub(crate) fn dataclass_metadata(self, db: &'db dyn Db) -> Option { - self.class(db).dataclass_metadata + pub(crate) fn dataclass_params(self, db: &'db dyn Db) -> Option { + self.class(db).dataclass_params } /// Return `true` if this class represents `known_class` @@ -699,6 +700,7 @@ impl<'db> ClassLiteralType<'db> { /// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred. pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> { self.try_metaclass(db) + .map(|(ty, _)| ty) .unwrap_or_else(|_| SubclassOfType::subclass_of_unknown()) } @@ -712,7 +714,10 @@ impl<'db> ClassLiteralType<'db> { /// Return the metaclass of this class, or an error if the metaclass cannot be inferred. #[salsa::tracked] - pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result, MetaclassError<'db>> { + pub(super) fn try_metaclass( + self, + db: &'db dyn Db, + ) -> Result<(Type<'db>, Option), MetaclassError<'db>> { let class = self.class(db); tracing::trace!("ClassLiteralType::try_metaclass: {}", class.name); @@ -723,7 +728,7 @@ impl<'db> ClassLiteralType<'db> { // We emit diagnostics for cyclic class definitions elsewhere. // Avoid attempting to infer the metaclass if the class is cyclically defined: // it would be easy to enter an infinite loop. - return Ok(SubclassOfType::subclass_of_unknown()); + return Ok((SubclassOfType::subclass_of_unknown(), None)); } let explicit_metaclass = self.explicit_metaclass(db); @@ -768,7 +773,7 @@ impl<'db> ClassLiteralType<'db> { }), }; - return return_ty_result.map(|ty| ty.to_meta_type(db)); + return return_ty_result.map(|ty| (ty.to_meta_type(db), None)); }; // Reconcile all base classes' metaclasses with the candidate metaclass. @@ -805,7 +810,10 @@ impl<'db> ClassLiteralType<'db> { }); } - Ok(candidate.metaclass.into()) + Ok(( + candidate.metaclass.into(), + candidate.metaclass.class(db).dataclass_transformer_params, + )) } /// Returns the class member of this class named `name`. @@ -969,12 +977,8 @@ impl<'db> ClassLiteralType<'db> { }); if symbol.symbol.is_unbound() { - if let Some(metadata) = self.dataclass_metadata(db) { - if let Some(dataclass_member) = - self.own_dataclass_member(db, specialization, metadata, name) - { - return Symbol::bound(dataclass_member).into(); - } + if let Some(dataclass_member) = self.own_dataclass_member(db, specialization, name) { + return Symbol::bound(dataclass_member).into(); } } @@ -986,70 +990,97 @@ impl<'db> ClassLiteralType<'db> { self, db: &'db dyn Db, specialization: Option>, - metadata: DataclassMetadata, name: &str, ) -> Option> { - if name == "__init__" && metadata.contains(DataclassMetadata::INIT) { - let mut parameters = vec![]; + let params = self.dataclass_params(db); + let has_dataclass_param = |param| params.is_some_and(|params| params.contains(param)); - for (name, (mut attr_ty, mut default_ty)) in self.dataclass_fields(db, specialization) { - // The descriptor handling below is guarded by this fully-static check, because dynamic - // types like `Any` are valid (data) descriptors: since they have all possible attributes, - // they also have a (callable) `__set__` method. The problem is that we can't determine - // the type of the value parameter this way. Instead, we want to use the dynamic type - // itself in this case, so we skip the special descriptor handling. - if attr_ty.is_fully_static(db) { - let dunder_set = attr_ty.class_member(db, "__set__".into()); - if let Some(dunder_set) = dunder_set.symbol.ignore_possibly_unbound() { - // This type of this attribute is a data descriptor. Instead of overwriting the - // descriptor attribute, data-classes will (implicitly) call the `__set__` method - // of the descriptor. This means that the synthesized `__init__` parameter for - // this attribute is determined by possible `value` parameter types with which - // the `__set__` method can be called. We build a union of all possible options - // to account for possible overloads. - let mut value_types = UnionBuilder::new(db); - for signature in &dunder_set.signatures(db) { - for overload in signature { - if let Some(value_param) = overload.parameters().get_positional(2) { - value_types = value_types.add( - value_param.annotated_type().unwrap_or_else(Type::unknown), - ); - } else if overload.parameters().is_gradual() { - value_types = value_types.add(Type::unknown()); + match name { + "__init__" => { + let has_synthesized_dunder_init = has_dataclass_param(DataclassParams::INIT) + || self + .try_metaclass(db) + .is_ok_and(|(_, transformer_params)| transformer_params.is_some()); + + if !has_synthesized_dunder_init { + return None; + } + + let mut parameters = vec![]; + + for (name, (mut attr_ty, mut default_ty)) in + self.dataclass_fields(db, specialization) + { + // The descriptor handling below is guarded by this fully-static check, because dynamic + // types like `Any` are valid (data) descriptors: since they have all possible attributes, + // they also have a (callable) `__set__` method. The problem is that we can't determine + // the type of the value parameter this way. Instead, we want to use the dynamic type + // itself in this case, so we skip the special descriptor handling. + if attr_ty.is_fully_static(db) { + let dunder_set = attr_ty.class_member(db, "__set__".into()); + if let Some(dunder_set) = dunder_set.symbol.ignore_possibly_unbound() { + // This type of this attribute is a data descriptor. Instead of overwriting the + // descriptor attribute, data-classes will (implicitly) call the `__set__` method + // of the descriptor. This means that the synthesized `__init__` parameter for + // this attribute is determined by possible `value` parameter types with which + // the `__set__` method can be called. We build a union of all possible options + // to account for possible overloads. + let mut value_types = UnionBuilder::new(db); + for signature in &dunder_set.signatures(db) { + for overload in signature { + if let Some(value_param) = + overload.parameters().get_positional(2) + { + value_types = value_types.add( + value_param + .annotated_type() + .unwrap_or_else(Type::unknown), + ); + } else if overload.parameters().is_gradual() { + value_types = value_types.add(Type::unknown()); + } } } - } - attr_ty = value_types.build(); + attr_ty = value_types.build(); - // The default value of the attribute is *not* determined by the right hand side - // of the class-body assignment. Instead, the runtime invokes `__get__` on the - // descriptor, as if it had been called on the class itself, i.e. it passes `None` - // for the `instance` argument. + // The default value of the attribute is *not* determined by the right hand side + // of the class-body assignment. Instead, the runtime invokes `__get__` on the + // descriptor, as if it had been called on the class itself, i.e. it passes `None` + // for the `instance` argument. - if let Some(ref mut default_ty) = default_ty { - *default_ty = default_ty - .try_call_dunder_get(db, Type::none(db), Type::ClassLiteral(self)) - .map(|(return_ty, _)| return_ty) - .unwrap_or_else(Type::unknown); + if let Some(ref mut default_ty) = default_ty { + *default_ty = default_ty + .try_call_dunder_get( + db, + Type::none(db), + Type::ClassLiteral(self), + ) + .map(|(return_ty, _)| return_ty) + .unwrap_or_else(Type::unknown); + } } } + + let mut parameter = + Parameter::positional_or_keyword(name).with_annotated_type(attr_ty); + + if let Some(default_ty) = default_ty { + parameter = parameter.with_default_type(default_ty); + } + + parameters.push(parameter); } - let mut parameter = - Parameter::positional_or_keyword(name).with_annotated_type(attr_ty); + let init_signature = + Signature::new(Parameters::new(parameters), Some(Type::none(db))); - if let Some(default_ty) = default_ty { - parameter = parameter.with_default_type(default_ty); - } - - parameters.push(parameter); + Some(Type::Callable(CallableType::single(db, init_signature))) } + "__lt__" | "__le__" | "__gt__" | "__ge__" => { + if !has_dataclass_param(DataclassParams::ORDER) { + return None; + } - let init_signature = Signature::new(Parameters::new(parameters), Some(Type::none(db))); - - return Some(Type::Callable(CallableType::single(db, init_signature))); - } else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { - if metadata.contains(DataclassMetadata::ORDER) { let signature = Signature::new( Parameters::new([Parameter::positional_or_keyword(Name::new_static("other")) // TODO: could be `Self`. @@ -1059,11 +1090,17 @@ impl<'db> ClassLiteralType<'db> { Some(KnownClass::Bool.to_instance(db)), ); - return Some(Type::Callable(CallableType::single(db, signature))); + Some(Type::Callable(CallableType::single(db, signature))) } + _ => None, } + } - None + fn is_dataclass(self, db: &'db dyn Db) -> bool { + self.dataclass_params(db).is_some() + || self + .try_metaclass(db) + .is_ok_and(|(_, transformer_params)| transformer_params.is_some()) } /// Returns a list of all annotated attributes defined in this class, or any of its superclasses. @@ -1079,7 +1116,7 @@ impl<'db> ClassLiteralType<'db> { .filter_map(|superclass| { if let Some(class) = superclass.into_class() { let class_literal = class.class_literal(db).0; - if class_literal.dataclass_metadata(db).is_some() { + if class_literal.is_dataclass(db) { Some(class_literal) } else { None diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index 5f342a3f3a..6142cd322f 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -90,6 +90,7 @@ impl<'db> ClassBase<'db> { | Type::MethodWrapper(_) | Type::WrapperDescriptor(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::BytesLiteral(_) | Type::IntLiteral(_) | Type::StringLiteral(_) diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index b420d3ef72..20ca22bd98 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -195,7 +195,10 @@ impl Display for DisplayRepresentation<'_> { write!(f, "") } Type::DataclassDecorator(_) => { - f.write_str("") + f.write_str("") + } + Type::DataclassTransformer(_) => { + f.write_str("") } Type::Union(union) => union.display(self.db).fmt(f), Type::Intersection(intersection) => intersection.display(self.db).fmt(f), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 466fe9be5a..13e20cffd4 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -82,7 +82,7 @@ use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class, - ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, + ClassLiteralType, ClassType, DataclassParams, DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType, @@ -1457,6 +1457,7 @@ impl<'db> TypeInferenceBuilder<'db> { let mut decorator_types_and_nodes = Vec::with_capacity(decorator_list.len()); let mut function_decorators = FunctionDecorators::empty(); + let mut dataclass_transformer_params = None; for decorator in decorator_list { let decorator_ty = self.infer_decorator(decorator); @@ -1477,6 +1478,8 @@ impl<'db> TypeInferenceBuilder<'db> { function_decorators |= FunctionDecorators::CLASSMETHOD; continue; } + } else if let Type::DataclassTransformer(params) = decorator_ty { + dataclass_transformer_params = Some(params); } decorator_types_and_nodes.push((decorator_ty, decorator)); @@ -1523,6 +1526,7 @@ impl<'db> TypeInferenceBuilder<'db> { function_kind, body_scope, function_decorators, + dataclass_transformer_params, generic_context, specialization, )); @@ -1757,19 +1761,32 @@ impl<'db> TypeInferenceBuilder<'db> { body: _, } = class_node; - let mut dataclass_metadata = None; + let mut dataclass_params = None; + let mut dataclass_transformer_params = None; for decorator in decorator_list { let decorator_ty = self.infer_decorator(decorator); if decorator_ty .into_function_literal() .is_some_and(|function| function.is_known(self.db(), KnownFunction::Dataclass)) { - dataclass_metadata = Some(DataclassMetadata::default()); + dataclass_params = Some(DataclassParams::default()); continue; } - if let Type::DataclassDecorator(metadata) = decorator_ty { - dataclass_metadata = Some(metadata); + if let Type::DataclassDecorator(params) = decorator_ty { + dataclass_params = Some(params); + continue; + } + + if let Type::FunctionLiteral(f) = decorator_ty { + if let Some(params) = f.dataclass_transformer_params(self.db()) { + dataclass_params = Some(params.into()); + continue; + } + } + + if let Type::DataclassTransformer(params) = decorator_ty { + dataclass_transformer_params = Some(params); continue; } } @@ -1789,7 +1806,8 @@ impl<'db> TypeInferenceBuilder<'db> { name: name.id.clone(), body_scope, known: maybe_known_class, - dataclass_metadata, + dataclass_params, + dataclass_transformer_params, }; let class_literal = match generic_context { Some(generic_context) => { @@ -2502,6 +2520,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::MethodWrapper(_) | Type::WrapperDescriptor(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::TypeVar(..) | Type::AlwaysTruthy | Type::AlwaysFalsy => { @@ -4882,6 +4901,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::BoundMethod(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) @@ -5164,6 +5184,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::GenericAlias(_) @@ -5188,6 +5209,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::GenericAlias(_) diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index 8f69687465..111b758bc7 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -79,6 +79,12 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::DataclassDecorator(_), _) => Ordering::Less, (_, Type::DataclassDecorator(_)) => Ordering::Greater, + (Type::DataclassTransformer(left), Type::DataclassTransformer(right)) => { + left.bits().cmp(&right.bits()) + } + (Type::DataclassTransformer(_), _) => Ordering::Less, + (_, Type::DataclassTransformer(_)) => Ordering::Greater, + (Type::Callable(left), Type::Callable(right)) => { debug_assert_eq!(*left, left.normalized(db)); debug_assert_eq!(*right, right.normalized(db));