diff --git a/crates/ty_python_semantic/resources/mdtest/call/function.md b/crates/ty_python_semantic/resources/mdtest/call/function.md index de4bbe2e7f..069558ad77 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/function.md +++ b/crates/ty_python_semantic/resources/mdtest/call/function.md @@ -159,6 +159,8 @@ def _(args: list[int]) -> None: takes_zero(*args) takes_one(*args) takes_two(*args) + takes_two(*b"ab") + takes_two(*b"abc") # error: [too-many-positional-arguments] takes_two_positional_only(*args) takes_two_different(*args) # error: [invalid-argument-type] takes_two_different_positional_only(*args) # error: [invalid-argument-type] diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index c769f904d6..fcb23dc077 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -931,6 +931,134 @@ def _(t: tuple[int, str] | tuple[int, str, int]) -> None: f(*t) # error: [no-matching-overload] ``` +## Filtering based on variaidic arguments + +This is step 4 of the overload call evaluation algorithm which specifies that: + +> If the argument list is compatible with two or more overloads, determine whether one or more of +> the overloads has a variadic parameter (either `*args` or `**kwargs`) that maps to a corresponding +> argument that supplies an indeterminate number of positional or keyword arguments. If so, +> eliminate overloads that do not have a variadic parameter. + +This is only performed if the previous step resulted in more than one matching overload. + +### Simple `*args` + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f(x1: int) -> tuple[int]: ... +@overload +def f(x1: int, x2: int) -> tuple[int, int]: ... +@overload +def f(*args: int) -> int: ... +``` + +```py +from overloaded import f + +def _(x1: int, x2: int, args: list[int]): + reveal_type(f(x1)) # revealed: tuple[int] + reveal_type(f(x1, x2)) # revealed: tuple[int, int] + reveal_type(f(*(x1, x2))) # revealed: tuple[int, int] + + # Step 4 should filter out all but the last overload. + reveal_type(f(*args)) # revealed: int +``` + +### Variable `*args` + +```toml +[environment] +python-version = "3.11" +``` + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f(x1: int) -> tuple[int]: ... +@overload +def f(x1: int, x2: int) -> tuple[int, int]: ... +@overload +def f(x1: int, *args: int) -> tuple[int, ...]: ... +``` + +```py +from overloaded import f + +def _(x1: int, x2: int, args1: list[int], args2: tuple[int, *tuple[int, ...]]): + reveal_type(f(x1, x2)) # revealed: tuple[int, int] + reveal_type(f(*(x1, x2))) # revealed: tuple[int, int] + + # Step 4 should filter out all but the last overload. + reveal_type(f(x1, *args1)) # revealed: tuple[int, ...] + reveal_type(f(*args2)) # revealed: tuple[int, ...] +``` + +### Simple `**kwargs` + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f(*, x1: int) -> int: ... +@overload +def f(*, x1: int, x2: int) -> tuple[int, int]: ... +@overload +def f(**kwargs: int) -> int: ... +``` + +```py +from overloaded import f + +def _(x1: int, x2: int, kwargs: dict[str, int]): + reveal_type(f(x1=x1)) # revealed: int + reveal_type(f(x1=x1, x2=x2)) # revealed: tuple[int, int] + + # Step 4 should filter out all but the last overload. + reveal_type(f(**{"x1": x1, "x2": x2})) # revealed: int + reveal_type(f(**kwargs)) # revealed: int +``` + +### `TypedDict` + +The keys in a `TypedDict` are static so there's no variable part to it, so step 4 shouldn't filter +out any overloads. + +`overloaded.pyi`: + +```pyi +from typing import TypedDict, overload + +@overload +def f(*, x: int) -> int: ... +@overload +def f(*, x: int, y: int) -> tuple[int, int]: ... +@overload +def f(**kwargs: int) -> tuple[int, ...]: ... +``` + +```py +from typing import TypedDict +from overloaded import f + +class Foo(TypedDict): + x: int + y: int + +def _(foo: Foo, kwargs: dict[str, int]): + reveal_type(f(**foo)) # revealed: tuple[int, int] + reveal_type(f(**kwargs)) # revealed: tuple[int, ...] +``` + ## Filtering based on `Any` / `Unknown` This is the step 5 of the overload call evaluation algorithm which specifies that: diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 46628983e2..7182694780 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1333,10 +1333,30 @@ impl<'db> CallableBinding<'db> { } MatchingOverloadIndex::Multiple(indexes) => { // If two or more candidate overloads remain, proceed to step 4. - // TODO: Step 4 + self.filter_overloads_containing_variadic(&indexes); - // Step 5 - self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes); + match self.matching_overload_index() { + MatchingOverloadIndex::None => { + // This shouldn't be possible because step 4 can only filter out overloads + // when there _is_ a matching variadic argument. + tracing::debug!("All overloads have been filtered out in step 4"); + return None; + } + MatchingOverloadIndex::Single(index) => { + // If only one candidate overload remains, it is the winning match. + // Evaluate it as a regular (non-overloaded) call. + self.matching_overload_index = Some(index); + return None; + } + MatchingOverloadIndex::Multiple(indexes) => { + // If two or more candidate overloads remain, proceed to step 5. + self.filter_overloads_using_any_or_unknown( + db, + argument_types.as_ref(), + &indexes, + ); + } + } // This shouldn't lead to argument type expansion. return None; @@ -1446,15 +1466,28 @@ impl<'db> CallableBinding<'db> { Some(self.overloads[index].return_type()) } MatchingOverloadIndex::Multiple(matching_overload_indexes) => { - // TODO: Step 4 + self.filter_overloads_containing_variadic(&matching_overload_indexes); - self.filter_overloads_using_any_or_unknown( - db, - expanded_arguments, - &matching_overload_indexes, - ); - - Some(self.return_type()) + match self.matching_overload_index() { + MatchingOverloadIndex::None => { + tracing::debug!( + "All overloads have been filtered out in step 4 during argument type expansion" + ); + None + } + MatchingOverloadIndex::Single(index) => { + self.matching_overload_index = Some(index); + Some(self.return_type()) + } + MatchingOverloadIndex::Multiple(indexes) => { + self.filter_overloads_using_any_or_unknown( + db, + expanded_arguments, + &indexes, + ); + Some(self.return_type()) + } + } } }; @@ -1511,6 +1544,32 @@ impl<'db> CallableBinding<'db> { None } + /// Filter overloads based on variadic argument to variadic parameter match. + /// + /// This is the step 4 of the [overload call evaluation algorithm][1]. + /// + /// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation + fn filter_overloads_containing_variadic(&mut self, matching_overload_indexes: &[usize]) { + let variadic_matching_overloads = matching_overload_indexes + .iter() + .filter(|&&overload_index| { + self.overloads[overload_index].variadic_argument_matched_to_variadic_parameter + }) + .collect::>(); + + if variadic_matching_overloads.is_empty() + || variadic_matching_overloads.len() == matching_overload_indexes.len() + { + return; + } + + for overload_index in matching_overload_indexes { + if !variadic_matching_overloads.contains(overload_index) { + self.overloads[*overload_index].mark_as_unmatched_overload(); + } + } + } + /// Filter overloads based on [`Any`] or [`Unknown`] argument types. /// /// This is the step 5 of the [overload call evaluation algorithm][1]. @@ -1995,6 +2054,7 @@ struct ArgumentMatcher<'a, 'db> { next_positional: usize, first_excess_positional: Option, num_synthetic_args: usize, + variadic_argument_matched_to_variadic_parameter: bool, } impl<'a, 'db> ArgumentMatcher<'a, 'db> { @@ -2014,6 +2074,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { next_positional: 0, first_excess_positional: None, num_synthetic_args: 0, + variadic_argument_matched_to_variadic_parameter: false, } } @@ -2029,6 +2090,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { } } + #[expect(clippy::too_many_arguments)] fn assign_argument( &mut self, argument_index: usize, @@ -2037,6 +2099,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { parameter_index: usize, parameter: &Parameter<'db>, positional: bool, + variable_argument_length: bool, ) { if !matches!(argument, Argument::Synthetic) { let adjusted_argument_index = argument_index - self.num_synthetic_args; @@ -2057,6 +2120,15 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { }); } } + if variable_argument_length + && matches!( + (argument, parameter.kind()), + (Argument::Variadic, ParameterKind::Variadic { .. }) + | (Argument::Keywords, ParameterKind::KeywordVariadic { .. }) + ) + { + self.variadic_argument_matched_to_variadic_parameter = true; + } let matched_argument = &mut self.argument_matches[argument_index]; matched_argument.parameters.push(parameter_index); matched_argument.types.push(argument_type); @@ -2069,6 +2141,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { argument_index: usize, argument: Argument<'a>, argument_type: Option>, + variable_argument_length: bool, ) -> Result<(), ()> { if matches!(argument, Argument::Synthetic) { self.num_synthetic_args += 1; @@ -2091,6 +2164,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { parameter_index, parameter, !parameter.is_variadic(), + variable_argument_length, ); Ok(()) } @@ -2131,6 +2205,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { parameter_index, parameter, false, + false, ); Ok(()) } @@ -2157,6 +2232,8 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { ), }; + let is_variable = length.is_variable(); + // We must be able to match up the fixed-length portion of the argument with positional // parameters, so we pass on any errors that occur. for _ in 0..length.minimum() { @@ -2164,12 +2241,13 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { argument_index, argument, argument_types.next().or(variable_element), + is_variable, )?; } // If the tuple is variable-length, we assume that it will soak up all remaining positional // parameters. - if length.is_variable() { + if is_variable { while self .parameters .get_positional(self.next_positional) @@ -2179,6 +2257,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { argument_index, argument, argument_types.next().or(variable_element), + is_variable, )?; } } @@ -2189,9 +2268,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { // raise a false positive as "too many arguments". if self.parameters.variadic().is_some() { if let Some(argument_type) = argument_types.next().or(variable_element) { - self.match_positional(argument_index, argument, Some(argument_type))?; + self.match_positional(argument_index, argument, Some(argument_type), is_variable)?; for argument_type in argument_types { - self.match_positional(argument_index, argument, Some(argument_type))?; + self.match_positional( + argument_index, + argument, + Some(argument_type), + is_variable, + )?; } } } @@ -2248,6 +2332,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { parameter_index, parameter, false, + true, ); } } @@ -2670,6 +2755,10 @@ pub(crate) struct Binding<'db> { /// order. argument_matches: Box<[MatchedArgument<'db>]>, + /// Whether an argument that supplies an indeterminate number of positional or keyword + /// arguments is mapped to a variadic parameter (`*args` or `**kwargs`). + variadic_argument_matched_to_variadic_parameter: bool, + /// Bound types for parameters, in parameter source order, or `None` if no argument was matched /// to that parameter. parameter_tys: Box<[Option>]>, @@ -2688,6 +2777,7 @@ impl<'db> Binding<'db> { specialization: None, inherited_specialization: None, argument_matches: Box::from([]), + variadic_argument_matched_to_variadic_parameter: false, parameter_tys: Box::from([]), errors: vec![], } @@ -2712,7 +2802,7 @@ impl<'db> Binding<'db> { for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() { match argument { Argument::Positional | Argument::Synthetic => { - let _ = matcher.match_positional(argument_index, argument, None); + let _ = matcher.match_positional(argument_index, argument, None, false); } Argument::Keyword(name) => { let _ = matcher.match_keyword(argument_index, argument, None, name); @@ -2730,6 +2820,8 @@ impl<'db> Binding<'db> { } self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); + self.variadic_argument_matched_to_variadic_parameter = + matcher.variadic_argument_matched_to_variadic_parameter; self.argument_matches = matcher.finish(); } diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 1b0ab92fb5..db7c4d26b7 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -44,7 +44,7 @@ impl TupleLength { TupleLength::Variable(0, 0) } - pub(crate) fn is_variable(self) -> bool { + pub(crate) const fn is_variable(self) -> bool { matches!(self, TupleLength::Variable(_, _)) }