infer arguments of generic calls with declared type context

This commit is contained in:
Ibraheem Ahmed 2025-10-25 00:51:24 -04:00
parent 64ab79e572
commit 72acb7e14e
2 changed files with 53 additions and 8 deletions

View File

@ -476,10 +476,8 @@ def _(i: int):
b: list[int | None] | None = id([i]) b: list[int | None] | None = id([i])
c: list[int | None] | int | None = id([i]) c: list[int | None] | int | None = id([i])
reveal_type(a) # revealed: list[int | None] reveal_type(a) # revealed: list[int | None]
# TODO: these should reveal `list[int | None]` reveal_type(b) # revealed: list[int | None]
# we currently do not use the call expression annotation as type context for argument inference reveal_type(c) # revealed: list[int | None]
reveal_type(b) # revealed: list[Unknown | int]
reveal_type(c) # revealed: list[Unknown | int]
a: list[int | None] | None = [i] a: list[int | None] | None = [i]
b: list[int | None] | None = lst(i) b: list[int | None] | None = lst(i)
@ -495,3 +493,26 @@ def _(i: int):
reveal_type(b) # revealed: list[Unknown] reveal_type(b) # revealed: list[Unknown]
reveal_type(c) # revealed: list[Unknown] reveal_type(c) # revealed: list[Unknown]
``` ```
The function arguments are inferred using the type context:
```py
from typing import TypedDict
class TD(TypedDict):
x: int
def f[T](x: list[T]) -> T:
return x[0]
a: TD = f([{"x": 0}, {"x": 1}])
reveal_type(a) # revealed: TD
b: TD | None = f([{"x": 0}, {"x": 1}])
# TODO: Narrow away the `None` here.
reveal_type(b) # revealed: TD | None
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
# error: [invalid-key] "Invalid key access on TypedDict `TD`: Unknown key "y""
c: TD | None = f([{"y": 0}, {"x": 1}])
```

View File

@ -5535,6 +5535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast_arguments: &ast::Arguments, ast_arguments: &ast::Arguments,
arguments: &mut CallArguments<'a, 'db>, arguments: &mut CallArguments<'a, 'db>,
bindings: &Bindings<'db>, bindings: &Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
) { ) {
debug_assert!( debug_assert!(
ast_arguments.len() == arguments.len() ast_arguments.len() == arguments.len()
@ -5603,10 +5604,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return None; return None;
}; };
let parameter_type = let mut parameter_type =
overload.signature.parameters()[*parameter_index].annotated_type()?; overload.signature.parameters()[*parameter_index].annotated_type()?;
// TODO: For now, skip any parameter annotations that mention any typevars. There // If this is a generic call, attempt to specialize the parameter type using the
// declared type context, if provided.
if let Some(generic_context) = overload.signature.generic_context
&& let Some(return_ty) = overload.signature.return_ty
&& let Some(declared_return_ty) = call_expression_tcx.annotation
{
let mut builder =
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
let _ = builder.infer(return_ty, declared_return_ty);
let specialization = builder.build(generic_context, call_expression_tcx);
// Note that we are not necessarily "preferring the declared type" here, as the
// type context will only be preferred during the inference of this expression
// by the same heuristics we use for the inference of the outer generic call.
parameter_type = parameter_type.apply_specialization(db, specialization);
}
// TODO: For now, skip any parameter annotations that still mention any typevars. There
// are two issues: // are two issues:
// //
// First, if we include those typevars in the type context that we use to infer the // First, if we include those typevars in the type context that we use to infer the
@ -6820,7 +6839,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let infer_call_arguments = |bindings: Option<Bindings<'db>>| { let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
if let Some(bindings) = bindings { if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments); let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); self.infer_all_argument_types(
arguments,
&mut call_arguments,
&bindings,
tcx,
);
} else { } else {
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
@ -6841,7 +6865,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let bindings = callable_type let bindings = callable_type
.bindings(self.db()) .bindings(self.db())
.match_parameters(self.db(), &call_arguments); .match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); self.infer_all_argument_types(arguments, &mut call_arguments, &bindings, tcx);
// Validate `TypedDict` constructor calls after argument type inference // Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.as_class_literal() { if let Some(class_literal) = callable_type.as_class_literal() {