From 9a3b9f9fb5f9e735d043ca18f510146bc182d80b Mon Sep 17 00:00:00 2001 From: plredmond <51248199+plredmond@users.noreply.github.com> Date: Tue, 28 May 2024 13:13:03 -0700 Subject: [PATCH] [redknot] add module type and attribute lookup for some types (#11416) * Add a module type, `ModuleTypeId` * Add an attribute lookup method `get_member` for `Type` * Only implemented for `ModuleTypeId` and `ClassTypeId` * [x] Should this be a trait? *Answer: no* * [x] Uses `unwrap`, but we should remove that. Maybe add a new variant to `QueryError`? *Answer: Return `Option` as is done elsewhere* * Add `infer_definition_type` case for `Import` * Add `infer_expr_type` case for `Attribute` * Add a test to exercise these * [x] remove all NOTE/FIXME/TODO after discussing with reviewers --- crates/red_knot/src/types.rs | 73 +++++++++++++++++++++++++++++- crates/red_knot/src/types/infer.rs | 71 ++++++++++++++++++++++++++--- 2 files changed, 136 insertions(+), 8 deletions(-) diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index cb17803521..478a35f1c1 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -2,7 +2,10 @@ use crate::ast_ids::NodeKey; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::symbols::{symbol_table, GlobalSymbolId, ScopeId, ScopeKind, SymbolId}; +use crate::module::{Module, ModuleName}; +use crate::symbols::{ + resolve_global_symbol, symbol_table, GlobalSymbolId, ScopeId, ScopeKind, SymbolId, +}; use crate::{FxDashMap, FxIndexSet, Name}; use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; @@ -25,6 +28,8 @@ pub enum Type { Unbound, /// a specific function object Function(FunctionTypeId), + /// a specific module object + Module(ModuleTypeId), /// a specific class object Class(ClassTypeId), /// the set of Python objects with the given class in their __class__'s method resolution order @@ -46,6 +51,35 @@ impl Type { pub const fn is_unknown(&self) -> bool { matches!(self, Type::Unknown) } + + pub fn get_member(&self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { + match self { + Type::Any => todo!("attribute lookup on Any type"), + Type::Never => todo!("attribute lookup on Never type"), + Type::Unknown => todo!("attribute lookup on Unknown type"), + Type::Unbound => todo!("attribute lookup on Unbound type"), + Type::Function(_) => todo!("attribute lookup on Function type"), + Type::Module(module_id) => module_id.get_member(db, name), + Type::Class(class_id) => class_id.get_class_member(db, name), + Type::Instance(_) => { + // TODO MRO? get_own_instance_member, get_instance_member + todo!("attribute lookup on Instance type") + } + Type::Union(union_id) => { + let jar: &SemanticJar = db.jar()?; + let _todo_union_ref = jar.type_store.get_union(*union_id); + // TODO perform the get_member on each type in the union + // TODO return the union of those results + // TODO if any of those results is `None` then include Unknown in the result union + todo!("attribute lookup on Union type") + } + Type::Intersection(_) => { + // TODO perform the get_member on each type in the intersection + // TODO return the intersection of those results + todo!("attribute lookup on Intersection type") + } + } + } } impl From for Type { @@ -336,6 +370,31 @@ impl FunctionTypeId { } } +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct ModuleTypeId { + module: Module, + file_id: FileId, +} + +impl ModuleTypeId { + fn module(self, db: &dyn SemanticDb) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + Ok(jar.type_store.add_or_get_module(self.file_id).downgrade()) + } + + pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { + self.module.name(db) + } + + fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { + if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? { + Ok(Some(infer_symbol_type(db, symbol_id)?)) + } else { + Ok(None) + } + } +} + #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct ClassTypeId { file_id: FileId, @@ -389,7 +448,13 @@ impl ClassTypeId { } } - // TODO: get_own_instance_member, get_class_member, get_instance_member + /// Get own class member or fall back to super-class member. + fn get_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { + self.get_own_class_member(db, name) + .or_else(|_| self.get_super_class_member(db, name)) + } + + // TODO: get_own_instance_member, get_instance_member } #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] @@ -529,6 +594,10 @@ impl std::fmt::Display for DisplayType<'_> { Type::Never => f.write_str("Never"), Type::Unknown => f.write_str("Unknown"), Type::Unbound => f.write_str("Unbound"), + Type::Module(module_id) => { + // NOTE: something like this?: "" + todo!("{module_id:?}") + } // TODO functions and classes should display using a fully qualified name Type::Class(class_id) => { f.write_str("Literal[")?; diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index a34f367fb5..0d6d23b8ce 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -5,13 +5,14 @@ use ruff_python_ast::AstNode; use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::module::ModuleName; +use crate::module::{resolve_module, ModuleName}; use crate::parse::parse; use crate::symbols::{ - resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, ImportFromDefinition, + resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, ImportDefinition, + ImportFromDefinition, }; -use crate::types::Type; -use crate::FileId; +use crate::types::{ModuleTypeId, Type}; +use crate::{FileId, Name}; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. #[tracing::instrument(level = "trace", skip(db))] @@ -46,6 +47,15 @@ pub fn infer_definition_type( let file_id = symbol.file_id; match definition { + Definition::Import(ImportDefinition { + module: module_name, + }) => { + if let Some(module) = resolve_module(db, module_name.clone())? { + Ok(Type::Module(ModuleTypeId { module, file_id })) + } else { + Ok(Type::Unknown) + } + } Definition::ImportFrom(ImportFromDefinition { module, name, @@ -114,10 +124,20 @@ pub fn infer_definition_type( let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); - // TODO handle unpacking assignment correctly + // TODO handle unpacking assignment correctly (here and for AnnotatedAssignment case, below) infer_expr_type(db, file_id, &node.value) } - _ => todo!("other kinds of definitions"), + Definition::AnnotatedAssignment(node_key) => { + let parsed = parse(db.upcast(), file_id)?; + let ast = parsed.ast(); + let node = node_key.resolve_unwrap(ast.as_any_node_ref()); + // TODO actually look at the annotation + let Some(value) = &node.value else { + return Ok(Type::Unknown); + }; + // TODO handle unpacking assignment correctly (here and for Assignment case, above) + infer_expr_type(db, file_id, value) + } } } @@ -133,6 +153,13 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu Ok(Type::Unknown) } } + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { + let value_type = infer_expr_type(db, file_id, value)?; + let attr_name = &Name::new(&attr.id); + value_type + .get_member(db, attr_name) + .map(|ty| ty.unwrap_or(Type::Unknown)) + } _ => todo!("full expression type resolution"), } } @@ -289,4 +316,36 @@ mod tests { Ok(()) } + + #[test] + fn resolve_module_member() -> anyhow::Result<()> { + let case = create_test()?; + let db = &case.db; + + let a_path = case.src.path().join("a.py"); + let b_path = case.src.path().join("b.py"); + std::fs::write(a_path, "import b; D = b.C")?; + std::fs::write(b_path, "class C: pass")?; + let a_file = resolve_module(db, ModuleName::new("a"))? + .expect("module should be found") + .path(db)? + .file(); + let a_syms = symbol_table(db, a_file)?; + let d_sym = a_syms + .root_symbol_id_by_name("D") + .expect("D symbol should be found"); + + let ty = infer_symbol_type( + db, + GlobalSymbolId { + file_id: a_file, + symbol_id: d_sym, + }, + )?; + + let jar = HasJar::::jar(db)?; + assert!(matches!(ty, Type::Class(_))); + assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]"); + Ok(()) + } }