Recognize custom field-specifier functions

This commit is contained in:
David Peter 2025-10-15 15:18:33 +02:00
parent 4dc88a0a5f
commit 23543194fc
6 changed files with 128 additions and 67 deletions

View File

@ -478,11 +478,8 @@ class Person:
name: str = fancy_field()
age: int | None = fancy_field(kw_only=True)
# TODO: Should be `(self: Person, name: str, *, age: int | None) -> None`
reveal_type(Person.__init__) # revealed: (self: Person, id: int = Any, name: str = Any, age: int | None = Any) -> None
reveal_type(Person.__init__) # revealed: (self: Person, name: str = Unknown, *, age: int | None = Unknown) -> None
# TODO: No error here
# error: [invalid-argument-type]
alice = Person("Alice", age=30)
reveal_type(alice.id) # revealed: int

View File

@ -619,8 +619,8 @@ impl<'db> PropertyInstanceType<'db> {
}
bitflags! {
/// Used for the return type of `dataclass(…)` calls. Keeps track of the arguments
/// that were passed in. For the precise meaning of the fields, see [1].
/// Used to store metadata about a dataclass or dataclass-like class.
/// For the precise meaning of the fields, see [1].
///
/// [1]: https://docs.python.org/3/library/dataclasses.html
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
@ -671,11 +671,16 @@ impl From<DataclassTransformerFlags> for DataclassFlags {
}
}
/// Metadata for a dataclass. Stored inside a `Type::DataclassDecorator(…)`
/// instance that we use as the return type of a `dataclasses.dataclass` and
/// dataclass-transformer decorator calls.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub struct DataclassParams<'db> {
flags: DataclassFlags,
field_specifiers: Type<'db>,
#[returns(deref)]
field_specifiers: Box<[Type<'db>]>,
}
impl get_size2::GetSize for DataclassParams<'_> {}
@ -691,7 +696,7 @@ impl<'db> DataclassParams<'db> {
.ignore_possibly_unbound()
.unwrap_or_else(|| Type::none(db));
Self::new(db, flags, dataclasses_field)
Self::new(db, flags, vec![dataclasses_field].into_boxed_slice())
}
fn from_transformer_params(db: &'db dyn Db, params: DataclassTransformerParams<'db>) -> Self {
@ -5464,7 +5469,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &TypeContext::default())
.check_types(db, argument_types, &TypeContext::default(), &[])
}
/// Look up a dunder method on the meta-type of `self` and call it.
@ -5516,7 +5521,7 @@ impl<'db> Type<'db> {
let bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &tcx)?;
.check_types(db, argument_types, &tcx, &[])?;
if boundness == Boundness::PossiblyUnbound {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}

View File

