[ty] Fix overload filtering to prefer more "precise" match (#21859)

## Summary

fixes: https://github.com/astral-sh/ty/issues/1809

I took this chance to add some debug level tracing logs for overload
call evaluation similar to Doug's implementation in `constraints.rs`.

## Test Plan

- Add new mdtests
- Tested it against `sqlalchemy.select` in pyx which results in the
correct overload being matched
This commit is contained in:
Dhruv Manilawala
2025-12-09 20:29:34 +05:30
committed by GitHub
parent 426125f5c0
commit c35bf8f441
3 changed files with 164 additions and 2 deletions

View File

@@ -228,7 +228,7 @@ impl<'a, 'db> CallArguments<'a, 'db> {
if expansion_size > MAX_EXPANSIONS {
tracing::debug!(
"Skipping argument type expansion as it would exceed the \
maximum number of expansions ({MAX_EXPANSIONS})"
maximum number of expansions ({MAX_EXPANSIONS})"
);
return Some(State::LimitReached(index));
}

View File

@@ -2,6 +2,13 @@
//! arguments against the parameters of the callable. Like with
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
//! union of types, each of which might contain multiple overloads.
//!
//! ### Tracing
//!
//! This module is instrumented with debug-level `tracing` messages. You can set the `TY_LOG`
//! environment variable to see this output when testing locally. `tracing` log messages typically
//! have a `target` field, which is the name of the module the message appears in — in this case,
//! `ty_python_semantic::types::call::bind`.
use std::borrow::Cow;
use std::collections::HashSet;
@@ -1582,6 +1589,13 @@ impl<'db> CallableBinding<'db> {
// before checking.
let argument_types = argument_types.with_self(self.bound_type);
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 1",
);
// Step 1: Check the result of the arity check which is done by `match_parameters`
let matching_overload_indexes = match self.matching_overload_index() {
MatchingOverloadIndex::None => {
@@ -1612,6 +1626,13 @@ impl<'db> CallableBinding<'db> {
overload.check_types(db, argument_types.as_ref(), call_expression_tcx);
}
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 2",
);
match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// If all overloads result in errors, proceed to step 3.
@@ -1624,6 +1645,13 @@ impl<'db> CallableBinding<'db> {
// If two or more candidate overloads remain, proceed to step 4.
self.filter_overloads_containing_variadic(&indexes);
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 4",
);
match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// This shouldn't be possible because step 4 can only filter out overloads
@@ -1642,6 +1670,13 @@ impl<'db> CallableBinding<'db> {
argument_types.as_ref(),
&indexes,
);
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 5",
);
}
}
@@ -1744,12 +1779,26 @@ impl<'db> CallableBinding<'db> {
overload.match_parameters(db, expanded_arguments, &mut argument_forms);
}
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 1",
);
merged_argument_forms.merge(&argument_forms);
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, expanded_arguments, call_expression_tcx);
}
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 2",
);
let return_type = match self.matching_overload_index() {
MatchingOverloadIndex::None => None,
MatchingOverloadIndex::Single(index) => {
@@ -1758,6 +1807,13 @@ impl<'db> CallableBinding<'db> {
MatchingOverloadIndex::Multiple(matching_overload_indexes) => {
self.filter_overloads_containing_variadic(&matching_overload_indexes);
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 4",
);
match self.matching_overload_index() {
MatchingOverloadIndex::None => {
tracing::debug!(
@@ -1772,6 +1828,14 @@ impl<'db> CallableBinding<'db> {
expanded_arguments,
&indexes,
);
tracing::debug!(
target: "ty_python_semantic::types::call::bind",
matching_overload_index = ?self.matching_overload_index(),
signature = %self.signature_type.display(db),
"after step 5",
);
Some(self.return_type())
}
}
@@ -1926,12 +1990,37 @@ impl<'db> CallableBinding<'db> {
.take(max_parameter_count)
.collect::<Vec<_>>();
// The following loop is trying to construct a tuple of argument types that correspond to
// the participating parameter indexes. Considering the following example:
//
// ```python
// @overload
// def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ...
// @overload
// def f(*args: Any) -> tuple[Any, ...]: ...
//
// f(1, 2)
// ```
//
// Here, only the first parameter participates in the filtering process because only one
// overload has the second parameter. So, while going through the argument types, the
// second argument needs to be skipped but for the second overload both arguments map to
// the first parameter and that parameter is considered for the filtering process. This
// flag is to handle that special case of many-to-one mapping from arguments to parameters.
let mut variadic_parameter_handled = false;
for (argument_index, argument_type) in arguments.iter_types().enumerate() {
if variadic_parameter_handled {
continue;
}
for overload_index in matching_overload_indexes {
let overload = &self.overloads[*overload_index];
for (parameter_index, variadic_argument_type) in
overload.argument_matches[argument_index].iter()
{
if overload.signature.parameters()[parameter_index].is_variadic() {
variadic_parameter_handled = true;
}
if !participating_parameter_indexes.contains(&parameter_index) {
continue;
}