diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index 79b3e3c49a..99f75f6517 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -478,11 +478,8 @@ class Person: name: str = fancy_field() age: int | None = fancy_field(kw_only=True) -# TODO: Should be `(self: Person, name: str, *, age: int | None) -> None` -reveal_type(Person.__init__) # revealed: (self: Person, id: int = Any, name: str = Any, age: int | None = Any) -> None +reveal_type(Person.__init__) # revealed: (self: Person, name: str = Unknown, *, age: int | None = Unknown) -> None -# TODO: No error here -# error: [invalid-argument-type] alice = Person("Alice", age=30) reveal_type(alice.id) # revealed: int diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index d6dba6f194..aa3fa82fc4 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -619,8 +619,8 @@ impl<'db> PropertyInstanceType<'db> { } bitflags! { - /// Used for the return type of `dataclass(…)` calls. Keeps track of the arguments - /// that were passed in. For the precise meaning of the fields, see [1]. + /// Used to store metadata about a dataclass or dataclass-like class. + /// For the precise meaning of the fields, see [1]. /// /// [1]: https://docs.python.org/3/library/dataclasses.html #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -671,11 +671,16 @@ impl From for DataclassFlags { } } +/// Metadata for a dataclass. Stored inside a `Type::DataclassDecorator(…)` +/// instance that we use as the return type of a `dataclasses.dataclass` and +/// dataclass-transformer decorator calls. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] #[derive(PartialOrd, Ord)] pub struct DataclassParams<'db> { flags: DataclassFlags, - field_specifiers: Type<'db>, + + #[returns(deref)] + field_specifiers: Box<[Type<'db>]>, } impl get_size2::GetSize for DataclassParams<'_> {} @@ -691,7 +696,7 @@ impl<'db> DataclassParams<'db> { .ignore_possibly_unbound() .unwrap_or_else(|| Type::none(db)); - Self::new(db, flags, dataclasses_field) + Self::new(db, flags, vec![dataclasses_field].into_boxed_slice()) } fn from_transformer_params(db: &'db dyn Db, params: DataclassTransformerParams<'db>) -> Self { @@ -5464,7 +5469,7 @@ impl<'db> Type<'db> { ) -> Result, CallError<'db>> { self.bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types, &TypeContext::default()) + .check_types(db, argument_types, &TypeContext::default(), &[]) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -5516,7 +5521,7 @@ impl<'db> Type<'db> { let bindings = dunder_callable .bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types, &tcx)?; + .check_types(db, argument_types, &tcx, &[])?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index b78c776a43..e04bc30348 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -136,6 +136,7 @@ impl<'db> Bindings<'db> { db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>, call_expression_tcx: &TypeContext<'db>, + dataclass_field_specifiers: &[Type<'db>], ) -> Result> { for element in &mut self.elements { if let Some(mut updated_argument_forms) = @@ -148,7 +149,7 @@ impl<'db> Bindings<'db> { } } - self.evaluate_known_cases(db); + self.evaluate_known_cases(db, dataclass_field_specifiers); // In order of precedence: // @@ -270,7 +271,7 @@ impl<'db> Bindings<'db> { /// Evaluates the return type of certain known callables, where we have special-case logic to /// determine the return type in a way that isn't directly expressible in the type system. - fn evaluate_known_cases(&mut self, db: &'db dyn Db) { + fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) { let to_bool = |ty: &Option>, default: bool| -> bool { if let Some(Type::BooleanLiteral(value)) = ty { *value @@ -597,6 +598,48 @@ impl<'db> Bindings<'db> { } } + function @ Type::FunctionLiteral(function_type) + if dataclass_field_specifiers.contains(&function) + || function_type.is_known(db, KnownFunction::Field) => + { + let default = overload.parameter_type_by_name("default").unwrap_or(None); + let default_factory = overload + .parameter_type_by_name("default_factory") + .unwrap_or(None); + let init = overload.parameter_type_by_name("init").unwrap_or(None); + let kw_only = overload.parameter_type_by_name("kw_only").unwrap_or(None); + + let default_ty = match (default, default_factory) { + (Some(default_ty), _) => default_ty, + (_, Some(default_factory_ty)) => default_factory_ty + .try_call(db, &CallArguments::none()) + .map_or(Type::unknown(), |binding| binding.return_type(db)), + _ => Type::unknown(), + }; + + let init = init + .map(|init| !init.bool(db).is_always_false()) + .unwrap_or(true); + + let kw_only = if Program::get(db).python_version(db) >= PythonVersion::PY310 + { + kw_only.map(|kw_only| !kw_only.bool(db).is_always_false()) + } else { + None + }; + + // `typeshed` pretends that `dataclasses.field()` returns the type of the + // default value directly. At runtime, however, this function returns an + // instance of `dataclasses.Field`. We also model it this way and return + // a known-instance type with information about the field. The drawback + // of this approach is that we need to pretend that instances of `Field` + // are assignable to `T` if the default type of the field is assignable + // to `T`. Otherwise, we would error on `name: str = field(default="")`. + overload.set_return_type(Type::KnownInstance(KnownInstanceType::Field( + FieldInstance::new(db, default_ty, init, kw_only), + ))); + } + Type::FunctionLiteral(function_type) => match function_type.known(db) { Some(KnownFunction::IsEquivalentTo) => { if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { @@ -956,59 +999,24 @@ impl<'db> Bindings<'db> { flags |= DataclassTransformerFlags::FROZEN_DEFAULT; } - let params = DataclassTransformerParams::new( - db, - flags, - field_specifiers.unwrap_or(Type::none(db)), - ); + let field_specifiers: Box<[Type<'db>]> = field_specifiers + .map(|tuple_type| { + tuple_type + .exact_tuple_instance_spec(db) + .iter() + .flat_map(|tuple_spec| tuple_spec.fixed_elements()) + .copied() + .collect() + }) + .unwrap_or_default(); + + let params = + DataclassTransformerParams::new(db, flags, field_specifiers); overload.set_return_type(Type::DataclassTransformer(params)); } } - Some(KnownFunction::Field) => { - let default = - overload.parameter_type_by_name("default").unwrap_or(None); - let default_factory = overload - .parameter_type_by_name("default_factory") - .unwrap_or(None); - let init = overload.parameter_type_by_name("init").unwrap_or(None); - let kw_only = - overload.parameter_type_by_name("kw_only").unwrap_or(None); - - let default_ty = match (default, default_factory) { - (Some(default_ty), _) => default_ty, - (_, Some(default_factory_ty)) => default_factory_ty - .try_call(db, &CallArguments::none()) - .map_or(Type::unknown(), |binding| binding.return_type(db)), - _ => Type::unknown(), - }; - - let init = init - .map(|init| !init.bool(db).is_always_false()) - .unwrap_or(true); - - let kw_only = - if Program::get(db).python_version(db) >= PythonVersion::PY310 { - kw_only.map(|kw_only| !kw_only.bool(db).is_always_false()) - } else { - None - }; - - // `typeshed` pretends that `dataclasses.field()` returns the type of the - // default value directly. At runtime, however, this function returns an - // instance of `dataclasses.Field`. We also model it this way and return - // a known-instance type with information about the field. The drawback - // of this approach is that we need to pretend that instances of `Field` - // are assignable to `T` if the default type of the field is assignable - // to `T`. Otherwise, we would error on `name: str = field(default="")`. - overload.set_return_type(Type::KnownInstance( - KnownInstanceType::Field(FieldInstance::new( - db, default_ty, init, kw_only, - )), - )); - } - _ => { // Ideally, either the implementation, or exactly one of the overloads // of the function can have the dataclass_transform decorator applied. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index be9f311616..ab12da49ce 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -2900,7 +2900,7 @@ impl<'db> ClassLiteral<'db> { default_ty = Some(field.default_type(db)); if self .dataclass_params(db) - .map(|params| params.field_specifiers(db).is_none(db)) + .map(|params| params.field_specifiers(db).is_empty()) .unwrap_or(false) { // This happens when constructing a `dataclass` with a `dataclass_transform` diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index dadafeede3..1a3a667656 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -169,11 +169,15 @@ impl Default for DataclassTransformerFlags { } } +/// Metadata for a dataclass-transformer. Stored inside a `Type::DataclassTransformer(…)` +/// instance that we use as the return type for `dataclass_transform(…)` calls. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] #[derive(PartialOrd, Ord)] pub struct DataclassTransformerParams<'db> { pub flags: DataclassTransformerFlags, - pub field_specifiers: Type<'db>, + + #[returns(deref)] + pub field_specifiers: Box<[Type<'db>]>, } impl get_size2::GetSize for DataclassTransformerParams<'_> {} diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index fd274da3ca..a5588d7140 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -9,6 +9,7 @@ use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, PythonVersion}; use ruff_python_stdlib::builtins::version_builtin_was_added; use ruff_text_size::{Ranged, TextRange}; use rustc_hash::{FxHashMap, FxHashSet}; +use smallvec::SmallVec; use super::{ CycleRecovery, DefinitionInference, DefinitionInferenceExtra, ExpressionInference, @@ -272,6 +273,10 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, + + /// A list of `dataclass_transform` field specifiers that are "active" (when inferring + /// the right hand side of an annotated assignment in a class that is a dataclass). + dataclass_field_specifiers: SmallVec<[Type<'db>; 2]>, } impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { @@ -307,6 +312,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { undecorated_type: None, cycle_recovery: None, all_definitely_bound: true, + dataclass_field_specifiers: SmallVec::new(), } } @@ -4512,10 +4518,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { debug_assert!(PlaceExpr::try_from_expr(target).is_some()); if let Some(value) = value { + fn field_specifiers<'db>( + db: &'db dyn Db, + index: &'db SemanticIndex<'db>, + scope: ScopeId<'db>, + ) -> Option; 2]>> { + let enclosing_scope = index.scope(scope.file_scope_id(db)); + + if enclosing_scope.kind() != ScopeKind::Class { + return None; + } + + let class_node = enclosing_scope.node().as_class()?; + + let class_definition = index.expect_single_definition(class_node); + infer_definition_types(db, class_definition) + .declaration_type(class_definition) + .inner_type() + .as_class_literal()? + .dataclass_params(db) + .map(|params| SmallVec::from(params.field_specifiers(db))) + } + + if let Some(specifiers) = field_specifiers(self.db(), self.index, self.scope()) { + self.dataclass_field_specifiers = specifiers; + } + let inferred_ty = self.infer_maybe_standalone_expression( value, TypeContext::new(Some(declared.inner_type())), ); + + self.dataclass_field_specifiers.clear(); + let inferred_ty = if target .as_name_expr() .is_some_and(|name| &name.id == "TYPE_CHECKING") @@ -6631,7 +6666,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) { + let mut bindings = match bindings.check_types( + self.db(), + &call_arguments, + &tcx, + &self.dataclass_field_specifiers[..], + ) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { bindings.report_diagnostics(&self.context, call_expression.into()); @@ -9218,8 +9258,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) .match_parameters(self.db(), &call_argument_types) - .check_types(self.db(), &call_argument_types, &TypeContext::default()) - { + .check_types( + self.db(), + &call_argument_types, + &TypeContext::default(), + &self.dataclass_field_specifiers[..], + ) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { bindings.report_diagnostics(&self.context, subscript.into()); @@ -9751,6 +9795,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deferred, cycle_recovery, all_definitely_bound, + dataclass_field_specifiers: _, // Ignored; only relevant to definition regions undecorated_type: _, @@ -9817,8 +9862,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deferred, cycle_recovery, undecorated_type, - all_definitely_bound: _, // builder only state + dataclass_field_specifiers: _, + all_definitely_bound: _, typevar_binding_context: _, deferred_state: _, multi_inference_state: _, @@ -9885,12 +9931,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deferred: _, bindings: _, declarations: _, - all_definitely_bound: _, // Ignored; only relevant to definition regions undecorated_type: _, // Builder only state + dataclass_field_specifiers: _, + all_definitely_bound: _, typevar_binding_context: _, deferred_state: _, multi_inference_state: _,