[red-knot] Handle generic constructors of generic classes (#17552)

We now handle generic constructor methods on generic classes correctly:

```py
class C[T]:
    def __init__[S](self, t: T, s: S): ...

x = C(1, "str")
```

Here, constructing `C` requires us to infer a specialization for the
generic contexts of `C` and `__init__` at the same time.

At first I thought I would need to track the full stack of nested
generic contexts here (since the `[S]` context is nested within the
`[T]` context). But I think this is the only way that we might need to
specialize more than one generic context at once — in all other cases, a
containing generic context must be specialized before we get to a nested
one, and so we can just special-case this.

While we're here, we also construct the generic context for a generic
function lazily, when its signature is accessed, instead of eagerly when
inferring the function body.
This commit is contained in:
Douglas Creager 2025-04-23 15:06:18 -04:00 committed by GitHub
parent 61e73481fe
commit 9db63fc58c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 101 additions and 59 deletions

View File

@ -232,21 +232,11 @@ TODO: These do not currently work yet, because we don't correctly model the nest
class C[T]: class C[T]:
def __init__[S](self, x: T, y: S) -> None: ... def __init__[S](self, x: T, y: S) -> None: ...
# TODO: no error reveal_type(C(1, 1)) # revealed: C[Literal[1]]
# TODO: revealed: C[Literal[1]] reveal_type(C(1, "string")) # revealed: C[Literal[1]]
# error: [invalid-argument-type] reveal_type(C(1, True)) # revealed: C[Literal[1]]
reveal_type(C(1, 1)) # revealed: C[Unknown]
# TODO: no error
# TODO: revealed: C[Literal[1]]
# error: [invalid-argument-type]
reveal_type(C(1, "string")) # revealed: C[Unknown]
# TODO: no error
# TODO: revealed: C[Literal[1]]
# error: [invalid-argument-type]
reveal_type(C(1, True)) # revealed: C[Unknown]
# TODO: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`" # error: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`"
# error: [invalid-argument-type] "Argument to this function is incorrect: Expected `S`, found `Literal[1]`"
wrong_innards: C[int] = C("five", 1) wrong_innards: C[int] = C("five", 1)
``` ```

View File

@ -4268,13 +4268,13 @@ impl<'db> Type<'db> {
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.specialization()); .and_then(|(_, binding)| binding.inherited_specialization());
let init_specialization = init_call_outcome let init_specialization = init_call_outcome
.and_then(Result::ok) .and_then(Result::ok)
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.specialization()); .and_then(|(_, binding)| binding.inherited_specialization());
let specialization = match (new_specialization, init_specialization) { let specialization = match (new_specialization, init_specialization) {
(None, None) => None, (None, None) => None,
(Some(specialization), None) | (None, Some(specialization)) => { (Some(specialization), None) | (None, Some(specialization)) => {
@ -5940,8 +5940,10 @@ pub struct FunctionType<'db> {
/// with `@dataclass_transformer(...)`. /// with `@dataclass_transformer(...)`.
dataclass_transformer_params: Option<DataclassTransformerParams>, dataclass_transformer_params: Option<DataclassTransformerParams>,
/// The generic context of a generic function. /// The inherited generic context, if this function is a class method being used to infer the
generic_context: Option<GenericContext<'db>>, /// specialization of its generic class. If the method is itself generic, this is in addition
/// to its own generic context.
inherited_generic_context: Option<GenericContext<'db>>,
/// A specialization that should be applied to the function's parameter and return types, /// A specialization that should be applied to the function's parameter and return types,
/// either because the function is itself generic, or because it appears in the body of a /// either because the function is itself generic, or because it appears in the body of a
@ -6007,11 +6009,7 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file. /// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(return_ref)] #[salsa::tracked(return_ref)]
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut internal_signature = self.internal_signature(db); let internal_signature = self.internal_signature(db);
if let Some(specialization) = self.specialization(db) {
internal_signature = internal_signature.apply_specialization(db, specialization);
}
// The semantic model records a use for each function on the name node. This is used here // The semantic model records a use for each function on the name node. This is used here
// to get the previous function definition with the same name. // to get the previous function definition with the same name.
@ -6071,14 +6069,51 @@ impl<'db> FunctionType<'db> {
let scope = self.body_scope(db); let scope = self.body_scope(db);
let function_stmt_node = scope.node(db).expect_function(); let function_stmt_node = scope.node(db).expect_function();
let definition = self.definition(db); let definition = self.definition(db);
Signature::from_function(db, self.generic_context(db), definition, function_stmt_node) let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let index = semantic_index(db, scope.file(db));
GenericContext::from_type_params(db, index, type_params)
});
let mut signature = Signature::from_function(
db,
generic_context,
self.inherited_generic_context(db),
definition,
function_stmt_node,
);
if let Some(specialization) = self.specialization(db) {
signature = signature.apply_specialization(db, specialization);
}
signature
} }
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function) self.known(db) == Some(known_function)
} }
fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { fn with_dataclass_transformer_params(
self,
db: &'db dyn Db,
params: DataclassTransformerParams,
) -> Self {
Self::new(
db,
self.name(db).clone(),
self.known(db),
self.body_scope(db),
self.decorators(db),
Some(params),
self.inherited_generic_context(db),
self.specialization(db),
)
}
fn with_inherited_generic_context(
self,
db: &'db dyn Db,
inherited_generic_context: GenericContext<'db>,
) -> Self {
// A function cannot inherit more than one generic context from its containing class.
debug_assert!(self.inherited_generic_context(db).is_none());
Self::new( Self::new(
db, db,
self.name(db).clone(), self.name(db).clone(),
@ -6086,7 +6121,7 @@ impl<'db> FunctionType<'db> {
self.body_scope(db), self.body_scope(db),
self.decorators(db), self.decorators(db),
self.dataclass_transformer_params(db), self.dataclass_transformer_params(db),
Some(generic_context), Some(inherited_generic_context),
self.specialization(db), self.specialization(db),
) )
} }
@ -6103,7 +6138,7 @@ impl<'db> FunctionType<'db> {
self.body_scope(db), self.body_scope(db),
self.decorators(db), self.decorators(db),
self.dataclass_transformer_params(db), self.dataclass_transformer_params(db),
self.generic_context(db), self.inherited_generic_context(db),
Some(specialization), Some(specialization),
) )
} }

