mirror of https://github.com/astral-sh/ruff
infer arguments of generic calls with declared type context
This commit is contained in:
parent
64ab79e572
commit
72acb7e14e
|
|
@ -476,10 +476,8 @@ def _(i: int):
|
||||||
b: list[int | None] | None = id([i])
|
b: list[int | None] | None = id([i])
|
||||||
c: list[int | None] | int | None = id([i])
|
c: list[int | None] | int | None = id([i])
|
||||||
reveal_type(a) # revealed: list[int | None]
|
reveal_type(a) # revealed: list[int | None]
|
||||||
# TODO: these should reveal `list[int | None]`
|
reveal_type(b) # revealed: list[int | None]
|
||||||
# we currently do not use the call expression annotation as type context for argument inference
|
reveal_type(c) # revealed: list[int | None]
|
||||||
reveal_type(b) # revealed: list[Unknown | int]
|
|
||||||
reveal_type(c) # revealed: list[Unknown | int]
|
|
||||||
|
|
||||||
a: list[int | None] | None = [i]
|
a: list[int | None] | None = [i]
|
||||||
b: list[int | None] | None = lst(i)
|
b: list[int | None] | None = lst(i)
|
||||||
|
|
@ -495,3 +493,26 @@ def _(i: int):
|
||||||
reveal_type(b) # revealed: list[Unknown]
|
reveal_type(b) # revealed: list[Unknown]
|
||||||
reveal_type(c) # revealed: list[Unknown]
|
reveal_type(c) # revealed: list[Unknown]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The function arguments are inferred using the type context:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
class TD(TypedDict):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
def f[T](x: list[T]) -> T:
|
||||||
|
return x[0]
|
||||||
|
|
||||||
|
a: TD = f([{"x": 0}, {"x": 1}])
|
||||||
|
reveal_type(a) # revealed: TD
|
||||||
|
|
||||||
|
b: TD | None = f([{"x": 0}, {"x": 1}])
|
||||||
|
# TODO: Narrow away the `None` here.
|
||||||
|
reveal_type(b) # revealed: TD | None
|
||||||
|
|
||||||
|
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
|
||||||
|
# error: [invalid-key] "Invalid key access on TypedDict `TD`: Unknown key "y""
|
||||||
|
c: TD | None = f([{"y": 0}, {"x": 1}])
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -5535,6 +5535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
ast_arguments: &ast::Arguments,
|
ast_arguments: &ast::Arguments,
|
||||||
arguments: &mut CallArguments<'a, 'db>,
|
arguments: &mut CallArguments<'a, 'db>,
|
||||||
bindings: &Bindings<'db>,
|
bindings: &Bindings<'db>,
|
||||||
|
call_expression_tcx: TypeContext<'db>,
|
||||||
) {
|
) {
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
ast_arguments.len() == arguments.len()
|
ast_arguments.len() == arguments.len()
|
||||||
|
|
@ -5603,10 +5604,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
return None;
|
return None;
|
||||||
};
|
};
|
||||||
|
|
||||||
let parameter_type =
|
let mut parameter_type =
|
||||||
overload.signature.parameters()[*parameter_index].annotated_type()?;
|
overload.signature.parameters()[*parameter_index].annotated_type()?;
|
||||||
|
|
||||||
// TODO: For now, skip any parameter annotations that mention any typevars. There
|
// If this is a generic call, attempt to specialize the parameter type using the
|
||||||
|
// declared type context, if provided.
|
||||||
|
if let Some(generic_context) = overload.signature.generic_context
|
||||||
|
&& let Some(return_ty) = overload.signature.return_ty
|
||||||
|
&& let Some(declared_return_ty) = call_expression_tcx.annotation
|
||||||
|
{
|
||||||
|
let mut builder =
|
||||||
|
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
|
||||||
|
|
||||||
|
let _ = builder.infer(return_ty, declared_return_ty);
|
||||||
|
let specialization = builder.build(generic_context, call_expression_tcx);
|
||||||
|
|
||||||
|
// Note that we are not necessarily "preferring the declared type" here, as the
|
||||||
|
// type context will only be preferred during the inference of this expression
|
||||||
|
// by the same heuristics we use for the inference of the outer generic call.
|
||||||
|
parameter_type = parameter_type.apply_specialization(db, specialization);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: For now, skip any parameter annotations that still mention any typevars. There
|
||||||
// are two issues:
|
// are two issues:
|
||||||
//
|
//
|
||||||
// First, if we include those typevars in the type context that we use to infer the
|
// First, if we include those typevars in the type context that we use to infer the
|
||||||
|
|
@ -6820,7 +6839,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
|
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
|
||||||
if let Some(bindings) = bindings {
|
if let Some(bindings) = bindings {
|
||||||
let bindings = bindings.match_parameters(self.db(), &call_arguments);
|
let bindings = bindings.match_parameters(self.db(), &call_arguments);
|
||||||
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
self.infer_all_argument_types(
|
||||||
|
arguments,
|
||||||
|
&mut call_arguments,
|
||||||
|
&bindings,
|
||||||
|
tcx,
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
|
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
|
||||||
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
|
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
|
||||||
|
|
@ -6841,7 +6865,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
let bindings = callable_type
|
let bindings = callable_type
|
||||||
.bindings(self.db())
|
.bindings(self.db())
|
||||||
.match_parameters(self.db(), &call_arguments);
|
.match_parameters(self.db(), &call_arguments);
|
||||||
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings, tcx);
|
||||||
|
|
||||||
// Validate `TypedDict` constructor calls after argument type inference
|
// Validate `TypedDict` constructor calls after argument type inference
|
||||||
if let Some(class_literal) = callable_type.as_class_literal() {
|
if let Some(class_literal) = callable_type.as_class_literal() {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue