diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index adf0de358d..dcbcb19f97 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -476,10 +476,8 @@ def _(i: int): b: list[int | None] | None = id([i]) c: list[int | None] | int | None = id([i]) reveal_type(a) # revealed: list[int | None] - # TODO: these should reveal `list[int | None]` - # we currently do not use the call expression annotation as type context for argument inference - reveal_type(b) # revealed: list[Unknown | int] - reveal_type(c) # revealed: list[Unknown | int] + reveal_type(b) # revealed: list[int | None] + reveal_type(c) # revealed: list[int | None] a: list[int | None] | None = [i] b: list[int | None] | None = lst(i) @@ -495,3 +493,26 @@ def _(i: int): reveal_type(b) # revealed: list[Unknown] reveal_type(c) # revealed: list[Unknown] ``` + +The function arguments are inferred using the type context: + +```py +from typing import TypedDict + +class TD(TypedDict): + x: int + +def f[T](x: list[T]) -> T: + return x[0] + +a: TD = f([{"x": 0}, {"x": 1}]) +reveal_type(a) # revealed: TD + +b: TD | None = f([{"x": 0}, {"x": 1}]) +# TODO: Narrow away the `None` here. +reveal_type(b) # revealed: TD | None + +# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor" +# error: [invalid-key] "Invalid key access on TypedDict `TD`: Unknown key "y"" +c: TD | None = f([{"y": 0}, {"x": 1}]) +``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 9a5091f837..59bed62217 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5535,6 +5535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast_arguments: &ast::Arguments, arguments: &mut CallArguments<'a, 'db>, bindings: &Bindings<'db>, + call_expression_tcx: TypeContext<'db>, ) { debug_assert!( ast_arguments.len() == arguments.len() @@ -5603,10 +5604,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return None; }; - let parameter_type = + let mut parameter_type = overload.signature.parameters()[*parameter_index].annotated_type()?; - // TODO: For now, skip any parameter annotations that mention any typevars. There + // If this is a generic call, attempt to specialize the parameter type using the + // declared type context, if provided. + if let Some(generic_context) = overload.signature.generic_context + && let Some(return_ty) = overload.signature.return_ty + && let Some(declared_return_ty) = call_expression_tcx.annotation + { + let mut builder = + SpecializationBuilder::new(db, generic_context.inferable_typevars(db)); + + let _ = builder.infer(return_ty, declared_return_ty); + let specialization = builder.build(generic_context, call_expression_tcx); + + // Note that we are not necessarily "preferring the declared type" here, as the + // type context will only be preferred during the inference of this expression + // by the same heuristics we use for the inference of the outer generic call. + parameter_type = parameter_type.apply_specialization(db, specialization); + } + + // TODO: For now, skip any parameter annotations that still mention any typevars. There // are two issues: // // First, if we include those typevars in the type context that we use to infer the @@ -6820,7 +6839,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let infer_call_arguments = |bindings: Option>| { if let Some(bindings) = bindings { let bindings = bindings.match_parameters(self.db(), &call_arguments); - self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); + self.infer_all_argument_types( + arguments, + &mut call_arguments, + &bindings, + tcx, + ); } else { let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); @@ -6841,7 +6865,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); + self.infer_all_argument_types(arguments, &mut call_arguments, &bindings, tcx); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.as_class_literal() {