View File

@ -19,9 +19,9 @@ use crate::types::diagnostic::{
use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::generics::{Specialization, SpecializationBuilder};
use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{ use crate::types::{
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType, BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
UnionType, WrapperDescriptorKind, WrapperDescriptorKind,
}; };
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
use ruff_python_ast as ast; use ruff_python_ast as ast;
@ -424,16 +424,9 @@ impl<'db> Bindings<'db> {
Type::DataclassTransformer(params) => { Type::DataclassTransformer(params) => {
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() { if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() {
overload.set_return_type(Type::FunctionLiteral(FunctionType::new( overload.set_return_type(Type::FunctionLiteral(
db, function.with_dataclass_transformer_params(db, params),
function.name(db), ));
function.known(db),
function.body_scope(db),
function.decorators(db),
Some(params),
function.generic_context(db),
function.specialization(db),
)));
} }
} }
@ -961,6 +954,10 @@ pub(crate) struct Binding<'db> {
/// The specialization that was inferred from the argument types, if the callable is generic. /// The specialization that was inferred from the argument types, if the callable is generic.
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
/// The specialization that was inferred for a class method's containing generic class, if it
/// is being used to infer a specialization for the class.
inherited_specialization: Option<Specialization<'db>>,
/// The formal parameter that each argument is matched with, in argument source order, or /// The formal parameter that each argument is matched with, in argument source order, or
/// `None` if the argument was not matched to any parameter. /// `None` if the argument was not matched to any parameter.
argument_parameters: Box<[Option<usize>]>, argument_parameters: Box<[Option<usize>]>,
@ -1097,6 +1094,7 @@ impl<'db> Binding<'db> {
Self { Self {
return_ty: signature.return_ty.unwrap_or(Type::unknown()), return_ty: signature.return_ty.unwrap_or(Type::unknown()),
specialization: None, specialization: None,
inherited_specialization: None,
argument_parameters: argument_parameters.into_boxed_slice(), argument_parameters: argument_parameters.into_boxed_slice(),
parameter_tys: vec![None; parameters.len()].into_boxed_slice(), parameter_tys: vec![None; parameters.len()].into_boxed_slice(),
errors, errors,
@ -1112,8 +1110,8 @@ impl<'db> Binding<'db> {
// If this overload is generic, first see if we can infer a specialization of the function // If this overload is generic, first see if we can infer a specialization of the function
// from the arguments that were passed in. // from the arguments that were passed in.
let parameters = signature.parameters(); let parameters = signature.parameters();
self.specialization = signature.generic_context.map(|generic_context| { if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() {
let mut builder = SpecializationBuilder::new(db, generic_context); let mut builder = SpecializationBuilder::new(db);
for (argument_index, (_, argument_type)) in argument_types.iter().enumerate() { for (argument_index, (_, argument_type)) in argument_types.iter().enumerate() {
let Some(parameter_index) = self.argument_parameters[argument_index] else { let Some(parameter_index) = self.argument_parameters[argument_index] else {
// There was an error with argument when matching parameters, so don't bother // There was an error with argument when matching parameters, so don't bother
@ -1126,8 +1124,11 @@ impl<'db> Binding<'db> {
}; };
builder.infer(expected_type, argument_type); builder.infer(expected_type, argument_type);
} }
builder.build() self.specialization = signature.generic_context.map(|gc| builder.build(gc));
}); self.inherited_specialization = signature
.inherited_generic_context
.map(|gc| builder.build(gc));
}
let mut num_synthetic_args = 0; let mut num_synthetic_args = 0;
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
@ -1155,6 +1156,9 @@ impl<'db> Binding<'db> {
if let Some(specialization) = self.specialization { if let Some(specialization) = self.specialization {
expected_ty = expected_ty.apply_specialization(db, specialization); expected_ty = expected_ty.apply_specialization(db, specialization);
} }
if let Some(inherited_specialization) = self.inherited_specialization {
expected_ty = expected_ty.apply_specialization(db, inherited_specialization);
}
if !argument_type.is_assignable_to(db, expected_ty) { if !argument_type.is_assignable_to(db, expected_ty) {
let positional = matches!(argument, Argument::Positional | Argument::Synthetic) let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
&& !parameter.is_variadic(); && !parameter.is_variadic();
@ -1180,6 +1184,11 @@ impl<'db> Binding<'db> {
if let Some(specialization) = self.specialization { if let Some(specialization) = self.specialization {
self.return_ty = self.return_ty.apply_specialization(db, specialization); self.return_ty = self.return_ty.apply_specialization(db, specialization);
} }
if let Some(inherited_specialization) = self.inherited_specialization {
self.return_ty = self
.return_ty
.apply_specialization(db, inherited_specialization);
}
} }
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
@ -1190,8 +1199,8 @@ impl<'db> Binding<'db> {
self.return_ty self.return_ty
} }
pub(crate) fn specialization(&self) -> Option<Specialization<'db>> { pub(crate) fn inherited_specialization(&self) -> Option<Specialization<'db>> {
self.specialization self.inherited_specialization
} }
pub(crate) fn parameter_types(&self) -> &[Option<Type<'db>>] { pub(crate) fn parameter_types(&self) -> &[Option<Type<'db>>] {

View File

@ -1017,7 +1017,7 @@ impl<'db> ClassLiteralType<'db> {
Some(_), Some(_),
"__new__" | "__init__", "__new__" | "__init__",
) => Type::FunctionLiteral( ) => Type::FunctionLiteral(
function.with_generic_context(db, origin.generic_context(db)), function.with_inherited_generic_context(db, origin.generic_context(db)),
), ),
_ => ty, _ => ty,
} }

View File

@ -299,22 +299,19 @@ impl<'db> Specialization<'db> {
/// specialization of a generic function. /// specialization of a generic function.
pub(crate) struct SpecializationBuilder<'db> { pub(crate) struct SpecializationBuilder<'db> {
db: &'db dyn Db, db: &'db dyn Db,
generic_context: GenericContext<'db>,
types: FxHashMap<TypeVarInstance<'db>, UnionBuilder<'db>>, types: FxHashMap<TypeVarInstance<'db>, UnionBuilder<'db>>,
} }
impl<'db> SpecializationBuilder<'db> { impl<'db> SpecializationBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { pub(crate) fn new(db: &'db dyn Db) -> Self {
Self { Self {
db, db,
generic_context,
types: FxHashMap::default(), types: FxHashMap::default(),
} }
} }
pub(crate) fn build(mut self) -> Specialization<'db> { pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
let types: Box<[_]> = self let types: Box<[_]> = generic_context
.generic_context
.variables(self.db) .variables(self.db)
.iter() .iter()
.map(|variable| { .map(|variable| {
@ -324,7 +321,7 @@ impl<'db> SpecializationBuilder<'db> {
.unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown())) .unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown()))
}) })
.collect(); .collect();
Specialization::new(self.db, self.generic_context, types) Specialization::new(self.db, generic_context, types)
} }
fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) { fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) {

View File

@ -1525,10 +1525,6 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
} }
let generic_context = type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(self.db(), self.index, type_params)
});
let function_kind = let function_kind =
KnownFunction::try_from_definition_and_name(self.db(), definition, name); KnownFunction::try_from_definition_and_name(self.db(), definition, name);
@ -1537,6 +1533,7 @@ impl<'db> TypeInferenceBuilder<'db> {
.node_scope(NodeWithScopeRef::Function(function)) .node_scope(NodeWithScopeRef::Function(function))
.to_scope_id(self.db(), self.file()); .to_scope_id(self.db(), self.file());
let inherited_generic_context = None;
let specialization = None; let specialization = None;
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new( let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
@ -1546,7 +1543,7 @@ impl<'db> TypeInferenceBuilder<'db> {
body_scope, body_scope,
function_decorators, function_decorators,
dataclass_transformer_params, dataclass_transformer_params,
generic_context, inherited_generic_context,
specialization, specialization,
)); ));

