From 8ac5f9d8bcd34738a2c93d20b218ca894a93e197 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 12 Jan 2026 16:05:05 -0500 Subject: [PATCH] [ty] Use key and value parameter types as type context for `__setitem__` dunder calls (#22148) ## Summary Resolves https://github.com/astral-sh/ty/issues/2136. --- Cargo.lock | 1 - crates/ruff_python_ast/Cargo.toml | 1 - crates/ruff_python_ast/src/nodes.rs | 44 +- .../resources/mdtest/bidirectional.md | 27 + .../src/types/infer/builder.rs | 567 +++++++++++------- .../types/infer/builder/type_expression.rs | 4 +- 6 files changed, 413 insertions(+), 231 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7093422be..edd3008780 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3305,7 +3305,6 @@ dependencies = [ "compact_str", "get-size2", "is-macro", - "itertools 0.14.0", "memchr", "ruff_cache", "ruff_macros", diff --git a/crates/ruff_python_ast/Cargo.toml b/crates/ruff_python_ast/Cargo.toml index b6f0433a56..9a6ea65412 100644 --- a/crates/ruff_python_ast/Cargo.toml +++ b/crates/ruff_python_ast/Cargo.toml @@ -28,7 +28,6 @@ bitflags = { workspace = true } compact_str = { workspace = true } get-size2 = { workspace = true, optional = true } is-macro = { workspace = true } -itertools = { workspace = true } memchr = { workspace = true } rustc-hash = { workspace = true } salsa = { workspace = true, optional = true } diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index 2ea2701f82..f0f7b54cc4 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -14,7 +14,6 @@ use std::slice::{Iter, IterMut}; use std::sync::OnceLock; use bitflags::bitflags; -use itertools::Itertools; use ruff_text_size::{Ranged, TextLen, TextRange, TextSize}; @@ -3380,10 +3379,13 @@ impl Arguments { /// 2 /// {'4': 5} /// ``` - pub fn arguments_source_order(&self) -> impl Iterator> { - let args = self.args.iter().map(ArgOrKeyword::Arg); - let keywords = self.keywords.iter().map(ArgOrKeyword::Keyword); - args.merge_by(keywords, |left, right| left.start() <= right.start()) + pub fn arguments_source_order(&self) -> ArgumentsSourceOrder<'_> { + ArgumentsSourceOrder { + args: &self.args, + keywords: &self.keywords, + next_arg: 0, + next_keyword: 0, + } } pub fn inner_range(&self) -> TextRange { @@ -3399,6 +3401,38 @@ impl Arguments { } } +/// The iterator returned by [`Arguments::arguments_source_order`]. +#[derive(Clone)] +pub struct ArgumentsSourceOrder<'a> { + args: &'a [Expr], + keywords: &'a [Keyword], + next_arg: usize, + next_keyword: usize, +} + +impl<'a> Iterator for ArgumentsSourceOrder<'a> { + type Item = ArgOrKeyword<'a>; + + fn next(&mut self) -> Option { + let arg = self.args.get(self.next_arg); + let keyword = self.keywords.get(self.next_keyword); + + if let Some(arg) = arg + && keyword.is_none_or(|keyword| arg.start() <= keyword.start()) + { + self.next_arg += 1; + Some(ArgOrKeyword::Arg(arg)) + } else if let Some(keyword) = keyword { + self.next_keyword += 1; + Some(ArgOrKeyword::Keyword(keyword)) + } else { + None + } + } +} + +impl FusedIterator for ArgumentsSourceOrder<'_> {} + /// An AST node used to represent a sequence of type parameters. /// /// For example, given: diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 17d31794ab..bc4b08b3ae 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -297,6 +297,33 @@ def _(flag: bool): reveal_type(x2) # revealed: list[int | None] ``` +## Dunder Calls + +The key and value parameters types are used as type context for `__setitem__` dunder calls: + +```py +from typing import TypedDict + +class Bar(TypedDict): + baz: float + +def _(x: dict[str, Bar]): + x["foo"] = reveal_type({"baz": 2}) # revealed: Bar + +class X: + def __setitem__(self, key: Bar, value: Bar): ... + +def _(x: X): + # revealed: Bar + x[reveal_type({"baz": 1})] = reveal_type({"baz": 2}) # revealed: Bar + +# TODO: Support type context with union subscripting. +def _(x: X | dict[Bar, Bar]): + # error: [invalid-assignment] + # error: [invalid-assignment] + x[{"baz": 1}] = {"baz": 2} +``` + ## Multi-inference diagnostics ```toml diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 9abed64628..d819726c18 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -7,7 +7,8 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_db::source::source_text; use ruff_python_ast::visitor::{Visitor, walk_expr}; use ruff_python_ast::{ - self as ast, AnyNodeRef, ExprContext, HasNodeIndex, NodeIndex, PythonVersion, + self as ast, AnyNodeRef, ArgOrKeyword, ArgumentsSourceOrder, ExprContext, HasNodeIndex, + NodeIndex, PythonVersion, }; use ruff_python_stdlib::builtins::version_builtin_was_added; use ruff_text_size::{Ranged, TextRange}; @@ -3203,7 +3204,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for item in items { let target = item.optional_vars.as_deref(); if let Some(target) = target { - self.infer_target(target, &item.context_expr, |builder, tcx| { + self.infer_target(target, &item.context_expr, &|builder, tcx| { // TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `with not_context_manager as a.x: ... @@ -3932,7 +3933,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = assignment; for target in targets { - self.infer_target(target, value, |builder, tcx| { + self.infer_target(target, value, &|builder, tcx| { builder.infer_standalone_expression(value, tcx) }); } @@ -3947,20 +3948,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// The `infer_value_expr` function is used to infer the type of the `value` expression which /// are not `Name` expressions. The returned type is the one that is eventually assigned to the /// `target`. - fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F) - where - F: Fn(&mut Self, TypeContext<'db>) -> Type<'db>, - { + fn infer_target( + &mut self, + target: &ast::Expr, + value: &ast::Expr, + infer_value_expr: &dyn Fn(&mut Self, TypeContext<'db>) -> Type<'db>, + ) { match target { ast::Expr::Name(_) => { self.infer_target_impl(target, value, None); } - _ => self.infer_target_impl( - target, - value, - Some(&|builder, tcx| infer_value_expr(builder, tcx)), - ), + _ => self.infer_target_impl(target, value, Some(&infer_value_expr)), } } @@ -3969,7 +3968,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &mut self, target: &ast::ExprSubscript, rhs_value: &ast::Expr, - rhs_value_ty: Type<'db>, + infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, ) -> bool { let ast::ExprSubscript { range: _, @@ -3980,28 +3979,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = target; let object_ty = self.infer_expression(object, TypeContext::default()); - let slice_ty = self.infer_expression(slice, TypeContext::default()); + let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx); self.validate_subscript_assignment_impl( target, None, object_ty, - slice_ty, + &mut infer_slice_ty, rhs_value, - rhs_value_ty, + infer_rhs_value, true, ) } #[expect(clippy::too_many_arguments)] fn validate_subscript_assignment_impl( - &self, - target: &'ast ast::ExprSubscript, + &mut self, + target: &ast::ExprSubscript, full_object_ty: Option>, object_ty: Type<'db>, - slice_ty: Type<'db>, - rhs_value_node: &'ast ast::Expr, - rhs_value_ty: Type<'db>, + infer_slice_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, + rhs_value_node: &ast::Expr, + infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, emit_diagnostic: bool, ) -> bool { /// Given a string literal or a union of string literals, return an iterator over the contained @@ -4037,6 +4036,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match object_ty { Type::Union(union) => { + // TODO: Perform multi-inference here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + // Note that we use a loop here instead of .all(…) to avoid short-circuiting. // We need to keep iterating to emit all diagnostics. let mut valid = true; @@ -4045,9 +4048,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target, full_object_ty.or(Some(object_ty)), *element_ty, - slice_ty, + &mut |_, _| slice_ty, rhs_value_node, - rhs_value_ty, + &mut |_, _| rhs_value_ty, emit_diagnostic, ); } @@ -4055,16 +4058,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } Type::Intersection(intersection) => { - let check_positive_elements = |emit_diagnostic_and_short_circuit| { + // TODO: Perform multi-inference here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + + let mut check_positive_elements = |emit_diagnostic_and_short_circuit| { let mut valid = false; for element_ty in intersection.positive(db) { valid |= self.validate_subscript_assignment_impl( target, full_object_ty.or(Some(object_ty)), *element_ty, - slice_ty, + &mut |_, _| slice_ty, rhs_value_node, - rhs_value_ty, + &mut |_, _| rhs_value_ty, emit_diagnostic_and_short_circuit, ); @@ -4092,6 +4099,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // As an optimization, prevent calling `__setitem__` on (unions of) large `TypedDict`s, and // validate the assignment ourselves. This also allows us to emit better diagnostics. + // TODO: Use type context here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + let mut valid = true; let Some(keys) = key_literals(db, slice_ty) else { // Check if the key has a valid type. We only allow string literals, a union of string literals, @@ -4155,168 +4166,190 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } _ => { - match object_ty.try_call_dunder( - db, - "__setitem__", - CallArguments::positional([slice_ty, rhs_value_ty]), - TypeContext::default(), - ) { - Ok(_) => true, - Err(err) => match err { - CallDunderError::PossiblyUnbound { .. } => { - if emit_diagnostic - && let Some(builder) = self - .context - .report_lint(&POSSIBLY_MISSING_IMPLICIT_CALL, target) - { - let mut diagnostic = builder.into_diagnostic(format_args!( - "Method `__setitem__` of type `{}` may be missing", - object_ty.display(db), - )); - attach_original_type_info(&mut diagnostic); - } - false - } - CallDunderError::CallError(call_error_kind, bindings) => { - match call_error_kind { - CallErrorKind::NotCallable => { - if emit_diagnostic - && let Some(builder) = - self.context.report_lint(&CALL_NON_CALLABLE, target) - { - let mut diagnostic = builder.into_diagnostic(format_args!( - "Method `__setitem__` of type `{}` is not callable \ - on object of type `{}`", - bindings.callable_type().display(db), - object_ty.display(db), - )); - attach_original_type_info(&mut diagnostic); - } - } - CallErrorKind::BindingError => { - if let Some(typed_dict) = object_ty.as_typed_dict() { - if let Some(key) = slice_ty.as_string_literal() { - let key = key.value(db); - validate_typed_dict_key_assignment( - &self.context, - typed_dict, - full_object_ty, - key, - rhs_value_ty, - target.value.as_ref(), - target.slice.as_ref(), - rhs_value_node, - TypedDictAssignmentKind::Subscript, - true, - ); - } - } else { - if emit_diagnostic - && let Some(builder) = self.context.report_lint( - &INVALID_ASSIGNMENT, - target.range.cover(rhs_value_node.range()), - ) - { - let assigned_d = rhs_value_ty.display(db); - let object_d = object_ty.display(db); + let ast_arguments = [ + ArgOrKeyword::Arg(&target.slice), + ArgOrKeyword::Arg(rhs_value_node), + ]; - let mut diagnostic = builder.into_diagnostic(format_args!( + let mut call_arguments = + CallArguments::positional([Type::unknown(), Type::unknown()]); + + let mut infer_argument_ty = + |builder: &mut Self, (argument_index, _, tcx): ArgExpr<'db, '_>| { + match argument_index { + 0 => infer_slice_ty(builder, tcx), + 1 => infer_rhs_value(builder, tcx), + _ => unreachable!(), + } + }; + + let Err(call_dunder_err) = self.infer_and_try_call_dunder( + db, + object_ty, + "__setitem__", + ArgumentsIter::synthesized(&ast_arguments), + &mut call_arguments, + &mut infer_argument_ty, + TypeContext::default(), + ) else { + return true; + }; + + let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else { + unreachable!(); + }; + + match call_dunder_err { + CallDunderError::PossiblyUnbound { .. } => { + if emit_diagnostic + && let Some(builder) = self + .context + .report_lint(&POSSIBLY_MISSING_IMPLICIT_CALL, target) + { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Method `__setitem__` of type `{}` may be missing", + object_ty.display(db), + )); + attach_original_type_info(&mut diagnostic); + } + false + } + CallDunderError::CallError(call_error_kind, bindings) => { + match call_error_kind { + CallErrorKind::NotCallable => { + if emit_diagnostic + && let Some(builder) = + self.context.report_lint(&CALL_NON_CALLABLE, target) + { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Method `__setitem__` of type `{}` is not callable \ + on object of type `{}`", + bindings.callable_type().display(db), + object_ty.display(db), + )); + attach_original_type_info(&mut diagnostic); + } + } + CallErrorKind::BindingError => { + if let Some(typed_dict) = object_ty.as_typed_dict() { + if let Some(key) = slice_ty.as_string_literal() { + let key = key.value(db); + validate_typed_dict_key_assignment( + &self.context, + typed_dict, + full_object_ty, + key, + *rhs_value_ty, + target.value.as_ref(), + target.slice.as_ref(), + rhs_value_node, + TypedDictAssignmentKind::Subscript, + true, + ); + } + } else { + if emit_diagnostic + && let Some(builder) = self.context.report_lint( + &INVALID_ASSIGNMENT, + target.range.cover(rhs_value_node.range()), + ) + { + let assigned_d = rhs_value_ty.display(db); + let object_d = object_ty.display(db); + + let mut diagnostic = builder.into_diagnostic(format_args!( "Invalid subscript assignment with key of type `{}` and value of \ type `{assigned_d}` on object of type `{object_d}`", slice_ty.display(db), )); - // Special diagnostic for dictionaries - if let Some([expected_key_ty, expected_value_ty]) = - object_ty - .known_specialization(db, KnownClass::Dict) - .map(|s| s.types(db)) - { - if !slice_ty.is_assignable_to(db, *expected_key_ty) - { - diagnostic.annotate( - self.context - .secondary(target.slice.as_ref()) - .message(format_args!( - "Expected key of type `{}`, got `{}`", - expected_key_ty.display(db), - slice_ty.display(db), - )), - ); - } - - if !rhs_value_ty - .is_assignable_to(db, *expected_value_ty) - { - diagnostic.annotate( - self.context - .secondary(rhs_value_node) - .message(format_args!( - "Expected value of type `{}`, got `{}`", - expected_value_ty.display(db), - rhs_value_ty.display(db), - )), - ); - } + // Special diagnostic for dictionaries + if let Some([expected_key_ty, expected_value_ty]) = + object_ty + .known_specialization(db, KnownClass::Dict) + .map(|s| s.types(db)) + { + if !slice_ty.is_assignable_to(db, *expected_key_ty) { + diagnostic.annotate( + self.context + .secondary(target.slice.as_ref()) + .message(format_args!( + "Expected key of type `{}`, got `{}`", + expected_key_ty.display(db), + slice_ty.display(db), + )), + ); } - attach_original_type_info(&mut diagnostic); + if !rhs_value_ty + .is_assignable_to(db, *expected_value_ty) + { + diagnostic.annotate( + self.context.secondary(rhs_value_node).message( + format_args!( + "Expected value of type `{}`, got `{}`", + expected_value_ty.display(db), + rhs_value_ty.display(db), + ), + ), + ); + } } - } - } - CallErrorKind::PossiblyNotCallable => { - if emit_diagnostic - && let Some(builder) = - self.context.report_lint(&CALL_NON_CALLABLE, target) - { - let mut diagnostic = builder.into_diagnostic(format_args!( - "Method `__setitem__` of type `{}` may not be callable on object of type `{}`", - bindings.callable_type().display(db), - object_ty.display(db), - )); + attach_original_type_info(&mut diagnostic); } } } - false - } - CallDunderError::MethodNotAvailable => { - if emit_diagnostic - && let Some(builder) = - self.context.report_lint(&INVALID_ASSIGNMENT, target) - { - let mut diagnostic = builder.into_diagnostic(format_args!( - "Cannot assign to a subscript on an object of type `{}`", - object_ty.display(db), - )); - attach_original_type_info(&mut diagnostic); - - // If it's a user-defined class, suggest adding a `__setitem__` method. - if object_ty - .as_nominal_instance() - .and_then(|instance| { - instance.class(db).static_class_literal(db) - }) - .and_then(|(class_literal, _)| { - file_to_module(db, class_literal.file(db)) - }) - .and_then(|module| module.search_path(db)) - .is_some_and(ty_module_resolver::SearchPath::is_first_party) + CallErrorKind::PossiblyNotCallable => { + if emit_diagnostic + && let Some(builder) = + self.context.report_lint(&CALL_NON_CALLABLE, target) { - diagnostic.help(format_args!( - "Consider adding a `__setitem__` method to `{}`.", - object_ty.display(db), - )); - } else { - diagnostic.info(format_args!( - "`{}` does not have a `__setitem__` method.", - object_ty.display(db), - )); + let mut diagnostic = builder.into_diagnostic(format_args!( + "Method `__setitem__` of type `{}` may not be callable on object of type `{}`", + bindings.callable_type().display(db), + object_ty.display(db), + )); + attach_original_type_info(&mut diagnostic); } } - false } - }, + false + } + CallDunderError::MethodNotAvailable => { + if emit_diagnostic + && let Some(builder) = + self.context.report_lint(&INVALID_ASSIGNMENT, target) + { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Cannot assign to a subscript on an object of type `{}`", + object_ty.display(db), + )); + attach_original_type_info(&mut diagnostic); + + // If it's a user-defined class, suggest adding a `__setitem__` method. + if object_ty + .as_nominal_instance() + .and_then(|instance| instance.class(db).static_class_literal(db)) + .and_then(|(class_literal, _)| { + file_to_module(db, class_literal.file(db)) + }) + .and_then(|module| module.search_path(db)) + .is_some_and(ty_module_resolver::SearchPath::is_first_party) + { + diagnostic.help(format_args!( + "Consider adding a `__setitem__` method to `{}`.", + object_ty.display(db), + )); + } else { + diagnostic.info(format_args!( + "`{}` does not have a `__setitem__` method.", + object_ty.display(db), + )); + } + } + false + } } } } @@ -5315,11 +5348,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } ast::Expr::Subscript(subscript_expr) => { - let assigned_ty = infer_assigned_ty.map(|f| f(self, TypeContext::default())); - self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); + if let Some(infer_assigned_ty) = infer_assigned_ty { + let infer_assigned_ty = &mut |builder: &mut Self, tcx| { + let assigned_ty = infer_assigned_ty(builder, tcx); + builder.store_expression_type(target, assigned_ty); + assigned_ty + }; - if let Some(assigned_ty) = assigned_ty { - self.validate_subscript_assignment(subscript_expr, value, assigned_ty); + self.validate_subscript_assignment(subscript_expr, value, infer_assigned_ty); } } @@ -6638,7 +6674,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { is_async: _, } = for_statement; - self.infer_target(target, iter, |builder, tcx| { + self.infer_target(target, iter, &|builder, tcx| { // TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `for a.x in not_iterable: ... @@ -7511,10 +7547,54 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } + #[expect(clippy::too_many_arguments)] + fn infer_and_try_call_dunder( + &mut self, + db: &'db dyn Db, + object: Type<'db>, + name: &str, + ast_arguments: ArgumentsIter<'_>, + argument_types: &mut CallArguments<'_, 'db>, + infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>, + call_expression_tcx: TypeContext<'db>, + ) -> Result, CallDunderError<'db>> { + match object + .member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NO_INSTANCE_FALLBACK) + .place + { + Place::Defined(DefinedPlace { + ty: dunder_callable, + definedness: boundness, + .. + }) => { + let mut bindings = dunder_callable + .bindings(db) + .match_parameters(db, argument_types); + + if let Err(call_error) = self.infer_and_check_argument_types( + ast_arguments, + argument_types, + infer_argument_ty, + &mut bindings, + call_expression_tcx, + ) { + return Err(CallDunderError::CallError(call_error, Box::new(bindings))); + } + + if boundness == Definedness::PossiblyUndefined { + return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); + } + Ok(bindings) + } + Place::Undefined => Err(CallDunderError::MethodNotAvailable), + } + } + fn infer_and_check_argument_types( &mut self, - ast_arguments: &ast::Arguments, + ast_arguments: ArgumentsIter<'_>, argument_types: &mut CallArguments<'_, 'db>, + infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>, bindings: &mut Bindings<'db>, call_expression_tcx: TypeContext<'db>, ) -> Result<(), CallErrorKind> { @@ -7546,8 +7626,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Attempt to infer the argument types using the narrowed type context. self.infer_all_argument_types( - ast_arguments, + ast_arguments.clone(), argument_types, + infer_argument_ty, bindings, narrowed_tcx, MultiInferenceState::Ignore, @@ -7586,8 +7667,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.context.set_multi_inference(was_in_multi_inference); self.infer_all_argument_types( - ast_arguments, + ast_arguments.clone(), argument_types, + infer_argument_ty, bindings, narrowed_tcx, MultiInferenceState::Intersect, @@ -7631,6 +7713,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_all_argument_types( ast_arguments, argument_types, + infer_argument_ty, bindings, call_expression_tcx, MultiInferenceState::Intersect, @@ -7651,13 +7734,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// behavior. fn infer_all_argument_types( &mut self, - ast_arguments: &ast::Arguments, + ast_arguments: ArgumentsIter<'_>, arguments_types: &mut CallArguments<'_, 'db>, + infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>, bindings: &Bindings<'db>, call_expression_tcx: TypeContext<'db>, multi_inference_state: MultiInferenceState, ) { - debug_assert_eq!(ast_arguments.len(), arguments_types.len()); debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len()); let db = self.db(); @@ -7665,7 +7748,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { 0.., arguments_types.iter_mut(), bindings.argument_forms().iter().copied(), - ast_arguments.arguments_source_order() + ast_arguments ); let overloads_with_binding = bindings @@ -7774,14 +7857,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If there is only a single binding and overload, we can infer the argument directly with // the unique parameter type annotation. if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() { - *argument_type = Some(self.infer_expression( - ast_argument, - TypeContext::new(parameter_type(overload, binding)), - )); + let tcx = TypeContext::new(parameter_type(overload, binding)); + *argument_type = Some(infer_argument_ty(self, (argument_index, ast_argument, tcx))); } else { // We perform inference once without any type context, emitting any diagnostics that are unrelated // to bidirectional type inference. - *argument_type = Some(self.infer_expression(ast_argument, TypeContext::default())); + *argument_type = Some(infer_argument_ty( + self, + (argument_index, ast_argument, TypeContext::default()), + )); // We then silence any diagnostics emitted during multi-inference, as the type context is only // used as a hint to infer a more assignable argument type, and should not lead to diagnostics @@ -7799,8 +7883,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if !seen.insert(parameter_type) { continue; } - let inferred_ty = - self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type))); + + let tcx = TypeContext::new(Some(parameter_type)); + let inferred_ty = infer_argument_ty(self, (argument_index, ast_argument, tcx)); // Ensure the inferred type is assignable to the declared type. // @@ -8244,9 +8329,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ctx: _, } = list; - let elts = elts.iter().map(|elt| [Some(elt)]); - let infer_elt_ty = |builder: &mut Self, elt, tcx| builder.infer_expression(elt, tcx); - self.infer_collection_literal(elts, tcx, infer_elt_ty, KnownClass::List) + let mut elts = elts.iter().map(|elt| [Some(elt)]); + let infer_elt_ty = + &mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx); + self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::List) .unwrap_or_else(|| { KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()]) }) @@ -8259,9 +8345,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { elts, } = set; - let elts = elts.iter().map(|elt| [Some(elt)]); - let infer_elt_ty = |builder: &mut Self, elt, tcx| builder.infer_expression(elt, tcx); - self.infer_collection_literal(elts, tcx, infer_elt_ty, KnownClass::Set) + let mut elts = elts.iter().map(|elt| [Some(elt)]); + let infer_elt_ty = + &mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx); + self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::Set) .unwrap_or_else(|| { KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()]) }) @@ -8293,21 +8380,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()]); } - let items = items + let mut items = items .iter() .map(|item| [item.key.as_ref(), Some(&item.value)]); // Avoid inferring the items multiple times if we already attempted to infer the // dictionary literal as a `TypedDict`. This also allows us to infer using the // type context of the expected `TypedDict` field. - let infer_elt_ty = |builder: &mut Self, elt: &ast::Expr, tcx| { + let infer_elt_ty = &mut |builder: &mut Self, (_, elt, tcx): ArgExpr<'db, '_>| { item_types .get(&elt.node_index().load()) .copied() .unwrap_or_else(|| builder.infer_expression(elt, tcx)) }; - self.infer_collection_literal(items, tcx, infer_elt_ty, KnownClass::Dict) + self.infer_collection_literal(&mut items, tcx, infer_elt_ty, KnownClass::Dict) .unwrap_or_else(|| { KnownClass::Dict .to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()]) @@ -8356,17 +8443,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // Infer the type of a collection literal expression. - fn infer_collection_literal<'expr, const N: usize, F, I>( + fn infer_collection_literal<'expr, const N: usize>( &mut self, - elts: I, + elts: &mut dyn Iterator; N]>, tcx: TypeContext<'db>, - mut infer_elt_expression: F, + infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>, collection_class: KnownClass, - ) -> Option> - where - I: Iterator; N]>, - F: FnMut(&mut Self, &'expr ast::Expr, TypeContext<'db>) -> Type<'db>, - { + ) -> Option> { // Extract the type variable `T` from `list[T]` in typeshed. let elt_tys = |collection_class: KnownClass| { let collection_alias = collection_class @@ -8388,8 +8471,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else { // Infer the element types without type context, and fallback to unknown for // custom typesheds. - for elt in elts.flatten().flatten() { - infer_elt_expression(self, elt, TypeContext::default()); + for (i, elt) in elts.flatten().flatten().enumerate() { + infer_elt_expression(self, (i, elt, TypeContext::default())); } return None; @@ -8483,7 +8566,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for elts in elts { // An unpacking expression for a dictionary. if let &[None, Some(value)] = elts.as_slice() { - let inferred_value_ty = infer_elt_expression(self, value, TypeContext::default()); + let inferred_value_ty = + infer_elt_expression(self, (1, value, TypeContext::default())); // Merge the inferred type of the nested dictionary. if let Some(specialization) = @@ -8502,7 +8586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // The inferred type of each element acts as an additional constraint on `T`. - for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) { + for (i, elt, elt_ty) in itertools::izip!(0.., elts, elt_tys.clone()) { let Some(elt) = elt else { continue }; // Note that unlike when preferring the declared type, we use covariant type @@ -8511,7 +8595,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let elt_ty_identity = elt_ty.identity(self.db()); let elt_tcx = elt_tcx_constraints.get(&elt_ty_identity).copied(); - let inferred_elt_ty = infer_elt_expression(self, elt, TypeContext::new(elt_tcx)); + let inferred_elt_ty = + infer_elt_expression(self, (i, elt, TypeContext::new(elt_tcx))); // Simplify the inference based on a non-covariant declared type. if let Some(elt_tcx) = @@ -8789,7 +8874,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { is_async: _, } = comprehension; - self.infer_target(target, iter, |builder, tcx| { + self.infer_target(target, iter, &|builder, tcx| { // TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `[... for a.x in not_iterable] @@ -9033,7 +9118,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &mut self, call_expression: &ast::ExprCall, callable_type: Type<'db>, - tcx: TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, ) -> Type<'db> { let ast::ExprCall { range: _, @@ -9044,7 +9129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Fast-path dict(...) in TypedDict context: infer keyword values against fields, // then validate and return the TypedDict type. - if let Some(tcx) = tcx.annotation + if let Some(tcx) = call_expression_tcx.annotation && let Some(typed_dict) = tcx .filter_union(self.db(), Type::is_typed_dict) .as_typed_dict() @@ -9281,10 +9366,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(bindings) = bindings { let bindings = bindings.match_parameters(self.db(), &call_arguments); self.infer_all_argument_types( - arguments, + ArgumentsIter::from_ast(arguments), &mut call_arguments, + &mut |builder, (_, expr, tcx)| builder.infer_expression(expr, tcx), &bindings, - tcx, + call_expression_tcx, MultiInferenceState::Intersect, ); } else { @@ -9296,7 +9382,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; return callable_type - .try_call_constructor(db, infer_call_arguments, tcx) + .try_call_constructor(db, infer_call_arguments, call_expression_tcx) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type() @@ -9308,8 +9394,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - let bindings_result = - self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx); + let bindings_result = self.infer_and_check_argument_types( + ArgumentsIter::from_ast(arguments), + &mut call_arguments, + &mut |builder, (_, expr, tcx)| builder.infer_expression(expr, tcx), + &mut bindings, + call_expression_tcx, + ); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.as_class_literal() { @@ -12541,7 +12632,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { generic_context: GenericContext<'db>, ) -> Type<'db> { let db = self.db(); - let specialize = |types: &[Option>]| { + let specialize = &|types: &[Option>]| { Type::from(generic_class.apply_specialization(db, |_| { generic_context.specialize_partial(db, types.iter().copied()) })) @@ -12563,7 +12654,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { generic_context: GenericContext<'db>, ) -> Type<'db> { let db = self.db(); - let specialize = |types: &[Option>]| { + let specialize = &|types: &[Option>]| { let type_alias = generic_type_alias.apply_specialization(db, |_| { generic_context.specialize_partial(db, types.iter().copied()) }); @@ -12584,7 +12675,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { subscript: &ast::ExprSubscript, value_ty: Type<'db>, generic_context: GenericContext<'db>, - specialize: impl FnOnce(&[Option>]) -> Type<'db>, + specialize: &dyn Fn(&[Option>]) -> Type<'db>, ) -> Type<'db> { enum ExplicitSpecializationError { InvalidParamSpec, @@ -13577,6 +13668,38 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } +/// An expression representing the function argument at the given index, along with its type +/// context. +type ArgExpr<'db, 'ast> = (usize, &'ast ast::Expr, TypeContext<'db>); + +/// An iterator over arguments to a functional call. +#[derive(Clone)] +enum ArgumentsIter<'a> { + FromAst(ArgumentsSourceOrder<'a>), + Synthesized(std::slice::Iter<'a, ArgOrKeyword<'a>>), +} + +impl<'a> ArgumentsIter<'a> { + fn from_ast(arguments: &'a ast::Arguments) -> Self { + Self::FromAst(arguments.arguments_source_order()) + } + + fn synthesized(arguments: &'a [ArgOrKeyword<'a>]) -> Self { + Self::Synthesized(arguments.iter()) + } +} + +impl<'a> Iterator for ArgumentsIter<'a> { + type Item = ArgOrKeyword<'a>; + + fn next(&mut self) -> Option { + match self { + ArgumentsIter::FromAst(args) => args.next(), + ArgumentsIter::Synthesized(args) => args.next().copied(), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum GenericContextError { /// It's invalid to subscript `Generic` or `Protocol` with this type diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index c7a59e0986..a63f94443d 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -710,7 +710,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { match class_literal.generic_context(self.db()) { Some(generic_context) => { let db = self.db(); - let specialize = |types: &[Option>]| { + let specialize = &|types: &[Option>]| { SubclassOfType::from( db, class_literal.apply_specialization(db, |_| { @@ -805,7 +805,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { }; } - let specialize = |types: &[Option>]| { + let specialize = &|types: &[Option>]| { let specialized = value_ty.apply_specialization( db, generic_context.specialize_partial(db, types.iter().copied()),