From 4e94b22815c7979003d07a4931c9174285617c90 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 2 Oct 2025 20:11:56 +0530 Subject: [PATCH] [ty] Support single-starred argument for overload call (#20223) ## Summary closes: https://github.com/astral-sh/ty/issues/247 This PR adds support for variadic arguments to overload call evaluation. This basically boils down to making sure that the overloads are not filtered out incorrectly during the step 5 in the overload call evaluation algorithm. For context, the step 5 tries to filter out the remaining overloads after finding an overload where the materialization of argument types are assignable to the parameter types. The issue with the previous implementation was that it wouldn't unpack the variadic argument and wouldn't consider the many-to-one (multiple arguments mapping to a single variadic parameter) correctly. This PR fixes that. ## Test Plan Update existing test cases and resolve the TODOs. --- .../resources/mdtest/call/overloads.md | 60 ++++------ .../ty_python_semantic/src/types/call/bind.rs | 110 +++++++++++++----- 2 files changed, 99 insertions(+), 71 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index fcb23dc077..bfb516f026 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -139,8 +139,7 @@ reveal_type(f(A())) # revealed: A reveal_type(f(*(A(),))) # revealed: A reveal_type(f(B())) # revealed: A -# TODO: revealed: A -reveal_type(f(*(B(),))) # revealed: Unknown +reveal_type(f(*(B(),))) # revealed: A # But, in this case, the arity check filters out the first overload, so we only have one match: reveal_type(f(B(), 1)) # revealed: B @@ -551,16 +550,13 @@ from overloaded import MyEnumSubclass, ActualEnum, f def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass): reveal_type(f(actual_enum)) # revealed: Both - # TODO: revealed: Both - reveal_type(f(*(actual_enum,))) # revealed: Unknown + reveal_type(f(*(actual_enum,))) # revealed: Both reveal_type(f(ActualEnum.A)) # revealed: OnlyA - # TODO: revealed: OnlyA - reveal_type(f(*(ActualEnum.A,))) # revealed: Unknown + reveal_type(f(*(ActualEnum.A,))) # revealed: OnlyA reveal_type(f(ActualEnum.B)) # revealed: OnlyB - # TODO: revealed: OnlyB - reveal_type(f(*(ActualEnum.B,))) # revealed: Unknown + reveal_type(f(*(ActualEnum.B,))) # revealed: OnlyB reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass reveal_type(f(*(my_enum_instance,))) # revealed: MyEnumSubclass @@ -1097,12 +1093,10 @@ reveal_type(f(*(1,))) # revealed: str def _(list_int: list[int], list_any: list[Any]): reveal_type(f(list_int)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: int reveal_type(f(list_any)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_any,))) # revealed: Unknown + reveal_type(f(*(list_any,))) # revealed: int ``` ### Single list argument (ambiguous) @@ -1136,8 +1130,7 @@ def _(list_int: list[int], list_any: list[Any]): # All materializations of `list[int]` are assignable to `list[int]`, so it matches the first # overload. reveal_type(f(list_int)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: int # All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the # return type of first and second overloads are not equivalent, so the overload matching @@ -1170,25 +1163,21 @@ reveal_type(f("a")) # revealed: str reveal_type(f(*("a",))) # revealed: str reveal_type(f((1, "b"))) # revealed: int -# TODO: revealed: int -reveal_type(f(*((1, "b"),))) # revealed: Unknown +reveal_type(f(*((1, "b"),))) # revealed: int reveal_type(f((1, 2))) # revealed: int -# TODO: revealed: int -reveal_type(f(*((1, 2),))) # revealed: Unknown +reveal_type(f(*((1, 2),))) # revealed: int def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]): # All materializations are assignable to first overload, so second and third overloads are # eliminated reveal_type(f(int_str)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(int_str,))) # revealed: Unknown + reveal_type(f(*(int_str,))) # revealed: int # All materializations are assignable to second overload, so the third overload is eliminated; # the return type of first and second overload is equivalent reveal_type(f(int_any)) # revealed: int - # TODO: revealed: int - reveal_type(f(*(int_any,))) # revealed: Unknown + reveal_type(f(*(int_any,))) # revealed: int # All materializations of `tuple[Any, Any]` are assignable to the parameters of all the # overloads, but the return types aren't equivalent, so the overload matching is ambiguous @@ -1266,26 +1255,22 @@ def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_an # All materializations of both argument types are assignable to the first overload, so the # second and third overloads are filtered out reveal_type(f(list_int, int_str)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int, int_str))) # revealed: Unknown + reveal_type(f(*(list_int, int_str))) # revealed: A # All materialization of first argument is assignable to first overload and for the second # argument, they're assignable to the second overload, so the third overload is filtered out reveal_type(f(list_int, int_any)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int, int_any))) # revealed: Unknown + reveal_type(f(*(list_int, int_any))) # revealed: A # All materialization of first argument is assignable to second overload and for the second # argument, they're assignable to the first overload, so the third overload is filtered out reveal_type(f(list_any, int_str)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_any, int_str))) # revealed: Unknown + reveal_type(f(*(list_any, int_str))) # revealed: A # All materializations of both arguments are assignable to the second overload, so the third # overload is filtered out reveal_type(f(list_any, int_any)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_any, int_any))) # revealed: Unknown + reveal_type(f(*(list_any, int_any))) # revealed: A # All materializations of first argument is assignable to the second overload and for the second # argument, they're assignable to the third overload, so no overloads are filtered out; the @@ -1316,8 +1301,7 @@ from overloaded import f def _(literal: LiteralString, string: str, any: Any): reveal_type(f(literal)) # revealed: LiteralString - # TODO: revealed: LiteralString - reveal_type(f(*(literal,))) # revealed: Unknown + reveal_type(f(*(literal,))) # revealed: LiteralString reveal_type(f(string)) # revealed: str reveal_type(f(*(string,))) # revealed: str @@ -1355,12 +1339,10 @@ from overloaded import f def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any): reveal_type(f(list_int)) # revealed: A - # TODO: revealed: A - reveal_type(f(*(list_int,))) # revealed: Unknown + reveal_type(f(*(list_int,))) # revealed: A reveal_type(f(list_str)) # revealed: str - # TODO: Should be `str` - reveal_type(f(*(list_str,))) # revealed: Unknown + reveal_type(f(*(list_str,))) # revealed: str reveal_type(f(list_any)) # revealed: Unknown reveal_type(f(*(list_any,))) # revealed: Unknown @@ -1561,12 +1543,10 @@ def _(any: Any): reveal_type(f(*(any,), flag=False)) # revealed: str def _(args: tuple[Any, Literal[True]]): - # TODO: revealed: int - reveal_type(f(*args)) # revealed: Unknown + reveal_type(f(*args)) # revealed: int def _(args: tuple[Any, Literal[False]]): - # TODO: revealed: str - reveal_type(f(*args)) # revealed: Unknown + reveal_type(f(*args)) # revealed: str ``` ### Argument type expansion diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 59a38e303b..841159221b 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, - TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType, + TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; @@ -1588,6 +1588,14 @@ impl<'db> CallableBinding<'db> { arguments: &CallArguments<'_, 'db>, matching_overload_indexes: &[usize], ) { + // The maximum number of parameters across all the overloads that are being considered + // for filtering. + let max_parameter_count = matching_overload_indexes + .iter() + .map(|&index| self.overloads[index].signature.parameters().len()) + .max() + .unwrap_or(0); + // These are the parameter indexes that matches the arguments that participate in the // filtering process. // @@ -1595,41 +1603,67 @@ impl<'db> CallableBinding<'db> { // gradual equivalent to the parameter types at the same index for other overloads. let mut participating_parameter_indexes = HashSet::new(); - // These only contain the top materialized argument types for the corresponding - // participating parameter indexes. - let mut top_materialized_argument_types = vec![]; + // The parameter types at each index for the first overload containing a parameter at + // that index. + let mut first_parameter_types: Vec>> = vec![None; max_parameter_count]; - for (argument_index, argument_type) in arguments.iter_types().enumerate() { - let mut first_parameter_type: Option> = None; - let mut participating_parameter_index = None; - - 'overload: for overload_index in matching_overload_indexes { + for argument_index in 0..arguments.len() { + for overload_index in matching_overload_indexes { let overload = &self.overloads[*overload_index]; - for parameter_index in &overload.argument_matches[argument_index].parameters { + for ¶meter_index in &overload.argument_matches[argument_index].parameters { // TODO: For an unannotated `self` / `cls` parameter, the type should be // `typing.Self` / `type[typing.Self]` - let current_parameter_type = overload.signature.parameters()[*parameter_index] + let current_parameter_type = overload.signature.parameters()[parameter_index] .annotated_type() .unwrap_or(Type::unknown()); + let first_parameter_type = &mut first_parameter_types[parameter_index]; if let Some(first_parameter_type) = first_parameter_type { if !first_parameter_type.is_equivalent_to(db, current_parameter_type) { - participating_parameter_index = Some(*parameter_index); - break 'overload; + participating_parameter_indexes.insert(parameter_index); } } else { - first_parameter_type = Some(current_parameter_type); + *first_parameter_type = Some(current_parameter_type); } } } + } - if let Some(parameter_index) = participating_parameter_index { - participating_parameter_indexes.insert(parameter_index); - top_materialized_argument_types.push(argument_type.top_materialization(db)); + let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db)) + .take(max_parameter_count) + .collect::>(); + + for (argument_index, argument_type) in arguments.iter_types().enumerate() { + for overload_index in matching_overload_indexes { + let overload = &self.overloads[*overload_index]; + for (parameter_index, variadic_argument_type) in + overload.argument_matches[argument_index].iter() + { + if !participating_parameter_indexes.contains(¶meter_index) { + continue; + } + union_argument_type_builders[parameter_index].add_in_place( + variadic_argument_type + .unwrap_or(argument_type) + .top_materialization(db), + ); + } } } - let top_materialized_argument_type = - Type::heterogeneous_tuple(db, top_materialized_argument_types); + // These only contain the top materialized argument types for the corresponding + // participating parameter indexes. + let top_materialized_argument_type = Type::heterogeneous_tuple( + db, + union_argument_type_builders + .into_iter() + .filter_map(|builder| { + if builder.is_empty() { + None + } else { + Some(builder.build()) + } + }), + ); // A flag to indicate whether we've found the overload that makes the remaining overloads // unmatched for the given argument types. @@ -1640,15 +1674,22 @@ impl<'db> CallableBinding<'db> { self.overloads[*current_index].mark_as_unmatched_overload(); continue; } - let mut parameter_types = Vec::with_capacity(arguments.len()); + + let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db)) + .take(max_parameter_count) + .collect::>(); + + // The number of parameters that have been skipped because they don't participate in + // the filtering process. This is used to make sure the types are added to the + // corresponding parameter index in `union_parameter_types`. + let mut skipped_parameters = 0; + for argument_index in 0..arguments.len() { - // The parameter types at the current argument index. - let mut current_parameter_types = vec![]; for overload_index in &matching_overload_indexes[..=upto] { let overload = &self.overloads[*overload_index]; for parameter_index in &overload.argument_matches[argument_index].parameters { if !participating_parameter_indexes.contains(parameter_index) { - // This parameter doesn't participate in the filtering process. + skipped_parameters += 1; continue; } // TODO: For an unannotated `self` / `cls` parameter, the type should be @@ -1664,17 +1705,24 @@ impl<'db> CallableBinding<'db> { parameter_type = parameter_type.apply_specialization(db, inherited_specialization); } - current_parameter_types.push(parameter_type); + union_parameter_types[parameter_index.saturating_sub(skipped_parameters)] + .add_in_place(parameter_type); } } - if current_parameter_types.is_empty() { - continue; - } - parameter_types.push(UnionType::from_elements(db, current_parameter_types)); } - if top_materialized_argument_type - .is_assignable_to(db, Type::heterogeneous_tuple(db, parameter_types)) - { + + let parameter_types = Type::heterogeneous_tuple( + db, + union_parameter_types.into_iter().filter_map(|builder| { + if builder.is_empty() { + None + } else { + Some(builder.build()) + } + }), + ); + + if top_materialized_argument_type.is_assignable_to(db, parameter_types) { filter_remaining_overloads = true; } }