View File

@ -166,6 +166,7 @@ impl<'db> CallableSignature<'db> {
pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { pub(crate) fn dynamic(signature_type: Type<'db>) -> Self {
let signature = Signature { let signature = Signature {
generic_context: None, generic_context: None,
inherited_generic_context: None,
parameters: Parameters::gradual_form(), parameters: Parameters::gradual_form(),
return_ty: Some(signature_type), return_ty: Some(signature_type),
}; };
@ -178,6 +179,7 @@ impl<'db> CallableSignature<'db> {
let signature_type = todo_type!(reason); let signature_type = todo_type!(reason);
let signature = Signature { let signature = Signature {
generic_context: None, generic_context: None,
inherited_generic_context: None,
parameters: Parameters::todo(), parameters: Parameters::todo(),
return_ty: Some(signature_type), return_ty: Some(signature_type),
}; };
@ -215,6 +217,11 @@ pub struct Signature<'db> {
/// The generic context for this overload, if it is generic. /// The generic context for this overload, if it is generic.
pub(crate) generic_context: Option<GenericContext<'db>>, pub(crate) generic_context: Option<GenericContext<'db>>,
/// The inherited generic context, if this function is a class method being used to infer the
/// specialization of its generic class. If the method is itself generic, this is in addition
/// to its own generic context.
pub(crate) inherited_generic_context: Option<GenericContext<'db>>,
/// Parameters, in source order. /// Parameters, in source order.
/// ///
/// The ordering of parameters in a valid signature must be: first positional-only parameters, /// The ordering of parameters in a valid signature must be: first positional-only parameters,
@ -233,6 +240,7 @@ impl<'db> Signature<'db> {
pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option<Type<'db>>) -> Self { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option<Type<'db>>) -> Self {
Self { Self {
generic_context: None, generic_context: None,
inherited_generic_context: None,
parameters, parameters,
return_ty, return_ty,
} }
@ -245,6 +253,7 @@ impl<'db> Signature<'db> {
) -> Self { ) -> Self {
Self { Self {
generic_context, generic_context,
inherited_generic_context: None,
parameters, parameters,
return_ty, return_ty,
} }
@ -254,6 +263,7 @@ impl<'db> Signature<'db> {
pub(super) fn from_function( pub(super) fn from_function(
db: &'db dyn Db, db: &'db dyn Db,
generic_context: Option<GenericContext<'db>>, generic_context: Option<GenericContext<'db>>,
inherited_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>, definition: Definition<'db>,
function_node: &ast::StmtFunctionDef, function_node: &ast::StmtFunctionDef,
) -> Self { ) -> Self {
@ -267,6 +277,7 @@ impl<'db> Signature<'db> {
Self { Self {
generic_context, generic_context,
inherited_generic_context,
parameters: Parameters::from_parameters( parameters: Parameters::from_parameters(
db, db,
definition, definition,
@ -279,6 +290,7 @@ impl<'db> Signature<'db> {
pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self { pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self {
Self { Self {
generic_context: self.generic_context, generic_context: self.generic_context,
inherited_generic_context: self.inherited_generic_context,
parameters: self parameters: self
.parameters .parameters
.iter() .iter()
@ -295,6 +307,7 @@ impl<'db> Signature<'db> {
) -> Self { ) -> Self {
Self { Self {
generic_context: self.generic_context, generic_context: self.generic_context,
inherited_generic_context: self.inherited_generic_context,
parameters: self.parameters.apply_specialization(db, specialization), parameters: self.parameters.apply_specialization(db, specialization),
return_ty: self return_ty: self
.return_ty .return_ty
@ -310,6 +323,7 @@ impl<'db> Signature<'db> {
pub(crate) fn bind_self(&self) -> Self { pub(crate) fn bind_self(&self) -> Self {
Self { Self {
generic_context: self.generic_context, generic_context: self.generic_context,
inherited_generic_context: self.inherited_generic_context,
parameters: Parameters::new(self.parameters().iter().skip(1).cloned()), parameters: Parameters::new(self.parameters().iter().skip(1).cloned()),
return_ty: self.return_ty, return_ty: self.return_ty,
} }