From 36888198a61eeb86f2a3a416a99ac537676c6bbc Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 11 Sep 2025 15:19:12 -0400 Subject: [PATCH] [ty] Integrate type context for bidirectional inference (#20337) ## Summary Adds the infrastructure necessary to perform bidirectional type inference (https://github.com/astral-sh/ty/issues/168) without any typing changes. --- crates/ty_python_semantic/src/dunder_all.rs | 5 +- .../reachability_constraints.rs | 21 +- crates/ty_python_semantic/src/types.rs | 67 +--- crates/ty_python_semantic/src/types/class.rs | 16 +- crates/ty_python_semantic/src/types/infer.rs | 122 ++++++- .../src/types/infer/builder.rs | 334 ++++++++++-------- .../infer/builder/annotation_expression.rs | 8 +- .../types/infer/builder/type_expression.rs | 46 +-- .../src/types/infer/tests.rs | 248 +++++++++---- crates/ty_python_semantic/src/types/narrow.rs | 14 +- .../ty_python_semantic/src/types/unpacker.rs | 7 +- 11 files changed, 548 insertions(+), 340 deletions(-) diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 3827d9762a..10eab9321a 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -7,7 +7,7 @@ use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; use crate::semantic_index::{SemanticIndex, semantic_index}; -use crate::types::{Truthiness, Type, infer_expression_types}; +use crate::types::{Truthiness, Type, TypeContext, infer_expression_types}; use crate::{Db, ModuleName, resolve_module}; #[allow(clippy::ref_option)] @@ -182,7 +182,8 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - infer_expression_types(self.db, self.index.expression(expr)).expression_type(expr) + infer_expression_types(self.db, self.index.expression(expr), TypeContext::default()) + .expression_type(expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 0d646ca626..2ef1fe4c6a 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{ Predicates, ScopedPredicateId, }; use crate::types::{ - IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type, - static_expression_truthiness, + IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType, + infer_expression_type, static_expression_truthiness, }; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -328,10 +328,12 @@ fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> { match kind { PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton), - PatternPredicateKind::Value(value) => infer_expression_type(db, *value), + PatternPredicateKind::Value(value) => { + infer_expression_type(db, *value, TypeContext::default()) + } PatternPredicateKind::Class(class_expr, kind) => { if kind.is_irrefutable() { - infer_expression_type(db, *class_expr) + infer_expression_type(db, *class_expr, TypeContext::default()) .to_instance(db) .unwrap_or(Type::Never) } else { @@ -718,7 +720,7 @@ impl ReachabilityConstraints { ) -> Truthiness { match predicate_kind { PatternPredicateKind::Value(value) => { - let value_ty = infer_expression_type(db, *value); + let value_ty = infer_expression_type(db, *value, TypeContext::default()); if subject_ty.is_single_valued(db) { Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)) @@ -769,7 +771,8 @@ impl ReachabilityConstraints { truthiness } PatternPredicateKind::Class(class_expr, kind) => { - let class_ty = infer_expression_type(db, *class_expr).to_instance(db); + let class_ty = + infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db); class_ty.map_or(Truthiness::Ambiguous, |class_ty| { if subject_ty.is_subtype_of(db, class_ty) { @@ -797,7 +800,7 @@ impl ReachabilityConstraints { } fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { - let subject_ty = infer_expression_type(db, predicate.subject(db)); + let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default()); let narrowed_subject_ty = IntersectionBuilder::new(db) .add_positive(subject_ty) @@ -837,7 +840,7 @@ impl ReachabilityConstraints { // selection algorithm). // Avoiding this on the happy-path is important because these constraints can be // very large in number, since we add them on all statement level function calls. - let ty = infer_expression_type(db, callable); + let ty = infer_expression_type(db, callable, TypeContext::default()); // Short-circuit for well known types that are known not to return `Never` when called. // Without the short-circuit, we've seen that threads keep blocking each other @@ -875,7 +878,7 @@ impl ReachabilityConstraints { } else if all_overloads_return_never { Truthiness::AlwaysTrue } else { - let call_expr_ty = infer_expression_type(db, call_expr); + let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default()); if call_expr_ty.is_equivalent_to(db, Type::Never) { Truthiness::AlwaysTrue } else { diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 42e42eb5fd..f462a4021b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -23,8 +23,8 @@ pub(crate) use self::cyclic::{CycleDetector, PairVisitor, TypeTransformer}; pub use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::diagnostic::register_lints; pub(crate) use self::infer::{ - infer_deferred_types, infer_definition_types, infer_expression_type, infer_expression_types, - infer_scope_types, static_expression_truthiness, + TypeContext, infer_deferred_types, infer_definition_types, infer_expression_type, + infer_expression_types, infer_scope_types, static_expression_truthiness, }; pub(crate) use self::signatures::{CallableSignature, Signature}; pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; @@ -10824,12 +10824,10 @@ static_assertions::assert_eq_size!(Type, [u8; 16]); pub(crate) mod tests { use super::*; use crate::db::tests::{TestDbBuilder, setup_db}; - use crate::place::{global_symbol, typing_extensions_symbol, typing_symbol}; + use crate::place::{typing_extensions_symbol, typing_symbol}; use crate::semantic_index::FileScopeId; use ruff_db::files::system_path_to_file; - use ruff_db::parsed::parsed_module; use ruff_db::system::DbWithWritableSystem as _; - use ruff_db::testing::assert_function_query_was_not_run; use ruff_python_ast::PythonVersion; use test_case::test_case; @@ -10868,65 +10866,6 @@ pub(crate) mod tests { ); } - /// Inferring the result of a call-expression shouldn't need to re-run after - /// a trivial change to the function's file (e.g. by adding a docstring to the function). - #[test] - fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_dedented( - "src/foo.py", - r#" - def foo() -> int: - return 5 - "#, - )?; - db.write_dedented( - "src/bar.py", - r#" - from foo import foo - - a = foo() - "#, - )?; - - let bar = system_path_to_file(&db, "src/bar.py")?; - let a = global_symbol(&db, bar, "a").place; - - assert_eq!( - a.expect_type(), - UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) - ); - - // Add a docstring to foo to trigger a re-run. - // The bar-call site of foo should not be re-run because of that - db.write_dedented( - "src/foo.py", - r#" - def foo() -> int: - "Computes a value" - return 5 - "#, - )?; - db.clear_salsa_events(); - - let a = global_symbol(&db, bar, "a").place; - - assert_eq!( - a.expect_type(), - UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) - ); - let events = db.take_salsa_events(); - - let module = parsed_module(&db, bar).load(&db); - let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; - let foo_call = semantic_index(&db, bar).expression(call); - - assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events); - - Ok(()) - } - /// All other tests also make sure that `Type::Todo` works as expected. This particular /// test makes sure that we handle `Todo` types correctly, even if they originate from /// different sources. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 9a43c78c73..726cb1104e 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -28,9 +28,9 @@ use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, - NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, - TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, - UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, + NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext, + TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, + TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -2926,7 +2926,11 @@ impl<'db> ClassLiteral<'db> { // `self.SOME_CONSTANT: Final = 1`, infer the type from the value // on the right-hand side. - let inferred_ty = infer_expression_type(db, index.expression(value)); + let inferred_ty = infer_expression_type( + db, + index.expression(value), + TypeContext::default(), + ); return Place::bound(inferred_ty).with_qualifiers(all_qualifiers); } @@ -3014,6 +3018,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = infer_expression_type( db, index.expression(assign.value(&module)), + TypeContext::default(), ); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); @@ -3041,6 +3046,7 @@ impl<'db> ClassLiteral<'db> { let iterable_ty = infer_expression_type( db, index.expression(for_stmt.iterable(&module)), + TypeContext::default(), ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = @@ -3071,6 +3077,7 @@ impl<'db> ClassLiteral<'db> { let context_ty = infer_expression_type( db, index.expression(with_item.context_expr(&module)), + TypeContext::default(), ); let inferred_ty = if with_item.is_async() { context_ty.aenter(db) @@ -3104,6 +3111,7 @@ impl<'db> ClassLiteral<'db> { let iterable_ty = infer_expression_type( db, index.expression(comprehension.iterable(&module)), + TypeContext::default(), ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index a13f36e35a..3ccba3064d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -171,11 +171,21 @@ fn deferred_cycle_initial<'db>( /// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an /// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a /// type narrowing guard expression (e.g. if statement test node). -#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn infer_expression_types<'db>( db: &'db dyn Db, expression: Expression<'db>, + tcx: TypeContext<'db>, +) -> &'db ExpressionInference<'db> { + infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) +} + +#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] +fn infer_expression_types_impl<'db>( + db: &'db dyn Db, + input: InferExpression<'db>, ) -> ExpressionInference<'db> { + let (expression, tcx) = (input.expression(db), input.tcx(db)); + let file = expression.file(db); let module = parsed_module(db, file).load(db); let _span = tracing::trace_span!( @@ -188,8 +198,13 @@ pub(crate) fn infer_expression_types<'db>( let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index, &module) - .finish_expression() + TypeInferenceBuilder::new( + db, + InferenceRegion::Expression(expression, tcx), + index, + &module, + ) + .finish_expression() } /// How many fixpoint iterations to allow before falling back to Divergent type. @@ -199,11 +214,11 @@ fn expression_cycle_recover<'db>( db: &'db dyn Db, _value: &ExpressionInference<'db>, count: u32, - expression: Expression<'db>, + input: InferExpression<'db>, ) -> salsa::CycleRecoveryAction> { if count == ITERATIONS_BEFORE_FALLBACK { salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback( - expression.scope(db), + input.expression(db).scope(db), )) } else { salsa::CycleRecoveryAction::Iterate @@ -212,9 +227,9 @@ fn expression_cycle_recover<'db>( fn expression_cycle_initial<'db>( db: &'db dyn Db, - expression: Expression<'db>, + input: InferExpression<'db>, ) -> ExpressionInference<'db> { - ExpressionInference::cycle_initial(expression.scope(db)) + ExpressionInference::cycle_initial(input.expression(db).scope(db)) } /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. @@ -225,9 +240,10 @@ fn expression_cycle_initial<'db>( pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, + tcx: TypeContext<'db>, parsed: &ParsedModuleRef, ) -> Type<'db> { - let inference = infer_expression_types(db, expression); + let inference = infer_expression_types(db, expression, tcx); inference.expression_type(expression.node_ref(db, parsed)) } @@ -238,34 +254,108 @@ pub(super) fn infer_same_file_expression_type<'db>( /// /// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same /// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`. -#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn infer_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, + tcx: TypeContext<'db>, ) -> Type<'db> { - let file = expression.file(db); + infer_expression_type_impl(db, InferExpression::new(db, expression, tcx)) +} + +#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] +fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> { + let file = input.expression(db).file(db); let module = parsed_module(db, file).load(db); // It's okay to call the "same file" version here because we're inside a salsa query. - infer_same_file_expression_type(db, expression, &module) + let inference = infer_expression_types_impl(db, input); + inference.expression_type(input.expression(db).node_ref(db, &module)) } fn single_expression_cycle_recover<'db>( _db: &'db dyn Db, _value: &Type<'db>, _count: u32, - _expression: Expression<'db>, + _input: InferExpression<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } fn single_expression_cycle_initial<'db>( _db: &'db dyn Db, - _expression: Expression<'db>, + _input: InferExpression<'db>, ) -> Type<'db> { Type::Never } +/// An `Expression` with an optional `TypeContext`. +/// +/// This is a Salsa supertype used as the input to `infer_expression_types` to avoid +/// interning an `ExpressionWithContext` unnecessarily when no type context is provided. +#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)] +enum InferExpression<'db> { + Bare(Expression<'db>), + WithContext(ExpressionWithContext<'db>), +} + +impl<'db> InferExpression<'db> { + fn new( + db: &'db dyn Db, + expression: Expression<'db>, + tcx: TypeContext<'db>, + ) -> InferExpression<'db> { + if tcx.annotation.is_some() { + InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx)) + } else { + // Drop the empty `TypeContext` to avoid the interning cost. + InferExpression::Bare(expression) + } + } + + fn expression(self, db: &'db dyn Db) -> Expression<'db> { + match self { + InferExpression::Bare(expression) => expression, + InferExpression::WithContext(expression_with_context) => { + expression_with_context.expression(db) + } + } + } + + fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> { + match self { + InferExpression::Bare(_) => TypeContext::default(), + InferExpression::WithContext(expression_with_context) => { + expression_with_context.tcx(db) + } + } + } +} + +/// An `Expression` with a `TypeContext`. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +struct ExpressionWithContext<'db> { + expression: Expression<'db>, + tcx: TypeContext<'db>, +} + +/// The type context for a given expression, namely the type annotation +/// in an annotated assignment. +/// +/// Knowing the outer type context when inferring an expression can enable +/// more precise inference results, aka "bidirectional type inference". +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] +pub(crate) struct TypeContext<'db> { + annotation: Option>, +} + +impl<'db> TypeContext<'db> { + pub(crate) fn new(annotation: Type<'db>) -> Self { + Self { + annotation: Some(annotation), + } + } +} + /// Returns the statically-known truthiness of a given expression. /// /// Returns [`Truthiness::Ambiguous`] in case any non-definitely bound places @@ -275,7 +365,7 @@ pub(crate) fn static_expression_truthiness<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Truthiness { - let inference = infer_expression_types(db, expression); + let inference = infer_expression_types_impl(db, InferExpression::Bare(expression)); if !inference.all_places_definitely_bound() { return Truthiness::Ambiguous; @@ -366,7 +456,7 @@ pub(crate) fn nearest_enclosing_class<'db>( #[derive(Copy, Clone, Debug)] pub(crate) enum InferenceRegion<'db> { /// infer types for a standalone [`Expression`] - Expression(Expression<'db>), + Expression(Expression<'db>, TypeContext<'db>), /// infer types for a [`Definition`] Definition(Definition<'db>), /// infer deferred types for a [`Definition`] @@ -378,7 +468,7 @@ pub(crate) enum InferenceRegion<'db> { impl<'db> InferenceRegion<'db> { fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { match self { - InferenceRegion::Expression(expression) => expression.scope(db), + InferenceRegion::Expression(expression, _) => expression.scope(db), InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => { definition.scope(db) } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 0a5561dc6d..1f01021651 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -90,8 +90,8 @@ use crate::types::{ IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, - TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, + TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, + TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -440,7 +440,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { InferenceRegion::Scope(scope) => self.infer_region_scope(scope), InferenceRegion::Definition(definition) => self.infer_region_definition(definition), InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition), - InferenceRegion::Expression(expression) => self.infer_region_expression(expression), + InferenceRegion::Expression(expression, tcx) => { + self.infer_region_expression(expression, tcx); + } } } @@ -1221,10 +1223,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - fn infer_region_expression(&mut self, expression: Expression<'db>) { + fn infer_region_expression(&mut self, expression: Expression<'db>, tcx: TypeContext<'db>) { match expression.kind(self.db()) { ExpressionKind::Normal => { - self.infer_expression_impl(expression.node_ref(self.db(), self.module())); + self.infer_expression_impl(expression.node_ref(self.db(), self.module()), tcx); } ExpressionKind::TypeExpression => { self.infer_type_expression(expression.node_ref(self.db(), self.module())); @@ -1435,7 +1437,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let declared_ty = if resolved_place.is_unbound() && !place_table.place(place_id).is_symbol() { if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { - let value_type = self.infer_maybe_standalone_expression(value); + let value_type = + self.infer_maybe_standalone_expression(value, TypeContext::default()); if let Place::Type(ty, Boundness::Bound) = value_type.member(db, attr).place { // TODO: also consider qualifiers on the attribute ty @@ -1448,8 +1451,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }, ) = node { - let value_ty = self.infer_expression(value); - let slice_ty = self.infer_expression(slice); + let value_ty = self.infer_expression(value, TypeContext::default()); + let slice_ty = self.infer_expression(slice, TypeContext::default()); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx) } else { unwrap_declared_ty() @@ -1517,9 +1520,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // In the following cases, the bound type may not be the same as the RHS value type. if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { - let value_ty = self - .try_expression_type(value) - .unwrap_or_else(|| self.infer_maybe_standalone_expression(value)); + let value_ty = self.try_expression_type(value).unwrap_or_else(|| { + self.infer_maybe_standalone_expression(value, TypeContext::default()) + }); // If the member is a data descriptor, the RHS value may differ from the value actually assigned. if value_ty .class_member(db, attr.id.clone()) @@ -1532,7 +1535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node { let value_ty = self .try_expression_type(value) - .unwrap_or_else(|| self.infer_expression(value)); + .unwrap_or_else(|| self.infer_expression(value, TypeContext::default())); if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) { bound_ty = declared_ty; @@ -1719,7 +1722,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { std::mem::replace(&mut self.deferred_state, in_stub.into()); let mut call_arguments = CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { - let ty = self.infer_expression(splatted_value); + let ty = self.infer_expression(splatted_value, TypeContext::default()); self.store_expression_type(argument, ty); ty }); @@ -1988,7 +1991,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }) => { // If this is a call expression, we would have added a `ReturnsNever` constraint, // meaning this will be a standalone expression. - self.infer_maybe_standalone_expression(value); + self.infer_maybe_standalone_expression(value, TypeContext::default()); } ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement), ast::Stmt::Try(try_statement) => self.infer_try_statement(try_statement), @@ -2085,7 +2088,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .iter_non_variadic_params() .filter_map(|param| param.default.as_deref()) { - self.infer_expression(default); + self.infer_expression(default, TypeContext::default()); } // If there are type params, parameters and returns are evaluated in that scope, that is, in @@ -2517,7 +2520,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // and we don't need to run inference here if type_params.is_none() { for keyword in class_node.keywords() { - self.infer_expression(&keyword.value); + self.infer_expression(&keyword.value, TypeContext::default()); } // Inference of bases deferred in stubs, or if any are string literals. @@ -2527,7 +2530,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); for base in class_node.bases() { - self.infer_expression(base); + self.infer_expression(base, TypeContext::default()); } self.typevar_binding_context = previous_typevar_binding_context; } @@ -2552,9 +2555,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); for base in class.bases() { if self.in_stub() { - self.infer_expression_with_state(base, DeferredExpressionState::Deferred); + self.infer_expression_with_state( + base, + TypeContext::default(), + DeferredExpressionState::Deferred, + ); } else { - self.infer_expression(base); + self.infer_expression(base, TypeContext::default()); } } self.typevar_binding_context = previous_typevar_binding_context; @@ -2565,7 +2572,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { type_alias: &ast::StmtTypeAlias, definition: Definition<'db>, ) { - self.infer_expression(&type_alias.name); + self.infer_expression(&type_alias.name, TypeContext::default()); let rhs_scope = self .index @@ -2597,7 +2604,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { elif_else_clauses, } = if_statement; - let test_ty = self.infer_standalone_expression(test); + let test_ty = self.infer_standalone_expression(test, TypeContext::default()); if let Err(err) = test_ty.try_bool(self.db()) { err.report_diagnostic(&self.context, &**test); @@ -2614,7 +2621,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = clause; if let Some(test) = &test { - let test_ty = self.infer_standalone_expression(test); + let test_ty = self.infer_standalone_expression(test, TypeContext::default()); if let Err(err) = test_ty.try_bool(self.db()) { err.report_diagnostic(&self.context, test); @@ -2681,15 +2688,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `with not_context_manager as a.x: ... builder - .infer_standalone_expression(context_expr) + .infer_standalone_expression(context_expr, TypeContext::default()) .enter(builder.db()) }); } else { // Call into the context expression inference to validate that it evaluates // to a valid context manager. - let context_expression_ty = self.infer_expression(&item.context_expr); + let context_expression_ty = + self.infer_expression(&item.context_expr, TypeContext::default()); self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async); - self.infer_optional_expression(target); + self.infer_optional_expression(target, TypeContext::default()); } } @@ -2713,7 +2721,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { unpacked.expression_type(target) } TargetKind::Single => { - let context_expr_ty = self.infer_standalone_expression(context_expr); + let context_expr_ty = + self.infer_standalone_expression(context_expr, TypeContext::default()); self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async()) } }; @@ -2755,7 +2764,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_exception(&mut self, node: Option<&ast::Expr>, is_star: bool) -> Type<'db> { // If there is no handled exception, it's invalid syntax; // a diagnostic will have already been emitted - let node_ty = node.map_or(Type::unknown(), |ty| self.infer_expression(ty)); + let node_ty = node.map_or(Type::unknown(), |ty| { + self.infer_expression(ty, TypeContext::default()) + }); let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db()); // If it's an `except*` handler, this won't actually be the type of the bound symbol; @@ -2947,7 +2958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { name: _, default, } = node; - self.infer_optional_expression(default.as_deref()); + self.infer_optional_expression(default.as_deref(), TypeContext::default()); let pep_695_todo = Type::Dynamic(DynamicType::TodoPEP695ParamSpec); self.add_declaration_with_binding( node.into(), @@ -2967,7 +2978,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { name: _, default, } = node; - self.infer_optional_expression(default.as_deref()); + self.infer_optional_expression(default.as_deref(), TypeContext::default()); let pep_695_todo = todo_type!("PEP-695 TypeVarTuple definition types"); self.add_declaration_with_binding( node.into(), @@ -2984,7 +2995,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { cases, } = match_statement; - self.infer_standalone_expression(subject); + self.infer_standalone_expression(subject, TypeContext::default()); for case in cases { let ast::MatchCase { @@ -2997,7 +3008,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_match_pattern(pattern); if let Some(guard) = guard.as_deref() { - let guard_ty = self.infer_standalone_expression(guard); + let guard_ty = self.infer_standalone_expression(guard, TypeContext::default()); if let Err(err) = guard_ty.try_bool(self.db()) { err.report_diagnostic(&self.context, guard); @@ -3052,7 +3063,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 match pattern { ast::Pattern::MatchValue(match_value) => { - self.infer_standalone_expression(&match_value.value); + self.infer_standalone_expression(&match_value.value, TypeContext::default()); } ast::Pattern::MatchClass(match_class) => { let ast::PatternMatchClass { @@ -3067,7 +3078,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for keyword in &arguments.keywords { self.infer_nested_match_pattern(&keyword.pattern); } - self.infer_standalone_expression(cls); + self.infer_standalone_expression(cls, TypeContext::default()); } ast::Pattern::MatchOr(match_or) => { for pattern in &match_or.patterns { @@ -3083,7 +3094,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) { match pattern { ast::Pattern::MatchValue(match_value) => { - self.infer_maybe_standalone_expression(&match_value.value); + self.infer_maybe_standalone_expression(&match_value.value, TypeContext::default()); } ast::Pattern::MatchSequence(match_sequence) => { for pattern in &match_sequence.patterns { @@ -3099,7 +3110,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { rest: _, } = match_mapping; for key in keys { - self.infer_expression(key); + self.infer_expression(key, TypeContext::default()); } for pattern in patterns { self.infer_nested_match_pattern(pattern); @@ -3118,7 +3129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for keyword in &arguments.keywords { self.infer_nested_match_pattern(&keyword.pattern); } - self.infer_maybe_standalone_expression(cls); + self.infer_maybe_standalone_expression(cls, TypeContext::default()); } ast::Pattern::MatchAs(match_as) => { if let Some(pattern) = &match_as.pattern { @@ -3144,7 +3155,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for target in targets { self.infer_target(target, value, |builder, value_expr| { - builder.infer_standalone_expression(value_expr) + builder.infer_standalone_expression(value_expr, TypeContext::default()) }); } } @@ -3184,8 +3195,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ctx: _, } = target; - let value_ty = self.infer_expression(value); - let slice_ty = self.infer_expression(slice); + let value_ty = self.infer_expression(value, TypeContext::default()); + let slice_ty = self.infer_expression(slice, TypeContext::default()); let db = self.db(); let context = &self.context; @@ -3878,7 +3889,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) => { self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); - let object_ty = self.infer_expression(object); + let object_ty = self.infer_expression(object, TypeContext::default()); if let Some(assigned_ty) = assigned_ty { self.validate_attribute_assignment( @@ -3899,7 +3910,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } _ => { // TODO: Remove this once we handle all possible assignment targets. - self.infer_expression(target); + self.infer_expression(target, TypeContext::default()); } } } @@ -3924,7 +3935,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { unpacked.expression_type(target) } TargetKind::Single => { - let value_ty = self.infer_standalone_expression(value); + let value_ty = self.infer_standalone_expression(value, TypeContext::default()); // `TYPE_CHECKING` is a special variable that should only be assigned `False` // at runtime, but is always considered `True` in type checking. @@ -3988,12 +3999,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(value) = value { - self.infer_maybe_standalone_expression(value); + self.infer_maybe_standalone_expression( + value, + TypeContext::new(annotated.inner_type()), + ); } // If we have an annotated assignment like `self.attr: int = 1`, we still need to // do type inference on the `self.attr` target to get types for all sub-expressions. - self.infer_expression(target); + self.infer_expression(target, TypeContext::default()); // But here we explicitly overwrite the type for the overall `self.attr` node with // the annotated type. We do no use `store_expression_type` here, because it checks @@ -4080,7 +4094,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { debug_assert!(PlaceExpr::try_from_expr(target).is_some()); if let Some(value) = value { - let inferred_ty = self.infer_maybe_standalone_expression(value); + let inferred_ty = self + .infer_maybe_standalone_expression(value, TypeContext::new(declared.inner_type())); let mut inferred_ty = if target .as_name_expr() .is_some_and(|name| &name.id == "TYPE_CHECKING") @@ -4236,9 +4251,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.store_expression_type(target, previous_value); previous_value } - _ => self.infer_expression(target), + _ => self.infer_expression(target, TypeContext::default()), }; - let value_type = self.infer_expression(value); + let value_type = self.infer_expression(value, TypeContext::default()); self.infer_augmented_op(assignment, target_type, value_type) } @@ -4263,7 +4278,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `for a.x in not_iterable: ... builder - .infer_standalone_expression(iter_expr) + .infer_standalone_expression(iter_expr, TypeContext::default()) .iterate(builder.db()) .homogeneous_element_type(builder.db()) }); @@ -4290,7 +4305,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { unpacked.expression_type(target) } TargetKind::Single => { - let iterable_type = self.infer_standalone_expression(iterable); + let iterable_type = + self.infer_standalone_expression(iterable, TypeContext::default()); iterable_type .try_iterate_with_mode( @@ -4318,7 +4334,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { orelse, } = while_statement; - let test_ty = self.infer_standalone_expression(test); + let test_ty = self.infer_standalone_expression(test, TypeContext::default()); if let Err(err) = test_ty.try_bool(self.db()) { err.report_diagnostic(&self.context, &**test); @@ -4500,13 +4516,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { msg, } = assert; - let test_ty = self.infer_standalone_expression(test); + let test_ty = self.infer_standalone_expression(test, TypeContext::default()); if let Err(err) = test_ty.try_bool(self.db()) { err.report_diagnostic(&self.context, &**test); } - self.infer_optional_expression(msg.as_deref()); + self.infer_optional_expression(msg.as_deref(), TypeContext::default()); } fn infer_raise_statement(&mut self, raise: &ast::StmtRaise) { @@ -4526,7 +4542,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { UnionType::from_elements(self.db(), [can_be_raised, Type::none(self.db())]); if let Some(raised) = exc { - let raised_type = self.infer_expression(raised); + let raised_type = self.infer_expression(raised, TypeContext::default()); if !raised_type.is_assignable_to(self.db(), can_be_raised) { report_invalid_exception_raised(&self.context, raised, raised_type); @@ -4534,7 +4550,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(cause) = cause { - let cause_type = self.infer_expression(cause); + let cause_type = self.infer_expression(cause, TypeContext::default()); if !cause_type.is_assignable_to(self.db(), can_be_exception_cause) { report_invalid_exception_cause(&self.context, cause, cause_type); @@ -4740,7 +4756,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - if let Some(ty) = self.infer_optional_expression(ret.value.as_deref()) { + if let Some(ty) = + self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()) + { let range = ret .value .as_ref() @@ -4758,7 +4776,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { targets, } = delete; for target in targets { - self.infer_expression(target); + self.infer_expression(target, TypeContext::default()); } } @@ -4898,7 +4916,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { expression, } = decorator; - self.infer_expression(expression) + self.infer_expression(expression, TypeContext::default()) } fn infer_argument_types<'a>( @@ -4920,7 +4938,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::ArgOrKeyword::Arg(arg) => arg, ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, }; - let ty = self.infer_argument_type(argument, form); + let ty = self.infer_argument_type(argument, form, TypeContext::default()); *argument_type = Some(ty); } } @@ -4929,58 +4947,73 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &mut self, ast_argument: &ast::Expr, form: Option, + tcx: TypeContext<'db>, ) -> Type<'db> { match form { - None | Some(ParameterForm::Value) => self.infer_expression(ast_argument), + None | Some(ParameterForm::Value) => self.infer_expression(ast_argument, tcx), Some(ParameterForm::Type) => self.infer_type_expression(ast_argument), } } - fn infer_optional_expression(&mut self, expression: Option<&ast::Expr>) -> Option> { - expression.map(|expr| self.infer_expression(expr)) + fn infer_optional_expression( + &mut self, + expression: Option<&ast::Expr>, + tcx: TypeContext<'db>, + ) -> Option> { + expression.map(|expr| self.infer_expression(expr, tcx)) } #[track_caller] - fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> { debug_assert!( !self.index.is_standalone_expression(expression), "Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead." ); - self.infer_expression_impl(expression) + self.infer_expression_impl(expression, tcx) } fn infer_expression_with_state( &mut self, expression: &ast::Expr, + tcx: TypeContext<'db>, state: DeferredExpressionState, ) -> Type<'db> { let previous_deferred_state = std::mem::replace(&mut self.deferred_state, state); - let ty = self.infer_expression(expression); + let ty = self.infer_expression(expression, tcx); self.deferred_state = previous_deferred_state; ty } - fn infer_maybe_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_maybe_standalone_expression( + &mut self, + expression: &ast::Expr, + tcx: TypeContext<'db>, + ) -> Type<'db> { if let Some(standalone_expression) = self.index.try_expression(expression) { - self.infer_standalone_expression_impl(expression, standalone_expression) + self.infer_standalone_expression_impl(expression, standalone_expression, tcx) } else { - self.infer_expression(expression) + self.infer_expression(expression, tcx) } } #[track_caller] - fn infer_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_standalone_expression( + &mut self, + expression: &ast::Expr, + tcx: TypeContext<'db>, + ) -> Type<'db> { let standalone_expression = self.index.expression(expression); - self.infer_standalone_expression_impl(expression, standalone_expression) + self.infer_standalone_expression_impl(expression, standalone_expression, tcx) } fn infer_standalone_expression_impl( &mut self, expression: &ast::Expr, standalone_expression: Expression<'db>, + tcx: TypeContext<'db>, ) -> Type<'db> { - let types = infer_expression_types(self.db(), standalone_expression); + let types = infer_expression_types(self.db(), standalone_expression, tcx); self.extend_expression(types); // Instead of calling `self.expression_type(expr)` after extending here, we get @@ -4990,7 +5023,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { types.expression_type(expression) } - fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> { + fn infer_expression_impl( + &mut self, + expression: &ast::Expr, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ty = match expression { ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _, @@ -5005,10 +5042,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::FString(fstring) => self.infer_fstring_expression(fstring), ast::Expr::TString(tstring) => self.infer_tstring_expression(tstring), ast::Expr::EllipsisLiteral(literal) => self.infer_ellipsis_literal_expression(literal), - ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple), - ast::Expr::List(list) => self.infer_list_expression(list), - ast::Expr::Set(set) => self.infer_set_expression(set), - ast::Expr::Dict(dict) => self.infer_dict_expression(dict), + ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple, tcx), + ast::Expr::List(list) => self.infer_list_expression(list, tcx), + ast::Expr::Set(set) => self.infer_set_expression(set, tcx), + ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx), ast::Expr::Generator(generator) => self.infer_generator_expression(generator), ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp), ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp), @@ -5024,7 +5061,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Named(named) => self.infer_named_expression(named), ast::Expr::If(if_expression) => self.infer_if_expression(if_expression), ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression), - ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression), + ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx), ast::Expr::Starred(starred) => self.infer_starred_expression(starred), ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from), @@ -5038,7 +5075,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ty } - fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { if self.deferred_state.in_string_annotation() { // Avoid storing the type of expressions that are part of a string annotation because @@ -5120,11 +5156,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { conversion, format_spec, } = expression; - let ty = self.infer_expression(expression); + let ty = self.infer_expression(expression, TypeContext::default()); if let Some(format_spec) = format_spec { for element in format_spec.elements.interpolations() { - self.infer_expression(&element.expression); + self.infer_expression( + &element.expression, + TypeContext::default(), + ); } } @@ -5166,10 +5205,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { format_spec, .. } = tstring_interpolation_element; - self.infer_expression(expression); + self.infer_expression(expression, TypeContext::default()); if let Some(format_spec) = format_spec { for element in format_spec.elements.interpolations() { - self.infer_expression(&element.expression); + self.infer_expression(&element.expression, TypeContext::default()); } } } @@ -5187,7 +5226,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { KnownClass::EllipsisType.to_instance(self.db()) } - fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> { + fn infer_tuple_expression( + &mut self, + tuple: &ast::ExprTuple, + _tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprTuple { range: _, node_index: _, @@ -5199,7 +5242,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let db = self.db(); let divergent = Type::divergent(self.scope()); let element_types = elts.iter().map(|element| { - let element_type = self.infer_expression(element); + // TODO: Use the type context for more precise inference. + let element_type = self.infer_expression(element, TypeContext::default()); if element_type.has_divergent_type(self.db(), divergent) { divergent } else { @@ -5210,7 +5254,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::heterogeneous_tuple(db, element_types) } - fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> { + fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> { let ast::ExprList { range: _, node_index: _, @@ -5218,38 +5262,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ctx: _, } = list; + // TODO: Use the type context for more precise inference. for elt in elts { - self.infer_expression(elt); + self.infer_expression(elt, TypeContext::default()); } KnownClass::List .to_specialized_instance(self.db(), [todo_type!("list literal element type")]) } - fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { + fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> { let ast::ExprSet { range: _, node_index: _, elts, } = set; + // TODO: Use the type context for more precise inference. for elt in elts { - self.infer_expression(elt); + self.infer_expression(elt, TypeContext::default()); } KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")]) } - fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { + fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> { let ast::ExprDict { range: _, node_index: _, items, } = dict; + // TODO: Use the type context for more precise inference. for item in items { - self.infer_optional_expression(item.key.as_ref()); - self.infer_expression(&item.value); + self.infer_optional_expression(item.key.as_ref(), TypeContext::default()); + self.infer_expression(&item.value, TypeContext::default()); } KnownClass::Dict.to_specialized_instance( @@ -5260,14 +5307,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ], ) } - /// Infer the type of the `iter` expression of the first comprehension. fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { let mut comprehensions_iter = comprehensions.iter(); let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); }; - self.infer_standalone_expression(&first_comprehension.iter); + self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default()); } fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { @@ -5348,7 +5394,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { parenthesized: _, } = generator; - self.infer_expression(elt); + self.infer_expression(elt, TypeContext::default()); self.infer_comprehensions(generators); } @@ -5360,7 +5406,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { generators, } = listcomp; - self.infer_expression(elt); + self.infer_expression(elt, TypeContext::default()); self.infer_comprehensions(generators); } @@ -5373,8 +5419,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { generators, } = dictcomp; - self.infer_expression(key); - self.infer_expression(value); + self.infer_expression(key, TypeContext::default()); + self.infer_expression(value, TypeContext::default()); self.infer_comprehensions(generators); } @@ -5386,7 +5432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { generators, } = setcomp; - self.infer_expression(elt); + self.infer_expression(elt, TypeContext::default()); self.infer_comprehensions(generators); } @@ -5419,16 +5465,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { infer_same_file_expression_type( builder.db(), builder.index.expression(iter_expr), + TypeContext::default(), builder.module(), ) } else { - builder.infer_standalone_expression(iter_expr) + builder.infer_standalone_expression(iter_expr, TypeContext::default()) } .iterate(builder.db()) .homogeneous_element_type(builder.db()) }); for expr in ifs { - self.infer_standalone_expression(expr); + self.infer_standalone_expression(expr, TypeContext::default()); } } @@ -5442,7 +5489,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut infer_iterable_type = || { let expression = self.index.expression(iterable); - let result = infer_expression_types(self.db(), expression); + let result = infer_expression_types(self.db(), expression, TypeContext::default()); // Two things are different if it's the first comprehension: // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, @@ -5496,8 +5543,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { result.binding_type(definition) } else { // For syntactically invalid targets, we still need to run type inference: - self.infer_expression(&named.target); - self.infer_expression(&named.value); + self.infer_expression(&named.target, TypeContext::default()); + self.infer_expression(&named.value, TypeContext::default()); Type::unknown() } } @@ -5514,8 +5561,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { value, } = named; - let value_ty = self.infer_expression(value); - self.infer_expression(target); + let value_ty = self.infer_expression(value, TypeContext::default()); + self.infer_expression(target, TypeContext::default()); self.add_binding(named.into(), definition, value_ty); @@ -5531,9 +5578,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { orelse, } = if_expression; - let test_ty = self.infer_standalone_expression(test); - let body_ty = self.infer_expression(body); - let orelse_ty = self.infer_expression(orelse); + let test_ty = self.infer_standalone_expression(test, TypeContext::default()); + let body_ty = self.infer_expression(body, TypeContext::default()); + let orelse_ty = self.infer_expression(orelse, TypeContext::default()); match test_ty.try_bool(self.db()).unwrap_or_else(|err| { err.report_diagnostic(&self.context, &**test); @@ -5546,7 +5593,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) { - self.infer_expression(&lambda_expression.body); + self.infer_expression(&lambda_expression.body, TypeContext::default()); } fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> { @@ -5564,7 +5611,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .map(|param| { let mut parameter = Parameter::positional_only(Some(param.name().id.clone())); if let Some(default) = param.default() { - parameter = parameter.with_default_type(self.infer_expression(default)); + parameter = parameter.with_default_type( + self.infer_expression(default, TypeContext::default()), + ); } parameter }) @@ -5575,7 +5624,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .map(|param| { let mut parameter = Parameter::positional_or_keyword(param.name().id.clone()); if let Some(default) = param.default() { - parameter = parameter.with_default_type(self.infer_expression(default)); + parameter = parameter.with_default_type( + self.infer_expression(default, TypeContext::default()), + ); } parameter }) @@ -5590,7 +5641,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .map(|param| { let mut parameter = Parameter::keyword_only(param.name().id.clone()); if let Some(default) = param.default() { - parameter = parameter.with_default_type(self.infer_expression(default)); + parameter = parameter.with_default_type( + self.infer_expression(default, TypeContext::default()), + ); } parameter }) @@ -5618,7 +5671,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { CallableType::function_like(self.db(), Signature::new(parameters, Some(Type::unknown()))) } - fn infer_call_expression(&mut self, call_expression: &ast::ExprCall) -> Type<'db> { + fn infer_call_expression( + &mut self, + call_expression: &ast::ExprCall, + _tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprCall { range: _, node_index: _, @@ -5631,12 +5688,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // are assignable to any parameter annotations. let mut call_arguments = CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { - let ty = self.infer_expression(splatted_value); + let ty = self.infer_expression(splatted_value, TypeContext::default()); self.store_expression_type(argument, ty); ty }); - let callable_type = self.infer_maybe_standalone_expression(func); + // TODO: Use the type context for more precise inference. + let callable_type = self.infer_maybe_standalone_expression(func, TypeContext::default()); // Special handling for `TypedDict` method calls if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() { @@ -5881,7 +5939,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ctx: _, } = starred; - let iterable_type = self.infer_expression(value); + let iterable_type = self.infer_expression(value, TypeContext::default()); iterable_type .try_iterate(self.db()) .map(|tuple| tuple.homogeneous_element_type(self.db())) @@ -5900,7 +5958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { node_index: _, value, } = yield_expression; - self.infer_optional_expression(value.as_deref()); + self.infer_optional_expression(value.as_deref(), TypeContext::default()); todo_type!("yield expressions") } @@ -5911,7 +5969,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { value, } = yield_from; - let iterable_type = self.infer_expression(value); + let iterable_type = self.infer_expression(value, TypeContext::default()); iterable_type .try_iterate(self.db()) .map(|tuple| tuple.homogeneous_element_type(self.db())) @@ -5931,7 +5989,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { node_index: _, value, } = await_expression; - let expr_type = self.infer_expression(value); + let expr_type = self.infer_expression(value, TypeContext::default()); expr_type.try_await(self.db()).unwrap_or_else(|err| { err.report_diagnostic(&self.context, expr_type, value.as_ref().into()); Type::unknown() @@ -6576,7 +6634,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { let ast::ExprAttribute { value, attr, .. } = attribute; - let value_type = self.infer_maybe_standalone_expression(value); + let value_type = self.infer_maybe_standalone_expression(value, TypeContext::default()); let db = self.db(); let mut constraint_keys = vec![]; @@ -6687,7 +6745,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match ctx { ExprContext::Load => self.infer_attribute_load(attribute), ExprContext::Store => { - self.infer_expression(value); + self.infer_expression(value, TypeContext::default()); Type::Never } ExprContext::Del => { @@ -6695,7 +6753,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::Never } ExprContext::Invalid => { - self.infer_expression(value); + self.infer_expression(value, TypeContext::default()); Type::unknown() } } @@ -6709,7 +6767,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { operand, } = unary; - let operand_type = self.infer_expression(operand); + let operand_type = self.infer_expression(operand, TypeContext::default()); self.infer_unary_expression_type(*op, operand_type, unary) } @@ -6830,8 +6888,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { node_index: _, } = binary; - let left_ty = self.infer_expression(left); - let right_ty = self.infer_expression(right); + let left_ty = self.infer_expression(left, TypeContext::default()); + let right_ty = self.infer_expression(right, TypeContext::default()); self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op) .unwrap_or_else(|| { @@ -7276,9 +7334,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { values.iter().enumerate(), |builder, (index, value)| { let ty = if index == values.len() - 1 { - builder.infer_expression(value) + builder.infer_expression(value, TypeContext::default()) } else { - builder.infer_standalone_expression(value) + builder.infer_standalone_expression(value, TypeContext::default()) }; (ty, value.range()) @@ -7359,7 +7417,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { comparators, } = compare; - self.infer_expression(left); + self.infer_expression(left, TypeContext::default()); // https://docs.python.org/3/reference/expressions.html#comparisons // > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison @@ -7376,7 +7434,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .zip(ops), |builder, ((left, right), op)| { let left_ty = builder.expression_type(left); - let right_ty = builder.infer_expression(right); + let right_ty = builder.infer_expression(right, TypeContext::default()); let range = TextRange::new(left.start(), right.end()); @@ -8143,8 +8201,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match ctx { ExprContext::Load => self.infer_subscript_load(subscript), ExprContext::Store => { - let value_ty = self.infer_expression(value); - let slice_ty = self.infer_expression(slice); + let value_ty = self.infer_expression(value, TypeContext::default()); + let slice_ty = self.infer_expression(slice, TypeContext::default()); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); Type::Never } @@ -8153,8 +8211,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::Never } ExprContext::Invalid => { - let value_ty = self.infer_expression(value); - let slice_ty = self.infer_expression(slice); + let value_ty = self.infer_expression(value, TypeContext::default()); + let slice_ty = self.infer_expression(slice, TypeContext::default()); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); Type::unknown() } @@ -8169,7 +8227,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { slice, ctx, } = subscript; - let value_ty = self.infer_expression(value); + let value_ty = self.infer_expression(value, TypeContext::default()); let mut constraint_keys = vec![]; // If `value` is a valid reference, we attempt type narrowing by assignment. @@ -8183,7 +8241,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Place::Type(ty, Boundness::Bound) = place.place { // Even if we can obtain the subscript type based on the assignments, we still perform default type inference // (to store the expression type and to report errors). - let slice_ty = self.infer_expression(slice); + let slice_ty = self.infer_expression(slice, TypeContext::default()); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); return ty; } @@ -8228,7 +8286,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let slice_ty = self.infer_expression(slice); + let slice_ty = self.infer_expression(slice, TypeContext::default()); let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys) } @@ -8767,9 +8825,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { step, } = slice; - let ty_lower = self.infer_optional_expression(lower.as_deref()); - let ty_upper = self.infer_optional_expression(upper.as_deref()); - let ty_step = self.infer_optional_expression(step.as_deref()); + let ty_lower = self.infer_optional_expression(lower.as_deref(), TypeContext::default()); + let ty_upper = self.infer_optional_expression(upper.as_deref(), TypeContext::default()); + let ty_step = self.infer_optional_expression(step.as_deref(), TypeContext::default()); let type_to_slice_argument = |ty: Option>| match ty { Some(ty @ (Type::IntLiteral(_) | Type::BooleanLiteral(_))) => SliceArg::Arg(ty), diff --git a/crates/ty_python_semantic/src/types/infer/builder/annotation_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/annotation_expression.rs index facf8c8fa6..3316808579 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/annotation_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/annotation_expression.rs @@ -6,7 +6,7 @@ use crate::types::string_annotation::{ BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation, }; use crate::types::{ - KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeQualifiers, todo_type, + KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeContext, TypeQualifiers, todo_type, }; /// Annotation expressions. @@ -122,7 +122,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { }, ast::Expr::Subscript(subscript @ ast::ExprSubscript { value, slice, .. }) => { - let value_ty = self.infer_expression(value); + let value_ty = self.infer_expression(value, TypeContext::default()); let slice = &**slice; @@ -141,7 +141,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { if let [inner_annotation, metadata @ ..] = &arguments[..] { for element in metadata { - self.infer_expression(element); + self.infer_expression(element, TypeContext::default()); } let inner_annotation_ty = @@ -151,7 +151,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { inner_annotation_ty } else { for argument in arguments { - self.infer_expression(argument); + self.infer_expression(argument, TypeContext::default()); } self.store_expression_type(slice, Type::unknown()); TypeAndQualifiers::unknown() diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 7ff6fb2279..71976256a5 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -14,7 +14,7 @@ use crate::types::visitor::any_over_type; use crate::types::{ CallableType, DynamicType, IntersectionBuilder, KnownClass, KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, Type, - TypeAliasType, TypeIsType, UnionBuilder, UnionType, todo_type, + TypeAliasType, TypeContext, TypeIsType, UnionBuilder, UnionType, todo_type, }; /// Type expressions @@ -114,7 +114,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { node_index: _, } = subscript; - let value_ty = self.infer_expression(value); + let value_ty = self.infer_expression(value, TypeContext::default()); self.infer_subscript_type_expression_no_store(subscript, slice, value_ty) } @@ -324,7 +324,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::Dict(dict) => { - self.infer_dict_expression(dict); + self.infer_dict_expression(dict, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("Dict literals are not allowed in type expressions"), @@ -333,7 +333,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::Set(set) => { - self.infer_set_expression(set); + self.infer_set_expression(set, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("Set literals are not allowed in type expressions"), @@ -414,7 +414,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::Call(call_expr) => { - self.infer_call_expression(call_expr); + self.infer_call_expression(call_expr, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("Function calls are not allowed in type expressions"), @@ -544,7 +544,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { let value_ty = if builder.deferred_state.in_string_annotation() { // Using `.expression_type` does not work in string annotations, because // we do not store types for sub-expressions. Re-infer the type here. - builder.infer_expression(value) + builder.infer_expression(value, TypeContext::default()) } else { builder.expression_type(value) }; @@ -559,7 +559,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { match tuple_slice { ast::Expr::Tuple(elements) => { if let [element, ellipsis @ ast::Expr::EllipsisLiteral(_)] = &*elements.elts { - self.infer_expression(ellipsis); + self.infer_expression(ellipsis, TypeContext::default()); let result = TupleType::homogeneous(self.db(), self.infer_type_expression(element)); self.store_expression_type(tuple_slice, Type::tuple(Some(result))); @@ -617,7 +617,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { fn infer_subclass_of_type_expression(&mut self, slice: &ast::Expr) -> Type<'db> { match slice { ast::Expr::Name(_) | ast::Expr::Attribute(_) => { - let name_ty = self.infer_expression(slice); + let name_ty = self.infer_expression(slice, TypeContext::default()); match name_ty { Type::ClassLiteral(class_literal) => { if class_literal.is_protocol(self.db()) { @@ -663,7 +663,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { slice: parameters, .. }) => { - let parameters_ty = match self.infer_expression(value) { + let parameters_ty = match self.infer_expression(value, TypeContext::default()) { Type::SpecialForm(SpecialFormType::Union) => match &**parameters { ast::Expr::Tuple(tuple) => { let ty = UnionType::from_elements_leave_aliases( @@ -713,7 +713,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // `infer_expression` (instead of `infer_type_expression`) here to avoid // false-positive `invalid-type-form` diagnostics (`1` is not a valid type // expression). - self.infer_expression(&subscript.slice); + self.infer_expression(&subscript.slice, TypeContext::default()); Type::unknown() } Type::SpecialForm(special_form) => { @@ -912,14 +912,14 @@ impl<'db> TypeInferenceBuilder<'db, '_> { let [type_expr, metadata @ ..] = &arguments[..] else { for argument in arguments { - self.infer_expression(argument); + self.infer_expression(argument, TypeContext::default()); } self.store_expression_type(arguments_slice, Type::unknown()); return Type::unknown(); }; for element in metadata { - self.infer_expression(element); + self.infer_expression(element, TypeContext::default()); } let ty = self.infer_type_expression(type_expr); @@ -1107,7 +1107,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { let num_arguments = arguments.len(); let type_of_type = if num_arguments == 1 { // N.B. This uses `infer_expression` rather than `infer_type_expression` - self.infer_expression(&arguments[0]) + self.infer_expression(&arguments[0], TypeContext::default()) } else { for argument in arguments { self.infer_type_expression(argument); @@ -1137,7 +1137,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { if num_arguments != 1 { for argument in arguments { - self.infer_expression(argument); + self.infer_expression(argument, TypeContext::default()); } report_invalid_argument_number_to_special_form( &self.context, @@ -1152,7 +1152,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { return Type::unknown(); } - let argument_type = self.infer_expression(&arguments[0]); + let argument_type = self.infer_expression(&arguments[0], TypeContext::default()); let bindings = argument_type.bindings(db); // SAFETY: This is enforced by the constructor methods on `Bindings` even in @@ -1362,7 +1362,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { Type::tuple(self.infer_tuple_type_expression(arguments_slice)) } SpecialFormType::Generic | SpecialFormType::Protocol => { - self.infer_expression(arguments_slice); + self.infer_expression(arguments_slice, TypeContext::default()); if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { builder.into_diagnostic(format_args!( "`{special_form}` is not allowed in type expressions", @@ -1380,7 +1380,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { Ok(match parameters { // TODO handle type aliases ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => { - let value_ty = self.infer_expression(value); + let value_ty = self.infer_expression(value, TypeContext::default()); if matches!(value_ty, Type::SpecialForm(SpecialFormType::Literal)) { let ty = self.infer_literal_parameter_type(slice)?; @@ -1389,7 +1389,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { self.store_expression_type(parameters, ty); ty } else { - self.infer_expression(slice); + self.infer_expression(slice, TypeContext::default()); self.store_expression_type(parameters, Type::unknown()); return Err(vec![parameters]); @@ -1426,13 +1426,13 @@ impl<'db> TypeInferenceBuilder<'db, '_> { literal @ (ast::Expr::StringLiteral(_) | ast::Expr::BytesLiteral(_) | ast::Expr::BooleanLiteral(_) - | ast::Expr::NoneLiteral(_)) => self.infer_expression(literal), + | ast::Expr::NoneLiteral(_)) => self.infer_expression(literal, TypeContext::default()), literal @ ast::Expr::NumberLiteral(number) if number.value.is_int() => { - self.infer_expression(literal) + self.infer_expression(literal, TypeContext::default()) } // For enum values ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { - let value_ty = self.infer_expression(value); + let value_ty = self.infer_expression(value, TypeContext::default()); if is_enum_class(self.db(), value_ty) { let ty = value_ty @@ -1461,7 +1461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { ty } _ => { - self.infer_expression(parameters); + self.infer_expression(parameters, TypeContext::default()); return Err(vec![parameters]); } }) @@ -1507,7 +1507,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { }); } ast::Expr::Subscript(subscript) => { - let value_ty = self.infer_expression(&subscript.value); + let value_ty = self.infer_expression(&subscript.value, TypeContext::default()); self.infer_subscript_type_expression(subscript, value_ty); // TODO: Support `Concatenate[...]` return Some(Parameters::todo()); diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index e5ca28a4b8..ec4a0d46c8 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -5,7 +5,7 @@ use crate::place::{ConsideredDefinitions, Place, global_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; -use crate::types::{KnownInstanceType, check_types}; +use crate::types::{KnownClass, KnownInstanceType, UnionType, check_types}; use ruff_db::diagnostic::Diagnostic; use ruff_db::files::{File, system_path_to_file}; use ruff_db::system::DbWithWritableSystem as _; @@ -409,17 +409,17 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - def f(self): - self.attr: int | None = None - "#, + class C: + def f(self): + self.attr: int | None = None + "#, )?; db.write_dedented( "/src/main.py", r#" - from mod import C - x = C().attr - "#, + from mod import C + x = C().attr + "#, )?; let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); @@ -430,10 +430,10 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - def f(self): - self.attr: str | None = None - "#, + class C: + def f(self): + self.attr: str | None = None + "#, )?; let events = { @@ -442,17 +442,22 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); // Add a comment; this should not trigger the type of `x` to be re-inferred db.write_dedented( "/src/mod.py", r#" - class C: - def f(self): - # a comment! - self.attr: str | None = None - "#, + class C: + def f(self): + # a comment! + self.attr: str | None = None + "#, )?; let events = { @@ -462,7 +467,12 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { db.take_salsa_events() }; - assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_not_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); Ok(()) } @@ -487,19 +497,19 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - if random.choice([True, False]): - attr: int = 42 - else: - attr: None = None - "#, + class C: + if random.choice([True, False]): + attr: int = 42 + else: + attr: None = None + "#, )?; db.write_dedented( "/src/main.py", r#" - from mod import C - x = C().attr - "#, + from mod import C + x = C().attr + "#, )?; let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); @@ -510,12 +520,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - if random.choice([True, False]): - attr: str = "42" - else: - attr: None = None - "#, + class C: + if random.choice([True, False]): + attr: str = "42" + else: + attr: None = None + "#, )?; let events = { @@ -524,19 +534,24 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); // Add a comment; this should not trigger the type of `x` to be re-inferred db.write_dedented( "/src/mod.py", r#" - class C: - # comment - if random.choice([True, False]): - attr: str = "42" - else: - attr: None = None - "#, + class C: + # comment + if random.choice([True, False]): + attr: str = "42" + else: + attr: None = None + "#, )?; let events = { @@ -546,7 +561,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { db.take_salsa_events() }; - assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_not_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); Ok(()) } @@ -569,22 +589,22 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - def __init__(self): - self.instance_attr: str = "24" + class C: + def __init__(self): + self.instance_attr: str = "24" - @classmethod - def method(cls): - cls.class_attr: int = 42 - "#, + @classmethod + def method(cls): + cls.class_attr: int = 42 + "#, )?; db.write_dedented( "/src/main.py", r#" - from mod import C - C.method() - x = C().class_attr - "#, + from mod import C + C.method() + x = C().class_attr + "#, )?; let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); @@ -595,14 +615,14 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { db.write_dedented( "/src/mod.py", r#" - class C: - def __init__(self): - self.instance_attr: str = "24" + class C: + def __init__(self): + self.instance_attr: str = "24" - @classmethod - def method(cls): - cls.class_attr: str = "42" - "#, + @classmethod + def method(cls): + cls.class_attr: str = "42" + "#, )?; let events = { @@ -611,21 +631,26 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); // Add a comment; this should not trigger the type of `x` to be re-inferred db.write_dedented( "/src/mod.py", r#" - class C: - def __init__(self): - self.instance_attr: str = "24" + class C: + def __init__(self): + self.instance_attr: str = "24" - @classmethod - def method(cls): - # comment - cls.class_attr: str = "42" - "#, + @classmethod + def method(cls): + # comment + cls.class_attr: str = "42" + "#, )?; let events = { @@ -635,7 +660,88 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { db.take_salsa_events() }; - assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); + assert_function_query_was_not_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(x_rhs_expression(&db)), + &events, + ); + + Ok(()) +} + +/// Inferring the result of a call-expression shouldn't need to re-run after +/// a trivial change to the function's file (e.g. by adding a docstring to the function). +#[test] +fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/foo.py", + r#" + def foo() -> int: + return 5 + "#, + )?; + db.write_dedented( + "src/bar.py", + r#" + from foo import foo + + a = foo() + "#, + )?; + + let bar = system_path_to_file(&db, "src/bar.py")?; + let a = global_symbol(&db, bar, "a").place; + + assert_eq!( + a.expect_type(), + UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) + ); + let events = db.take_salsa_events(); + + let module = parsed_module(&db, bar).load(&db); + let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; + let foo_call = semantic_index(&db, bar).expression(call); + + assert_function_query_was_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(foo_call), + &events, + ); + + // Add a docstring to foo to trigger a re-run. + // The bar-call site of foo should not be re-run because of that + db.write_dedented( + "src/foo.py", + r#" + def foo() -> int: + "Computes a value" + return 5 + "#, + )?; + db.clear_salsa_events(); + + let a = global_symbol(&db, bar, "a").place; + + assert_eq!( + a.expect_type(), + UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) + ); + let events = db.take_salsa_events(); + + let module = parsed_module(&db, bar).load(&db); + let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; + let foo_call = semantic_index(&db, bar).expression(call); + + assert_function_query_was_not_run( + &db, + infer_expression_types_impl, + InferExpression::Bare(foo_call), + &events, + ); Ok(()) } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index b1ba138c65..2fb6157acb 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -12,7 +12,7 @@ use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType, - Truthiness, Type, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, + Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, }; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; @@ -773,7 +773,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let inference = infer_expression_types(self.db, expression); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let comparator_tuples = std::iter::once(&**left) .chain(comparators) @@ -863,7 +863,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let inference = infer_expression_types(self.db, expression); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let callable_ty = inference.expression_type(&*expr_call.func); @@ -983,7 +983,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?; + let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) + .to_instance(self.db)?; Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -996,7 +997,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let ty = infer_same_file_expression_type(self.db, value, self.module); + let ty = + infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -1025,7 +1027,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let inference = infer_expression_types(self.db, expression); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let mut sub_constraints = expr_bool_op .values .iter() diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index 7197d57e03..b448053295 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -9,7 +9,7 @@ use crate::Db; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::scope::ScopeId; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; -use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types}; +use crate::types::{Type, TypeCheckDiagnostics, TypeContext, infer_expression_types}; use crate::unpack::{UnpackKind, UnpackValue}; use super::context::InferContext; @@ -48,8 +48,9 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { "Unpacking target must be a list or tuple expression" ); - let value_type = infer_expression_types(self.db(), value.expression()) - .expression_type(value.expression().node_ref(self.db(), self.module())); + let value_type = + infer_expression_types(self.db(), value.expression(), TypeContext::default()) + .expression_type(value.expression().node_ref(self.db(), self.module())); let value_type = match value.kind() { UnpackKind::Assign => {