From bccc33c6fe31acc2894def097cf9a2fc9b83d03d Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 14 Aug 2024 17:02:17 -0700 Subject: [PATCH] [WIP] working version of SSA-style use-def map --- Cargo.lock | 1 + crates/red_knot_python_semantic/Cargo.toml | 1 + .../src/semantic_index.rs | 111 ++-- .../src/semantic_index/builder.rs | 114 ++-- .../src/semantic_index/definition.rs | 59 +- .../src/semantic_index/symbol.rs | 6 +- .../src/semantic_index/use_def.rs | 512 +++++++++--------- .../src/semantic_model.rs | 6 +- crates/red_knot_python_semantic/src/types.rs | 78 +-- .../src/types/infer.rs | 183 ++++--- 10 files changed, 580 insertions(+), 491 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4cc92e1c76..149a45c841 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1904,6 +1904,7 @@ dependencies = [ "ruff_text_size", "rustc-hash 2.0.0", "salsa", + "smallvec", "tempfile", "tracing", "walkdir", diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index 1019ce9434..70b3252756 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -26,6 +26,7 @@ countme = { workspace = true } once_cell = { workspace = true } ordermap = { workspace = true } salsa = { workspace = true } +smallvec = { workspace = true } tracing = { workspace = true } rustc-hash = { workspace = true } hashbrown = { workspace = true } diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index a3626c0bdc..6804e96e37 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -126,7 +126,7 @@ impl<'db> SemanticIndex<'db> { /// /// Use the Salsa cached [`use_def_map()`] query if you only need the /// use-def map for a single scope. 
- pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc { + pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc> { self.use_def_maps[scope_id].clone() } @@ -311,7 +311,7 @@ mod tests { use crate::db::tests::TestDb; use crate::semantic_index::ast_ids::HasScopedUseId; - use crate::semantic_index::definition::DefinitionKind; + use crate::semantic_index::definition::{DefinitionKind, DefinitionNode}; use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable}; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::Db; @@ -374,10 +374,11 @@ mod tests { let foo = global_table.symbol_id_by_name("foo").unwrap(); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions(foo) else { - panic!("expected one definition"); - }; - assert!(matches!(definition.node(&db), DefinitionKind::Import(_))); + let definition = use_def.public_definition(foo).unwrap(); + assert!(matches!( + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::Import(_)) + )); } #[test] @@ -411,16 +412,16 @@ mod tests { ); let use_def = use_def_map(&db, scope); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("foo") - .expect("symbol to exist"), - ) else { - panic!("expected one definition"); - }; + let definition = use_def + .public_definition( + global_table + .symbol_id_by_name("foo") + .expect("symbol to exist"), + ) + .unwrap(); assert!(matches!( - definition.node(&db), - DefinitionKind::ImportFrom(_) + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::ImportFrom(_)) )); } @@ -438,14 +439,12 @@ mod tests { "a symbol used but not defined in a scope should have only the used flag" ); let use_def = use_def_map(&db, scope); - let [definition] = - use_def.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists")) - else { - panic!("expected one definition"); - }; + let definition = use_def + 
.public_definition(global_table.symbol_id_by_name("x").expect("symbol exists")) + .unwrap(); assert!(matches!( - definition.node(&db), - DefinitionKind::Assignment(_) + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::Assignment(_)) )); } @@ -477,14 +476,12 @@ y = 2 assert_eq!(names(&class_table), vec!["x"]); let use_def = index.use_def_map(class_scope_id); - let [definition] = - use_def.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists")) - else { - panic!("expected one definition"); - }; + let definition = use_def + .public_definition(class_table.symbol_id_by_name("x").expect("symbol exists")) + .unwrap(); assert!(matches!( - definition.node(&db), - DefinitionKind::Assignment(_) + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::Assignment(_)) )); } @@ -515,16 +512,16 @@ y = 2 assert_eq!(names(&function_table), vec!["x"]); let use_def = index.use_def_map(function_scope_id); - let [definition] = use_def.public_definitions( - function_table - .symbol_id_by_name("x") - .expect("symbol exists"), - ) else { - panic!("expected one definition"); - }; + let definition = use_def + .public_definition( + function_table + .symbol_id_by_name("x") + .expect("symbol exists"), + ) + .unwrap(); assert!(matches!( - definition.node(&db), - DefinitionKind::Assignment(_) + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::Assignment(_)) )); } @@ -594,10 +591,10 @@ y = 2 let element_use_id = element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)); - let [definition] = use_def.use_definitions(element_use_id) else { - panic!("expected one definition") - }; - let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else { + let definition = use_def.definition_for_use(element_use_id).unwrap(); + let DefinitionKind::Node(DefinitionNode::Comprehension(comprehension)) = + definition.kind(&db) + else { panic!("expected generator definition") }; let ast::Comprehension { target, .. 
} = comprehension.node(); @@ -611,7 +608,7 @@ y = 2 /// the outer comprehension scope and the variables are correctly defined in the respective /// scopes. #[test] - fn nested_generators() { + fn nested_comprehensions() { let TestCase { db, file } = test_case( " [{x for x in iter2} for y in iter1] @@ -644,7 +641,7 @@ y = 2 .child_scopes(comprehension_scope_id) .collect::>()[..] else { - panic!("expected one inner generator scope") + panic!("expected one inner comprehension scope") }; assert_eq!(inner_comprehension_scope.kind(), ScopeKind::Comprehension); @@ -693,14 +690,17 @@ def func(): assert_eq!(names(&func2_table), vec!["y"]); let use_def = index.use_def_map(FileScopeId::global()); - let [definition] = use_def.public_definitions( - global_table - .symbol_id_by_name("func") - .expect("symbol exists"), - ) else { - panic!("expected one definition"); - }; - assert!(matches!(definition.node(&db), DefinitionKind::Function(_))); + let definition = use_def + .public_definition( + global_table + .symbol_id_by_name("func") + .expect("symbol exists"), + ) + .unwrap(); + assert!(matches!( + definition.kind(&db), + DefinitionKind::Node(DefinitionNode::Function(_)) + )); } #[test] @@ -800,10 +800,9 @@ class C[T]: }; let x_use_id = x_use_expr_name.scoped_use_id(&db, scope); let use_def = use_def_map(&db, scope); - let [definition] = use_def.use_definitions(x_use_id) else { - panic!("expected one definition"); - }; - let DefinitionKind::Assignment(assignment) = definition.node(&db) else { + let definition = use_def.definition_for_use(x_use_id).unwrap(); + let DefinitionKind::Node(DefinitionNode::Assignment(assignment)) = definition.kind(&db) + else { panic!("should be an assignment definition") }; let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index ee17e228d9..8c548974d2 100644 --- 
a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -13,15 +13,15 @@ use crate::ast_node_ref::AstNodeRef; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; use crate::semantic_index::definition::{ - AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey, - DefinitionNodeRef, ImportFromDefinitionNodeRef, + AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionKind, + DefinitionNodeKey, DefinitionNodeRef, ImportFromDefinitionNodeRef, }; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, }; -use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; +use crate::semantic_index::use_def::{BasicBlockId, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; @@ -33,8 +33,8 @@ pub(super) struct SemanticIndexBuilder<'db> { scope_stack: Vec, /// The assignment we're currently visiting. current_assignment: Option>, - /// Flow states at each `break` in the current loop. - loop_break_states: Vec, + /// Basic block ending at each `break` in the current loop. 
+ loop_breaks: Vec, // Semantic Index fields scopes: IndexVec, @@ -56,7 +56,7 @@ impl<'db> SemanticIndexBuilder<'db> { module: parsed, scope_stack: Vec::new(), current_assignment: None, - loop_break_states: vec![], + loop_breaks: vec![], scopes: IndexVec::new(), symbol_tables: IndexVec::new(), @@ -98,7 +98,8 @@ impl<'db> SemanticIndexBuilder<'db> { let file_scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); - self.use_def_maps.push(UseDefMapBuilder::new()); + self.use_def_maps + .push(UseDefMapBuilder::new(self.db, self.file, file_scope_id)); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); #[allow(unsafe_code)] @@ -132,41 +133,50 @@ impl<'db> SemanticIndexBuilder<'db> { &mut self.symbol_tables[scope_id] } - fn current_use_def_map_mut(&mut self) -> &mut UseDefMapBuilder<'db> { + fn current_use_def_map(&mut self) -> &mut UseDefMapBuilder<'db> { let scope_id = self.current_scope(); &mut self.use_def_maps[scope_id] } - fn current_use_def_map(&self) -> &UseDefMapBuilder<'db> { - let scope_id = self.current_scope(); - &self.use_def_maps[scope_id] - } - fn current_ast_ids(&mut self) -> &mut AstIdsBuilder { let scope_id = self.current_scope(); &mut self.ast_ids[scope_id] } - fn flow_snapshot(&self) -> FlowSnapshot { - self.current_use_def_map().snapshot() + /// Start a new basic block and return the previous block's ID. + fn next_block(&mut self) -> BasicBlockId { + self.current_use_def_map().next_block(/* sealed */ true) } - fn flow_restore(&mut self, state: FlowSnapshot) { - self.current_use_def_map_mut().restore(state); + /// Start a new unsealed basic block and return the previous block's ID. + fn next_block_unsealed(&mut self) -> BasicBlockId { + self.current_use_def_map().next_block(/* sealed */ false) } - fn flow_merge(&mut self, state: &FlowSnapshot) { - self.current_use_def_map_mut().merge(state); + /// Seal an unsealed basic block. 
+ fn seal_block(&mut self) { + self.current_use_def_map().seal_current_block(); + } + + /// Start a new basic block with the given block as predecessor. + fn new_block_from(&mut self, predecessor: BasicBlockId) { + self.current_use_def_map() + .new_block_from(predecessor, /* sealed */ true); + } + + /// Add a predecessor to the current block. + fn merge_block(&mut self, predecessor: BasicBlockId) { + self.current_use_def_map().merge_block(predecessor); + } + + /// Add predecessors to the current block. + fn merge_blocks(&mut self, predecessors: Vec) { + self.current_use_def_map().merge_blocks(predecessors); } fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags); - if added { - let use_def_map = self.current_use_def_map_mut(); - use_def_map.add_symbol(symbol_id); - } - symbol_id + symbol_table.add_or_update_symbol(name, flags) } fn add_definition<'a>( @@ -181,15 +191,13 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_scope(), symbol, #[allow(unsafe_code)] - unsafe { - definition_node.into_owned(self.module.clone()) - }, + DefinitionKind::Node(unsafe { definition_node.into_owned(self.module.clone()) }), countme::Count::default(), ); self.definitions_by_node .insert(definition_node.key(), definition); - self.current_use_def_map_mut() + self.current_use_def_map() .record_definition(symbol, definition); definition @@ -455,21 +463,19 @@ where } ast::Stmt::If(node) => { self.visit_expr(&node.test); - let pre_if = self.flow_snapshot(); + let pre_if = self.next_block(); self.visit_body(&node.body); - let mut post_clauses: Vec = vec![]; + let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { // snapshot after every block except the last; the last one will just become // the state that we merge the other snapshots into - post_clauses.push(self.flow_snapshot()); + 
post_clauses.push(self.next_block()); // we can only take an elif/else branch if none of the previous ones were // taken, so the block entry state is always `pre_if` - self.flow_restore(pre_if.clone()); + self.new_block_from(pre_if); self.visit_elif_else_clause(clause); } - for post_clause_state in post_clauses { - self.flow_merge(&post_clause_state); - } + self.next_block_unsealed(); let has_else = node .elif_else_clauses .last() @@ -477,35 +483,39 @@ where if !has_else { // if there's no else clause, then it's possible we took none of the branches, // and the pre_if state can reach here - self.flow_merge(&pre_if); + self.merge_block(pre_if); } + self.merge_blocks(post_clauses); + self.seal_block(); } ast::Stmt::While(node) => { self.visit_expr(&node.test); - let pre_loop = self.flow_snapshot(); + let pre_loop = self.next_block(); // Save aside any break states from an outer loop - let saved_break_states = std::mem::take(&mut self.loop_break_states); + let saved_break_states = std::mem::take(&mut self.loop_breaks); self.visit_body(&node.body); // Get the break states from the body of this loop, and restore the saved outer // ones. - let break_states = - std::mem::replace(&mut self.loop_break_states, saved_break_states); + let break_states = std::mem::replace(&mut self.loop_breaks, saved_break_states); // We may execute the `else` clause without ever executing the body, so merge in // the pre-loop state before visiting `else`. - self.flow_merge(&pre_loop); + self.next_block_unsealed(); + self.merge_block(pre_loop); + self.seal_block(); self.visit_body(&node.orelse); // Breaking out of a while loop bypasses the `else` clause, so merge in the break // states after visiting `else`. 
- for break_state in break_states { - self.flow_merge(&break_state); - } + self.next_block_unsealed(); + self.merge_blocks(break_states); + self.seal_block(); } ast::Stmt::Break(_) => { - self.loop_break_states.push(self.flow_snapshot()); + let block_id = self.next_block(); + self.loop_breaks.push(block_id); } _ => { walk_stmt(self, stmt); @@ -559,7 +569,7 @@ where if flags.contains(SymbolFlags::IS_USED) { let use_id = self.current_ast_ids().record_use(expr); - self.current_use_def_map_mut().record_use(symbol, use_id); + self.current_use_def_map().record_use(symbol, use_id); } walk_expr(self, expr); @@ -586,12 +596,14 @@ where // AST inspection, so we can't simplify here, need to record test expression for // later checking) self.visit_expr(test); - let pre_if = self.flow_snapshot(); + let pre_if = self.next_block(); self.visit_expr(body); - let post_body = self.flow_snapshot(); - self.flow_restore(pre_if); + let post_body = self.next_block(); + self.new_block_from(pre_if); self.visit_expr(orelse); - self.flow_merge(&post_body); + self.next_block_unsealed(); + self.merge_block(post_body); + self.seal_block(); } ast::Expr::ListComp( list_comprehension @ ast::ExprListComp { diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 0c4c9f39fe..90d289ff07 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -1,5 +1,6 @@ use ruff_db::files::File; use ruff_db::parsed::ParsedModule; +use ruff_index::newtype_index; use ruff_python_ast as ast; use crate::ast_node_ref::AstNodeRef; @@ -8,7 +9,7 @@ use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId}; use crate::Db; #[salsa::tracked] -pub struct Definition<'db> { +pub(crate) struct Definition<'db> { /// The file in which the definition occurs. 
#[id] pub(crate) file: File, @@ -23,7 +24,7 @@ pub struct Definition<'db> { #[no_eq] #[return_ref] - pub(crate) node: DefinitionKind, + pub(crate) kind: DefinitionKind, #[no_eq] count: countme::Count>, @@ -35,6 +36,22 @@ impl<'db> Definition<'db> { } } +#[derive(Clone, Debug)] +pub(crate) enum DefinitionKind { + /// Inserted at control-flow merge points, if multiple definitions can reach the merge point. + /// + /// Operands are not kept inline, since it's not possible to construct cyclically-referential + /// Salsa tracked structs; they are kept instead in the + /// [`UseDefMap`](super::use_def::UseDefMap). + Phi(ScopedPhiId), + + /// An assignment to the symbol. + Node(DefinitionNode), +} + +#[newtype_index] +pub(crate) struct ScopedPhiId; + #[derive(Copy, Clone, Debug)] pub(crate) enum DefinitionNodeRef<'a> { Import(&'a ast::Alias), @@ -115,37 +132,37 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> { impl DefinitionNodeRef<'_> { #[allow(unsafe_code)] - pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { + pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionNode { match self { DefinitionNodeRef::Import(alias) => { - DefinitionKind::Import(AstNodeRef::new(parsed, alias)) + DefinitionNode::Import(AstNodeRef::new(parsed, alias)) } DefinitionNodeRef::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => { - DefinitionKind::ImportFrom(ImportFromDefinitionKind { + DefinitionNode::ImportFrom(ImportFromDefinitionNode { node: AstNodeRef::new(parsed, node), alias_index, }) } DefinitionNodeRef::Function(function) => { - DefinitionKind::Function(AstNodeRef::new(parsed, function)) + DefinitionNode::Function(AstNodeRef::new(parsed, function)) } DefinitionNodeRef::Class(class) => { - DefinitionKind::Class(AstNodeRef::new(parsed, class)) + DefinitionNode::Class(AstNodeRef::new(parsed, class)) } DefinitionNodeRef::NamedExpression(named) => { - DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) + 
DefinitionNode::NamedExpression(AstNodeRef::new(parsed, named)) } DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => { - DefinitionKind::Assignment(AssignmentDefinitionKind { + DefinitionNode::Assignment(AssignmentDefinitionNode { assignment: AstNodeRef::new(parsed.clone(), assignment), target: AstNodeRef::new(parsed, target), }) } DefinitionNodeRef::AnnotatedAssignment(assign) => { - DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) + DefinitionNode::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => { - DefinitionKind::Comprehension(ComprehensionDefinitionKind { + DefinitionNode::Comprehension(ComprehensionDefinitionNode { node: AstNodeRef::new(parsed, node), first, }) @@ -173,24 +190,24 @@ impl DefinitionNodeRef<'_> { } #[derive(Clone, Debug)] -pub enum DefinitionKind { +pub enum DefinitionNode { Import(AstNodeRef), - ImportFrom(ImportFromDefinitionKind), + ImportFrom(ImportFromDefinitionNode), Function(AstNodeRef), Class(AstNodeRef), NamedExpression(AstNodeRef), - Assignment(AssignmentDefinitionKind), + Assignment(AssignmentDefinitionNode), AnnotatedAssignment(AstNodeRef), - Comprehension(ComprehensionDefinitionKind), + Comprehension(ComprehensionDefinitionNode), } #[derive(Clone, Debug)] -pub struct ComprehensionDefinitionKind { +pub struct ComprehensionDefinitionNode { node: AstNodeRef, first: bool, } -impl ComprehensionDefinitionKind { +impl ComprehensionDefinitionNode { pub(crate) fn node(&self) -> &ast::Comprehension { self.node.node() } @@ -201,12 +218,12 @@ impl ComprehensionDefinitionKind { } #[derive(Clone, Debug)] -pub struct ImportFromDefinitionKind { +pub struct ImportFromDefinitionNode { node: AstNodeRef, alias_index: usize, } -impl ImportFromDefinitionKind { +impl ImportFromDefinitionNode { pub(crate) fn import(&self) -> &ast::StmtImportFrom { self.node.node() } @@ -218,12 +235,12 @@ impl 
ImportFromDefinitionKind { #[derive(Clone, Debug)] #[allow(dead_code)] -pub struct AssignmentDefinitionKind { +pub struct AssignmentDefinitionNode { assignment: AstNodeRef, target: AstNodeRef, } -impl AssignmentDefinitionKind { +impl AssignmentDefinitionNode { pub(crate) fn assignment(&self) -> &ast::StmtAssign { self.assignment.node() } diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index 44db9d0d42..4389f112f5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -272,7 +272,7 @@ impl SymbolTableBuilder { &mut self, name: Name, flags: SymbolFlags, - ) -> (ScopedSymbolId, bool) { + ) -> ScopedSymbolId { let hash = SymbolTable::hash_name(&name); let entry = self .table @@ -285,7 +285,7 @@ impl SymbolTableBuilder { let symbol = &mut self.table.symbols[*entry.key()]; symbol.insert_flags(flags); - (*entry.key(), false) + *entry.key() } RawEntryMut::Vacant(entry) => { let mut symbol = Symbol::new(name); @@ -295,7 +295,7 @@ impl SymbolTableBuilder { entry.insert_with_hasher(hash, id, (), |id| { SymbolTable::hash_name(self.table.symbols[*id].name().as_str()) }); - (id, true) + id } } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index f3e1afe982..f9d60d19d4 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -56,299 +56,323 @@ //! visible at the end of the scope. //! //! The data structure we build to answer these two questions is the `UseDefMap`. It has a -//! `definitions_by_use` vector indexed by [`ScopedUseId`] and a `public_definitions` vector -//! indexed by [`ScopedSymbolId`]. The values in each of these vectors are (in principle) a list of -//! 
visible definitions at that use, or at the end of the scope for that symbol. +//! `definitions_by_use` vector indexed by [`ScopedUseId`] and a `public_definitions` map +//! indexed by [`ScopedSymbolId`]. The values in each are the visible definition of a symbol at +//! that use, or at the end of the scope. //! -//! In order to avoid vectors-of-vectors and all the allocations that would entail, we don't -//! actually store these "list of visible definitions" as a vector of [`Definition`] IDs. Instead, -//! the values in `definitions_by_use` and `public_definitions` are a [`Definitions`] struct that -//! keeps a [`Range`] into a third vector of [`Definition`] IDs, `all_definitions`. The trick with -//! this representation is that it requires that the definitions visible at any given use of a -//! symbol are stored sequentially in `all_definitions`. -//! -//! There is another special kind of possible "definition" for a symbol: it might be unbound in the -//! scope. (This isn't equivalent to "zero visible definitions", since we may go through an `if` -//! that has a definition for the symbol, leaving us with one visible definition, but still also -//! the "unbound" possibility, since we might not have taken the `if` branch.) -//! -//! The simplest way to model "unbound" would be as an actual [`Definition`] itself: the initial -//! visible [`Definition`] for each symbol in a scope. But actually modeling it this way would -//! dramatically increase the number of [`Definition`] that Salsa must track. Since "unbound" is a -//! special definition in that all symbols share it, and it doesn't have any additional per-symbol -//! state, we can represent it more efficiently: we use the `may_be_unbound` boolean on the -//! [`Definitions`] struct. If this flag is `true`, it means the symbol/use really has one -//! additional visible "definition", which is the unbound state. If this flag is `false`, it means -//! 
we've eliminated the possibility of unbound: every path we've followed includes a definition -//! for this symbol. -//! -//! To build a [`UseDefMap`], the [`UseDefMapBuilder`] is notified of each new use and definition -//! as they are encountered by the -//! [`SemanticIndexBuilder`](crate::semantic_index::builder::SemanticIndexBuilder) AST visit. For -//! each symbol, the builder tracks the currently-visible definitions for that symbol. When we hit -//! a use of a symbol, it records the currently-visible definitions for that symbol as the visible -//! definitions for that use. When we reach the end of the scope, it records the currently-visible -//! definitions for each symbol as the public definitions of that symbol. -//! -//! Let's walk through the above example. Initially we record for `x` that it has no visible -//! definitions, and may be unbound. When we see `x = 1`, we record that as the sole visible -//! definition of `x`, and flip `may_be_unbound` to `false`. Then we see `x = 2`, and it replaces -//! `x = 1` as the sole visible definition of `x`. When we get to `y = x`, we record that the -//! visible definitions for that use of `x` are just the `x = 2` definition. -//! -//! Then we hit the `if` branch. We visit the `test` node (`flag` in this case), since that will -//! happen regardless. Then we take a pre-branch snapshot of the currently visible definitions for -//! all symbols, which we'll need later. Then we go ahead and visit the `if` body. When we see `x = -//! 3`, it replaces `x = 2` as the sole visible definition of `x`. At the end of the `if` body, we -//! take another snapshot of the currently-visible definitions; we'll call this the post-if-body -//! snapshot. -//! -//! Now we need to visit the `else` clause. The conditions when entering the `else` clause should -//! be the pre-if conditions; if we are entering the `else` clause, we know that the `if` test -//! failed and we didn't execute the `if` body. 
So we first reset the builder to the pre-if state, -//! using the snapshot we took previously (meaning we now have `x = 2` as the sole visible -//! definition for `x` again), then visit the `else` clause, where `x = 4` replaces `x = 2` as the -//! sole visible definition of `x`. -//! -//! Now we reach the end of the if/else, and want to visit the following code. The state here needs -//! to reflect that we might have gone through the `if` branch, or we might have gone through the -//! `else` branch, and we don't know which. So we need to "merge" our current builder state -//! (reflecting the end-of-else state, with `x = 4` as the only visible definition) with our -//! post-if-body snapshot (which has `x = 3` as the only visible definition). The result of this -//! merge is that we now have two visible definitions of `x`: `x = 3` and `x = 4`. -//! -//! The [`UseDefMapBuilder`] itself just exposes methods for taking a snapshot, resetting to a -//! snapshot, and merging a snapshot into the current state. The logic using these methods lives in -//! [`SemanticIndexBuilder`](crate::semantic_index::builder::SemanticIndexBuilder), e.g. where it -//! visits a `StmtIf` node. -//! -//! (In the future we may have some other questions we want to answer as well, such as "is this -//! definition used?", which will require tracking a bit more info in our map, e.g. a "used" bit -//! for each [`Definition`] which is flipped to true when we record that definition for a use.) +//! Rather than have multiple definitions, we use a Phi definition at control flow join points to +//! merge the visible definition in each path. This means at any given point we always have exactly +//! one definition for a symbol. (This is analogous to static-single-assignment, or SSA, form, and +//! in fact we use the algorithm from [Simple and efficient construction of static single +//! assignment form](https://dl.acm.org/doi/10.1007/978-3-642-37051-9_6) here.) 
use crate::semantic_index::ast_ids::ScopedUseId; -use crate::semantic_index::definition::Definition; -use crate::semantic_index::symbol::ScopedSymbolId; -use ruff_index::IndexVec; -use std::ops::Range; +use crate::semantic_index::definition::{Definition, DefinitionKind, ScopedPhiId}; +use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; +use crate::Db; +use ruff_db::files::File; +use ruff_index::{newtype_index, IndexVec}; +use rustc_hash::{FxHashMap, FxHashSet}; +use smallvec::{smallvec, SmallVec}; -/// All definitions that can reach a given use of a name. +/// Number of basic block predecessors we store inline. +const PREDECESSORS: usize = 2; + +/// Input operands (definitions) for a Phi definition. None means not defined. +// TODO would like to use SmallVec here but can't due to lifetime invariance issue. +type PhiOperands<'db> = Vec>>; + +/// Definition for each use of a name. #[derive(Debug, PartialEq, Eq)] pub(crate) struct UseDefMap<'db> { // TODO store constraints with definitions for type narrowing - /// Definition IDs array for `definitions_by_use` and `public_definitions` to slice into. - all_definitions: Vec>, + /// Definition that reaches each [`ScopedUseId`]. + definitions_by_use: IndexVec>>, - /// Definitions that can reach a [`ScopedUseId`]. - definitions_by_use: IndexVec, + /// Definition of each symbol visible at end of scope. + /// + /// Sparse, because it only includes symbols defined in the scope. + public_definitions: FxHashMap>, - /// Definitions of each symbol visible at end of scope. - public_definitions: IndexVec, + /// Operands for each Phi definition in this scope. + phi_operands: IndexVec>, } impl<'db> UseDefMap<'db> { - pub(crate) fn use_definitions(&self, use_id: ScopedUseId) -> &[Definition<'db>] { - &self.all_definitions[self.definitions_by_use[use_id].definitions_range.clone()] + /// Return the dominating definition for a given use of a name; None means not-defined. 
+ pub(crate) fn definition_for_use(&self, use_id: ScopedUseId) -> Option> { + self.definitions_by_use[use_id] } - pub(crate) fn use_may_be_unbound(&self, use_id: ScopedUseId) -> bool { - self.definitions_by_use[use_id].may_be_unbound + /// Return the definition visible at end of scope for a symbol. + /// + /// Return None if the symbol is never defined in the scope. + pub(crate) fn public_definition(&self, symbol_id: ScopedSymbolId) -> Option> { + self.public_definitions.get(&symbol_id).copied() } - pub(crate) fn public_definitions(&self, symbol: ScopedSymbolId) -> &[Definition<'db>] { - &self.all_definitions[self.public_definitions[symbol].definitions_range.clone()] - } - - pub(crate) fn public_may_be_unbound(&self, symbol: ScopedSymbolId) -> bool { - self.public_definitions[symbol].may_be_unbound + /// Return the operands for a Phi in this scope; a None means not-defined. + pub(crate) fn phi_operands<'s>(&'s self, phi_id: ScopedPhiId) -> &'s [Option>] { + self.phi_operands[phi_id].as_slice() } } -/// Definitions visible for a symbol at a particular use (or end-of-scope). -#[derive(Clone, Debug, PartialEq, Eq)] -struct Definitions { - /// [`Range`] in `all_definitions` of the visible definition IDs. - definitions_range: Range, - /// Is the symbol possibly unbound at this point? - may_be_unbound: bool, -} +type PredecessorBlocks = SmallVec<[BasicBlockId; PREDECESSORS]>; -impl Definitions { - /// The default state of a symbol is "no definitions, may be unbound", aka definitely-unbound. - fn unbound() -> Self { - Self { - definitions_range: Range::default(), - may_be_unbound: true, - } - } -} +/// A basic block is a linear region of code (no branches.) +#[newtype_index] +pub(super) struct BasicBlockId; -impl Default for Definitions { - fn default() -> Self { - Definitions::unbound() - } -} - -/// A snapshot of the visible definitions for each symbol at a particular point in control flow. 
-#[derive(Clone, Debug)] -pub(super) struct FlowSnapshot { - definitions_by_symbol: IndexVec, -} - -#[derive(Debug)] pub(super) struct UseDefMapBuilder<'db> { - /// Definition IDs array for `definitions_by_use` and `definitions_by_symbol` to slice into. - all_definitions: Vec>, + db: &'db dyn Db, + file: File, + file_scope: FileScopeId, - /// Visible definitions at each so-far-recorded use. - definitions_by_use: IndexVec, + /// Predecessor blocks for each basic block. + /// + /// Entry block has none, all other blocks have at least one, blocks that join control flow can + /// have two or more. + predecessors: IndexVec, - /// Currently visible definitions for each symbol. - definitions_by_symbol: IndexVec, + /// The definition of each symbol which dominates each basic block. + /// + /// No entry means "lazily unfilled"; we haven't had to query for it yet, and we may never have + /// to, if the symbol isn't used in this block or any successor block. + /// + /// Each block has an [`FxHashMap`] of symbols instead of an [`IndexVec`] because it is lazy + /// and potentially sparse; it will only include a definition for a symbol that is actually + /// used in that block or a successor. An [`IndexVec`] would have to be eagerly filled with + /// placeholders. + definitions_per_block: + IndexVec>>>, + + /// Incomplete Phi definitions in each block. + /// + /// An incomplete Phi is used when we don't know, while processing a block's body, what new + /// predecessors it may later gain (that is, backward jumps.) + /// + /// Sparse, because relatively few blocks (just loop headers) will have any incomplete Phis. + incomplete_phis: FxHashMap>>, + + /// Operands for each Phi definition in this scope. + phi_operands: IndexVec>, + + /// Are this block's predecessors fully populated? + /// + /// If not, it isn't safe to recurse to predecessors yet; we might miss a predecessor block. + sealed_blocks: IndexVec, + + /// Definition for each so-far-recorded use.
+ definitions_by_use: IndexVec>>, + + /// All symbols defined in this scope. + defined_symbols: FxHashSet, } impl<'db> UseDefMapBuilder<'db> { - pub(super) fn new() -> Self { - Self { - all_definitions: Vec::new(), + pub(super) fn new(db: &'db dyn Db, file: File, file_scope: FileScopeId) -> Self { + let mut new = Self { + db, + file, + file_scope, + predecessors: IndexVec::new(), + definitions_per_block: IndexVec::new(), + incomplete_phis: FxHashMap::default(), + sealed_blocks: IndexVec::new(), definitions_by_use: IndexVec::new(), - definitions_by_symbol: IndexVec::new(), - } - } - - pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) { - let new_symbol = self.definitions_by_symbol.push(Definitions::unbound()); - debug_assert_eq!(symbol, new_symbol); + phi_operands: IndexVec::new(), + defined_symbols: FxHashSet::default(), + }; + + // create the entry basic block + new.predecessors.push(PredecessorBlocks::default()); + new.definitions_per_block.push(FxHashMap::default()); + new.sealed_blocks.push(true); + + new } + /// Record a definition for a symbol. pub(super) fn record_definition( &mut self, - symbol: ScopedSymbolId, + symbol_id: ScopedSymbolId, definition: Definition<'db>, ) { - // We have a new definition of a symbol; this replaces any previous definitions in this - // path. - let def_idx = self.all_definitions.len(); - self.all_definitions.push(definition); - self.definitions_by_symbol[symbol] = Definitions { - #[allow(clippy::range_plus_one)] - definitions_range: def_idx..(def_idx + 1), - may_be_unbound: false, - }; + self.memoize(self.current_block_id(), symbol_id, Some(definition)); + self.defined_symbols.insert(symbol_id); } - pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) { - // We have a use of a symbol; clone the currently visible definitions for that symbol, and - // record them as the visible definitions for this use. 
- let new_use = self - .definitions_by_use - .push(self.definitions_by_symbol[symbol].clone()); + /// Record a use of a symbol. + pub(super) fn record_use(&mut self, symbol_id: ScopedSymbolId, use_id: ScopedUseId) { + let definition_id = self.lookup(symbol_id); + let new_use = self.definitions_by_use.push(definition_id); debug_assert_eq!(use_id, new_use); } - /// Take a snapshot of the current visible-symbols state. - pub(super) fn snapshot(&self) -> FlowSnapshot { - FlowSnapshot { - definitions_by_symbol: self.definitions_by_symbol.clone(), + /// Get the id of the current basic block. + pub(super) fn current_block_id(&self) -> BasicBlockId { + BasicBlockId::from(self.definitions_per_block.len() - 1) + } + + /// Push a new basic block, with given block as predecessor. + pub(super) fn new_block_from(&mut self, block_id: BasicBlockId, sealed: bool) { + self.new_block_with_predecessors(smallvec![block_id], sealed); + } + + /// Push a new basic block, with current block as predecessor; return the current block's ID. + pub(super) fn next_block(&mut self, sealed: bool) -> BasicBlockId { + let current_block_id = self.current_block_id(); + self.new_block_from(current_block_id, sealed); + current_block_id + } + + /// Add a predecessor to the current block. + pub(super) fn merge_block(&mut self, new_predecessor: BasicBlockId) { + let block_id = self.current_block_id(); + debug_assert!(!self.sealed_blocks[block_id]); + self.predecessors[block_id].push(new_predecessor); + } + + /// Add predecessors to the current block. + pub(super) fn merge_blocks(&mut self, new_predecessors: Vec) { + let block_id = self.current_block_id(); + debug_assert!(!self.sealed_blocks[block_id]); + self.predecessors[block_id].extend(new_predecessors); + } + + /// Mark the current block as sealed; it cannot have any more predecessors added. 
+ pub(super) fn seal_current_block(&mut self) { + self.seal_block(self.current_block_id()); + } + + /// Mark a block as sealed; it cannot have any more predecessors added. + pub(super) fn seal_block(&mut self, block_id: BasicBlockId) { + debug_assert!(!self.sealed_blocks[block_id]); + if let Some(phis) = self.incomplete_phis.get(&block_id) { + for phi in phis.clone() { + self.add_phi_operands(block_id, phi); + } + self.incomplete_phis.remove(&block_id); + } + self.sealed_blocks[block_id] = true; + } + + pub(super) fn finish(mut self) -> UseDefMap<'db> { + debug_assert!(self.incomplete_phis.is_empty()); + debug_assert!(self.sealed_blocks.iter().all(|&b| b)); + self.definitions_by_use.shrink_to_fit(); + self.phi_operands.shrink_to_fit(); + + let mut public_definitions: FxHashMap> = + FxHashMap::default(); + + for symbol_id in self.defined_symbols.clone() { + // SAFETY: We are only looking up defined symbols here, can't get None. + public_definitions.insert(symbol_id, self.lookup(symbol_id).unwrap()); + } + + UseDefMap { + definitions_by_use: self.definitions_by_use, + public_definitions, + phi_operands: self.phi_operands, } } - /// Restore the current builder visible-definitions state to the given snapshot. - pub(super) fn restore(&mut self, snapshot: FlowSnapshot) { - // We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol - // IDs must line up), so the current number of known symbols must always be equal to or - // greater than the number of known symbols in a previously-taken snapshot. - let num_symbols = self.definitions_by_symbol.len(); - debug_assert!(num_symbols >= snapshot.definitions_by_symbol.len()); + /// Push a new basic block (with given predecessors) and return its ID. 
+ fn new_block_with_predecessors( + &mut self, + predecessors: PredecessorBlocks, + sealed: bool, + ) -> BasicBlockId { + let new_block_id = self.predecessors.push(predecessors); + self.definitions_per_block.push(FxHashMap::default()); + self.sealed_blocks.push(sealed); - // Restore the current visible-definitions state to the given snapshot. - self.definitions_by_symbol = snapshot.definitions_by_symbol; - - // If the snapshot we are restoring is missing some symbols we've recorded since, we need - // to fill them in so the symbol IDs continue to line up. Since they don't exist in the - // snapshot, the correct state to fill them in with is "unbound", the default. - self.definitions_by_symbol - .resize(num_symbols, Definitions::unbound()); + new_block_id } - /// Merge the given snapshot into the current state, reflecting that we might have taken either - /// path to get here. The new visible-definitions state for each symbol should include - /// definitions from both the prior state and the snapshot. - pub(super) fn merge(&mut self, snapshot: &FlowSnapshot) { - // The tricky thing about merging two Ranges pointing into `all_definitions` is that if the - // two Ranges aren't already adjacent in `all_definitions`, we will have to copy at least - // one or the other of the ranges to the end of `all_definitions` so as to make them - // adjacent. We can't ever move things around in `all_definitions` because previously - // recorded uses may still have ranges pointing to any part of it; all we can do is append. - // It's possible we may end up with some old entries in `all_definitions` that nobody is - // pointing to, but that's OK. + /// Look up the dominating definition for a symbol in the current block. + /// + /// If there isn't a local definition, recursively look up the symbol in predecessor blocks, + /// memoizing the found symbol in each block. 
+ fn lookup(&mut self, symbol_id: ScopedSymbolId) -> Option> { + self.lookup_impl(self.current_block_id(), symbol_id) + } - // We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol - // IDs must line up), so the current number of known symbols must always be equal to or - // greater than the number of known symbols in a previously-taken snapshot. - debug_assert!(self.definitions_by_symbol.len() >= snapshot.definitions_by_symbol.len()); - - for (symbol_id, current) in self.definitions_by_symbol.iter_mut_enumerated() { - let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) else { - // Symbol not present in snapshot, so it's unbound from that path. - current.may_be_unbound = true; - continue; - }; - - // If the symbol can be unbound in either predecessor, it can be unbound post-merge. - current.may_be_unbound |= snapshot.may_be_unbound; - - // Merge the definition ranges. - let current = &mut current.definitions_range; - let snapshot = &snapshot.definitions_range; - - // We never create reversed ranges. - debug_assert!(current.end >= current.start); - debug_assert!(snapshot.end >= snapshot.start); - - if current == snapshot { - // Ranges already identical, nothing to do. - } else if snapshot.is_empty() { - // Merging from an empty range; nothing to do. - } else if (*current).is_empty() { - // Merging to an empty range; just use the incoming range. - *current = snapshot.clone(); - } else if snapshot.end >= current.start && snapshot.start <= current.end { - // Ranges are adjacent or overlapping, merge them in-place. - *current = current.start.min(snapshot.start)..current.end.max(snapshot.end); - } else if current.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `current` is at the end of - // `all_definitions`, we need to copy `snapshot` to the end so they are adjacent - // and can be merged into one range. 
- self.all_definitions.extend_from_within(snapshot.clone()); - current.end = self.all_definitions.len(); - } else if snapshot.end == self.all_definitions.len() { - // Ranges are not adjacent or overlapping, `snapshot` is at the end of - // `all_definitions`, we need to copy `current` to the end so they are adjacent and - // can be merged into one range. - self.all_definitions.extend_from_within(current.clone()); - current.start = snapshot.start; - current.end = self.all_definitions.len(); - } else { - // Ranges are not adjacent and neither one is at the end of `all_definitions`, we - // have to copy both to the end so they are adjacent and we can merge them. - let start = self.all_definitions.len(); - self.all_definitions.extend_from_within(current.clone()); - self.all_definitions.extend_from_within(snapshot.clone()); - current.start = start; - current.end = self.all_definitions.len(); + fn lookup_impl( + &mut self, + block_id: BasicBlockId, + symbol_id: ScopedSymbolId, + ) -> Option> { + if let Some(local) = self.definitions_per_block[block_id].get(&symbol_id) { + return *local; + } + if !self.sealed_blocks[block_id] { + // we may still be missing predecessors; insert an incomplete Phi. + let definition = self.create_incomplete_phi(block_id, symbol_id); + self.incomplete_phis + .entry(block_id) + .or_default() + .push(definition); + return Some(definition); + } + match self.predecessors[block_id].as_slice() { + // entry block, no definition found: return None + [] => None, + // single predecessor, recurse + &[single_predecessor_id] => { + let definition = self.lookup_impl(single_predecessor_id, symbol_id); + self.memoize(block_id, symbol_id, definition); + definition + } + // multiple predecessors: create and memoize an incomplete Phi to break cycles, then + // recurse into predecessors and fill the Phi operands. 
+ _ => { + let phi = self.create_incomplete_phi(block_id, symbol_id); + self.add_phi_operands(block_id, phi); + Some(phi) } } } - pub(super) fn finish(mut self) -> UseDefMap<'db> { - self.all_definitions.shrink_to_fit(); - self.definitions_by_symbol.shrink_to_fit(); - self.definitions_by_use.shrink_to_fit(); + /// Recurse into predecessors to add operands for an incomplete Phi. + fn add_phi_operands(&mut self, block_id: BasicBlockId, phi: Definition<'db>) { + let predecessors: PredecessorBlocks = self.predecessors[block_id].clone(); + let operands: PhiOperands = predecessors + .iter() + .map(|pred_id| self.lookup_impl(*pred_id, phi.symbol(self.db))) + .collect(); + let DefinitionKind::Phi(phi_id) = phi.kind(self.db) else { + unreachable!("add_phi_operands called with non-Phi"); + }; + self.phi_operands[*phi_id] = operands; + } - UseDefMap { - all_definitions: self.all_definitions, - definitions_by_use: self.definitions_by_use, - public_definitions: self.definitions_by_symbol, - } + /// Remember a given definition for a given symbol in the given block. + fn memoize( + &mut self, + block_id: BasicBlockId, + symbol_id: ScopedSymbolId, + definition_id: Option>, + ) { + self.definitions_per_block[block_id].insert(symbol_id, definition_id); + } + + /// Create an incomplete Phi for the given block and symbol, memoize it, and return its ID. 
+ fn create_incomplete_phi( + &mut self, + block_id: BasicBlockId, + symbol_id: ScopedSymbolId, + ) -> Definition<'db> { + let phi_id = self.phi_operands.push(vec![]); + let definition = Definition::new( + self.db, + self.file, + self.file_scope, + symbol_id, + DefinitionKind::Phi(phi_id), + countme::Count::default(), + ); + self.memoize(block_id, symbol_id, Some(definition)); + definition } } diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 6b76b42b7c..f191f18cbb 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -151,7 +151,7 @@ impl HasTy for ast::StmtFunctionDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - definition_ty(model.db, definition) + definition_ty(model.db, Some(definition)) } } @@ -159,7 +159,7 @@ impl HasTy for StmtClassDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - definition_ty(model.db, definition) + definition_ty(model.db, Some(definition)) } } @@ -167,7 +167,7 @@ impl HasTy for ast::Alias { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - definition_ty(model.db, definition) + definition_ty(model.db, Some(definition)) } } diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 37430d95c3..e4746192e9 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -23,13 +23,7 @@ pub(crate) fn symbol_ty<'db>( let _span = tracing::trace_span!("symbol_ty", ?symbol).entered(); let use_def = use_def_map(db, scope); - definitions_ty( - db, - 
use_def.public_definitions(symbol), - use_def - .public_may_be_unbound(symbol) - .then_some(Type::Unbound), - ) + definition_ty(db, use_def.public_definition(symbol)) } /// Shorthand for `symbol_ty` that takes a symbol name instead of an ID. @@ -60,49 +54,16 @@ pub(crate) fn builtins_symbol_ty_by_name<'db>(db: &'db dyn Db, name: &str) -> Ty } /// Infer the type of a [`Definition`]. -pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> { - let inference = infer_definition_types(db, definition); - inference.definition_ty(definition) -} - -/// Infer the combined type of an array of [`Definition`]s, plus one optional "unbound type". -/// -/// Will return a union if there is more than one definition, or at least one plus an unbound -/// type. -/// -/// The "unbound type" represents the type in case control flow may not have passed through any -/// definitions in this scope. If this isn't possible, then it will be `None`. If it is possible, -/// and the result in that case should be Unbound (e.g. an unbound function local), then it will be -/// `Some(Type::Unbound)`. If it is possible and the result should be something else (e.g. an -/// implicit global lookup), then `unbound_type` will be `Some(the_global_symbol_type)`. -/// -/// # Panics -/// Will panic if called with zero definitions and no `unbound_ty`. This is a logic error, -/// as any symbol with zero visible definitions clearly may be unbound, and the caller should -/// provide an `unbound_ty`. 
-pub(crate) fn definitions_ty<'db>( +pub(crate) fn definition_ty<'db>( db: &'db dyn Db, - definitions: &[Definition<'db>], - unbound_ty: Option>, + definition: Option>, ) -> Type<'db> { - let def_types = definitions.iter().map(|def| definition_ty(db, *def)); - let mut all_types = unbound_ty.into_iter().chain(def_types); - - let Some(first) = all_types.next() else { - panic!("definitions_ty should never be called with zero definitions and no unbound_ty.") - }; - - if let Some(second) = all_types.next() { - let mut builder = UnionBuilder::new(db); - builder = builder.add(first).add(second); - - for variant in all_types { - builder = builder.add(variant); + match definition { + Some(definition) => { + let inference = infer_definition_types(db, definition); + inference.definition_ty(definition) } - - builder.build() - } else { - first + None => Type::Unbound, } } @@ -145,8 +106,27 @@ impl<'db> Type<'db> { matches!(self, Type::Unbound) } - pub const fn is_unknown(&self) -> bool { - matches!(self, Type::Unknown) + pub fn may_be_unbound(&self, db: &'db dyn Db) -> bool { + match self { + Type::Unbound => true, + Type::Union(union) => union.contains(db, Type::Unbound), + _ => false, + } + } + + #[must_use] + pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> { + match self { + Type::Unbound => replacement, + Type::Union(union) => union + .elements(db) + .into_iter() + .fold(UnionBuilder::new(db), |builder, ty| { + builder.add(ty.replace_unbound_with(db, replacement)) + }) + .build(), + ty => *ty, + } } #[must_use] diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index ea39ee0725..8517a98b69 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -33,13 +33,17 @@ use crate::builtins::builtins_scope; use crate::module_name::ModuleName; use crate::module_resolver::resolve_module; use 
crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; -use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey}; +use crate::semantic_index::definition::{ + Definition, DefinitionKind, DefinitionNode, DefinitionNodeKey, ScopedPhiId, +}; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; -use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId}; +use crate::semantic_index::symbol::{ + FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, Symbol, +}; use crate::semantic_index::SemanticIndex; use crate::types::{ - builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, ClassType, FunctionType, + builtins_symbol_ty_by_name, definition_ty, global_symbol_ty_by_name, ClassType, FunctionType, Name, Type, UnionBuilder, }; use crate::Db; @@ -276,37 +280,45 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_region_definition(&mut self, definition: Definition<'db>) { - match definition.node(self.db) { - DefinitionKind::Function(function) => { - self.infer_function_definition(function.node(), definition); - } - DefinitionKind::Class(class) => self.infer_class_definition(class.node(), definition), - DefinitionKind::Import(import) => { - self.infer_import_definition(import.node(), definition); - } - DefinitionKind::ImportFrom(import_from) => { - self.infer_import_from_definition( - import_from.import(), - import_from.alias(), - definition, - ); - } - DefinitionKind::Assignment(assignment) => { - self.infer_assignment_definition(assignment.assignment(), definition); - } - DefinitionKind::AnnotatedAssignment(annotated_assignment) => { - self.infer_annotated_assignment_definition(annotated_assignment.node(), definition); - } - DefinitionKind::NamedExpression(named_expression) => { - self.infer_named_expression_definition(named_expression.node(), definition); - } - DefinitionKind::Comprehension(comprehension) => { - 
self.infer_comprehension_definition( - comprehension.node(), - comprehension.is_first(), - definition, - ); - } + match definition.kind(self.db) { + DefinitionKind::Phi(phi_id) => self.infer_phi_definition(*phi_id, definition), + DefinitionKind::Node(node) => match node { + DefinitionNode::Function(function) => { + self.infer_function_definition(function.node(), definition); + } + DefinitionNode::Class(class) => { + self.infer_class_definition(class.node(), definition); + } + DefinitionNode::Import(import) => { + self.infer_import_definition(import.node(), definition); + } + DefinitionNode::ImportFrom(import_from) => { + self.infer_import_from_definition( + import_from.import(), + import_from.alias(), + definition, + ); + } + DefinitionNode::Assignment(assignment) => { + self.infer_assignment_definition(assignment.assignment(), definition); + } + DefinitionNode::AnnotatedAssignment(annotated_assignment) => { + self.infer_annotated_assignment_definition( + annotated_assignment.node(), + definition, + ); + } + DefinitionNode::NamedExpression(named_expression) => { + self.infer_named_expression_definition(named_expression.node(), definition); + } + DefinitionNode::Comprehension(comprehension) => { + self.infer_comprehension_definition( + comprehension.node(), + comprehension.is_first(), + definition, + ); + } + }, } } @@ -396,6 +408,18 @@ impl<'db> TypeInferenceBuilder<'db> { self.extend(result); } + fn infer_phi_definition(&mut self, phi_id: ScopedPhiId, definition: Definition<'db>) { + let file_scope_id = self.scope.file_scope_id(self.db); + let use_def = self.index.use_def_map(file_scope_id); + let ty = use_def + .phi_operands(phi_id) + .iter() + .map(|&definition| definition_ty(self.db, definition)) + .fold(UnionBuilder::new(self.db), UnionBuilder::add) + .build(); + self.types.definitions.insert(definition, ty); + } + fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { self.infer_definition(function); } @@ -1338,6 +1362,22 @@ 
impl<'db> TypeInferenceBuilder<'db> { Type::Unknown } + fn infer_global_name_reference(&self, symbol: &Symbol) -> Type<'db> { + let file_scope_id = self.scope.file_scope_id(self.db); + // implicit global + let mut ty = if file_scope_id == FileScopeId::global() { + Type::Unbound + } else { + global_symbol_ty_by_name(self.db, self.file, symbol.name()) + }; + // fallback to builtins + if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_scope(self.db) { + ty = ty + .replace_unbound_with(self.db, builtins_symbol_ty_by_name(self.db, symbol.name())); + } + ty + } + fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { let ast::ExprName { range: _, id, ctx } = name; @@ -1346,34 +1386,18 @@ impl<'db> TypeInferenceBuilder<'db> { let file_scope_id = self.scope.file_scope_id(self.db); let use_def = self.index.use_def_map(file_scope_id); let use_id = name.scoped_use_id(self.db, self.scope); - let may_be_unbound = use_def.use_may_be_unbound(use_id); - - let unbound_ty = if may_be_unbound { + let mut ty = definition_ty(self.db, use_def.definition_for_use(use_id)); + if ty.may_be_unbound(self.db) { let symbols = self.index.symbol_table(file_scope_id); - // SAFETY: the symbol table always creates a symbol for every Name node. 
let symbol = symbols.symbol_by_name(id).unwrap(); if !symbol.is_defined() || !self.scope.is_function_like(self.db) { - // implicit global - let mut unbound_ty = if file_scope_id == FileScopeId::global() { - Type::Unbound - } else { - global_symbol_ty_by_name(self.db, self.file, id) - }; - // fallback to builtins - if matches!(unbound_ty, Type::Unbound) - && Some(self.scope) != builtins_scope(self.db) - { - unbound_ty = builtins_symbol_ty_by_name(self.db, id); - } - Some(unbound_ty) - } else { - Some(Type::Unbound) + ty = ty.replace_unbound_with( + self.db, + self.infer_global_name_reference(symbol), + ); } - } else { - None - }; - - definitions_ty(self.db, use_def.use_definitions(use_id), unbound_ty) + } + ty } ExprContext::Store | ExprContext::Del => Type::None, ExprContext::Invalid => Type::Unknown, @@ -2163,6 +2187,38 @@ mod tests { Ok(()) } + #[test] + fn conditionally_global_or_builtin() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + if flag: + copyright = 1 + def f(): + y = copyright + ", + )?; + + let file = system_path_to_file(&db, "src/a.py").expect("Expected file to exist."); + let index = semantic_index(&db, file); + let function_scope = index + .child_scopes(FileScopeId::global()) + .next() + .unwrap() + .0 + .to_scope_id(&db, file); + let y_ty = symbol_ty_by_name(&db, function_scope, "y"); + + assert_eq!( + y_ty.display(&db).to_string(), + "Literal[1] | Literal[copyright]" + ); + + Ok(()) + } + /// Class name lookups do fall back to globals, but the public type never does. 
#[test] fn unbound_class_local() -> anyhow::Result<()> { @@ -2386,11 +2442,10 @@ mod tests { Ok(()) } - fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { + fn public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { let scope = global_scope(db, file); - *use_def_map(db, scope) - .public_definitions(symbol_table(db, scope).symbol_id_by_name(name).unwrap()) - .first() + use_def_map(db, scope) + .public_definition(symbol_table(db, scope).symbol_id_by_name(name).unwrap()) .unwrap() } @@ -2533,7 +2588,7 @@ mod tests { assert_function_query_was_not_run( &db, infer_definition_types, - first_public_def(&db, a, "x"), + public_def(&db, a, "x"), &events, ); @@ -2569,7 +2624,7 @@ mod tests { assert_function_query_was_not_run( &db, infer_definition_types, - first_public_def(&db, a, "x"), + public_def(&db, a, "x"), &events, ); Ok(())