diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs index 9cea81808c..a7bf50b868 100644 --- a/crates/red_knot/src/db.rs +++ b/crates/red_knot/src/db.rs @@ -32,13 +32,13 @@ pub trait SemanticDb: SourceDb { fn symbol_table(&self, file_id: FileId) -> Arc; + fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type; + // mutations fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)>; fn set_module_search_paths(&mut self, paths: Vec); - - fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type; } pub trait Db: SemanticDb {} @@ -155,6 +155,10 @@ pub(crate) mod tests { file_to_module(self, file_id) } + fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type { + infer_symbol_type(self, file_id, symbol_id) + } + fn path_to_module(&self, path: &Path) -> Option { path_to_module(self, path) } @@ -170,9 +174,5 @@ pub(crate) mod tests { fn set_module_search_paths(&mut self, paths: Vec) { set_module_search_paths(self, paths); } - - fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type { - infer_symbol_type(self, file_id, symbol_id) - } } } diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs index 0070d53086..205f628be7 100644 --- a/crates/red_knot/src/program/mod.rs +++ b/crates/red_knot/src/program/mod.rs @@ -111,6 +111,10 @@ impl SemanticDb for Program { symbol_table(self, file_id) } + fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type { + infer_symbol_type(self, file_id, symbol_id) + } + // Mutations fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)> { @@ -120,10 +124,6 @@ impl SemanticDb for Program { fn set_module_search_paths(&mut self, paths: Vec) { set_module_search_paths(self, paths); } - - fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type { - infer_symbol_type(self, file_id, symbol_id) - } } impl Db for Program {} diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index fc414da192..7e12ef1213 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -1,8 +1,9 @@ #![allow(dead_code)] use crate::ast_ids::NodeKey; use crate::files::FileId; +use crate::module::ModuleName; use crate::symbols::SymbolId; -use crate::{FxDashMap, FxIndexSet, Name}; +use crate::{FxDashMap, FxHashSet, FxIndexSet, Name}; use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; @@ -49,17 +50,17 @@ pub struct TypeStore { } impl TypeStore { - pub fn remove_module(&mut self, file_id: FileId) { + pub fn remove_module(&self, file_id: FileId) { self.modules.remove(&file_id); } - pub fn cache_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId, ty: Type) { + pub fn cache_symbol_type(&self, file_id: FileId, symbol_id: SymbolId, ty: Type) { self.add_or_get_module(file_id) .symbol_types .insert(symbol_id, ty); } - pub fn cache_node_type(&mut self, file_id: FileId, node_key: NodeKey, ty: Type) { + pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) { self.add_or_get_module(file_id) .node_types .insert(node_key, ty); @@ -79,7 +80,7 @@ impl TypeStore { .copied() } - fn add_or_get_module(&mut self, file_id: FileId) -> ModuleStoreRefMut { + fn add_or_get_module(&self, file_id: FileId) -> ModuleStoreRefMut { self.modules .entry(file_id) .or_insert_with(|| ModuleTypeStore::new(file_id)) @@ -93,20 +94,20 @@ impl TypeStore { self.modules.get(&file_id) } - fn add_function(&mut self, file_id: FileId, name: &str) -> FunctionTypeId { + fn add_function(&self, file_id: FileId, name: &str) -> FunctionTypeId { self.add_or_get_module(file_id).add_function(name) } - fn add_class(&mut self, file_id: FileId, name: &str) -> ClassTypeId { + fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId { self.add_or_get_module(file_id).add_class(name) } - fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { + fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId { self.add_or_get_module(file_id).add_union(elems) } fn add_intersection( - &mut self, + &self, file_id: FileId, positive: &[Type], negative: &[Type], @@ -142,6 +143,24 @@ impl TypeStore { intersection_id: id.intersection_id, } } + + fn record_symbol_dependency(&self, from: (FileId, SymbolId), to: (FileId, SymbolId)) { + let (from_file_id, from_symbol_id) = from; + self.add_or_get_module(from_file_id) + .symbol_dependencies + .entry(from_symbol_id) + .or_default() + .insert(to); + } + + fn record_module_dependency(&self, from: (FileId, SymbolId), to: ModuleName) { + let (from_file_id, from_symbol_id) = from; + self.add_or_get_module(from_file_id) + .module_dependencies + .entry(from_symbol_id) + .or_default() + .insert(to); + } } type ModuleStoreRef<'a> = dashmap::mapref::one::Ref< @@ -265,6 +284,12 @@ struct ModuleTypeStore { symbol_types: FxHashMap, /// cached types of AST nodes in this module node_types: FxHashMap, + // the inferred type for symbol K depends on the type of symbols in V + symbol_dependencies: FxHashMap>, + // the inferred type for symbol K depends on the modules in V; this type of dependency is + // recorded when e.g. the target symbol doesn't exist in the module, so we can't record a + // dependency on a symbol, but if the module changes it could still change our resolution) + module_dependencies: FxHashMap>, } impl ModuleTypeStore { @@ -277,6 +302,8 @@ impl ModuleTypeStore { intersections: IndexVec::default(), symbol_types: FxHashMap::default(), node_types: FxHashMap::default(), + symbol_dependencies: FxHashMap::default(), + module_dependencies: FxHashMap::default(), } } @@ -462,7 +489,7 @@ mod tests { #[test] fn add_class() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let id = store.add_class(file_id, "C"); @@ -473,7 +500,7 @@ mod tests { #[test] fn add_function() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let id = store.add_function(file_id, "func"); @@ -484,7 +511,7 @@ mod tests { #[test] fn add_union() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1"); @@ -501,7 +528,7 @@ mod tests { #[test] fn add_intersection() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1"); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index 87b6cb8e13..7876fa086e 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -9,7 +9,7 @@ use ruff_python_ast::AstNode; // TODO this should not take a &mut db, it should be a query, not a mutation. This means we'll need // to use interior mutability in TypeStore instead, and avoid races in populating the cache. #[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_type(db: &mut Db, file_id: FileId, symbol_id: SymbolId) -> Type +pub fn infer_symbol_type(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type where Db: SemanticDb + HasJar, { @@ -36,15 +36,27 @@ where // TODO relative imports assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - if let Some(module) = db.resolve_module(module_name) { + if let Some(module) = db.resolve_module(module_name.clone()) { let remote_file_id = module.path(db).file(); let remote_symbols = db.symbol_table(remote_file_id); if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { + // TODO integrate this into module and symbol-resolution APIs (requiring a + // "requester" argument) so that it doesn't have to be remembered + db.jar().type_store.record_symbol_dependency( + (file_id, symbol_id), + (remote_file_id, remote_symbol_id), + ); db.infer_symbol_type(remote_file_id, remote_symbol_id) } else { + db.jar() + .type_store + .record_module_dependency((file_id, symbol_id), module_name); Type::Unknown } } else { + db.jar() + .type_store + .record_module_dependency((file_id, symbol_id), module_name); Type::Unknown } } @@ -60,7 +72,7 @@ where let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); - let store = &mut db.jar_mut().type_store; + let store = &db.jar().type_store; let ty = Type::Class(store.add_class(file_id, &node.name.id)); store.cache_node_type(file_id, *node_key.erased(), ty); ty @@ -69,10 +81,9 @@ where _ => todo!("other kinds of definitions"), }; - db.jar_mut() + db.jar() .type_store .cache_symbol_type(file_id, symbol_id, ty); - // TODO record dependencies ty } @@ -112,7 +123,7 @@ mod tests { fn follow_import_to_class() -> std::io::Result<()> { let TestCase { src, - mut db, + db, temp_dir: _temp_dir, } = create_test()?; @@ -132,10 +143,24 @@ mod tests { let ty = db.infer_symbol_type(a_file, d_sym); + let b_file = db + .resolve_module(ModuleName::new("b")) + .expect("module should be found") + .path(&db) + .file(); + let b_syms = db.symbol_table(b_file); + let c_sym = b_syms + .root_symbol_id_by_name("C") + .expect("C symbol should be found"); + let jar = HasJar::::jar(&db); assert!(matches!(ty, Type::Class(_))); assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]"); + assert_eq!( + jar.type_store.get_module(a_file).symbol_dependencies[&d_sym], + [(b_file, c_sym)].iter().copied().collect() + ); Ok(()) } }