diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index c74611d0b2..66bff3d7e0 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -117,7 +117,7 @@ static COLOUR_SCIENCE: std::sync::LazyLock> = std::sync::Lazy max_dep_date: "2025-06-17", python_version: PythonVersion::PY310, }, - 500, + 600, ) }); diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index bfb516f026..05b3d7ffed 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -1662,3 +1662,67 @@ def _(arg: tuple[A | B, Any]): reveal_type(f(arg)) # revealed: Unknown reveal_type(f(*(arg,))) # revealed: Unknown ``` + +## Bidirectional Type Inference + +```toml +[environment] +python-version = "3.12" +``` + +Type inference accounts for parameter type annotations across all overloads. + +```py +from typing import TypedDict, overload + +class T(TypedDict): + x: int + +@overload +def f(a: list[T], b: int) -> int: ... +@overload +def f(a: list[dict[str, int]], b: str) -> str: ... +def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str: + return 1 + +def int_or_str() -> int | str: + return 1 + +x = f([{"x": 1}], int_or_str()) +reveal_type(x) # revealed: int | str + +# TODO: error: [no-matching-overload] "No overload of function `f` matches arguments" +# we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]` +f([{"y": 1}], int_or_str()) +``` + +Non-matching overloads do not produce diagnostics: + +```py +from typing import TypedDict, overload + +class T(TypedDict): + x: int + +@overload +def f(a: T, b: int) -> int: ... +@overload +def f(a: dict[str, int], b: str) -> str: ... +def f(a: T | dict[str, int], b: int | str) -> int | str: + return 1 + +x = f({"y": 1}, "a") +reveal_type(x) # revealed: str +``` + +```py +from typing import SupportsRound, overload + +@overload +def takes_str_or_float(x: str): ... +@overload +def takes_str_or_float(x: float): ... +def takes_str_or_float(x: float | str): ... + +takes_str_or_float(round(1.0)) +``` diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index b7df404383..c950f7f482 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -251,3 +251,59 @@ from ty_extensions import Intersection, Not def _(x: Union[Intersection[Any, Not[int]], Intersection[Any, Not[int]]]): reveal_type(x) # revealed: Any & ~int ``` + +## Bidirectional Type Inference + +```toml +[environment] +python-version = "3.12" +``` + +Type inference accounts for parameter type annotations across all signatures in a union. + +```py +from typing import TypedDict, overload + +class T(TypedDict): + x: int + +def _(flag: bool): + if flag: + def f(x: T) -> int: + return 1 + else: + def f(x: dict[str, int]) -> int: + return 1 + x = f({"x": 1}) + reveal_type(x) # revealed: int + + # TODO: error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `T`, found `dict[str, int]`" + # we currently consider `TypedDict` instances to be subtypes of `dict` + f({"y": 1}) +``` + +Diagnostics unrelated to the type-context are only reported once: + +```py +def f[T](x: T) -> list[T]: + return [x] + +def a(x: list[bool], y: list[bool]): ... +def b(x: list[int], y: list[int]): ... +def c(x: list[int], y: list[int]): ... +def _(x: int): + if x == 0: + y = a + elif x == 1: + y = b + else: + y = c + + if x == 0: + z = True + + y(f(True), [True]) + + # error: [possibly-unresolved-reference] "Name `z` used when possibly not defined" + y(f(True), [z]) +``` diff --git a/crates/ty_python_semantic/resources/mdtest/directives/assert_type.md b/crates/ty_python_semantic/resources/mdtest/directives/assert_type.md index 07ad5d555b..fd7c4d1649 100644 --- a/crates/ty_python_semantic/resources/mdtest/directives/assert_type.md +++ b/crates/ty_python_semantic/resources/mdtest/directives/assert_type.md @@ -10,6 +10,7 @@ from typing_extensions import assert_type def _(x: int): assert_type(x, int) # fine assert_type(x, str) # error: [type-assertion-failure] + assert_type(assert_type(x, int), int) ``` ## Narrowing diff --git a/crates/ty_python_semantic/resources/mdtest/snapshots/assert_type.md_-_`assert_type`_-_Basic_(c507788da2659ec9).snap b/crates/ty_python_semantic/resources/mdtest/snapshots/assert_type.md_-_`assert_type`_-_Basic_(c507788da2659ec9).snap index 4b9da52c7b..42b1ad283f 100644 --- a/crates/ty_python_semantic/resources/mdtest/snapshots/assert_type.md_-_`assert_type`_-_Basic_(c507788da2659ec9).snap +++ b/crates/ty_python_semantic/resources/mdtest/snapshots/assert_type.md_-_`assert_type`_-_Basic_(c507788da2659ec9).snap @@ -17,6 +17,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/directives/assert_type.m 3 | def _(x: int): 4 | assert_type(x, int) # fine 5 | assert_type(x, str) # error: [type-assertion-failure] +6 | assert_type(assert_type(x, int), int) ``` # Diagnostics @@ -31,6 +32,7 @@ error[type-assertion-failure]: Argument does not have asserted type `str` | ^^^^^^^^^^^^-^^^^^^ | | | Inferred type of argument is `int` +6 | assert_type(assert_type(x, int), int) | info: `str` and `int` are not equivalent types info: rule `type-assertion-failure` is enabled by default diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index b96b953e20..dc3734f368 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -152,7 +152,7 @@ Person(name="Alice") # error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" Person({"name": "Alice"}) -# TODO: this should be an error, similar to the above +# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" accepts_person({"name": "Alice"}) # TODO: this should be an error, similar to the above house.owner = {"name": "Alice"} @@ -171,7 +171,7 @@ Person(name=None, age=30) # error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" Person({"name": None, "age": 30}) -# TODO: this should be an error, similar to the above +# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" accepts_person({"name": None, "age": 30}) # TODO: this should be an error, similar to the above house.owner = {"name": None, "age": 30} @@ -190,7 +190,7 @@ Person(name="Alice", age=30, extra=True) # error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" Person({"name": "Alice", "age": 30, "extra": True}) -# TODO: this should be an error +# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" accepts_person({"name": "Alice", "age": 30, "extra": True}) # TODO: this should be an error house.owner = {"name": "Alice", "age": 30, "extra": True} diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index bb359c0b54..9d0e04bbba 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4194,20 +4194,26 @@ impl<'db> Type<'db> { .into() } - Some(KnownFunction::AssertType) => Binding::single( - self, - Signature::new( - Parameters::new([ - Parameter::positional_only(Some(Name::new_static("value"))) - .with_annotated_type(Type::any()), - Parameter::positional_only(Some(Name::new_static("type"))) - .type_form() - .with_annotated_type(Type::any()), - ]), - Some(Type::none(db)), - ), - ) - .into(), + Some(KnownFunction::AssertType) => { + let val_ty = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Invariant); + + Binding::single( + self, + Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [val_ty])), + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("value"))) + .with_annotated_type(Type::TypeVar(val_ty)), + Parameter::positional_only(Some(Name::new_static("type"))) + .type_form() + .with_annotated_type(Type::any()), + ]), + Some(Type::TypeVar(val_ty)), + ), + ) + .into() + } Some(KnownFunction::AssertNever) => { Binding::single( diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 0451ec8a09..9d101c40b2 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -1077,7 +1077,11 @@ impl<'db> InnerIntersectionBuilder<'db> { // don't need to worry about finding any particular constraint more than once. let constraints = constraints.elements(db); let mut positive_constraint_count = 0; - for positive in &self.positive { + for (i, positive) in self.positive.iter().enumerate() { + if i == typevar_index { + continue; + } + // This linear search should be fine as long as we don't encounter typevars with // thousands of constraints. positive_constraint_count += constraints diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 0f2853a5be..60a0394ec8 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -33,10 +33,10 @@ use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType, - WrapperDescriptorKind, enums, ide_support, todo_type, + WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; -use ruff_python_ast::{self as ast, PythonVersion}; +use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; /// Binding information for a possible union of callables. At a call site, the arguments must be /// compatible with _all_ of the types in the union for the call to be valid. @@ -1776,7 +1776,7 @@ impl<'db> CallableBinding<'db> { } /// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`]. - fn matching_overload_index(&self) -> MatchingOverloadIndex { + pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex { let mut matching_overloads = self.matching_overloads(); match matching_overloads.next() { None => MatchingOverloadIndex::None, @@ -1794,8 +1794,15 @@ impl<'db> CallableBinding<'db> { } } + /// Returns all overloads for this call binding, including overloads that did not match. + pub(crate) fn overloads(&self) -> &[Binding<'db>] { + self.overloads.as_slice() + } + /// Returns an iterator over all the overloads that matched for this call binding. - pub(crate) fn matching_overloads(&self) -> impl Iterator)> { + pub(crate) fn matching_overloads( + &self, + ) -> impl Iterator)> + Clone { self.overloads .iter() .enumerate() @@ -2026,7 +2033,7 @@ enum OverloadCallReturnType<'db> { } #[derive(Debug)] -enum MatchingOverloadIndex { +pub(crate) enum MatchingOverloadIndex { /// No matching overloads found. None, @@ -2504,9 +2511,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { if let Some(return_ty) = self.signature.return_ty && let Some(call_expression_tcx) = self.call_expression_tcx.annotation { - // Ignore any specialization errors here, because the type context is only used to - // optionally widen the return type. - let _ = builder.infer(return_ty, call_expression_tcx); + match call_expression_tcx { + // A type variable is not a useful type-context for expression inference, and applying it + // to the return type can lead to confusing unions in nested generic calls. + Type::TypeVar(_) => {} + + _ => { + // Ignore any specialization errors here, because the type context is only used as a hint + // to infer a more assignable return type. + let _ = builder.infer(return_ty, call_expression_tcx); + } + } } let parameters = self.signature.parameters(); @@ -3289,6 +3304,23 @@ impl<'db> BindingError<'db> { return; }; + // Re-infer the argument type of call expressions, ignoring the type context for more + // precise error messages. + let provided_ty = match Self::get_argument_node(node, *argument_index) { + None => *provided_ty, + + // Ignore starred arguments, as those are difficult to re-infer. + Some( + ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) + | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }), + ) => *provided_ty, + + Some( + ast::ArgOrKeyword::Arg(value) + | ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }), + ) => infer_isolated_expression(context.db(), context.scope(), value), + }; + let provided_ty_display = provided_ty.display(context.db()); let expected_ty_display = expected_ty.display(context.db()); @@ -3624,22 +3656,29 @@ impl<'db> BindingError<'db> { } } - fn get_node(node: ast::AnyNodeRef, argument_index: Option) -> ast::AnyNodeRef { + fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option) -> ast::AnyNodeRef<'_> { // If we have a Call node and an argument index, report the diagnostic on the correct // argument node; otherwise, report it on the entire provided node. + match Self::get_argument_node(node, argument_index) { + Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(), + Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(), + None => node, + } + } + + fn get_argument_node( + node: ast::AnyNodeRef<'_>, + argument_index: Option, + ) -> Option> { match (node, argument_index) { - (ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => { - match call_node + (ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some( + call_node .arguments .arguments_source_order() .nth(argument_index) - .expect("argument index should not be out of range") - { - ast::ArgOrKeyword::Arg(expr) => expr.into(), - ast::ArgOrKeyword::Keyword(keyword) => keyword.into(), - } - } - _ => node, + .expect("argument index should not be out of range"), + ), + _ => None, } } } diff --git a/crates/ty_python_semantic/src/types/context.rs b/crates/ty_python_semantic/src/types/context.rs index d13d99a75e..c3901ecbad 100644 --- a/crates/ty_python_semantic/src/types/context.rs +++ b/crates/ty_python_semantic/src/types/context.rs @@ -40,6 +40,7 @@ pub(crate) struct InferContext<'db, 'ast> { module: &'ast ParsedModuleRef, diagnostics: std::cell::RefCell, no_type_check: InNoTypeCheck, + multi_inference: bool, bomb: DebugDropBomb, } @@ -50,6 +51,7 @@ impl<'db, 'ast> InferContext<'db, 'ast> { scope, module, file: scope.file(db), + multi_inference: false, diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()), no_type_check: InNoTypeCheck::default(), bomb: DebugDropBomb::new( @@ -156,6 +158,18 @@ impl<'db, 'ast> InferContext<'db, 'ast> { DiagnosticGuardBuilder::new(self, id, severity) } + /// Returns `true` if the current expression is being inferred for a second + /// (or subsequent) time, with a potentially different bidirectional type + /// context. + pub(super) fn is_in_multi_inference(&self) -> bool { + self.multi_inference + } + + /// Set the multi-inference state, returning the previous value. + pub(super) fn set_multi_inference(&mut self, multi_inference: bool) -> bool { + std::mem::replace(&mut self.multi_inference, multi_inference) + } + pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) { self.no_type_check = no_type_check; } @@ -410,6 +424,11 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> { if ctx.is_in_no_type_check() { return None; } + // If this lint is being reported as part of multi-inference of a given expression, + // silence it to avoid duplicated diagnostics. + if ctx.is_in_multi_inference() { + return None; + } let id = DiagnosticId::Lint(lint.name()); let suppressions = suppressions(ctx.db(), ctx.file()); @@ -575,6 +594,11 @@ impl<'db, 'ctx> DiagnosticGuardBuilder<'db, 'ctx> { if !ctx.db.should_check_file(ctx.file) { return None; } + // If this lint is being reported as part of multi-inference of a given expression, + // silence it to avoid duplicated diagnostics. + if ctx.is_in_multi_inference() { + return None; + } Some(DiagnosticGuardBuilder { ctx, id, severity }) } diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index ea7d54841c..ebdd72b2b1 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -1975,7 +1975,7 @@ pub(super) fn report_invalid_assignment<'db>( if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db()) && let Some(value) = annotated_assignment.value(context.module()) { - // Re-infer the RHS of the annotated assignment, ignoring the type context, for more precise + // Re-infer the RHS of the annotated assignment, ignoring the type context for more precise // error messages. source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value); } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index b4fe2b5a23..94a823c8c1 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::{iter, mem}; use itertools::{Either, Itertools}; use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity}; @@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, }; +use crate::types::call::bind::MatchingOverloadIndex; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; use crate::types::context::{InNoTypeCheck, InferContext}; @@ -88,12 +89,13 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, - IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, - MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, - SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, - TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, + CallDunderError, CallableBinding, CallableType, ClassLiteral, ClassType, DataclassParams, + DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, + Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, + TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, + TypedDictType, UnionBuilder, UnionType, binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -257,6 +259,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// is a stub file but we're still in a non-deferred region. deferred_state: DeferredExpressionState, + multi_inference_state: MultiInferenceState, + /// For function definitions, the undecorated type of the function. undecorated_type: Option>, @@ -287,10 +291,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context: InferContext::new(db, scope, module), index, region, + scope, return_types_and_ranges: vec![], called_functions: FxHashSet::default(), deferred_state: DeferredExpressionState::None, - scope, + multi_inference_state: MultiInferenceState::Panic, expressions: FxHashMap::default(), bindings: VecMap::default(), declarations: VecMap::default(), @@ -4911,6 +4916,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_expression(expression, TypeContext::default()) } + /// Infer the argument types for a single binding. fn infer_argument_types<'a>( &mut self, ast_arguments: &ast::Arguments, @@ -4920,22 +4926,155 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { debug_assert!( ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len() ); - let iter = (arguments.iter_mut()) - .zip(argument_forms.iter().copied()) - .zip(ast_arguments.arguments_source_order()); - for (((_, argument_type), form), arg_or_keyword) in iter { - let argument = match arg_or_keyword { - // We already inferred the type of splatted arguments. + + let iter = itertools::izip!( + arguments.iter_mut(), + argument_forms.iter().copied(), + ast_arguments.arguments_source_order() + ); + + for ((_, argument_type), argument_form, ast_argument) in iter { + let argument = match ast_argument { + // Splatted arguments are inferred before parameter matching to + // determine their length. ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue, + ast::ArgOrKeyword::Arg(arg) => arg, ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, }; - let ty = self.infer_argument_type(argument, form, TypeContext::default()); + + let ty = self.infer_argument_type(argument, argument_form, TypeContext::default()); *argument_type = Some(ty); } } + /// Infer the argument types for multiple potential bindings and overloads. + fn infer_all_argument_types<'a>( + &mut self, + ast_arguments: &ast::Arguments, + arguments: &mut CallArguments<'a, 'db>, + bindings: &Bindings<'db>, + ) { + debug_assert!( + ast_arguments.len() == arguments.len() + && arguments.len() == bindings.argument_forms().len() + ); + + let iter = itertools::izip!( + 0.., + arguments.iter_mut(), + bindings.argument_forms().iter().copied(), + ast_arguments.arguments_source_order() + ); + + let overloads_with_binding = bindings + .into_iter() + .filter_map(|binding| { + match binding.matching_overload_index() { + MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => { + let overloads = binding + .matching_overloads() + .map(move |(_, overload)| (overload, binding)); + + Some(Either::Right(overloads)) + } + + // If there is a single overload that does not match, we still infer the argument + // types for better diagnostics. + MatchingOverloadIndex::None => match binding.overloads() { + [overload] => Some(Either::Left(std::iter::once((overload, binding)))), + _ => None, + }, + } + }) + .flatten(); + + for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { + let ast_argument = match ast_argument { + // Splatted arguments are inferred before parameter matching to + // determine their length. + // + // TODO: Re-infer splatted arguments with their type context. + ast::ArgOrKeyword::Arg(ast::Expr::Starred(_)) + | ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue, + + ast::ArgOrKeyword::Arg(arg) => arg, + ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, + }; + + // Type-form arguments are inferred without type context, so we can infer the argument type directly. + if let Some(ParameterForm::Type) = argument_form { + *argument_type = Some(self.infer_type_expression(ast_argument)); + continue; + } + + // Retrieve the parameter type for the current argument in a given overload and its binding. + let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| { + let argument_index = if binding.bound_type.is_some() { + argument_index + 1 + } else { + argument_index + }; + + let argument_matches = &overload.argument_matches()[argument_index]; + let [parameter_index] = argument_matches.parameters.as_slice() else { + return None; + }; + + overload.signature.parameters()[*parameter_index].annotated_type() + }; + + // If there is only a single binding and overload, we can infer the argument directly with + // the unique parameter type annotation. + if let Ok((overload, binding)) = overloads_with_binding.clone().exactly_one() { + self.infer_expression_impl( + ast_argument, + TypeContext::new(parameter_type(overload, binding)), + ); + } else { + // Otherwise, each type is a valid independent inference of the given argument, and we may + // require different permutations of argument types to correctly perform argument expansion + // during overload evaluation, so we take the intersection of all the types we inferred for + // each argument. + // + // Note that this applies to all nested expressions within each argument. + let old_multi_inference_state = mem::replace( + &mut self.multi_inference_state, + MultiInferenceState::Intersect, + ); + + // We perform inference once without any type context, emitting any diagnostics that are unrelated + // to bidirectional type inference. + self.infer_expression_impl(ast_argument, TypeContext::default()); + + // We then silence any diagnostics emitted during multi-inference, as the type context is only + // used as a hint to infer a more assignable argument type, and should not lead to diagnostics + // for non-matching overloads. + let was_in_multi_inference = self.context.set_multi_inference(true); + + // Infer the type of each argument once with each distinct parameter type as type context. + let parameter_types = overloads_with_binding + .clone() + .filter_map(|(overload, binding)| parameter_type(overload, binding)) + .collect::>(); + + for parameter_type in parameter_types { + self.infer_expression_impl( + ast_argument, + TypeContext::new(Some(parameter_type)), + ); + } + + // Restore the multi-inference state. + self.multi_inference_state = old_multi_inference_state; + self.context.set_multi_inference(was_in_multi_inference); + } + + *argument_type = self.try_expression_type(ast_argument); + } + } + fn infer_argument_type( &mut self, ast_argument: &ast::Expr, @@ -4956,6 +5095,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { expression.map(|expr| self.infer_expression(expr, tcx)) } + fn get_or_infer_expression( + &mut self, + expression: &ast::Expr, + tcx: TypeContext<'db>, + ) -> Type<'db> { + self.try_expression_type(expression) + .unwrap_or_else(|| self.infer_expression(expression, tcx)) + } + #[track_caller] fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> { debug_assert!( @@ -5016,6 +5164,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { types.expression_type(expression) } + /// Infer the type of an expression. fn infer_expression_impl( &mut self, expression: &ast::Expr, @@ -5051,7 +5200,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Compare(compare) => self.infer_compare_expression(compare), ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript), ast::Expr::Slice(slice) => self.infer_slice_expression(slice), - ast::Expr::Named(named) => self.infer_named_expression(named), ast::Expr::If(if_expression) => self.infer_if_expression(if_expression), ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression), ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx), @@ -5059,6 +5207,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from), ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression), + ast::Expr::Named(named) => { + // Definitions must be unique, so we bypass multi-inference for named expressions. + if !self.multi_inference_state.is_panic() + && let Some(ty) = self.expressions.get(&expression.into()) + { + return *ty; + } + + self.infer_named_expression(named) + } ast::Expr::IpyEscapeCommand(_) => { todo_type!("Ipy escape command support") } @@ -5068,6 +5226,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ty } + fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { if self.deferred_state.in_string_annotation() { // Avoid storing the type of expressions that are part of a string annotation because @@ -5075,8 +5234,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // on the string expression itself that represents the annotation. return; } - let previous = self.expressions.insert(expression.into(), ty); - assert_eq!(previous, None); + + let db = self.db(); + + match self.multi_inference_state { + MultiInferenceState::Panic => { + let previous = self.expressions.insert(expression.into(), ty); + assert_eq!(previous, None); + } + + MultiInferenceState::Intersect => { + self.expressions + .entry(expression.into()) + .and_modify(|current| { + *current = IntersectionType::from_elements(db, [*current, ty]); + }) + .or_insert(ty); + } + } } fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> { @@ -5297,31 +5472,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = dict; // Validate `TypedDict` dictionary literal assignments. - if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) { - let typed_dict_items = typed_dict.items(self.db()); - - for item in items { - self.infer_optional_expression(item.key.as_ref(), TypeContext::default()); - - if let Some(ast::Expr::StringLiteral(ref key)) = item.key - && let Some(key) = key.as_single_part_string() - && let Some(field) = typed_dict_items.get(key.as_str()) - { - self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty))); - } else { - self.infer_expression(&item.value, TypeContext::default()); - } - } - - validate_typed_dict_dict_literal( - &self.context, - typed_dict, - dict, - dict.into(), - |expr| self.expression_type(expr), - ); - - return Type::TypedDict(typed_dict); + if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) + && let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict) + { + return ty; } // Avoid false positives for the functional `TypedDict` form, which is currently @@ -5342,6 +5496,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }) } + fn infer_typed_dict_expression( + &mut self, + dict: &ast::ExprDict, + typed_dict: TypedDictType<'db>, + ) -> Option> { + let ast::ExprDict { + range: _, + node_index: _, + items, + } = dict; + + let typed_dict_items = typed_dict.items(self.db()); + + for item in items { + self.infer_optional_expression(item.key.as_ref(), TypeContext::default()); + + if let Some(ast::Expr::StringLiteral(ref key)) = item.key + && let Some(key) = key.as_single_part_string() + && let Some(field) = typed_dict_items.get(key.as_str()) + { + self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty))); + } else { + self.infer_expression(&item.value, TypeContext::default()); + } + } + + validate_typed_dict_dict_literal(&self.context, typed_dict, dict, dict.into(), |expr| { + self.expression_type(expr) + }) + .ok() + .map(|_| Type::TypedDict(typed_dict)) + } + // Infer the type of a collection literal expression. fn infer_collection_literal<'expr, const N: usize>( &mut self, @@ -5399,7 +5586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for elts in elts { // An unpacking expression for a dictionary. if let &[None, Some(value)] = elts.as_slice() { - let inferred_value_ty = self.infer_expression(value, TypeContext::default()); + let inferred_value_ty = self.get_or_infer_expression(value, TypeContext::default()); // Merge the inferred type of the nested dictionary. if let Some(specialization) = @@ -5420,9 +5607,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // The inferred type of each element acts as an additional constraint on `T`. for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone()) { - let Some(inferred_elt_ty) = self.infer_optional_expression(elt, elt_tcx) else { - continue; - }; + let Some(elt) = elt else { continue }; + + let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx); // Convert any element literals to their promoted type form to avoid excessively large // unions for large nested list literals, which the constraint solver struggles with. @@ -5967,7 +6154,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms()); + self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.into_class_literal() { @@ -9087,6 +9274,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9149,6 +9337,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9220,6 +9409,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Builder only state typevar_binding_context: _, deferred_state: _, + multi_inference_state: _, called_functions: _, index: _, region: _, @@ -9265,6 +9455,26 @@ impl GenericContextError { } } +/// Dictates the behavior when an expression is inferred multiple times. +#[derive(Default, Debug, Clone, Copy)] +enum MultiInferenceState { + /// Panic if the expression has already been inferred. + #[default] + Panic, + + /// Store the intersection of all types inferred for the expression. + Intersect, +} + +impl MultiInferenceState { + fn is_panic(self) -> bool { + match self { + MultiInferenceState::Panic => true, + MultiInferenceState::Intersect => false, + } + } +} + /// The deferred state of a specific expression in an inference region. #[derive(Default, Debug, Clone, Copy)] enum DeferredExpressionState { @@ -9538,7 +9748,7 @@ impl Default for VecMap { /// Set based on a `Vec`. It doesn't enforce /// uniqueness on insertion. Instead, it relies on the caller -/// that elements are uniuqe. For example, the way we visit definitions +/// that elements are unique. For example, the way we visit definitions /// in the `TypeInference` builder make already implicitly guarantees that each definition /// is only visited once. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/crates/ty_python_semantic/src/types/typed_dict.rs b/crates/ty_python_semantic/src/types/typed_dict.rs index c1b241093b..3cfa861849 100644 --- a/crates/ty_python_semantic/src/types/typed_dict.rs +++ b/crates/ty_python_semantic/src/types/typed_dict.rs @@ -132,7 +132,8 @@ impl TypedDictAssignmentKind { } /// Validates assignment of a value to a specific key on a `TypedDict`. -/// Returns true if the assignment is valid, false otherwise. +/// +/// Returns true if the assignment is valid, or false otherwise. #[allow(clippy::too_many_arguments)] pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( context: &InferContext<'db, 'ast>, @@ -157,6 +158,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( Type::string_literal(db, key), &items, ); + return false; }; @@ -240,13 +242,16 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( } /// Validates that all required keys are provided in a `TypedDict` construction. +/// /// Reports errors for any keys that are required but not provided. +/// +/// Returns true if the assignment is valid, or false otherwise. pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( context: &InferContext<'db, 'ast>, typed_dict: TypedDictType<'db>, provided_keys: &OrderSet<&str>, error_node: AnyNodeRef<'ast>, -) { +) -> bool { let db = context.db(); let items = typed_dict.items(db); @@ -255,7 +260,12 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( .filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str())) .collect(); - for missing_key in required_keys.difference(provided_keys) { + let missing_keys = required_keys.difference(provided_keys); + + let mut has_missing_key = false; + for missing_key in missing_keys { + has_missing_key = true; + report_missing_typed_dict_key( context, error_node, @@ -263,6 +273,8 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( missing_key, ); } + + !has_missing_key } pub(super) fn validate_typed_dict_constructor<'db, 'ast>( @@ -373,7 +385,7 @@ fn validate_from_keywords<'db, 'ast>( provided_keys } -/// Validates a `TypedDict` dictionary literal assignment +/// Validates a `TypedDict` dictionary literal assignment, /// e.g. `person: Person = {"name": "Alice", "age": 30}` pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( context: &InferContext<'db, 'ast>, @@ -381,7 +393,8 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( dict_expr: &'ast ast::ExprDict, error_node: AnyNodeRef<'ast>, expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>, -) -> OrderSet<&'ast str> { +) -> Result, OrderSet<&'ast str>> { + let mut valid = true; let mut provided_keys = OrderSet::new(); // Validate each key-value pair in the dictionary literal @@ -392,7 +405,8 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( provided_keys.insert(key_str); let value_type = expression_type_fn(&item.value); - validate_typed_dict_key_assignment( + + valid &= validate_typed_dict_key_assignment( context, typed_dict, key_str, @@ -406,7 +420,11 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( } } - validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); + valid &= validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); - provided_keys + if valid { + Ok(provided_keys) + } else { + Err(provided_keys) + } }