From caf3c916e85e545c357fc55d95a090d85051fc7d Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Fri, 27 Jun 2025 17:01:52 -0400 Subject: [PATCH] [ty] Refactor argument matching / type checking in call binding (#18997) This PR extracts a lot of the complex logic in the `match_parameters` and `check_types` methods of our call binding machinery into separate helper types. This is setup for #18996, which will update this logic to handle variadic arguments. To do so, it is helpful to have the per-argument logic extracted into a method that we can call repeatedly for each _element_ of a variadic argument. This should be a pure refactoring, with no behavioral changes. --- .../ty_python_semantic/src/types/call/bind.rs | 545 +++++++++++------- 1 file changed, 352 insertions(+), 193 deletions(-) diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 6d0810bab6..5498635041 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -26,7 +26,7 @@ use crate::types::function::{ DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral, }; use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError}; -use crate::types::signatures::{Parameter, ParameterForm}; +use crate::types::signatures::{Parameter, ParameterForm, Parameters}; use crate::types::tuple::TupleType; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType, @@ -1754,6 +1754,334 @@ enum MatchingOverloadIndex { Multiple(Vec), } +struct ArgumentMatcher<'a, 'db> { + parameters: &'a Parameters<'db>, + argument_forms: &'a mut [Option], + conflicting_forms: &'a mut [bool], + errors: &'a mut Vec>, + + /// The parameter that each argument is matched with. + argument_parameters: Vec>, + /// Whether each parameter has been matched with an argument. + parameter_matched: Vec, + next_positional: usize, + first_excess_positional: Option, + num_synthetic_args: usize, +} + +impl<'a, 'db> ArgumentMatcher<'a, 'db> { + fn new( + arguments: &CallArguments, + parameters: &'a Parameters<'db>, + argument_forms: &'a mut [Option], + conflicting_forms: &'a mut [bool], + errors: &'a mut Vec>, + ) -> Self { + Self { + parameters, + argument_forms, + conflicting_forms, + errors, + argument_parameters: vec![None; arguments.len()], + parameter_matched: vec![false; parameters.len()], + next_positional: 0, + first_excess_positional: None, + num_synthetic_args: 0, + } + } + + fn get_argument_index(&self, argument_index: usize) -> Option { + if argument_index >= self.num_synthetic_args { + // Adjust the argument index to skip synthetic args, which don't appear at the call + // site and thus won't be in the Call node arguments list. + Some(argument_index - self.num_synthetic_args) + } else { + // we are erroring on a synthetic argument, we'll just emit the diagnostic on the + // entire Call node, since there's no argument node for this argument at the call site + None + } + } + + fn assign_argument( + &mut self, + argument_index: usize, + argument: Argument<'a>, + parameter_index: usize, + parameter: &Parameter<'db>, + positional: bool, + ) { + if !matches!(argument, Argument::Synthetic) { + if let Some(existing) = self.argument_forms[argument_index - self.num_synthetic_args] + .replace(parameter.form) + { + if existing != parameter.form { + self.conflicting_forms[argument_index - self.num_synthetic_args] = true; + } + } + } + if self.parameter_matched[parameter_index] { + if !parameter.is_variadic() && !parameter.is_keyword_variadic() { + self.errors.push(BindingError::ParameterAlreadyAssigned { + argument_index: self.get_argument_index(argument_index), + parameter: ParameterContext::new(parameter, parameter_index, positional), + }); + } + } + self.argument_parameters[argument_index] = Some(parameter_index); + self.parameter_matched[parameter_index] = true; + } + + fn match_positional( + &mut self, + argument_index: usize, + argument: Argument<'a>, + ) -> Result<(), ()> { + if matches!(argument, Argument::Synthetic) { + self.num_synthetic_args += 1; + } + let Some((parameter_index, parameter)) = self + .parameters + .get_positional(self.next_positional) + .map(|param| (self.next_positional, param)) + .or_else(|| self.parameters.variadic()) + else { + self.first_excess_positional.get_or_insert(argument_index); + self.next_positional += 1; + return Err(()); + }; + self.next_positional += 1; + self.assign_argument( + argument_index, + argument, + parameter_index, + parameter, + !parameter.is_variadic(), + ); + Ok(()) + } + + fn match_keyword( + &mut self, + argument_index: usize, + argument: Argument<'a>, + name: &str, + ) -> Result<(), ()> { + let Some((parameter_index, parameter)) = self + .parameters + .keyword_by_name(name) + .or_else(|| self.parameters.keyword_variadic()) + else { + self.errors.push(BindingError::UnknownArgument { + argument_name: ast::name::Name::new(name), + argument_index: self.get_argument_index(argument_index), + }); + return Err(()); + }; + self.assign_argument(argument_index, argument, parameter_index, parameter, false); + Ok(()) + } + + fn finish(self) -> Box<[Option]> { + if let Some(first_excess_argument_index) = self.first_excess_positional { + self.errors.push(BindingError::TooManyPositionalArguments { + first_excess_argument_index: self.get_argument_index(first_excess_argument_index), + expected_positional_count: self.parameters.positional().count(), + provided_positional_count: self.next_positional, + }); + } + + let mut missing = vec![]; + for (index, matched) in self.parameter_matched.iter().copied().enumerate() { + if !matched { + let param = &self.parameters[index]; + if param.is_variadic() + || param.is_keyword_variadic() + || param.default_type().is_some() + { + // variadic/keywords and defaulted arguments are not required + continue; + } + missing.push(ParameterContext::new(param, index, false)); + } + } + if !missing.is_empty() { + self.errors.push(BindingError::MissingArguments { + parameters: ParameterContexts(missing), + }); + } + + self.argument_parameters.into_boxed_slice() + } +} + +struct ArgumentTypeChecker<'a, 'db> { + db: &'db dyn Db, + signature: &'a Signature<'db>, + arguments: &'a CallArguments<'a>, + argument_types: &'a [Type<'db>], + argument_parameters: &'a [Option], + parameter_tys: &'a mut [Option>], + errors: &'a mut Vec>, + + specialization: Option>, + inherited_specialization: Option>, +} + +impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { + fn new( + db: &'db dyn Db, + signature: &'a Signature<'db>, + arguments: &'a CallArguments<'a>, + argument_types: &'a [Type<'db>], + argument_parameters: &'a [Option], + parameter_tys: &'a mut [Option>], + errors: &'a mut Vec>, + ) -> Self { + Self { + db, + signature, + arguments, + argument_types, + argument_parameters, + parameter_tys, + errors, + specialization: None, + inherited_specialization: None, + } + } + + fn enumerate_argument_types( + &self, + ) -> impl Iterator, Argument<'a>, Type<'db>)> + 'a { + let mut iter = (self.arguments.iter()) + .zip(self.argument_types.iter().copied()) + .enumerate(); + let mut num_synthetic_args = 0; + std::iter::from_fn(move || { + let (argument_index, (argument, argument_type)) = iter.next()?; + let adjusted_argument_index = if matches!(argument, Argument::Synthetic) { + // If we are erroring on a synthetic argument, we'll just emit the + // diagnostic on the entire Call node, since there's no argument node for + // this argument at the call site + num_synthetic_args += 1; + None + } else { + // Adjust the argument index to skip synthetic args, which don't appear at + // the call site and thus won't be in the Call node arguments list. + Some(argument_index - num_synthetic_args) + }; + Some(( + argument_index, + adjusted_argument_index, + argument, + argument_type, + )) + }) + } + + fn infer_specialization(&mut self) { + if self.signature.generic_context.is_none() + && self.signature.inherited_generic_context.is_none() + { + return; + } + + let parameters = self.signature.parameters(); + let mut builder = SpecializationBuilder::new(self.db); + for (argument_index, adjusted_argument_index, _, argument_type) in + self.enumerate_argument_types() + { + let Some(parameter_index) = self.argument_parameters[argument_index] else { + // There was an error with argument when matching parameters, so don't bother + // type-checking it. + continue; + }; + let parameter = ¶meters[parameter_index]; + let Some(expected_type) = parameter.annotated_type() else { + continue; + }; + if let Err(error) = builder.infer(expected_type, argument_type) { + self.errors.push(BindingError::SpecializationError { + error, + argument_index: adjusted_argument_index, + }); + } + } + self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); + self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| { + // The inherited generic context is used when inferring the specialization of a generic + // class from a constructor call. In this case (only), we promote any typevars that are + // inferred as a literal to the corresponding instance type. + builder + .build(gc) + .apply_type_mapping(self.db, &TypeMapping::PromoteLiterals) + }); + } + + fn check_argument_type( + &mut self, + argument_index: usize, + adjusted_argument_index: Option, + argument: Argument<'a>, + mut argument_type: Type<'db>, + ) { + let Some(parameter_index) = self.argument_parameters[argument_index] else { + // There was an error with argument when matching parameters, so don't bother + // type-checking it. + return; + }; + let parameters = self.signature.parameters(); + let parameter = ¶meters[parameter_index]; + if let Some(mut expected_ty) = parameter.annotated_type() { + if let Some(specialization) = self.specialization { + argument_type = argument_type.apply_specialization(self.db, specialization); + expected_ty = expected_ty.apply_specialization(self.db, specialization); + } + if let Some(inherited_specialization) = self.inherited_specialization { + argument_type = + argument_type.apply_specialization(self.db, inherited_specialization); + expected_ty = expected_ty.apply_specialization(self.db, inherited_specialization); + } + if !argument_type.is_assignable_to(self.db, expected_ty) { + let positional = matches!(argument, Argument::Positional | Argument::Synthetic) + && !parameter.is_variadic(); + self.errors.push(BindingError::InvalidArgumentType { + parameter: ParameterContext::new(parameter, parameter_index, positional), + argument_index: adjusted_argument_index, + expected_ty, + provided_ty: argument_type, + }); + } + } + // We still update the actual type of the parameter in this binding to match the + // argument, even if the argument type is not assignable to the expected parameter + // type. + if let Some(existing) = self.parameter_tys[parameter_index].replace(argument_type) { + // We already verified in `match_parameters` that we only match multiple arguments + // with variadic parameters. + let union = UnionType::from_elements(self.db, [existing, argument_type]); + self.parameter_tys[parameter_index] = Some(union); + } + } + + fn check_argument_types(&mut self) { + for (argument_index, adjusted_argument_index, argument, argument_type) in + self.enumerate_argument_types() + { + self.check_argument_type( + argument_index, + adjusted_argument_index, + argument, + argument_type, + ); + } + } + + fn finish(self) -> (Option>, Option>) { + (self.specialization, self.inherited_specialization) + } +} + /// Binding information for one of the overloads of a callable. #[derive(Debug)] pub(crate) struct Binding<'db> { @@ -1817,115 +2145,30 @@ impl<'db> Binding<'db> { conflicting_forms: &mut [bool], ) { let parameters = self.signature.parameters(); - // The parameter that each argument is matched with. - let mut argument_parameters = vec![None; arguments.len()]; - // Whether each parameter has been matched with an argument. - let mut parameter_matched = vec![false; parameters.len()]; - let mut next_positional = 0; - let mut first_excess_positional = None; - let mut num_synthetic_args = 0; - let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { - if argument_index >= num_synthetic_args { - // Adjust the argument index to skip synthetic args, which don't appear at the call - // site and thus won't be in the Call node arguments list. - Some(argument_index - num_synthetic_args) - } else { - // we are erroring on a synthetic argument, we'll just emit the diagnostic on the - // entire Call node, since there's no argument node for this argument at the call site - None - } - }; + let mut matcher = ArgumentMatcher::new( + arguments, + parameters, + argument_forms, + conflicting_forms, + &mut self.errors, + ); for (argument_index, argument) in arguments.iter().enumerate() { - let (index, parameter, positional) = match argument { + match argument { Argument::Positional | Argument::Synthetic => { - if matches!(argument, Argument::Synthetic) { - num_synthetic_args += 1; - } - let Some((index, parameter)) = parameters - .get_positional(next_positional) - .map(|param| (next_positional, param)) - .or_else(|| parameters.variadic()) - else { - first_excess_positional.get_or_insert(argument_index); - next_positional += 1; - continue; - }; - next_positional += 1; - (index, parameter, !parameter.is_variadic()) + let _ = matcher.match_positional(argument_index, argument); } Argument::Keyword(name) => { - let Some((index, parameter)) = parameters - .keyword_by_name(name) - .or_else(|| parameters.keyword_variadic()) - else { - self.errors.push(BindingError::UnknownArgument { - argument_name: ast::name::Name::new(name), - argument_index: get_argument_index(argument_index, num_synthetic_args), - }); - continue; - }; - (index, parameter, false) + let _ = matcher.match_keyword(argument_index, argument, name); } - Argument::Variadic | Argument::Keywords => { // TODO continue; } - }; - if !matches!(argument, Argument::Synthetic) { - if let Some(existing) = - argument_forms[argument_index - num_synthetic_args].replace(parameter.form) - { - if existing != parameter.form { - conflicting_forms[argument_index - num_synthetic_args] = true; - } - } - } - if parameter_matched[index] { - if !parameter.is_variadic() && !parameter.is_keyword_variadic() { - self.errors.push(BindingError::ParameterAlreadyAssigned { - argument_index: get_argument_index(argument_index, num_synthetic_args), - parameter: ParameterContext::new(parameter, index, positional), - }); - } - } - argument_parameters[argument_index] = Some(index); - parameter_matched[index] = true; - } - if let Some(first_excess_argument_index) = first_excess_positional { - self.errors.push(BindingError::TooManyPositionalArguments { - first_excess_argument_index: get_argument_index( - first_excess_argument_index, - num_synthetic_args, - ), - expected_positional_count: parameters.positional().count(), - provided_positional_count: next_positional, - }); - } - let mut missing = vec![]; - for (index, matched) in parameter_matched.iter().copied().enumerate() { - if !matched { - let param = ¶meters[index]; - if param.is_variadic() - || param.is_keyword_variadic() - || param.default_type().is_some() - { - // variadic/keywords and defaulted arguments are not required - continue; - } - missing.push(ParameterContext::new(param, index, false)); } } - - if !missing.is_empty() { - self.errors.push(BindingError::MissingArguments { - parameters: ParameterContexts(missing), - }); - } - self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); - self.argument_parameters = argument_parameters.into_boxed_slice(); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); + self.argument_parameters = matcher.finish(); } fn check_types( @@ -1934,106 +2177,22 @@ impl<'db> Binding<'db> { arguments: &CallArguments<'_>, argument_types: &[Type<'db>], ) { - let mut num_synthetic_args = 0; - let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { - if argument_index >= num_synthetic_args { - // Adjust the argument index to skip synthetic args, which don't appear at the call - // site and thus won't be in the Call node arguments list. - Some(argument_index - num_synthetic_args) - } else { - // we are erroring on a synthetic argument, we'll just emit the diagnostic on the - // entire Call node, since there's no argument node for this argument at the call site - None - } - }; - - let enumerate_argument_types = || { - arguments - .iter() - .zip(argument_types.iter().copied()) - .enumerate() - }; + let mut checker = ArgumentTypeChecker::new( + db, + &self.signature, + arguments, + argument_types, + &self.argument_parameters, + &mut self.parameter_tys, + &mut self.errors, + ); // If this overload is generic, first see if we can infer a specialization of the function // from the arguments that were passed in. - let signature = &self.signature; - let parameters = signature.parameters(); - if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() { - let mut builder = SpecializationBuilder::new(db); - for (argument_index, (argument, argument_type)) in enumerate_argument_types() { - if matches!(argument, Argument::Synthetic) { - num_synthetic_args += 1; - } - let Some(parameter_index) = self.argument_parameters[argument_index] else { - // There was an error with argument when matching parameters, so don't bother - // type-checking it. - continue; - }; - let parameter = ¶meters[parameter_index]; - let Some(expected_type) = parameter.annotated_type() else { - continue; - }; - if let Err(error) = builder.infer(expected_type, argument_type) { - self.errors.push(BindingError::SpecializationError { - error, - argument_index: get_argument_index(argument_index, num_synthetic_args), - }); - } - } - self.specialization = signature.generic_context.map(|gc| builder.build(gc)); - self.inherited_specialization = signature.inherited_generic_context.map(|gc| { - // The inherited generic context is used when inferring the specialization of a - // generic class from a constructor call. In this case (only), we promote any - // typevars that are inferred as a literal to the corresponding instance type. - builder - .build(gc) - .apply_type_mapping(db, &TypeMapping::PromoteLiterals) - }); - } - - num_synthetic_args = 0; - for (argument_index, (argument, mut argument_type)) in enumerate_argument_types() { - if matches!(argument, Argument::Synthetic) { - num_synthetic_args += 1; - } - let Some(parameter_index) = self.argument_parameters[argument_index] else { - // There was an error with argument when matching parameters, so don't bother - // type-checking it. - continue; - }; - let parameter = ¶meters[parameter_index]; - if let Some(mut expected_ty) = parameter.annotated_type() { - if let Some(specialization) = self.specialization { - argument_type = argument_type.apply_specialization(db, specialization); - expected_ty = expected_ty.apply_specialization(db, specialization); - } - if let Some(inherited_specialization) = self.inherited_specialization { - argument_type = - argument_type.apply_specialization(db, inherited_specialization); - expected_ty = expected_ty.apply_specialization(db, inherited_specialization); - } - if !argument_type.is_assignable_to(db, expected_ty) { - let positional = matches!(argument, Argument::Positional | Argument::Synthetic) - && !parameter.is_variadic(); - self.errors.push(BindingError::InvalidArgumentType { - parameter: ParameterContext::new(parameter, parameter_index, positional), - argument_index: get_argument_index(argument_index, num_synthetic_args), - expected_ty, - provided_ty: argument_type, - }); - } - } - // We still update the actual type of the parameter in this binding to match the - // argument, even if the argument type is not assignable to the expected parameter - // type. - if let Some(existing) = self.parameter_tys[parameter_index].replace(argument_type) { - // We already verified in `match_parameters` that we only match multiple arguments - // with variadic parameters. - let union = UnionType::from_elements(db, [existing, argument_type]); - self.parameter_tys[parameter_index] = Some(union); - } - } + checker.infer_specialization(); + checker.check_argument_types(); + (self.specialization, self.inherited_specialization) = checker.finish(); if let Some(specialization) = self.specialization { self.return_ty = self.return_ty.apply_specialization(db, specialization); }