[red-knot] @override lint rule (#11282)

## Summary

Lots of TODOs and things to clean up here, but it demonstrates the
working lint rule.

## Test Plan

```
➜ cat main.py
from typing import override
from base import B

class C(B):
    @override
    def method(self): pass

➜ cat base.py
class B: pass

➜ cat typing.py
def override(func):
    return func
```

(We provide our own `typing.py` since we don't have typeshed vendored or
type stub support yet.)

```
➜ ./target/debug/red_knot main.py
...
1   0.012086s TRACE red_knot Main Loop: Tick
[crates/red_knot/src/main.rs:157:21] diagnostics = [
    "Method C.method is decorated with `typing.override` but does not override any base class method",
]
```

If we add `def method(self): pass` to class `B` in `base.py` and run
red_knot again, there is no lint error.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Carl Meyer 2024-05-09 09:25:08 -06:00 committed by GitHub
parent dd42961dd9
commit b6b4ad9949
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 462 additions and 114 deletions

View File

@ -9,10 +9,13 @@ use ruff_python_ast::{ModModule, StringLiteral};
use crate::cache::KeyValueCache;
use crate::db::{LintDb, LintJar, QueryResult};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::parse::{parse, Parsed};
use crate::source::{source_text, Source};
use crate::symbols::{symbol_table, Definition, SymbolId, SymbolTable};
use crate::types::{infer_symbol_type, Type};
use crate::symbols::{
resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable,
};
use crate::types::{infer_definition_type, infer_symbol_type, Type};
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult<Diagnostics> {
@ -90,6 +93,7 @@ pub(crate) fn lint_semantic(db: &dyn LintDb, file_id: FileId) -> QueryResult<Dia
};
lint_unresolved_imports(&context)?;
lint_bad_overrides(&context)?;
Ok(Diagnostics::from(context.diagnostics.take()))
})
@ -136,6 +140,57 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
Ok(())
}
/// Reports methods that are decorated with `typing.override` but for which no base class
/// provides a member of the same name.
///
/// Every definition in the file's symbol table is visited; only function definitions that
/// have a containing class are considered. If the `typing` module or its `override` symbol
/// cannot be resolved, the rule silently does nothing.
fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> {
    // TODO we should have a special marker on the real typing module (from typeshed) so if you
    // have your own "typing" module in your project, we don't consider it THE typing module (and
    // same for other stdlib modules that our lint rules care about)
    let Some(typing_override) =
        resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")?
    else {
        // TODO once we bundle typeshed, this should be unreachable!()
        return Ok(());
    };

    // TODO we should maybe index definitions by type instead of iterating all, or else iterate all
    // just once, match, and branch to all lint rules that care about a type of definition
    for (symbol, definition) in context.symbols().all_definitions() {
        // Only function definitions can be (badly) overriding methods.
        let Definition::FunctionDef(_) = definition else {
            continue;
        };

        let global_symbol = GlobalSymbolId {
            file_id: context.file_id,
            symbol_id: symbol,
        };
        let Type::Function(func) =
            infer_definition_type(context.db.upcast(), global_symbol, definition.clone())?
        else {
            unreachable!("type of a FunctionDef should always be a Function");
        };

        // Plain functions (no containing class) cannot override anything.
        let Some(class) = func.get_containing_class(context.db.upcast())? else {
            continue;
        };

        if !func.has_decorator(context.db.upcast(), typing_override)? {
            continue;
        }

        let method_name = func.name(context.db.upcast())?;
        let inherited = class.get_super_class_member(context.db.upcast(), &method_name)?;
        if inherited.is_some() {
            // A base class provides this member, so the decorator is justified.
            continue;
        }

        // TODO should have a qualname() method to support nested classes
        let class_name = class.name(context.db.upcast())?;
        context.push_diagnostic(format!(
            "Method {}.{} is decorated with `typing.override` but does not override any base class method",
            class_name, method_name,
        ));
    }

    Ok(())
}
pub struct SemanticLintContext<'a> {
file_id: FileId,
source: Source,
@ -163,7 +218,13 @@ impl<'a> SemanticLintContext<'a> {
}
pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(self.db.upcast(), self.file_id, symbol_id)
infer_symbol_type(
self.db.upcast(),
GlobalSymbolId {
file_id: self.file_id,
symbol_id,
},
)
}
pub fn push_diagnostic(&self, diagnostic: String) {

View File

@ -14,15 +14,14 @@ use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast as ast;
use ruff_python_ast::visitor::preorder::PreorderVisitor;
use crate::ast_ids::TypedNodeKey;
use crate::ast_ids::{NodeKey, TypedNodeKey};
use crate::cache::KeyValueCache;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::module::{resolve_module, ModuleName};
use crate::parse::parse;
use crate::Name;
#[allow(unreachable_pub)]
#[tracing::instrument(level = "debug", skip(db))]
pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
let jar: &SemanticJar = db.jar()?;
@ -33,6 +32,32 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<Sym
})
}
/// Looks up the root-level symbol called `name` in the module named `module`.
///
/// Returns `Ok(None)` when the module cannot be resolved or when its symbol table has no
/// root symbol with that name; otherwise returns the symbol id paired with the id of the
/// file the module was resolved to.
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_global_symbol(
    db: &dyn SemanticDb,
    module: ModuleName,
    name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
    let Some(resolved_module) = resolve_module(db, module)? else {
        return Ok(None);
    };

    let file_id = resolved_module.path(db)?.file();
    let table = symbol_table(db, file_id)?;

    Ok(table
        .root_symbol_id_by_name(name)
        .map(|symbol_id| GlobalSymbolId { file_id, symbol_id }))
}
/// Identifies a symbol across files: a `SymbolId` paired with the `FileId` of the file whose
/// symbol table the id was taken from.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct GlobalSymbolId {
/// file whose symbol table contains the symbol
pub(crate) file_id: FileId,
/// id of the symbol within that file's symbol table
pub(crate) symbol_id: SymbolId,
}
type Map<K, V> = hashbrown::HashMap<K, V, ()>;
#[newtype_index]
@ -65,7 +90,12 @@ pub(crate) enum ScopeKind {
pub(crate) struct Scope {
name: Name,
kind: ScopeKind,
child_scopes: Vec<ScopeId>,
parent: Option<ScopeId>,
children: Vec<ScopeId>,
/// the definition (e.g. class or function) that created this scope
definition: Option<Definition>,
/// the symbol (e.g. class or function) that owns this scope
defining_symbol: Option<SymbolId>,
/// symbol IDs, hashed by symbol name
symbols_by_name: Map<SymbolId, ()>,
}
@ -78,6 +108,14 @@ impl Scope {
pub(crate) fn kind(&self) -> ScopeKind {
self.kind
}
/// Returns a clone of the definition (e.g. class or function) that created this scope,
/// or `None` if no definition was recorded for it.
pub(crate) fn definition(&self) -> Option<Definition> {
self.definition.clone()
}
/// Returns the symbol (e.g. class or function) that owns this scope, or `None` if no
/// owning symbol was recorded for it.
pub(crate) fn defining_symbol(&self) -> Option<SymbolId> {
self.defining_symbol
}
}
#[derive(Debug)]
@ -114,6 +152,10 @@ impl Symbol {
self.name.as_str()
}
/// Returns the `ScopeId` stored on this symbol (presumably the scope the symbol was
/// added to — confirm against the `Symbol` definition).
pub(crate) fn scope_id(&self) -> ScopeId {
self.scope_id
}
/// Is the symbol used in its containing scope?
pub(crate) fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
@ -132,6 +174,7 @@ impl Symbol {
// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
// this is at best O(log n). If looking up definitions is a bottleneck we should look for
// alternatives here.
// TODO intern Definitions in SymbolTable and reference using IDs?
#[derive(Clone, Debug)]
pub(crate) enum Definition {
// For the import cases, we don't need reference to any arbitrary AST subtrees (annotations,
@ -140,7 +183,7 @@ pub(crate) enum Definition {
// the small amount of information we need from the AST.
Import(ImportDefinition),
ImportFrom(ImportFromDefinition),
ClassDef(ClassDefinition),
ClassDef(TypedNodeKey<ast::StmtClassDef>),
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
Assignment(TypedNodeKey<ast::StmtAssign>),
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
@ -173,12 +216,6 @@ impl ImportFromDefinition {
}
}
#[derive(Clone, Debug)]
pub(crate) struct ClassDefinition {
pub(crate) node_key: TypedNodeKey<ast::StmtClassDef>,
pub(crate) scope_id: ScopeId,
}
#[derive(Debug, Clone)]
pub enum Dependency {
Module(ModuleName),
@ -193,7 +230,11 @@ pub enum Dependency {
pub struct SymbolTable {
scopes_by_id: IndexVec<ScopeId, Scope>,
symbols_by_id: IndexVec<SymbolId, Symbol>,
/// the definitions for each symbol
defs: FxHashMap<SymbolId, Vec<Definition>>,
/// map of AST node (e.g. class/function def) to sub-scope it creates
scopes_by_node: FxHashMap<NodeKey, ScopeId>,
/// dependencies of this module
dependencies: Vec<Dependency>,
}
@ -214,12 +255,16 @@ impl SymbolTable {
scopes_by_id: IndexVec::new(),
symbols_by_id: IndexVec::new(),
defs: FxHashMap::default(),
scopes_by_node: FxHashMap::default(),
dependencies: Vec::new(),
};
table.scopes_by_id.push(Scope {
name: Name::new("<module>"),
kind: ScopeKind::Module,
child_scopes: Vec::new(),
parent: None,
children: Vec::new(),
definition: None,
defining_symbol: None,
symbols_by_name: Map::default(),
});
table
@ -260,7 +305,7 @@ impl SymbolTable {
}
pub(crate) fn child_scope_ids_of(&self, scope_id: ScopeId) -> &[ScopeId] {
&self.scopes_by_id[scope_id].child_scopes
&self.scopes_by_id[scope_id].children
}
pub(crate) fn child_scopes_of(&self, scope_id: ScopeId) -> ScopeIterator<&[ScopeId]> {
@ -303,6 +348,32 @@ impl SymbolTable {
self.symbol_by_name(SymbolTable::root_scope_id(), name)
}
/// Returns the `ScopeId` recorded on the symbol with the given id.
pub(crate) fn scope_id_of_symbol(&self, symbol_id: SymbolId) -> ScopeId {
self.symbols_by_id[symbol_id].scope_id
}
/// Returns the `Scope` whose id is recorded on the symbol with the given id.
pub(crate) fn scope_of_symbol(&self, symbol_id: SymbolId) -> &Scope {
&self.scopes_by_id[self.scope_id_of_symbol(symbol_id)]
}
/// Returns an iterator that walks up the scope tree by following `parent` links.
///
/// Note that the walk starts at `scope_id` itself: the given scope is yielded first,
/// followed by its parent, grandparent, and so on, ending after a scope whose `parent`
/// is `None`.
pub(crate) fn parent_scopes(
&self,
scope_id: ScopeId,
) -> ScopeIterator<impl Iterator<Item = ScopeId> + '_> {
ScopeIterator {
table: self,
// `successors` yields the seed first, then repeatedly applies the closure until it
// returns `None`
ids: std::iter::successors(Some(scope_id), |scope| self.scopes_by_id[*scope].parent),
}
}
/// Returns the id of the direct parent of the given scope, or `None` if the scope has no
/// parent.
pub(crate) fn parent_scope(&self, scope_id: ScopeId) -> Option<ScopeId> {
self.scopes_by_id[scope_id].parent
}
/// Returns the id of the sub-scope recorded for the given AST node (e.g. a class or
/// function definition).
///
/// # Panics
///
/// Panics if no scope has been recorded for `node_key`, because the map is indexed
/// directly.
pub(crate) fn scope_id_for_node(&self, node_key: &NodeKey) -> ScopeId {
self.scopes_by_node[node_key]
}
pub(crate) fn definitions(&self, symbol_id: SymbolId) -> &[Definition] {
self.defs
.get(&symbol_id)
@ -316,7 +387,7 @@ impl SymbolTable {
.flat_map(|(sym_id, defs)| defs.iter().map(move |def| (*sym_id, def)))
}
fn add_or_update_symbol(
pub(crate) fn add_or_update_symbol(
&mut self,
scope_id: ScopeId,
name: &str,
@ -357,15 +428,20 @@ impl SymbolTable {
parent_scope_id: ScopeId,
name: &str,
kind: ScopeKind,
definition: Option<Definition>,
defining_symbol: Option<SymbolId>,
) -> ScopeId {
let new_scope_id = self.scopes_by_id.push(Scope {
name: Name::new(name),
kind,
child_scopes: Vec::new(),
parent: Some(parent_scope_id),
children: Vec::new(),
definition,
defining_symbol,
symbols_by_name: Map::default(),
});
let parent_scope = &mut self.scopes_by_id[parent_scope_id];
parent_scope.child_scopes.push(new_scope_id);
parent_scope.children.push(new_scope_id);
new_scope_id
}
@ -412,20 +488,22 @@ where
}
}
// TODO maybe get rid of this and just do all data access via methods on ScopeId?
pub(crate) struct ScopeIterator<'a, I> {
table: &'a SymbolTable,
ids: I,
}
/// iterate (`ScopeId`, `Scope`) pairs for given `ScopeId` iterator
impl<'a, I> Iterator for ScopeIterator<'a, I>
where
I: Iterator<Item = ScopeId>,
{
type Item = &'a Scope;
type Item = (ScopeId, &'a Scope);
fn next(&mut self) -> Option<Self::Item> {
let id = self.ids.next()?;
Some(&self.table.scopes_by_id[id])
Some((id, &self.table.scopes_by_id[id]))
}
fn size_hint(&self) -> (usize, Option<usize>) {
@ -441,7 +519,7 @@ where
{
fn next_back(&mut self) -> Option<Self::Item> {
let id = self.ids.next_back()?;
Some(&self.table.scopes_by_id[id])
Some((id, &self.table.scopes_by_id[id]))
}
}
@ -472,8 +550,16 @@ impl SymbolTableBuilder {
symbol_id
}
fn push_scope(&mut self, name: &str, kind: ScopeKind) -> ScopeId {
let scope_id = self.table.add_child_scope(self.cur_scope(), name, kind);
fn push_scope(
&mut self,
name: &str,
kind: ScopeKind,
definition: Option<Definition>,
defining_symbol: Option<SymbolId>,
) -> ScopeId {
let scope_id =
self.table
.add_child_scope(self.cur_scope(), name, kind, definition, defining_symbol);
self.scopes.push(scope_id);
scope_id
}
@ -491,14 +577,20 @@ impl SymbolTableBuilder {
.expect("Scope stack should never be empty")
}
/// Records `scope_id` as the scope associated with the AST node `node_key` in the table
/// being built; inserting again for the same key replaces the earlier entry.
fn record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) {
self.table.scopes_by_node.insert(node_key, scope_id);
}
fn with_type_params(
&mut self,
name: &str,
params: &Option<Box<ast::TypeParams>>,
definition: Option<Definition>,
defining_symbol: Option<SymbolId>,
nested: impl FnOnce(&mut Self) -> ScopeId,
) -> ScopeId {
if let Some(type_params) = params {
self.push_scope(name, ScopeKind::Annotation);
self.push_scope(name, ScopeKind::Annotation, definition, defining_symbol);
for type_param in &type_params.type_params {
let name = match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
@ -539,27 +631,50 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
// TODO need to capture more definition statements here
match stmt {
ast::Stmt::ClassDef(node) => {
let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| {
let scope_id = builder.push_scope(&node.name, ScopeKind::Class);
let node_key = TypedNodeKey::from_node(node);
let def = Definition::ClassDef(node_key.clone());
let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone());
let scope_id = self.with_type_params(
&node.name,
&node.type_params,
Some(def.clone()),
Some(symbol_id),
|builder| {
let scope_id = builder.push_scope(
&node.name,
ScopeKind::Class,
Some(def.clone()),
Some(symbol_id),
);
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
let def = Definition::ClassDef(ClassDefinition {
node_key: TypedNodeKey::from_node(node),
scope_id,
});
self.add_or_update_symbol_with_def(&node.name, def);
},
);
self.record_scope_for_node(*node_key.erased(), scope_id);
}
ast::Stmt::FunctionDef(node) => {
let def = Definition::FunctionDef(TypedNodeKey::from_node(node));
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
let scope_id = builder.push_scope(&node.name, ScopeKind::Function);
let node_key = TypedNodeKey::from_node(node);
let def = Definition::FunctionDef(node_key.clone());
let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone());
let scope_id = self.with_type_params(
&node.name,
&node.type_params,
Some(def.clone()),
Some(symbol_id),
|builder| {
let scope_id = builder.push_scope(
&node.name,
ScopeKind::Function,
Some(def.clone()),
Some(symbol_id),
);
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
},
);
self.record_scope_for_node(*node_key.erased(), scope_id);
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
for alias in names {
@ -933,7 +1048,7 @@ mod tests {
let mut table = SymbolTable::new();
let root_scope_id = SymbolTable::root_scope_id();
let foo_symbol_top = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty());
let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class);
let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class, None, None);
let foo_symbol_inner = table.add_or_update_symbol(c_scope, "foo", SymbolFlags::empty());
assert_ne!(foo_symbol_top, foo_symbol_inner);
}

View File

@ -1,18 +1,16 @@
#![allow(dead_code)]
use rustc_hash::FxHashMap;
pub(crate) use infer::infer_symbol_type;
use ruff_index::{newtype_index, IndexVec};
use crate::ast_ids::NodeKey;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::symbols::{symbol_table, ScopeId, SymbolId};
use crate::symbols::{symbol_table, GlobalSymbolId, ScopeId, ScopeKind, SymbolId};
use crate::{FxDashMap, FxIndexSet, Name};
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;
pub(crate) mod infer;
pub(crate) use infer::{infer_definition_type, infer_symbol_type};
/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
@ -82,10 +80,10 @@ impl TypeStore {
self.modules.remove(&file_id);
}
pub fn cache_symbol_type(&self, file_id: FileId, symbol_id: SymbolId, ty: Type) {
self.add_or_get_module(file_id)
pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) {
self.add_or_get_module(symbol.file_id)
.symbol_types
.insert(symbol_id, ty);
.insert(symbol.symbol_id, ty);
}
pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) {
@ -94,10 +92,10 @@ impl TypeStore {
.insert(node_key, ty);
}
pub fn get_cached_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Option<Type> {
self.try_get_module(file_id)?
pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
self.try_get_module(symbol.file_id)?
.symbol_types
.get(&symbol_id)
.get(&symbol.symbol_id)
.copied()
}
@ -122,9 +120,16 @@ impl TypeStore {
self.modules.get(&file_id)
}
fn add_function(&self, file_id: FileId, name: &str, decorators: Vec<Type>) -> FunctionTypeId {
fn add_function(
&self,
file_id: FileId,
name: &str,
symbol_id: SymbolId,
scope_id: ScopeId,
decorators: Vec<Type>,
) -> FunctionTypeId {
self.add_or_get_module(file_id)
.add_function(name, decorators)
.add_function(name, symbol_id, scope_id, decorators)
}
fn add_class(
@ -257,6 +262,80 @@ pub struct FunctionTypeId {
func_id: ModuleFunctionTypeId,
}
impl FunctionTypeId {
    /// Fetches the stored function type for this id from the semantic jar's type store.
    fn function(self, db: &dyn SemanticDb) -> QueryResult<FunctionTypeRef> {
        let jar: &SemanticJar = db.jar()?;
        let function = jar.type_store.get_function(self);
        Ok(function)
    }

    /// Name of the function, as stored on its function type.
    pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult<Name> {
        let function = self.function(db)?;
        Ok(function.name().into())
    }

    /// The symbol this function is a definition of, qualified with this function's file.
    pub(crate) fn global_symbol(self, db: &dyn SemanticDb) -> QueryResult<GlobalSymbolId> {
        let file_id = self.file();
        let symbol_id = self.symbol(db)?;
        Ok(GlobalSymbolId { file_id, symbol_id })
    }

    /// The file this function type belongs to.
    pub(crate) fn file(self) -> FileId {
        self.file_id
    }

    /// The symbol this function is a definition of.
    pub(crate) fn symbol(self, db: &dyn SemanticDb) -> QueryResult<SymbolId> {
        let function = self.function(db)?;
        Ok(function.symbol_id)
    }

    /// Returns the class whose body directly contains this function, if there is one.
    ///
    /// Yields `Ok(None)` when the scope of the function's symbol is not a class scope, when
    /// that scope has no recorded definition or defining symbol, or when the definition does
    /// not infer to a class type.
    pub(crate) fn get_containing_class(
        self,
        db: &dyn SemanticDb,
    ) -> QueryResult<Option<ClassTypeId>> {
        let table = symbol_table(db, self.file_id)?;
        let function_symbol = self.symbol(db)?;

        // The scope holding the function's symbol is the candidate class body.
        let enclosing_scope = function_symbol.symbol(&table).scope_id().scope(&table);
        if !matches!(enclosing_scope.kind(), ScopeKind::Class) {
            return Ok(None);
        }

        let (Some(class_definition), Some(class_symbol)) = (
            enclosing_scope.definition(),
            enclosing_scope.defining_symbol(),
        ) else {
            return Ok(None);
        };

        let class_ty = infer_definition_type(
            db,
            GlobalSymbolId {
                file_id: self.file_id,
                symbol_id: class_symbol,
            },
            class_definition,
        )?;

        Ok(match class_ty {
            Type::Class(class) => Some(class),
            _ => None,
        })
    }

    /// Returns `true` if any decorator of this function is a function type whose global
    /// symbol equals `decorator_symbol`. Decorators of any other type are ignored.
    pub(crate) fn has_decorator(
        self,
        db: &dyn SemanticDb,
        decorator_symbol: GlobalSymbolId,
    ) -> QueryResult<bool> {
        let function = self.function(db)?;
        for decorator_ty in function.decorators() {
            if let Type::Function(decorator_func) = decorator_ty {
                if decorator_func.global_symbol(db)? == decorator_symbol {
                    return Ok(true);
                }
            }
        }
        Ok(false)
    }
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct ClassTypeId {
file_id: FileId,
@ -264,14 +343,47 @@ pub struct ClassTypeId {
}
impl ClassTypeId {
fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
fn class(self, db: &dyn SemanticDb) -> QueryResult<ClassTypeRef> {
let jar: &SemanticJar = db.jar()?;
Ok(jar.type_store.get_class(self))
}
pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult<Name> {
Ok(self.class(db)?.name().into())
}
pub(crate) fn get_super_class_member(
self,
db: &dyn SemanticDb,
name: &Name,
) -> QueryResult<Option<Type>> {
// TODO we should linearize the MRO instead of doing this recursively
let class = self.class(db)?;
for base in class.bases() {
if let Type::Class(base) = base {
if let Some(own_member) = base.get_own_class_member(db, name)? {
return Ok(Some(own_member));
}
if let Some(base_member) = base.get_super_class_member(db, name)? {
return Ok(Some(base_member));
}
}
}
Ok(None)
}
fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
// TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them
let ClassType { scope_id, .. } = *jar.type_store.get_class(self);
let ClassType { scope_id, .. } = *self.class(db)?;
let table = symbol_table(db, self.file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(infer_symbol_type(db, self.file_id, symbol_id)?))
Ok(Some(infer_symbol_type(
db,
GlobalSymbolId {
file_id: self.file_id,
symbol_id,
},
)?))
} else {
Ok(None)
}
@ -334,9 +446,17 @@ impl ModuleTypeStore {
}
}
fn add_function(&mut self, name: &str, decorators: Vec<Type>) -> FunctionTypeId {
fn add_function(
&mut self,
name: &str,
symbol_id: SymbolId,
scope_id: ScopeId,
decorators: Vec<Type>,
) -> FunctionTypeId {
let func_id = self.functions.push(FunctionType {
name: Name::new(name),
symbol_id,
scope_id,
decorators,
});
FunctionTypeId {
@ -436,7 +556,7 @@ pub(crate) struct ClassType {
/// Name of the class at definition
name: Name,
/// `ScopeId` of the class body
pub(crate) scope_id: ScopeId,
scope_id: ScopeId,
/// Types of all class bases
bases: Vec<Type>,
}
@ -453,7 +573,13 @@ impl ClassType {
#[derive(Debug)]
pub(crate) struct FunctionType {
/// name of the function at definition
name: Name,
/// symbol which this function is a definition of
symbol_id: SymbolId,
/// scope of this function's body
scope_id: ScopeId,
/// types of all decorators on this function
decorators: Vec<Type>,
}
@ -462,7 +588,11 @@ impl FunctionType {
self.name.as_str()
}
fn decorators(&self) -> &[Type] {
fn scope_id(&self) -> ScopeId {
self.scope_id
}
pub(crate) fn decorators(&self) -> &[Type] {
self.decorators.as_slice()
}
}
@ -493,12 +623,12 @@ impl UnionType {
// directly in intersections rather than as a separate type. This sacrifices some efficiency in the
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd
// have to represent it as a single-element intersection if it did) in exchange for better
// efficiency in the not-within-intersection case.
// efficiency in the within-intersection case.
#[derive(Debug)]
pub(crate) struct IntersectionType {
// the intersection type includes only values in all of these types
positive: FxIndexSet<Type>,
// negated elements of the intersection, e.g.
// the intersection type does not include any value in any of these types
negative: FxIndexSet<Type>,
}
@ -530,7 +660,7 @@ mod tests {
use std::path::Path;
use crate::files::Files;
use crate::symbols::SymbolTable;
use crate::symbols::{SymbolFlags, SymbolTable};
use crate::types::{Type, TypeStore};
use crate::FxIndexSet;
@ -550,7 +680,20 @@ mod tests {
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_function(file_id, "func", vec![Type::Unknown]);
let mut table = SymbolTable::new();
let func_symbol = table.add_or_update_symbol(
SymbolTable::root_scope_id(),
"func",
SymbolFlags::IS_DEFINED,
);
let id = store.add_function(
file_id,
"func",
func_symbol,
SymbolTable::root_scope_id(),
vec![Type::Unknown],
);
assert_eq!(store.get_function(id).name(), "func");
assert_eq!(store.get_function(id).decorators(), vec![Type::Unknown]);
let func = Type::Function(id);

View File

@ -5,33 +5,47 @@ use ruff_python_ast::AstNode;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::module::{resolve_module, ModuleName};
use crate::module::ModuleName;
use crate::parse::parse;
use crate::symbols::{symbol_table, ClassDefinition, Definition, ImportFromDefinition, SymbolId};
use crate::symbols::{
resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, ImportFromDefinition,
};
use crate::types::Type;
use crate::FileId;
// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type(
db: &dyn SemanticDb,
file_id: FileId,
symbol_id: SymbolId,
) -> QueryResult<Type> {
let symbols = symbol_table(db, file_id)?;
let defs = symbols.definitions(symbol_id);
pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
let symbols = symbol_table(db, symbol.file_id)?;
let defs = symbols.definitions(symbol.symbol_id);
let jar: &SemanticJar = db.jar()?;
let type_store = &jar.type_store;
if let Some(ty) = type_store.get_cached_symbol_type(file_id, symbol_id) {
if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) {
return Ok(ty);
}
// TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1);
let ty = match &defs[0] {
let ty = infer_definition_type(db, symbol, defs[0].clone())?;
jar.type_store.cache_symbol_type(symbol, ty);
// TODO record dependencies
Ok(ty)
}
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_definition_type(
db: &dyn SemanticDb,
symbol: GlobalSymbolId,
definition: Definition,
) -> QueryResult<Type> {
let jar: &SemanticJar = db.jar()?;
let type_store = &jar.type_store;
let file_id = symbol.file_id;
match definition {
Definition::ImportFrom(ImportFromDefinition {
module,
name,
@ -40,24 +54,19 @@ pub fn infer_symbol_type(
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = resolve_module(db, module_name)? {
let remote_file_id = module.path(db)?.file();
let remote_symbols = symbol_table(db, remote_file_id)?;
if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
infer_symbol_type(db, remote_file_id, remote_symbol_id)?
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
infer_symbol_type(db, remote_symbol)
} else {
Type::Unknown
}
} else {
Type::Unknown
Ok(Type::Unknown)
}
}
Definition::ClassDef(ClassDefinition { node_key, scope_id }) => {
Definition::ClassDef(node_key) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
Ok(ty)
} else {
let parsed = parse(db.upcast(), file_id)?;
let ast = parsed.ast();
let table = symbol_table(db, file_id)?;
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
let mut bases = Vec::with_capacity(node.bases().len());
@ -65,19 +74,19 @@ pub fn infer_symbol_type(
for base in node.bases() {
bases.push(infer_expr_type(db, file_id, base)?);
}
let ty =
Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases));
let scope_id = table.scope_id_for_node(node_key.erased());
let ty = Type::Class(type_store.add_class(file_id, &node.name.id, scope_id, bases));
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
Ok(ty)
}
}
Definition::FunctionDef(node_key) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
Ok(ty)
} else {
let parsed = parse(db.upcast(), file_id)?;
let ast = parsed.ast();
let table = symbol_table(db, file_id)?;
let node = node_key
.resolve(ast.as_any_node_ref())
.expect("node key should resolve");
@ -87,12 +96,18 @@ pub fn infer_symbol_type(
.iter()
.map(|decorator| infer_expr_type(db, file_id, &decorator.expression))
.collect::<QueryResult<_>>()?;
let scope_id = table.scope_id_for_node(node_key.erased());
let ty = type_store
.add_function(file_id, &node.name.id, decorator_tys)
.add_function(
file_id,
&node.name.id,
symbol.symbol_id,
scope_id,
decorator_tys,
)
.into();
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
Ok(ty)
}
}
Definition::Assignment(node_key) => {
@ -100,15 +115,10 @@ pub fn infer_symbol_type(
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
// TODO handle unpacking assignment correctly
infer_expr_type(db, file_id, &node.value)?
infer_expr_type(db, file_id, &node.value)
}
_ => todo!("other kinds of definitions"),
};
type_store.cache_symbol_type(file_id, symbol_id, ty);
// TODO record dependencies
Ok(ty)
}
}
fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> QueryResult<Type> {
@ -116,8 +126,9 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu
let symbols = symbol_table(db, file_id)?;
match expr {
ast::Expr::Name(name) => {
// TODO look up in the correct scope, don't assume global
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
infer_symbol_type(db, file_id, symbol_id)
infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id })
} else {
Ok(Type::Unknown)
}
@ -133,7 +144,7 @@ mod tests {
use crate::module::{
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::symbol_table;
use crate::symbols::{symbol_table, GlobalSymbolId};
use crate::types::{infer_symbol_type, Type};
use crate::Name;
@ -180,7 +191,13 @@ mod tests {
.root_symbol_id_by_name("E")
.expect("E symbol should be found");
let ty = infer_symbol_type(db, a_file, e_sym)?;
let ty = infer_symbol_type(
db,
GlobalSymbolId {
file_id: a_file,
symbol_id: e_sym,
},
)?;
let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::Class(_)));
@ -205,7 +222,13 @@ mod tests {
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");
let ty = infer_symbol_type(db, file, sym)?;
let ty = infer_symbol_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: sym,
},
)?;
let Type::Class(class_id) = ty else {
panic!("Sub is not a Class")
@ -240,7 +263,13 @@ mod tests {
.root_symbol_id_by_name("C")
.expect("C symbol should be found");
let ty = infer_symbol_type(db, file, sym)?;
let ty = infer_symbol_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: sym,
},
)?;
let Type::Class(class_id) = ty else {
panic!("C is not a Class");