[ty] Use key and value parameter types as type context for __setitem__ dunder calls (#22148)

## Summary

Resolves https://github.com/astral-sh/ty/issues/2136.
This commit is contained in:
Ibraheem Ahmed
2026-01-12 16:05:05 -05:00
committed by GitHub
parent 4abc5fe2f1
commit 8ac5f9d8bc
6 changed files with 413 additions and 231 deletions

1
Cargo.lock generated
View File

@@ -3305,7 +3305,6 @@ dependencies = [
"compact_str",
"get-size2",
"is-macro",
"itertools 0.14.0",
"memchr",
"ruff_cache",
"ruff_macros",

View File

@@ -28,7 +28,6 @@ bitflags = { workspace = true }
compact_str = { workspace = true }
get-size2 = { workspace = true, optional = true }
is-macro = { workspace = true }
itertools = { workspace = true }
memchr = { workspace = true }
rustc-hash = { workspace = true }
salsa = { workspace = true, optional = true }

View File

@@ -14,7 +14,6 @@ use std::slice::{Iter, IterMut};
use std::sync::OnceLock;
use bitflags::bitflags;
use itertools::Itertools;
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
@@ -3380,10 +3379,13 @@ impl Arguments {
/// 2
/// {'4': 5}
/// ```
pub fn arguments_source_order(&self) -> impl Iterator<Item = ArgOrKeyword<'_>> {
let args = self.args.iter().map(ArgOrKeyword::Arg);
let keywords = self.keywords.iter().map(ArgOrKeyword::Keyword);
args.merge_by(keywords, |left, right| left.start() <= right.start())
pub fn arguments_source_order(&self) -> ArgumentsSourceOrder<'_> {
ArgumentsSourceOrder {
args: &self.args,
keywords: &self.keywords,
next_arg: 0,
next_keyword: 0,
}
}
pub fn inner_range(&self) -> TextRange {
@@ -3399,6 +3401,38 @@ impl Arguments {
}
}
/// The iterator returned by [`Arguments::arguments_source_order`].
#[derive(Clone)]
pub struct ArgumentsSourceOrder<'a> {
args: &'a [Expr],
keywords: &'a [Keyword],
next_arg: usize,
next_keyword: usize,
}
impl<'a> Iterator for ArgumentsSourceOrder<'a> {
type Item = ArgOrKeyword<'a>;
fn next(&mut self) -> Option<Self::Item> {
let arg = self.args.get(self.next_arg);
let keyword = self.keywords.get(self.next_keyword);
if let Some(arg) = arg
&& keyword.is_none_or(|keyword| arg.start() <= keyword.start())
{
self.next_arg += 1;
Some(ArgOrKeyword::Arg(arg))
} else if let Some(keyword) = keyword {
self.next_keyword += 1;
Some(ArgOrKeyword::Keyword(keyword))
} else {
None
}
}
}
impl FusedIterator for ArgumentsSourceOrder<'_> {}
/// An AST node used to represent a sequence of type parameters.
///
/// For example, given:

View File

@@ -297,6 +297,33 @@ def _(flag: bool):
reveal_type(x2) # revealed: list[int | None]
```
## Dunder Calls
The key and value parameters types are used as type context for `__setitem__` dunder calls:
```py
from typing import TypedDict
class Bar(TypedDict):
baz: float
def _(x: dict[str, Bar]):
x["foo"] = reveal_type({"baz": 2}) # revealed: Bar
class X:
def __setitem__(self, key: Bar, value: Bar): ...
def _(x: X):
# revealed: Bar
x[reveal_type({"baz": 1})] = reveal_type({"baz": 2}) # revealed: Bar
# TODO: Support type context with union subscripting.
def _(x: X | dict[Bar, Bar]):
# error: [invalid-assignment]
# error: [invalid-assignment]
x[{"baz": 1}] = {"baz": 2}
```
## Multi-inference diagnostics
```toml

View File

@@ -7,7 +7,8 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_db::source::source_text;
use ruff_python_ast::visitor::{Visitor, walk_expr};
use ruff_python_ast::{
self as ast, AnyNodeRef, ExprContext, HasNodeIndex, NodeIndex, PythonVersion,
self as ast, AnyNodeRef, ArgOrKeyword, ArgumentsSourceOrder, ExprContext, HasNodeIndex,
NodeIndex, PythonVersion,
};
use ruff_python_stdlib::builtins::version_builtin_was_added;
use ruff_text_size::{Ranged, TextRange};
@@ -3203,7 +3204,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for item in items {
let target = item.optional_vars.as_deref();
if let Some(target) = target {
self.infer_target(target, &item.context_expr, |builder, tcx| {
self.infer_target(target, &item.context_expr, &|builder, tcx| {
// TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ...
@@ -3932,7 +3933,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = assignment;
for target in targets {
self.infer_target(target, value, |builder, tcx| {
self.infer_target(target, value, &|builder, tcx| {
builder.infer_standalone_expression(value, tcx)
});
}
@@ -3947,20 +3948,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// The `infer_value_expr` function is used to infer the type of the `value` expression which
/// are not `Name` expressions. The returned type is the one that is eventually assigned to the
/// `target`.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F)
where
F: Fn(&mut Self, TypeContext<'db>) -> Type<'db>,
{
fn infer_target(
&mut self,
target: &ast::Expr,
value: &ast::Expr,
infer_value_expr: &dyn Fn(&mut Self, TypeContext<'db>) -> Type<'db>,
) {
match target {
ast::Expr::Name(_) => {
self.infer_target_impl(target, value, None);
}
_ => self.infer_target_impl(
target,
value,
Some(&|builder, tcx| infer_value_expr(builder, tcx)),
),
_ => self.infer_target_impl(target, value, Some(&infer_value_expr)),
}
}
@@ -3969,7 +3968,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&mut self,
target: &ast::ExprSubscript,
rhs_value: &ast::Expr,
rhs_value_ty: Type<'db>,
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
) -> bool {
let ast::ExprSubscript {
range: _,
@@ -3980,28 +3979,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = target;
let object_ty = self.infer_expression(object, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx);
self.validate_subscript_assignment_impl(
target,
None,
object_ty,
slice_ty,
&mut infer_slice_ty,
rhs_value,
rhs_value_ty,
infer_rhs_value,
true,
)
}
#[expect(clippy::too_many_arguments)]
fn validate_subscript_assignment_impl(
&self,
target: &'ast ast::ExprSubscript,
&mut self,
target: &ast::ExprSubscript,
full_object_ty: Option<Type<'db>>,
object_ty: Type<'db>,
slice_ty: Type<'db>,
rhs_value_node: &'ast ast::Expr,
rhs_value_ty: Type<'db>,
infer_slice_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
rhs_value_node: &ast::Expr,
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
emit_diagnostic: bool,
) -> bool {
/// Given a string literal or a union of string literals, return an iterator over the contained
@@ -4037,6 +4036,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match object_ty {
Type::Union(union) => {
// TODO: Perform multi-inference here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
// Note that we use a loop here instead of .all(…) to avoid short-circuiting.
// We need to keep iterating to emit all diagnostics.
let mut valid = true;
@@ -4045,9 +4048,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target,
full_object_ty.or(Some(object_ty)),
*element_ty,
slice_ty,
&mut |_, _| slice_ty,
rhs_value_node,
rhs_value_ty,
&mut |_, _| rhs_value_ty,
emit_diagnostic,
);
}
@@ -4055,16 +4058,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
Type::Intersection(intersection) => {
let check_positive_elements = |emit_diagnostic_and_short_circuit| {
// TODO: Perform multi-inference here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
let mut check_positive_elements = |emit_diagnostic_and_short_circuit| {
let mut valid = false;
for element_ty in intersection.positive(db) {
valid |= self.validate_subscript_assignment_impl(
target,
full_object_ty.or(Some(object_ty)),
*element_ty,
slice_ty,
&mut |_, _| slice_ty,
rhs_value_node,
rhs_value_ty,
&mut |_, _| rhs_value_ty,
emit_diagnostic_and_short_circuit,
);
@@ -4092,6 +4099,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// As an optimization, prevent calling `__setitem__` on (unions of) large `TypedDict`s, and
// validate the assignment ourselves. This also allows us to emit better diagnostics.
// TODO: Use type context here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
let mut valid = true;
let Some(keys) = key_literals(db, slice_ty) else {
// Check if the key has a valid type. We only allow string literals, a union of string literals,
@@ -4155,168 +4166,190 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
_ => {
match object_ty.try_call_dunder(
db,
"__setitem__",
CallArguments::positional([slice_ty, rhs_value_ty]),
TypeContext::default(),
) {
Ok(_) => true,
Err(err) => match err {
CallDunderError::PossiblyUnbound { .. } => {
if emit_diagnostic
&& let Some(builder) = self
.context
.report_lint(&POSSIBLY_MISSING_IMPLICIT_CALL, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` may be missing",
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
false
}
CallDunderError::CallError(call_error_kind, bindings) => {
match call_error_kind {
CallErrorKind::NotCallable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&CALL_NON_CALLABLE, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` is not callable \
on object of type `{}`",
bindings.callable_type().display(db),
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
}
CallErrorKind::BindingError => {
if let Some(typed_dict) = object_ty.as_typed_dict() {
if let Some(key) = slice_ty.as_string_literal() {
let key = key.value(db);
validate_typed_dict_key_assignment(
&self.context,
typed_dict,
full_object_ty,
key,
rhs_value_ty,
target.value.as_ref(),
target.slice.as_ref(),
rhs_value_node,
TypedDictAssignmentKind::Subscript,
true,
);
}
} else {
if emit_diagnostic
&& let Some(builder) = self.context.report_lint(
&INVALID_ASSIGNMENT,
target.range.cover(rhs_value_node.range()),
)
{
let assigned_d = rhs_value_ty.display(db);
let object_d = object_ty.display(db);
let ast_arguments = [
ArgOrKeyword::Arg(&target.slice),
ArgOrKeyword::Arg(rhs_value_node),
];
let mut diagnostic = builder.into_diagnostic(format_args!(
let mut call_arguments =
CallArguments::positional([Type::unknown(), Type::unknown()]);
let mut infer_argument_ty =
|builder: &mut Self, (argument_index, _, tcx): ArgExpr<'db, '_>| {
match argument_index {
0 => infer_slice_ty(builder, tcx),
1 => infer_rhs_value(builder, tcx),
_ => unreachable!(),
}
};
let Err(call_dunder_err) = self.infer_and_try_call_dunder(
db,
object_ty,
"__setitem__",
ArgumentsIter::synthesized(&ast_arguments),
&mut call_arguments,
&mut infer_argument_ty,
TypeContext::default(),
) else {
return true;
};
let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else {
unreachable!();
};
match call_dunder_err {
CallDunderError::PossiblyUnbound { .. } => {
if emit_diagnostic
&& let Some(builder) = self
.context
.report_lint(&POSSIBLY_MISSING_IMPLICIT_CALL, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` may be missing",
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
false
}
CallDunderError::CallError(call_error_kind, bindings) => {
match call_error_kind {
CallErrorKind::NotCallable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&CALL_NON_CALLABLE, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` is not callable \
on object of type `{}`",
bindings.callable_type().display(db),
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
}
CallErrorKind::BindingError => {
if let Some(typed_dict) = object_ty.as_typed_dict() {
if let Some(key) = slice_ty.as_string_literal() {
let key = key.value(db);
validate_typed_dict_key_assignment(
&self.context,
typed_dict,
full_object_ty,
key,
*rhs_value_ty,
target.value.as_ref(),
target.slice.as_ref(),
rhs_value_node,
TypedDictAssignmentKind::Subscript,
true,
);
}
} else {
if emit_diagnostic
&& let Some(builder) = self.context.report_lint(
&INVALID_ASSIGNMENT,
target.range.cover(rhs_value_node.range()),
)
{
let assigned_d = rhs_value_ty.display(db);
let object_d = object_ty.display(db);
let mut diagnostic = builder.into_diagnostic(format_args!(
"Invalid subscript assignment with key of type `{}` and value of \
type `{assigned_d}` on object of type `{object_d}`",
slice_ty.display(db),
));
// Special diagnostic for dictionaries
if let Some([expected_key_ty, expected_value_ty]) =
object_ty
.known_specialization(db, KnownClass::Dict)
.map(|s| s.types(db))
{
if !slice_ty.is_assignable_to(db, *expected_key_ty)
{
diagnostic.annotate(
self.context
.secondary(target.slice.as_ref())
.message(format_args!(
"Expected key of type `{}`, got `{}`",
expected_key_ty.display(db),
slice_ty.display(db),
)),
);
}
if !rhs_value_ty
.is_assignable_to(db, *expected_value_ty)
{
diagnostic.annotate(
self.context
.secondary(rhs_value_node)
.message(format_args!(
"Expected value of type `{}`, got `{}`",
expected_value_ty.display(db),
rhs_value_ty.display(db),
)),
);
}
// Special diagnostic for dictionaries
if let Some([expected_key_ty, expected_value_ty]) =
object_ty
.known_specialization(db, KnownClass::Dict)
.map(|s| s.types(db))
{
if !slice_ty.is_assignable_to(db, *expected_key_ty) {
diagnostic.annotate(
self.context
.secondary(target.slice.as_ref())
.message(format_args!(
"Expected key of type `{}`, got `{}`",
expected_key_ty.display(db),
slice_ty.display(db),
)),
);
}
attach_original_type_info(&mut diagnostic);
if !rhs_value_ty
.is_assignable_to(db, *expected_value_ty)
{
diagnostic.annotate(
self.context.secondary(rhs_value_node).message(
format_args!(
"Expected value of type `{}`, got `{}`",
expected_value_ty.display(db),
rhs_value_ty.display(db),
),
),
);
}
}
}
}
CallErrorKind::PossiblyNotCallable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&CALL_NON_CALLABLE, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` may not be callable on object of type `{}`",
bindings.callable_type().display(db),
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
}
}
false
}
CallDunderError::MethodNotAvailable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Cannot assign to a subscript on an object of type `{}`",
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
// If it's a user-defined class, suggest adding a `__setitem__` method.
if object_ty
.as_nominal_instance()
.and_then(|instance| {
instance.class(db).static_class_literal(db)
})
.and_then(|(class_literal, _)| {
file_to_module(db, class_literal.file(db))
})
.and_then(|module| module.search_path(db))
.is_some_and(ty_module_resolver::SearchPath::is_first_party)
CallErrorKind::PossiblyNotCallable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&CALL_NON_CALLABLE, target)
{
diagnostic.help(format_args!(
"Consider adding a `__setitem__` method to `{}`.",
object_ty.display(db),
));
} else {
diagnostic.info(format_args!(
"`{}` does not have a `__setitem__` method.",
object_ty.display(db),
));
let mut diagnostic = builder.into_diagnostic(format_args!(
"Method `__setitem__` of type `{}` may not be callable on object of type `{}`",
bindings.callable_type().display(db),
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
}
}
false
}
},
false
}
CallDunderError::MethodNotAvailable => {
if emit_diagnostic
&& let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Cannot assign to a subscript on an object of type `{}`",
object_ty.display(db),
));
attach_original_type_info(&mut diagnostic);
// If it's a user-defined class, suggest adding a `__setitem__` method.
if object_ty
.as_nominal_instance()
.and_then(|instance| instance.class(db).static_class_literal(db))
.and_then(|(class_literal, _)| {
file_to_module(db, class_literal.file(db))
})
.and_then(|module| module.search_path(db))
.is_some_and(ty_module_resolver::SearchPath::is_first_party)
{
diagnostic.help(format_args!(
"Consider adding a `__setitem__` method to `{}`.",
object_ty.display(db),
));
} else {
diagnostic.info(format_args!(
"`{}` does not have a `__setitem__` method.",
object_ty.display(db),
));
}
}
false
}
}
}
}
@@ -5315,11 +5348,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
ast::Expr::Subscript(subscript_expr) => {
let assigned_ty = infer_assigned_ty.map(|f| f(self, TypeContext::default()));
self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown()));
if let Some(infer_assigned_ty) = infer_assigned_ty {
let infer_assigned_ty = &mut |builder: &mut Self, tcx| {
let assigned_ty = infer_assigned_ty(builder, tcx);
builder.store_expression_type(target, assigned_ty);
assigned_ty
};
if let Some(assigned_ty) = assigned_ty {
self.validate_subscript_assignment(subscript_expr, value, assigned_ty);
self.validate_subscript_assignment(subscript_expr, value, infer_assigned_ty);
}
}
@@ -6638,7 +6674,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
is_async: _,
} = for_statement;
self.infer_target(target, iter, |builder, tcx| {
self.infer_target(target, iter, &|builder, tcx| {
// TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `for a.x in not_iterable: ...
@@ -7511,10 +7547,54 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
#[expect(clippy::too_many_arguments)]
fn infer_and_try_call_dunder(
&mut self,
db: &'db dyn Db,
object: Type<'db>,
name: &str,
ast_arguments: ArgumentsIter<'_>,
argument_types: &mut CallArguments<'_, 'db>,
infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>,
call_expression_tcx: TypeContext<'db>,
) -> Result<Bindings<'db>, CallDunderError<'db>> {
match object
.member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NO_INSTANCE_FALLBACK)
.place
{
Place::Defined(DefinedPlace {
ty: dunder_callable,
definedness: boundness,
..
}) => {
let mut bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types);
if let Err(call_error) = self.infer_and_check_argument_types(
ast_arguments,
argument_types,
infer_argument_ty,
&mut bindings,
call_expression_tcx,
) {
return Err(CallDunderError::CallError(call_error, Box::new(bindings)));
}
if boundness == Definedness::PossiblyUndefined {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}
Ok(bindings)
}
Place::Undefined => Err(CallDunderError::MethodNotAvailable),
}
}
fn infer_and_check_argument_types(
&mut self,
ast_arguments: &ast::Arguments,
ast_arguments: ArgumentsIter<'_>,
argument_types: &mut CallArguments<'_, 'db>,
infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>,
bindings: &mut Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
) -> Result<(), CallErrorKind> {
@@ -7546,8 +7626,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Attempt to infer the argument types using the narrowed type context.
self.infer_all_argument_types(
ast_arguments,
ast_arguments.clone(),
argument_types,
infer_argument_ty,
bindings,
narrowed_tcx,
MultiInferenceState::Ignore,
@@ -7586,8 +7667,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.context.set_multi_inference(was_in_multi_inference);
self.infer_all_argument_types(
ast_arguments,
ast_arguments.clone(),
argument_types,
infer_argument_ty,
bindings,
narrowed_tcx,
MultiInferenceState::Intersect,
@@ -7631,6 +7713,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_all_argument_types(
ast_arguments,
argument_types,
infer_argument_ty,
bindings,
call_expression_tcx,
MultiInferenceState::Intersect,
@@ -7651,13 +7734,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// behavior.
fn infer_all_argument_types(
&mut self,
ast_arguments: &ast::Arguments,
ast_arguments: ArgumentsIter<'_>,
arguments_types: &mut CallArguments<'_, 'db>,
infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>,
bindings: &Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
multi_inference_state: MultiInferenceState,
) {
debug_assert_eq!(ast_arguments.len(), arguments_types.len());
debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len());
let db = self.db();
@@ -7665,7 +7748,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
0..,
arguments_types.iter_mut(),
bindings.argument_forms().iter().copied(),
ast_arguments.arguments_source_order()
ast_arguments
);
let overloads_with_binding = bindings
@@ -7774,14 +7857,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// If there is only a single binding and overload, we can infer the argument directly with
// the unique parameter type annotation.
if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() {
*argument_type = Some(self.infer_expression(
ast_argument,
TypeContext::new(parameter_type(overload, binding)),
));
let tcx = TypeContext::new(parameter_type(overload, binding));
*argument_type = Some(infer_argument_ty(self, (argument_index, ast_argument, tcx)));
} else {
// We perform inference once without any type context, emitting any diagnostics that are unrelated
// to bidirectional type inference.
*argument_type = Some(self.infer_expression(ast_argument, TypeContext::default()));
*argument_type = Some(infer_argument_ty(
self,
(argument_index, ast_argument, TypeContext::default()),
));
// We then silence any diagnostics emitted during multi-inference, as the type context is only
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
@@ -7799,8 +7883,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if !seen.insert(parameter_type) {
continue;
}
let inferred_ty =
self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type)));
let tcx = TypeContext::new(Some(parameter_type));
let inferred_ty = infer_argument_ty(self, (argument_index, ast_argument, tcx));
// Ensure the inferred type is assignable to the declared type.
//
@@ -8244,9 +8329,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _,
} = list;
let elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty = |builder: &mut Self, elt, tcx| builder.infer_expression(elt, tcx);
self.infer_collection_literal(elts, tcx, infer_elt_ty, KnownClass::List)
let mut elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty =
&mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::List)
.unwrap_or_else(|| {
KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()])
})
@@ -8259,9 +8345,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
elts,
} = set;
let elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty = |builder: &mut Self, elt, tcx| builder.infer_expression(elt, tcx);
self.infer_collection_literal(elts, tcx, infer_elt_ty, KnownClass::Set)
let mut elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty =
&mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::Set)
.unwrap_or_else(|| {
KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()])
})
@@ -8293,21 +8380,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()]);
}
let items = items
let mut items = items
.iter()
.map(|item| [item.key.as_ref(), Some(&item.value)]);
// Avoid inferring the items multiple times if we already attempted to infer the
// dictionary literal as a `TypedDict`. This also allows us to infer using the
// type context of the expected `TypedDict` field.
let infer_elt_ty = |builder: &mut Self, elt: &ast::Expr, tcx| {
let infer_elt_ty = &mut |builder: &mut Self, (_, elt, tcx): ArgExpr<'db, '_>| {
item_types
.get(&elt.node_index().load())
.copied()
.unwrap_or_else(|| builder.infer_expression(elt, tcx))
};
self.infer_collection_literal(items, tcx, infer_elt_ty, KnownClass::Dict)
self.infer_collection_literal(&mut items, tcx, infer_elt_ty, KnownClass::Dict)
.unwrap_or_else(|| {
KnownClass::Dict
.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()])
@@ -8356,17 +8443,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// Infer the type of a collection literal expression.
fn infer_collection_literal<'expr, const N: usize, F, I>(
fn infer_collection_literal<'expr, const N: usize>(
&mut self,
elts: I,
elts: &mut dyn Iterator<Item = [Option<&'expr ast::Expr>; N]>,
tcx: TypeContext<'db>,
mut infer_elt_expression: F,
infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>,
collection_class: KnownClass,
) -> Option<Type<'db>>
where
I: Iterator<Item = [Option<&'expr ast::Expr>; N]>,
F: FnMut(&mut Self, &'expr ast::Expr, TypeContext<'db>) -> Type<'db>,
{
) -> Option<Type<'db>> {
// Extract the type variable `T` from `list[T]` in typeshed.
let elt_tys = |collection_class: KnownClass| {
let collection_alias = collection_class
@@ -8388,8 +8471,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else {
// Infer the element types without type context, and fallback to unknown for
// custom typesheds.
for elt in elts.flatten().flatten() {
infer_elt_expression(self, elt, TypeContext::default());
for (i, elt) in elts.flatten().flatten().enumerate() {
infer_elt_expression(self, (i, elt, TypeContext::default()));
}
return None;
@@ -8483,7 +8566,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for elts in elts {
// An unpacking expression for a dictionary.
if let &[None, Some(value)] = elts.as_slice() {
let inferred_value_ty = infer_elt_expression(self, value, TypeContext::default());
let inferred_value_ty =
infer_elt_expression(self, (1, value, TypeContext::default()));
// Merge the inferred type of the nested dictionary.
if let Some(specialization) =
@@ -8502,7 +8586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// The inferred type of each element acts as an additional constraint on `T`.
for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) {
for (i, elt, elt_ty) in itertools::izip!(0.., elts, elt_tys.clone()) {
let Some(elt) = elt else { continue };
// Note that unlike when preferring the declared type, we use covariant type
@@ -8511,7 +8595,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let elt_ty_identity = elt_ty.identity(self.db());
let elt_tcx = elt_tcx_constraints.get(&elt_ty_identity).copied();
let inferred_elt_ty = infer_elt_expression(self, elt, TypeContext::new(elt_tcx));
let inferred_elt_ty =
infer_elt_expression(self, (i, elt, TypeContext::new(elt_tcx)));
// Simplify the inference based on a non-covariant declared type.
if let Some(elt_tcx) =
@@ -8789,7 +8874,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
is_async: _,
} = comprehension;
self.infer_target(target, iter, |builder, tcx| {
self.infer_target(target, iter, &|builder, tcx| {
// TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `[... for a.x in not_iterable]
@@ -9033,7 +9118,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&mut self,
call_expression: &ast::ExprCall,
callable_type: Type<'db>,
tcx: TypeContext<'db>,
call_expression_tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprCall {
range: _,
@@ -9044,7 +9129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Fast-path dict(...) in TypedDict context: infer keyword values against fields,
// then validate and return the TypedDict type.
if let Some(tcx) = tcx.annotation
if let Some(tcx) = call_expression_tcx.annotation
&& let Some(typed_dict) = tcx
.filter_union(self.db(), Type::is_typed_dict)
.as_typed_dict()
@@ -9281,10 +9366,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(
arguments,
ArgumentsIter::from_ast(arguments),
&mut call_arguments,
&mut |builder, (_, expr, tcx)| builder.infer_expression(expr, tcx),
&bindings,
tcx,
call_expression_tcx,
MultiInferenceState::Intersect,
);
} else {
@@ -9296,7 +9382,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
return callable_type
.try_call_constructor(db, infer_call_arguments, tcx)
.try_call_constructor(db, infer_call_arguments, call_expression_tcx)
.unwrap_or_else(|err| {
err.report_diagnostic(&self.context, callable_type, call_expression.into());
err.return_type()
@@ -9308,8 +9394,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);
let bindings_result =
self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx);
let bindings_result = self.infer_and_check_argument_types(
ArgumentsIter::from_ast(arguments),
&mut call_arguments,
&mut |builder, (_, expr, tcx)| builder.infer_expression(expr, tcx),
&mut bindings,
call_expression_tcx,
);
// Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.as_class_literal() {
@@ -12541,7 +12632,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generic_context: GenericContext<'db>,
) -> Type<'db> {
let db = self.db();
let specialize = |types: &[Option<Type<'db>>]| {
let specialize = &|types: &[Option<Type<'db>>]| {
Type::from(generic_class.apply_specialization(db, |_| {
generic_context.specialize_partial(db, types.iter().copied())
}))
@@ -12563,7 +12654,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generic_context: GenericContext<'db>,
) -> Type<'db> {
let db = self.db();
let specialize = |types: &[Option<Type<'db>>]| {
let specialize = &|types: &[Option<Type<'db>>]| {
let type_alias = generic_type_alias.apply_specialization(db, |_| {
generic_context.specialize_partial(db, types.iter().copied())
});
@@ -12584,7 +12675,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
subscript: &ast::ExprSubscript,
value_ty: Type<'db>,
generic_context: GenericContext<'db>,
specialize: impl FnOnce(&[Option<Type<'db>>]) -> Type<'db>,
specialize: &dyn Fn(&[Option<Type<'db>>]) -> Type<'db>,
) -> Type<'db> {
enum ExplicitSpecializationError {
InvalidParamSpec,
@@ -13577,6 +13668,38 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
/// An expression representing the function argument at the given index, along with its type
/// context.
type ArgExpr<'db, 'ast> = (usize, &'ast ast::Expr, TypeContext<'db>);
/// An iterator over arguments to a functional call.
#[derive(Clone)]
enum ArgumentsIter<'a> {
FromAst(ArgumentsSourceOrder<'a>),
Synthesized(std::slice::Iter<'a, ArgOrKeyword<'a>>),
}
impl<'a> ArgumentsIter<'a> {
fn from_ast(arguments: &'a ast::Arguments) -> Self {
Self::FromAst(arguments.arguments_source_order())
}
fn synthesized(arguments: &'a [ArgOrKeyword<'a>]) -> Self {
Self::Synthesized(arguments.iter())
}
}
impl<'a> Iterator for ArgumentsIter<'a> {
type Item = ArgOrKeyword<'a>;
fn next(&mut self) -> Option<Self::Item> {
match self {
ArgumentsIter::FromAst(args) => args.next(),
ArgumentsIter::Synthesized(args) => args.next().copied(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GenericContextError {
/// It's invalid to subscript `Generic` or `Protocol` with this type

View File

@@ -710,7 +710,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
match class_literal.generic_context(self.db()) {
Some(generic_context) => {
let db = self.db();
let specialize = |types: &[Option<Type<'db>>]| {
let specialize = &|types: &[Option<Type<'db>>]| {
SubclassOfType::from(
db,
class_literal.apply_specialization(db, |_| {
@@ -805,7 +805,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
};
}
let specialize = |types: &[Option<Type<'db>>]| {
let specialize = &|types: &[Option<Type<'db>>]| {
let specialized = value_ty.apply_specialization(
db,
generic_context.specialize_partial(db, types.iter().copied()),