From 3dedd70a92ae4e156df6b26c52628f7254b28a1d Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Bodas <55339528+abhijeetbodas2001@users.noreply.github.com> Date: Wed, 7 May 2025 19:21:13 +0530 Subject: [PATCH] [ty] Detect overloads decorated with `@dataclass_transform` (#17835) ## Summary Fixes #17541 Before this change, in the case of overloaded functions, `@dataclass_transform` was detected only when applied to the implementation, not the overloads. However, the spec also allows this decorator to be applied to any of the overloads as well. With this PR, we start handling `@dataclass_transform`s applied to overloads. ## Test Plan Fixed existing TODOs in the test suite. --- .../resources/mdtest/dataclass_transform.md | 10 ++- .../ty_python_semantic/src/types/call/bind.rs | 63 ++++++++++++------- crates/ty_python_semantic/src/types/infer.rs | 11 ++++ 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclass_transform.md index a57fcd612e..879557ad3c 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclass_transform.md @@ -281,13 +281,11 @@ class D1: class D2: x: str -# TODO: these should not be errors -D1("a") # error: [too-many-positional-arguments] -D2("a") # error: [too-many-positional-arguments] +D1("a") +D2("a") -# TODO: these should be invalid-argument-type errors -D1(1.2) # error: [too-many-positional-arguments] -D2(1.2) # error: [too-many-positional-arguments] +D1(1.2) # error: [invalid-argument-type] +D2(1.2) # error: [invalid-argument-type] ``` [`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 2bb1bc0730..8f28f5f141 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder, Specializati use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ todo_type, BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, - KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, - TupleType, UnionType, WrapperDescriptorKind, + FunctionType, KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, + PropertyInstanceType, TupleType, UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, SubDiagnostic}; use ruff_python_ast as ast; @@ -770,29 +770,50 @@ impl<'db> Bindings<'db> { } _ => { - if let Some(params) = function_type.dataclass_transformer_params(db) { - // This is a call to a custom function that was decorated with `@dataclass_transformer`. - // If this function was called with a keyword argument like `order=False`, we extract - // the argument type and overwrite the corresponding flag in `dataclass_params` after - // constructing them from the `dataclass_transformer`-parameter defaults. + let mut handle_dataclass_transformer_params = + |function_type: &FunctionType| { + if let Some(params) = + function_type.dataclass_transformer_params(db) + { + // This is a call to a custom function that was decorated with `@dataclass_transformer`. + // If this function was called with a keyword argument like `order=False`, we extract + // the argument type and overwrite the corresponding flag in `dataclass_params` after + // constructing them from the `dataclass_transformer`-parameter defaults. - let mut dataclass_params = DataclassParams::from(params); + let mut dataclass_params = DataclassParams::from(params); - if let Some(Some(Type::BooleanLiteral(order))) = callable_signature + if let Some(Some(Type::BooleanLiteral(order))) = + callable_signature.iter().nth(overload_index).and_then( + |signature| { + let (idx, _) = signature + .parameters() + .keyword_by_name("order")?; + overload.parameter_types().get(idx) + }, + ) + { + dataclass_params.set(DataclassParams::ORDER, *order); + } + + overload.set_return_type(Type::DataclassDecorator( + dataclass_params, + )); + } + }; + + // Ideally, either the implementation, or exactly one of the overloads + // of the function can have the dataclass_transform decorator applied. + // However, we do not yet enforce this, and in the case of multiple + // applications of the decorator, we will only consider the last one + // for the return value, since the prior ones will be over-written. + if let Some(overloaded) = function_type.to_overloaded(db) { + overloaded + .overloads .iter() - .nth(overload_index) - .and_then(|signature| { - let (idx, _) = - signature.parameters().keyword_by_name("order")?; - overload.parameter_types().get(idx) - }) - { - dataclass_params.set(DataclassParams::ORDER, *order); - } - - overload - .set_return_type(Type::DataclassDecorator(dataclass_params)); + .for_each(&mut handle_dataclass_transformer_params); } + + handle_dataclass_transformer_params(&function_type); } }, diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b2451b0ab3..c31e3d3a30 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -2121,6 +2121,17 @@ impl<'db> TypeInferenceBuilder<'db> { } if let Type::FunctionLiteral(f) = decorator_ty { + // We do not yet detect or flag `@dataclass_transform` applied to more than one + // overload, or an overload and the implementation both. Nevertheless, this is not + // allowed. We do not try to treat the offenders intelligently -- just use the + // params of the last seen usage of `@dataclass_transform` + if let Some(overloaded) = f.to_overloaded(self.db()) { + overloaded.overloads.iter().for_each(|overload| { + if let Some(params) = overload.dataclass_transformer_params(self.db()) { + dataclass_params = Some(params.into()); + } + }); + } if let Some(params) = f.dataclass_transformer_params(self.db()) { dataclass_params = Some(params.into()); continue;