@ -136,6 +136,7 @@ impl<'db> Bindings<'db> {
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
dataclass_field_specifiers: &[Type<'db>],
) -> Result<Self, CallError<'db>> {
for element in &mut self.elements {
if let Some(mut updated_argument_forms) =
@ -148,7 +149,7 @@ impl<'db> Bindings<'db> {
}
}
self.evaluate_known_cases(db);
self.evaluate_known_cases(db, dataclass_field_specifiers);
// In order of precedence:
//
@ -270,7 +271,7 @@ impl<'db> Bindings<'db> {
/// Evaluates the return type of certain known callables, where we have special-case logic to
/// determine the return type in a way that isn't directly expressible in the type system.
fn evaluate_known_cases(&mut self, db: &'db dyn Db) {
fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) {
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
@ -597,6 +598,48 @@ impl<'db> Bindings<'db> {
}
}
function @ Type::FunctionLiteral(function_type)
if dataclass_field_specifiers.contains(&function)
|| function_type.is_known(db, KnownFunction::Field) =>
{
let default = overload.parameter_type_by_name("default").unwrap_or(None);
let default_factory = overload
.parameter_type_by_name("default_factory")
.unwrap_or(None);
let init = overload.parameter_type_by_name("init").unwrap_or(None);
let kw_only = overload.parameter_type_by_name("kw_only").unwrap_or(None);
let default_ty = match (default, default_factory) {
(Some(default_ty), _) => default_ty,
(_, Some(default_factory_ty)) => default_factory_ty
.try_call(db, &CallArguments::none())
.map_or(Type::unknown(), |binding| binding.return_type(db)),
_ => Type::unknown(),
};
let init = init
.map(|init| !init.bool(db).is_always_false())
.unwrap_or(true);
let kw_only = if Program::get(db).python_version(db) >= PythonVersion::PY310
{
kw_only.map(|kw_only| !kw_only.bool(db).is_always_false())
} else {
None
};
// `typeshed` pretends that `dataclasses.field()` returns the type of the
// default value directly. At runtime, however, this function returns an
// instance of `dataclasses.Field`. We also model it this way and return
// a known-instance type with information about the field. The drawback
// of this approach is that we need to pretend that instances of `Field`
// are assignable to `T` if the default type of the field is assignable
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
overload.set_return_type(Type::KnownInstance(KnownInstanceType::Field(
FieldInstance::new(db, default_ty, init, kw_only),
)));
}
Type::FunctionLiteral(function_type) => match function_type.known(db) {
Some(KnownFunction::IsEquivalentTo) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
@ -956,59 +999,24 @@ impl<'db> Bindings<'db> {
flags |= DataclassTransformerFlags::FROZEN_DEFAULT;
}
let params = DataclassTransformerParams::new(
db,
flags,
field_specifiers.unwrap_or(Type::none(db)),
);
let field_specifiers: Box<[Type<'db>]> = field_specifiers
.map(|tuple_type| {
tuple_type
.exact_tuple_instance_spec(db)
.iter()
.flat_map(|tuple_spec| tuple_spec.fixed_elements())
.copied()
.collect()
})
.unwrap_or_default();
let params =
DataclassTransformerParams::new(db, flags, field_specifiers);
overload.set_return_type(Type::DataclassTransformer(params));
}
}
Some(KnownFunction::Field) => {
let default =
overload.parameter_type_by_name("default").unwrap_or(None);
let default_factory = overload
.parameter_type_by_name("default_factory")
.unwrap_or(None);
let init = overload.parameter_type_by_name("init").unwrap_or(None);
let kw_only =
overload.parameter_type_by_name("kw_only").unwrap_or(None);
let default_ty = match (default, default_factory) {
(Some(default_ty), _) => default_ty,
(_, Some(default_factory_ty)) => default_factory_ty
.try_call(db, &CallArguments::none())
.map_or(Type::unknown(), |binding| binding.return_type(db)),
_ => Type::unknown(),
};
let init = init
.map(|init| !init.bool(db).is_always_false())
.unwrap_or(true);
let kw_only =
if Program::get(db).python_version(db) >= PythonVersion::PY310 {
kw_only.map(|kw_only| !kw_only.bool(db).is_always_false())
} else {
None
};
// `typeshed` pretends that `dataclasses.field()` returns the type of the
// default value directly. At runtime, however, this function returns an
// instance of `dataclasses.Field`. We also model it this way and return
// a known-instance type with information about the field. The drawback
// of this approach is that we need to pretend that instances of `Field`
// are assignable to `T` if the default type of the field is assignable
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::Field(FieldInstance::new(
db, default_ty, init, kw_only,
)),
));
}
_ => {
// Ideally, either the implementation, or exactly one of the overloads
// of the function can have the dataclass_transform decorator applied.

View File

@ -2900,7 +2900,7 @@ impl<'db> ClassLiteral<'db> {
default_ty = Some(field.default_type(db));
if self
.dataclass_params(db)
.map(|params| params.field_specifiers(db).is_none(db))
.map(|params| params.field_specifiers(db).is_empty())
.unwrap_or(false)
{
// This happens when constructing a `dataclass` with a `dataclass_transform`

View File

@ -169,11 +169,15 @@ impl Default for DataclassTransformerFlags {
}
}
/// Metadata for a dataclass-transformer. Stored inside a `Type::DataclassTransformer(…)`
/// instance that we use as the return type for `dataclass_transform(…)` calls.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub struct DataclassTransformerParams<'db> {
pub flags: DataclassTransformerFlags,
pub field_specifiers: Type<'db>,
#[returns(deref)]
pub field_specifiers: Box<[Type<'db>]>,
}
impl get_size2::GetSize for DataclassTransformerParams<'_> {}

View File

@ -9,6 +9,7 @@ use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, PythonVersion};
use ruff_python_stdlib::builtins::version_builtin_was_added;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use super::{
CycleRecovery, DefinitionInference, DefinitionInferenceExtra, ExpressionInference,
@ -272,6 +273,10 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// `true` if all places in this expression are definitely bound
all_definitely_bound: bool,
/// A list of `dataclass_transform` field specifiers that are "active" (when inferring
/// the right hand side of an annotated assignment in a class that is a dataclass).
dataclass_field_specifiers: SmallVec<[Type<'db>; 2]>,
}
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
@ -307,6 +312,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
undecorated_type: None,
cycle_recovery: None,
all_definitely_bound: true,
dataclass_field_specifiers: SmallVec::new(),
}
}
@ -4512,10 +4518,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(PlaceExpr::try_from_expr(target).is_some());
if let Some(value) = value {
fn field_specifiers<'db>(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
scope: ScopeId<'db>,
) -> Option<SmallVec<[Type<'db>; 2]>> {
let enclosing_scope = index.scope(scope.file_scope_id(db));
if enclosing_scope.kind() != ScopeKind::Class {
return None;
}
let class_node = enclosing_scope.node().as_class()?;
let class_definition = index.expect_single_definition(class_node);
infer_definition_types(db, class_definition)
.declaration_type(class_definition)
.inner_type()
.as_class_literal()?
.dataclass_params(db)
.map(|params| SmallVec::from(params.field_specifiers(db)))
}
if let Some(specifiers) = field_specifiers(self.db(), self.index, self.scope()) {
self.dataclass_field_specifiers = specifiers;
}
let inferred_ty = self.infer_maybe_standalone_expression(
value,
TypeContext::new(Some(declared.inner_type())),
);
self.dataclass_field_specifiers.clear();
let inferred_ty = if target
.as_name_expr()
.is_some_and(|name| &name.id == "TYPE_CHECKING")
@ -6631,7 +6666,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
let mut bindings = match bindings.check_types(
self.db(),
&call_arguments,
&tcx,
&self.dataclass_field_specifiers[..],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, call_expression.into());
@ -9218,8 +9258,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
let bindings = match Bindings::from(binding)
.match_parameters(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types, &TypeContext::default())
{
.check_types(
self.db(),
&call_argument_types,
&TypeContext::default(),
&self.dataclass_field_specifiers[..],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, subscript.into());
@ -9751,6 +9795,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred,
cycle_recovery,
all_definitely_bound,
dataclass_field_specifiers: _,
// Ignored; only relevant to definition regions
undecorated_type: _,
@ -9817,8 +9862,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred,
cycle_recovery,
undecorated_type,
all_definitely_bound: _,
// builder only state
dataclass_field_specifiers: _,
all_definitely_bound: _,
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
@ -9885,12 +9931,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred: _,
bindings: _,
declarations: _,
all_definitely_bound: _,
// Ignored; only relevant to definition regions
undecorated_type: _,
// Builder only state
dataclass_field_specifiers: _,
all_definitely_bound: _,
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,