diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 1622f96178..5f72c71684 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -84,7 +84,7 @@ fn infer_definition_types_cycle_recovery<'db>( input: Definition<'db>, ) -> TypeInference<'db> { tracing::trace!("infer_definition_types_cycle_recovery"); - let mut inference = TypeInference::default(); + let mut inference = TypeInference::empty(input.scope(db)); let category = input.category(db); if category.is_declaration() { inference.declarations.insert(input, Type::Unknown); @@ -172,7 +172,7 @@ pub(crate) enum InferenceRegion<'db> { } /// The inferred types for a single region. -#[derive(Debug, Eq, PartialEq, Default)] +#[derive(Debug, Eq, PartialEq)] pub(crate) struct TypeInference<'db> { /// The types of every expression in this region. expressions: FxHashMap>, @@ -188,9 +188,23 @@ pub(crate) struct TypeInference<'db> { /// Are there deferred type expressions in this region? has_deferred: bool, + + /// The scope belong to this region. + scope: ScopeId<'db>, } impl<'db> TypeInference<'db> { + pub(crate) fn empty(scope: ScopeId<'db>) -> Self { + Self { + expressions: FxHashMap::default(), + bindings: FxHashMap::default(), + declarations: FxHashMap::default(), + diagnostics: TypeCheckDiagnostics::default(), + has_deferred: false, + scope, + } + } + pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { self.expressions[&expression] } @@ -272,7 +286,6 @@ pub(super) struct TypeInferenceBuilder<'db> { // Cached lookups file: File, - scope: ScopeId<'db>, /// The type inference results types: TypeInference<'db>, @@ -305,13 +318,14 @@ impl<'db> TypeInferenceBuilder<'db> { region, file, - scope, - types: TypeInference::default(), + types: TypeInference::empty(scope), } } fn extend(&mut self, inference: &TypeInference<'db>) { + debug_assert_eq!(self.types.scope, inference.scope); + self.types.bindings.extend(inference.bindings.iter()); self.types .declarations @@ -321,6 +335,10 @@ impl<'db> TypeInferenceBuilder<'db> { self.types.has_deferred |= inference.has_deferred; } + fn scope(&self) -> ScopeId<'db> { + self.types.scope + } + /// Are we currently inferring types in file with deferred types? /// This is true for stub files and files with `__future__.annotations` fn are_all_types_deferred(&self) -> bool { @@ -337,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db> { /// PANIC if no type has been inferred for this node. fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> { self.types - .expression_ty(expr.scoped_ast_id(self.db, self.scope)) + .expression_ty(expr.scoped_ast_id(self.db, self.scope())) } /// Infers types in the given [`InferenceRegion`]. @@ -799,8 +817,6 @@ impl<'db> TypeInferenceBuilder<'db> { } = parameter_with_default; self.infer_optional_expression(parameter.annotation.as_deref()); - - self.infer_definition(parameter_with_default); } fn infer_parameter(&mut self, parameter: &ast::Parameter) { @@ -811,8 +827,6 @@ impl<'db> TypeInferenceBuilder<'db> { } = parameter; self.infer_optional_expression(annotation.as_deref()); - - self.infer_definition(parameter); } fn infer_parameter_with_default_definition( @@ -1008,7 +1022,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.types .expressions - .insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty); + .insert(target.scoped_ast_id(self.db, self.scope()), context_expr_ty); self.add_binding(target.into(), definition, context_expr_ty); } @@ -1199,7 +1213,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.add_binding(name.into(), definition, target_ty); self.types .expressions - .insert(name.scoped_ast_id(self.db, self.scope), target_ty); + .insert(name.scoped_ast_id(self.db, self.scope()), target_ty); } fn infer_sequence_unpacking( @@ -1513,9 +1527,10 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_with_diagnostic(iterable.into(), self) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty); + self.types.expressions.insert( + target.scoped_ast_id(self.db, self.scope()), + loop_var_value_ty, + ); self.add_binding(target.into(), definition, loop_var_value_ty); } @@ -1842,7 +1857,7 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); + let expr_id = expression.scoped_ast_id(self.db, self.scope()); let previous = self.types.expressions.insert(expr_id, ty); assert_eq!(previous, None); @@ -2161,13 +2176,13 @@ impl<'db> TypeInferenceBuilder<'db> { let iterable_ty = if is_first { let lookup_scope = self .index - .parent_scope_id(self.scope.file_scope_id(self.db)) + .parent_scope_id(self.scope().file_scope_id(self.db)) .expect("A comprehension should never be the top-level scope") .to_scope_id(self.db, self.file); result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope)) } else { self.extend(result); - result.expression_ty(iterable.scoped_ast_id(self.db, self.scope)) + result.expression_ty(iterable.scoped_ast_id(self.db, self.scope())) }; let target_ty = if is_async { @@ -2181,7 +2196,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.types .expressions - .insert(target.scoped_ast_id(self.db, self.scope), target_ty); + .insert(target.scoped_ast_id(self.db, self.scope()), target_ty); self.add_binding(target.into(), definition, target_ty); } @@ -2318,7 +2333,7 @@ impl<'db> TypeInferenceBuilder<'db> { /// Look up a name reference that isn't bound in the local scope. fn lookup_name(&mut self, name_node: &ast::ExprName) -> Type<'db> { let ast::ExprName { id: name, .. } = name_node; - let file_scope_id = self.scope.file_scope_id(self.db); + let file_scope_id = self.scope().file_scope_id(self.db); let is_bound = self .index .symbol_table(file_scope_id) @@ -2329,7 +2344,7 @@ impl<'db> TypeInferenceBuilder<'db> { // In function-like scopes, any local variable (symbol that is bound in this scope) can // only have a definition in this scope, or error; it never references another scope. // (At runtime, it would use the `LOAD_FAST` opcode.) - if !is_bound || !self.scope.is_function_like(self.db) { + if !is_bound || !self.scope().is_function_like(self.db) { // Walk up parent scopes looking for a possible enclosing scope that may have a // definition of this name visible to us (would be `LOAD_DEREF` at runtime.) for (enclosing_scope_file_id, _) in self.index.ancestor_scopes(file_scope_id) { @@ -2361,7 +2376,7 @@ impl<'db> TypeInferenceBuilder<'db> { global_symbol_ty(self.db, self.file, name) }; // Fallback to builtins (without infinite recursion if we're already in builtins.) - if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_module_scope(self.db) { + if ty.may_be_unbound(self.db) && Some(self.scope()) != builtins_module_scope(self.db) { let mut builtin_ty = builtins_symbol_ty(self.db, name); if builtin_ty.is_unbound() && name == "reveal_type" { self.add_diagnostic( @@ -2383,7 +2398,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { let ast::ExprName { range: _, id, ctx } = name; - let file_scope_id = self.scope.file_scope_id(self.db); + let file_scope_id = self.scope().file_scope_id(self.db); match ctx { ExprContext::Load => { @@ -2400,7 +2415,7 @@ impl<'db> TypeInferenceBuilder<'db> { use_def.public_may_be_unbound(symbol), ) } else { - let use_id = name.scoped_use_id(self.db, self.scope); + let use_id = name.scoped_use_id(self.db, self.scope()); ( use_def.bindings_at_use(use_id), use_def.use_may_be_unbound(use_id), @@ -3388,7 +3403,7 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); + let expr_id = expression.scoped_ast_id(self.db, self.scope()); let previous = self.types.expressions.insert(expr_id, ty); assert!(previous.is_none());