diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md index 295b44827d..11d1d36b35 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md @@ -232,21 +232,11 @@ TODO: These do not currently work yet, because we don't correctly model the nest class C[T]: def __init__[S](self, x: T, y: S) -> None: ... -# TODO: no error -# TODO: revealed: C[Literal[1]] -# error: [invalid-argument-type] -reveal_type(C(1, 1)) # revealed: C[Unknown] -# TODO: no error -# TODO: revealed: C[Literal[1]] -# error: [invalid-argument-type] -reveal_type(C(1, "string")) # revealed: C[Unknown] -# TODO: no error -# TODO: revealed: C[Literal[1]] -# error: [invalid-argument-type] -reveal_type(C(1, True)) # revealed: C[Unknown] +reveal_type(C(1, 1)) # revealed: C[Literal[1]] +reveal_type(C(1, "string")) # revealed: C[Literal[1]] +reveal_type(C(1, True)) # revealed: C[Literal[1]] -# TODO: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`" -# error: [invalid-argument-type] "Argument to this function is incorrect: Expected `S`, found `Literal[1]`" +# error: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`" wrong_innards: C[int] = C("five", 1) ``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index ebb44ae9cc..9f45c4bb85 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -4268,13 +4268,13 @@ impl<'db> Type<'db> { .as_ref() .and_then(Bindings::single_element) .and_then(CallableBinding::matching_overload) - .and_then(|(_, binding)| binding.specialization()); + .and_then(|(_, binding)| binding.inherited_specialization()); let init_specialization = init_call_outcome .and_then(Result::ok) .as_ref() .and_then(Bindings::single_element) .and_then(CallableBinding::matching_overload) - .and_then(|(_, binding)| binding.specialization()); + .and_then(|(_, binding)| binding.inherited_specialization()); let specialization = match (new_specialization, init_specialization) { (None, None) => None, (Some(specialization), None) | (None, Some(specialization)) => { @@ -5940,8 +5940,10 @@ pub struct FunctionType<'db> { /// with `@dataclass_transformer(...)`. dataclass_transformer_params: Option, - /// The generic context of a generic function. - generic_context: Option>, + /// The inherited generic context, if this function is a class method being used to infer the + /// specialization of its generic class. If the method is itself generic, this is in addition + /// to its own generic context. + inherited_generic_context: Option>, /// A specialization that should be applied to the function's parameter and return types, /// either because the function is itself generic, or because it appears in the body of a @@ -6007,11 +6009,7 @@ impl<'db> FunctionType<'db> { /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { - let mut internal_signature = self.internal_signature(db); - - if let Some(specialization) = self.specialization(db) { - internal_signature = internal_signature.apply_specialization(db, specialization); - } + let internal_signature = self.internal_signature(db); // The semantic model records a use for each function on the name node. This is used here // to get the previous function definition with the same name. @@ -6071,14 +6069,51 @@ impl<'db> FunctionType<'db> { let scope = self.body_scope(db); let function_stmt_node = scope.node(db).expect_function(); let definition = self.definition(db); - Signature::from_function(db, self.generic_context(db), definition, function_stmt_node) + let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { + let index = semantic_index(db, scope.file(db)); + GenericContext::from_type_params(db, index, type_params) + }); + let mut signature = Signature::from_function( + db, + generic_context, + self.inherited_generic_context(db), + definition, + function_stmt_node, + ); + if let Some(specialization) = self.specialization(db) { + signature = signature.apply_specialization(db, specialization); + } + signature } pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { self.known(db) == Some(known_function) } - fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { + fn with_dataclass_transformer_params( + self, + db: &'db dyn Db, + params: DataclassTransformerParams, + ) -> Self { + Self::new( + db, + self.name(db).clone(), + self.known(db), + self.body_scope(db), + self.decorators(db), + Some(params), + self.inherited_generic_context(db), + self.specialization(db), + ) + } + + fn with_inherited_generic_context( + self, + db: &'db dyn Db, + inherited_generic_context: GenericContext<'db>, + ) -> Self { + // A function cannot inherit more than one generic context from its containing class. + debug_assert!(self.inherited_generic_context(db).is_none()); Self::new( db, self.name(db).clone(), @@ -6086,7 +6121,7 @@ impl<'db> FunctionType<'db> { self.body_scope(db), self.decorators(db), self.dataclass_transformer_params(db), - Some(generic_context), + Some(inherited_generic_context), self.specialization(db), ) } @@ -6103,7 +6138,7 @@ impl<'db> FunctionType<'db> { self.body_scope(db), self.decorators(db), self.dataclass_transformer_params(db), - self.generic_context(db), + self.inherited_generic_context(db), Some(specialization), ) } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 6fb2505e87..4a2b5a7bfd 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -19,9 +19,9 @@ use crate::types::diagnostic::{ use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType, - KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, - UnionType, WrapperDescriptorKind, + BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass, + KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, + WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_python_ast as ast; @@ -424,16 +424,9 @@ impl<'db> Bindings<'db> { Type::DataclassTransformer(params) => { if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() { - overload.set_return_type(Type::FunctionLiteral(FunctionType::new( - db, - function.name(db), - function.known(db), - function.body_scope(db), - function.decorators(db), - Some(params), - function.generic_context(db), - function.specialization(db), - ))); + overload.set_return_type(Type::FunctionLiteral( + function.with_dataclass_transformer_params(db, params), + )); } } @@ -961,6 +954,10 @@ pub(crate) struct Binding<'db> { /// The specialization that was inferred from the argument types, if the callable is generic. specialization: Option>, + /// The specialization that was inferred for a class method's containing generic class, if it + /// is being used to infer a specialization for the class. + inherited_specialization: Option>, + /// The formal parameter that each argument is matched with, in argument source order, or /// `None` if the argument was not matched to any parameter. argument_parameters: Box<[Option]>, @@ -1097,6 +1094,7 @@ impl<'db> Binding<'db> { Self { return_ty: signature.return_ty.unwrap_or(Type::unknown()), specialization: None, + inherited_specialization: None, argument_parameters: argument_parameters.into_boxed_slice(), parameter_tys: vec![None; parameters.len()].into_boxed_slice(), errors, @@ -1112,8 +1110,8 @@ impl<'db> Binding<'db> { // If this overload is generic, first see if we can infer a specialization of the function // from the arguments that were passed in. let parameters = signature.parameters(); - self.specialization = signature.generic_context.map(|generic_context| { - let mut builder = SpecializationBuilder::new(db, generic_context); + if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() { + let mut builder = SpecializationBuilder::new(db); for (argument_index, (_, argument_type)) in argument_types.iter().enumerate() { let Some(parameter_index) = self.argument_parameters[argument_index] else { // There was an error with argument when matching parameters, so don't bother @@ -1126,8 +1124,11 @@ impl<'db> Binding<'db> { }; builder.infer(expected_type, argument_type); } - builder.build() - }); + self.specialization = signature.generic_context.map(|gc| builder.build(gc)); + self.inherited_specialization = signature + .inherited_generic_context + .map(|gc| builder.build(gc)); + } let mut num_synthetic_args = 0; let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { @@ -1155,6 +1156,9 @@ impl<'db> Binding<'db> { if let Some(specialization) = self.specialization { expected_ty = expected_ty.apply_specialization(db, specialization); } + if let Some(inherited_specialization) = self.inherited_specialization { + expected_ty = expected_ty.apply_specialization(db, inherited_specialization); + } if !argument_type.is_assignable_to(db, expected_ty) { let positional = matches!(argument, Argument::Positional | Argument::Synthetic) && !parameter.is_variadic(); @@ -1180,6 +1184,11 @@ impl<'db> Binding<'db> { if let Some(specialization) = self.specialization { self.return_ty = self.return_ty.apply_specialization(db, specialization); } + if let Some(inherited_specialization) = self.inherited_specialization { + self.return_ty = self + .return_ty + .apply_specialization(db, inherited_specialization); + } } pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { @@ -1190,8 +1199,8 @@ impl<'db> Binding<'db> { self.return_ty } - pub(crate) fn specialization(&self) -> Option> { - self.specialization + pub(crate) fn inherited_specialization(&self) -> Option> { + self.inherited_specialization } pub(crate) fn parameter_types(&self) -> &[Option>] { diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index f35c6a3eb2..85aa50aa0d 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -1017,7 +1017,7 @@ impl<'db> ClassLiteralType<'db> { Some(_), "__new__" | "__init__", ) => Type::FunctionLiteral( - function.with_generic_context(db, origin.generic_context(db)), + function.with_inherited_generic_context(db, origin.generic_context(db)), ), _ => ty, } diff --git a/crates/red_knot_python_semantic/src/types/generics.rs b/crates/red_knot_python_semantic/src/types/generics.rs index ef5d66a772..66cd67f263 100644 --- a/crates/red_knot_python_semantic/src/types/generics.rs +++ b/crates/red_knot_python_semantic/src/types/generics.rs @@ -299,22 +299,19 @@ impl<'db> Specialization<'db> { /// specialization of a generic function. pub(crate) struct SpecializationBuilder<'db> { db: &'db dyn Db, - generic_context: GenericContext<'db>, types: FxHashMap, UnionBuilder<'db>>, } impl<'db> SpecializationBuilder<'db> { - pub(crate) fn new(db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { + pub(crate) fn new(db: &'db dyn Db) -> Self { Self { db, - generic_context, types: FxHashMap::default(), } } - pub(crate) fn build(mut self) -> Specialization<'db> { - let types: Box<[_]> = self - .generic_context + pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> { + let types: Box<[_]> = generic_context .variables(self.db) .iter() .map(|variable| { @@ -324,7 +321,7 @@ impl<'db> SpecializationBuilder<'db> { .unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown())) }) .collect(); - Specialization::new(self.db, self.generic_context, types) + Specialization::new(self.db, generic_context, types) } fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) { diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index e07f17d7ac..8dcfe14167 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1525,10 +1525,6 @@ impl<'db> TypeInferenceBuilder<'db> { } } - let generic_context = type_params.as_ref().map(|type_params| { - GenericContext::from_type_params(self.db(), self.index, type_params) - }); - let function_kind = KnownFunction::try_from_definition_and_name(self.db(), definition, name); @@ -1537,6 +1533,7 @@ impl<'db> TypeInferenceBuilder<'db> { .node_scope(NodeWithScopeRef::Function(function)) .to_scope_id(self.db(), self.file()); + let inherited_generic_context = None; let specialization = None; let mut inferred_ty = Type::FunctionLiteral(FunctionType::new( @@ -1546,7 +1543,7 @@ impl<'db> TypeInferenceBuilder<'db> { body_scope, function_decorators, dataclass_transformer_params, - generic_context, + inherited_generic_context, specialization, )); diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index c74554fe91..5e931c5d96 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -166,6 +166,7 @@ impl<'db> CallableSignature<'db> { pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { let signature = Signature { generic_context: None, + inherited_generic_context: None, parameters: Parameters::gradual_form(), return_ty: Some(signature_type), }; @@ -178,6 +179,7 @@ impl<'db> CallableSignature<'db> { let signature_type = todo_type!(reason); let signature = Signature { generic_context: None, + inherited_generic_context: None, parameters: Parameters::todo(), return_ty: Some(signature_type), }; @@ -215,6 +217,11 @@ pub struct Signature<'db> { /// The generic context for this overload, if it is generic. pub(crate) generic_context: Option>, + /// The inherited generic context, if this function is a class method being used to infer the + /// specialization of its generic class. If the method is itself generic, this is in addition + /// to its own generic context. + pub(crate) inherited_generic_context: Option>, + /// Parameters, in source order. /// /// The ordering of parameters in a valid signature must be: first positional-only parameters, @@ -233,6 +240,7 @@ impl<'db> Signature<'db> { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option>) -> Self { Self { generic_context: None, + inherited_generic_context: None, parameters, return_ty, } @@ -245,6 +253,7 @@ impl<'db> Signature<'db> { ) -> Self { Self { generic_context, + inherited_generic_context: None, parameters, return_ty, } @@ -254,6 +263,7 @@ impl<'db> Signature<'db> { pub(super) fn from_function( db: &'db dyn Db, generic_context: Option>, + inherited_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, ) -> Self { @@ -267,6 +277,7 @@ impl<'db> Signature<'db> { Self { generic_context, + inherited_generic_context, parameters: Parameters::from_parameters( db, definition, @@ -279,6 +290,7 @@ impl<'db> Signature<'db> { pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self { Self { generic_context: self.generic_context, + inherited_generic_context: self.inherited_generic_context, parameters: self .parameters .iter() @@ -295,6 +307,7 @@ impl<'db> Signature<'db> { ) -> Self { Self { generic_context: self.generic_context, + inherited_generic_context: self.inherited_generic_context, parameters: self.parameters.apply_specialization(db, specialization), return_ty: self .return_ty @@ -310,6 +323,7 @@ impl<'db> Signature<'db> { pub(crate) fn bind_self(&self) -> Self { Self { generic_context: self.generic_context, + inherited_generic_context: self.inherited_generic_context, parameters: Parameters::new(self.parameters().iter().skip(1).cloned()), return_ty: self.return_ty, }