mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 21:40:51 -05:00
[ty] Add basic support for overloads in ParamSpec (#21946)
## Summary fixes: https://github.com/astral-sh/ty/issues/1838 This PR adds basic support for overloaded function when used to specialize a `ParamSpec` type variable. Following cases are still remaining: 1. Updating the specialization with the matching overload after the paramspec sub-call logic 2. Updating the specialization with the matching overload using the return type Both of these cases are present in the mdtest file. ## Test Plan Update mdtest with new cases.
This commit is contained in:
@@ -681,11 +681,64 @@ reveal_type(change_return_type(int_str)) # revealed: Overload[(x: int) -> str,
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(change_return_type(str_str)) # revealed: (...) -> str
|
||||
|
||||
# TODO: Both of these shouldn't raise an error
|
||||
# error: [invalid-argument-type]
|
||||
# TODO: This should reveal the matching overload instead
|
||||
reveal_type(with_parameters(int_int, 1)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(with_parameters(int_int, "a")) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
|
||||
# error: [invalid-argument-type] "Argument to function `with_parameters` is incorrect: Expected `int`, found `None`"
|
||||
reveal_type(with_parameters(int_int, None)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
|
||||
def foo(int_or_str: int | str):
|
||||
# Argument type expansion leads to matching both overloads.
|
||||
# TODO: Should this be an error instead?
|
||||
reveal_type(with_parameters(int_int, int_or_str)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
|
||||
# Keyword argument matching should also work
|
||||
# TODO: This should reveal the matching overload instead
|
||||
reveal_type(with_parameters(int_int, x=1)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
reveal_type(with_parameters(int_int, x="a")) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
|
||||
# No matching overload should error
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(with_parameters(int_int, 1.5)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
|
||||
```
|
||||
|
||||
### Overloads with multiple parameters
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def multi(x: int, y: int) -> int: ...
|
||||
@overload
|
||||
def multi(x: str, y: str) -> str: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
from overloaded import multi
|
||||
|
||||
def run[**P, R](f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# Both arguments match first overload
|
||||
# TODO: should reveal `int`
|
||||
reveal_type(run(multi, 1, 2)) # revealed: int | str
|
||||
|
||||
# Both arguments match second overload
|
||||
# TODO: should reveal `str`
|
||||
reveal_type(run(multi, "a", "b")) # revealed: int | str
|
||||
|
||||
# Mixed positional and keyword
|
||||
# TODO: both should reveal `int`
|
||||
reveal_type(run(multi, 1, y=2)) # revealed: int | str
|
||||
reveal_type(run(multi, x=1, y=2)) # revealed: int | str
|
||||
|
||||
# No matching overload (int, str doesn't match either overload of `multi`)
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(run(multi, 1, "b")) # revealed: int | str
|
||||
```
|
||||
|
||||
### Overloads with subtitution of `P.args` and `P.kwargs`
|
||||
|
||||
@@ -3458,8 +3458,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||
/// are passed.
|
||||
///
|
||||
/// This method returns `false` if the specialization does not contain a mapping for the given
|
||||
/// `paramspec`, contains an invalid mapping (i.e., not a `Callable` of kind `ParamSpecValue`)
|
||||
/// or if the value is an overloaded callable.
|
||||
/// `paramspec` or contains an invalid mapping (i.e., not a `Callable` of kind `ParamSpecValue`).
|
||||
///
|
||||
/// For more details, refer to [`Self::try_paramspec_evaluation_at`].
|
||||
fn evaluate_paramspec_sub_call(
|
||||
@@ -3478,10 +3477,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Support overloads?
|
||||
let [signature] = callable.signatures(self.db).overloads.as_slice() else {
|
||||
let signatures = &callable.signatures(self.db).overloads;
|
||||
if signatures.is_empty() {
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
let sub_arguments = if let Some(argument_index) = argument_index {
|
||||
self.arguments.start_from(argument_index)
|
||||
@@ -3489,21 +3488,61 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||
CallArguments::none()
|
||||
};
|
||||
|
||||
// TODO: What should be the `signature_type` here?
|
||||
let bindings = match Bindings::from(Binding::single(self.signature_type, signature.clone()))
|
||||
// Create Bindings with all overloads and perform full overload resolution
|
||||
let callable_binding =
|
||||
CallableBinding::from_overloads(self.signature_type, signatures.iter().cloned());
|
||||
let bindings = match Bindings::from(callable_binding)
|
||||
.match_parameters(self.db, &sub_arguments)
|
||||
.check_types(self.db, &sub_arguments, self.call_expression_tcx, &[])
|
||||
{
|
||||
Ok(bindings) => Box::new(bindings),
|
||||
Err(CallError(_, bindings)) => bindings,
|
||||
Ok(bindings) => bindings,
|
||||
Err(CallError(_, bindings)) => *bindings,
|
||||
};
|
||||
|
||||
// SAFETY: `bindings` was created from a single binding above.
|
||||
let [binding] = bindings.single_element().unwrap().overloads.as_slice() else {
|
||||
unreachable!("ParamSpec sub-call should only contain a single binding");
|
||||
};
|
||||
// SAFETY: `bindings` was created from a single `CallableBinding` above.
|
||||
let callable_binding = bindings
|
||||
.single_element()
|
||||
.expect("ParamSpec sub-call should only contain a single CallableBinding");
|
||||
|
||||
self.errors.extend(binding.errors.iter().cloned());
|
||||
match callable_binding.matching_overload_index() {
|
||||
MatchingOverloadIndex::None => {
|
||||
if let [binding] = callable_binding.overloads() {
|
||||
// This is not an overloaded function, so we can propagate its errors to the
|
||||
// outer bindings.
|
||||
self.errors.extend(binding.errors.iter().cloned());
|
||||
} else {
|
||||
let index = callable_binding
|
||||
.matching_overload_before_type_checking
|
||||
.unwrap_or(0);
|
||||
// TODO: We should also update the specialization for the `ParamSpec` to reflect
|
||||
// the matching overload here.
|
||||
self.errors
|
||||
.extend(callable_binding.overloads()[index].errors.iter().cloned());
|
||||
}
|
||||
}
|
||||
MatchingOverloadIndex::Single(index) => {
|
||||
// TODO: We should also update the specialization for the `ParamSpec` to reflect the
|
||||
// matching overload here.
|
||||
self.errors
|
||||
.extend(callable_binding.overloads()[index].errors.iter().cloned());
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(_) => {
|
||||
if !matches!(
|
||||
callable_binding.overload_call_return_type,
|
||||
Some(OverloadCallReturnType::ArgumentTypeExpansion(_))
|
||||
) {
|
||||
self.errors.extend(
|
||||
callable_binding
|
||||
.overloads()
|
||||
.first()
|
||||
.unwrap()
|
||||
.errors
|
||||
.iter()
|
||||
.cloned(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user