From 0e44235981c4998b0979c3cd464b0f92fd19e8e3 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 5 Jul 2024 12:16:37 -0700 Subject: [PATCH] [red-knot] intern types using Salsa (#12061) Intern types using Salsa interning instead of in the `TypeInference` result. This eliminates the need for `TypingContext`, and also paves the way for finer-grained type inference queries. --- Cargo.lock | 11 +- Cargo.toml | 2 +- crates/red_knot/src/lint.rs | 15 +- crates/red_knot_python_semantic/Cargo.toml | 2 +- crates/red_knot_python_semantic/src/db.rs | 8 +- crates/red_knot_python_semantic/src/lib.rs | 2 +- crates/red_knot_python_semantic/src/mod.rs | 10 - .../src/semantic_index/symbol.rs | 1 - .../src/semantic_model.rs | 14 +- crates/red_knot_python_semantic/src/types.rs | 401 ++++-------------- .../src/types/display.rs | 88 ++-- .../src/types/infer.rs | 170 ++------ 12 files changed, 190 insertions(+), 534 deletions(-) delete mode 100644 crates/red_knot_python_semantic/src/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 07e945a516..ca7a28371e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1532,6 +1532,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordermap" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5a8e22be64dfa1123429350872e7be33594dbf5ae5212c90c5890e71966d1d" +dependencies = [ + "indexmap", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -1902,7 +1911,7 @@ dependencies = [ "anyhow", "bitflags 2.6.0", "hashbrown 0.14.5", - "indexmap", + "ordermap", "red_knot_module_resolver", "ruff_db", "ruff_index", diff --git a/Cargo.toml b/Cargo.toml index bfc8d351dc..0cb4f2e88e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,6 @@ hashbrown = "0.14.3" ignore = { version = "0.4.22" } imara-diff = { version = "0.1.5" } imperative = { version = "1.0.4" } -indexmap = { version = "2.2.6" } indicatif = { version = "0.17.8" } indoc = { version = "2.0.4" } insta = { version = "1.35.1" } @@ -95,6 +94,7 @@ mimalloc = { version = "0.1.39" } natord = { version = "1.0.9" } notify = { version = "6.1.1" } once_cell = { version = "1.19.0" } +ordermap = { version = "0.5.0" } path-absolutize = { version = "3.1.1" } path-slash = { version = "0.2.1" } pathdiff = { version = "0.2.1" } diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index e32a70424e..edef30d563 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -122,7 +122,6 @@ fn lint_unresolved_imports(context: &SemanticLintContext, import: AnyImportRef) fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { let semantic = &context.semantic; - let typing_context = semantic.typing_context(); // TODO we should have a special marker on the real typing module (from typeshed) so if you // have your own "typing" module in your project, we don't consider it THE typing module (and @@ -150,17 +149,17 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { return; }; - if ty.has_decorator(&typing_context, override_ty) { - let method_name = ty.name(&typing_context); - if class_ty - .inherited_class_member(&typing_context, method_name) - .is_none() - { + // TODO this shouldn't make direct use of the Db; see comment on SemanticModel::db + let db = semantic.db(); + + if ty.has_decorator(db, override_ty) { + let method_name = ty.name(db); + if class_ty.inherited_class_member(db, &method_name).is_none() { // TODO should have a qualname() method to support nested classes context.push_diagnostic( format!( "Method {}.{} is decorated with `typing.override` but does not override any base class method", - class_ty.name(&typing_context), + class_ty.name(db), method_name, )); } diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index eb66270ff2..b314905d7a 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -18,7 +18,7 @@ ruff_python_ast = { workspace = true } ruff_text_size = { workspace = true } bitflags = { workspace = true } -indexmap = { workspace = true } +ordermap = { workspace = true } salsa = { workspace = true } tracing = { workspace = true } rustc-hash = { workspace = true } diff --git a/crates/red_knot_python_semantic/src/db.rs b/crates/red_knot_python_semantic/src/db.rs index a40dcf7a3b..2ac63f2b45 100644 --- a/crates/red_knot_python_semantic/src/db.rs +++ b/crates/red_knot_python_semantic/src/db.rs @@ -7,13 +7,19 @@ use red_knot_module_resolver::Db as ResolverDb; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId}; use crate::semantic_index::{root_scope, semantic_index, symbol_table}; -use crate::types::{infer_types, public_symbol_ty}; +use crate::types::{ + infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType, +}; #[salsa::jar(db=Db)] pub struct Jar( ScopeId<'_>, PublicSymbolId<'_>, Definition<'_>, + FunctionType<'_>, + ClassType<'_>, + UnionType<'_>, + IntersectionType<'_>, symbol_table, root_scope, semantic_index, diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs index 86c195b567..6d0de8fb83 100644 --- a/crates/red_knot_python_semantic/src/lib.rs +++ b/crates/red_knot_python_semantic/src/lib.rs @@ -12,4 +12,4 @@ pub mod semantic_index; mod semantic_model; pub mod types; -type FxIndexSet = indexmap::set::IndexSet>; +type FxOrderSet = ordermap::set::OrderSet>; diff --git a/crates/red_knot_python_semantic/src/mod.rs b/crates/red_knot_python_semantic/src/mod.rs deleted file mode 100644 index cb43a1513f..0000000000 --- a/crates/red_knot_python_semantic/src/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::hash::BuildHasherDefault; - -use rustc_hash::FxHasher; - -pub mod ast_node_ref; -mod node_key; -pub mod semantic_index; -pub mod types; - -pub(crate) type FxIndexSet = indexmap::set::IndexSet>; diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index dc746081fa..00e73788dd 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -155,7 +155,6 @@ impl<'db> PublicSymbolsMap<'db> { /// A cross-module identifier of a scope that can be used as a salsa query parameter. #[salsa::tracked] pub struct ScopeId<'db> { - #[allow(clippy::used_underscore_binding)] #[id] pub file: VfsFile, #[id] diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 2348ac7150..290285cde8 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -6,7 +6,7 @@ use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef}; use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::symbol::PublicSymbolId; use crate::semantic_index::{public_symbol, semantic_index}; -use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; +use crate::types::{infer_types, public_symbol_ty, Type}; use crate::Db; pub struct SemanticModel<'db> { @@ -19,6 +19,12 @@ impl<'db> SemanticModel<'db> { Self { db, file } } + // TODO we don't actually want to expose the Db directly to lint rules, but we need to find a + // solution for exposing information from types + pub fn db(&self) -> &dyn Db { + self.db + } + pub fn resolve_module(&self, module_name: ModuleName) -> Option { resolve_module(self.db.upcast(), module_name) } @@ -27,13 +33,9 @@ impl<'db> SemanticModel<'db> { public_symbol(self.db, module.file(), symbol_name) } - pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type<'db> { + pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type { public_symbol_ty(self.db, symbol) } - - pub fn typing_context(&self) -> TypingContext<'db, '_> { - TypingContext::global(self.db) - } } pub trait HasTy { diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index c6016f5933..5e82c0c712 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,13 +1,11 @@ use ruff_db::parsed::parsed_module; use ruff_db::vfs::VfsFile; -use ruff_index::newtype_index; use ruff_python_ast::name::Name; -use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, PublicSymbolId, ScopeId}; +use crate::semantic_index::symbol::{NodeWithScopeKind, PublicSymbolId, ScopeId}; use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table}; use crate::types::infer::{TypeInference, TypeInferenceBuilder}; -use crate::Db; -use crate::FxIndexSet; +use crate::{Db, FxOrderSet}; mod display; mod infer; @@ -43,12 +41,12 @@ pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db> let file = symbol.file(db); let scope = root_scope(db, file); + // TODO switch to inferring just the definition(s), not the whole scope let inference = infer_types(db, scope); inference.symbol_ty(symbol.scoped_symbol_id(db)) } -/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`]. -#[allow(unused)] +/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`]. pub(crate) fn public_symbol_ty_by_name<'db>( db: &'db dyn Db, file: VfsFile, @@ -91,7 +89,7 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe } /// unique ID for a type -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub enum Type<'db> { /// the dynamic type: a statically-unknown set of values Any, @@ -105,15 +103,15 @@ pub enum Type<'db> { /// the None object (TODO remove this in favor of Instance(types.NoneType) None, /// a specific function object - Function(TypeId<'db, ScopedFunctionTypeId>), + Function(FunctionType<'db>), /// a specific module object - Module(TypeId<'db, ScopedModuleTypeId>), + Module(VfsFile), /// a specific class object - Class(TypeId<'db, ScopedClassTypeId>), + Class(ClassType<'db>), /// the set of Python objects with the given class in their __class__'s method resolution order - Instance(TypeId<'db, ScopedClassTypeId>), - Union(TypeId<'db, ScopedUnionTypeId>), - Intersection(TypeId<'db, ScopedIntersectionTypeId>), + Instance(ClassType<'db>), + Union(UnionType<'db>), + Intersection(IntersectionType<'db>), IntLiteral(i64), // TODO protocols, callable types, overloads, generics, type vars } @@ -127,7 +125,7 @@ impl<'db> Type<'db> { matches!(self, Type::Unknown) } - pub fn member(&self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { + pub fn member(&self, db: &'db dyn Db, name: &Name) -> Option> { match self { Type::Any => Some(Type::Any), Type::Never => todo!("attribute lookup on Never type"), @@ -135,14 +133,13 @@ impl<'db> Type<'db> { Type::Unbound => todo!("attribute lookup on Unbound type"), Type::None => todo!("attribute lookup on None type"), Type::Function(_) => todo!("attribute lookup on Function type"), - Type::Module(module) => module.member(context, name), - Type::Class(class) => class.class_member(context, name), + Type::Module(file) => public_symbol_ty_by_name(db, *file, name), + Type::Class(class) => class.class_member(db, name), Type::Instance(_) => { // TODO MRO? get_own_instance_member, get_instance_member todo!("attribute lookup on Instance type") } - Type::Union(union_id) => { - let _union = union_id.lookup(context); + Type::Union(_) => { // TODO perform the get_member on each type in the union // TODO return the union of those results // TODO if any of those results is `None` then include Unknown in the result union @@ -161,155 +158,25 @@ impl<'db> Type<'db> { } } -/// ID that uniquely identifies a type in a program. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct TypeId<'db, L> { - /// The scope in which this type is defined or was created. - scope: ScopeId<'db>, - /// The type's local ID in its scope. - scoped: L, -} - -impl<'db, Id> TypeId<'db, Id> -where - Id: Copy, -{ - pub fn scope(&self) -> ScopeId<'db> { - self.scope - } - - pub fn scoped_id(&self) -> Id { - self.scoped - } - - /// Resolves the type ID to the actual type. - pub(crate) fn lookup<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Id::Ty<'db> - where - Id: ScopedTypeId, - { - let types = context.types(self.scope); - self.scoped.lookup_scoped(types) - } -} - -/// ID that uniquely identifies a type in a scope. -pub(crate) trait ScopedTypeId { - /// The type that this ID points to. - type Ty<'db>; - - /// Looks up the type in `index`. - /// - /// ## Panics - /// May panic if this type is from another scope than `index`, or might just return an invalid type. - fn lookup_scoped<'a, 'db>(self, index: &'a TypeInference<'db>) -> &'a Self::Ty<'db>; -} - -/// ID uniquely identifying a function type in a `scope`. -#[newtype_index] -pub struct ScopedFunctionTypeId; - -impl ScopedTypeId for ScopedFunctionTypeId { - type Ty<'db> = FunctionType<'db>; - - fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> { - types.function_ty(self) - } -} - -#[derive(Debug, Eq, PartialEq, Clone)] -pub struct FunctionType<'a> { +#[salsa::interned] +pub struct FunctionType<'db> { /// name of the function at definition - name: Name, + pub name: Name, + /// types of all decorators on this function - decorators: Vec>, + decorators: Vec>, } -impl<'a> FunctionType<'a> { - fn name(&self) -> &str { - self.name.as_str() - } - - #[allow(unused)] - pub(crate) fn decorators(&self) -> &[Type<'a>] { - self.decorators.as_slice() +impl<'db> FunctionType<'db> { + pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool { + self.decorators(db).contains(&decorator) } } -impl<'db> TypeId<'db, ScopedFunctionTypeId> { - pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name { - let function_ty = self.lookup(context); - &function_ty.name - } - - pub fn has_decorator(self, context: &TypingContext, decorator: Type<'db>) -> bool { - let function_ty = self.lookup(context); - function_ty.decorators.contains(&decorator) - } -} - -#[newtype_index] -pub struct ScopedClassTypeId; - -impl ScopedTypeId for ScopedClassTypeId { - type Ty<'db> = ClassType<'db>; - - fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> { - types.class_ty(self) - } -} - -impl<'db> TypeId<'db, ScopedClassTypeId> { - pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name { - let class_ty = self.lookup(context); - &class_ty.name - } - - /// Returns the class member of this class named `name`. - /// - /// The member resolves to a member of the class itself or any of its bases. - pub fn class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { - if let Some(member) = self.own_class_member(context, name) { - return Some(member); - } - - self.inherited_class_member(context, name) - } - - /// Returns the inferred type of the class member named `name`. - pub fn own_class_member( - self, - context: &TypingContext<'db, '_>, - name: &Name, - ) -> Option> { - let class = self.lookup(context); - - let symbols = symbol_table(context.db, class.body_scope); - let symbol = symbols.symbol_id_by_name(name)?; - let types = context.types(class.body_scope); - - Some(types.symbol_ty(symbol)) - } - - pub fn inherited_class_member( - self, - context: &TypingContext<'db, '_>, - name: &Name, - ) -> Option> { - let class = self.lookup(context); - for base in &class.bases { - if let Some(member) = base.member(context, name) { - return Some(member); - } - } - - None - } -} - -#[derive(Debug, Eq, PartialEq, Clone)] +#[salsa::interned] pub struct ClassType<'db> { /// Name of the class at definition - name: Name, + pub name: Name, /// Types of all class bases bases: Vec>, @@ -318,52 +185,62 @@ pub struct ClassType<'db> { } impl<'db> ClassType<'db> { - fn name(&self) -> &str { - self.name.as_str() + /// Returns the class member of this class named `name`. + /// + /// The member resolves to a member of the class itself or any of its bases. + pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Option> { + if let Some(member) = self.own_class_member(db, name) { + return Some(member); + } + + self.inherited_class_member(db, name) } - #[allow(unused)] - pub(super) fn bases(&self) -> &'db [Type] { - self.bases.as_slice() + /// Returns the inferred type of the class member named `name`. + pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Option> { + let scope = self.body_scope(db); + let symbols = symbol_table(db, scope); + let symbol = symbols.symbol_id_by_name(name)?; + let types = infer_types(db, scope); + + Some(types.symbol_ty(symbol)) + } + + pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Option> { + for base in self.bases(db) { + if let Some(member) = base.member(db, name) { + return Some(member); + } + } + + None } } -#[newtype_index] -pub struct ScopedUnionTypeId; - -impl ScopedTypeId for ScopedUnionTypeId { - type Ty<'db> = UnionType<'db>; - - fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> { - types.union_ty(self) - } -} - -#[derive(Debug, Eq, PartialEq, Clone)] +#[salsa::interned] pub struct UnionType<'db> { - // the union type includes values in any of these types - elements: FxIndexSet>, + /// the union type includes values in any of these types + elements: FxOrderSet>, } -struct UnionTypeBuilder<'db, 'a> { - elements: FxIndexSet>, - context: &'a TypingContext<'db, 'a>, +struct UnionTypeBuilder<'db> { + elements: FxOrderSet>, + db: &'db dyn Db, } -impl<'db, 'a> UnionTypeBuilder<'db, 'a> { - fn new(context: &'a TypingContext<'db, 'a>) -> Self { +impl<'db> UnionTypeBuilder<'db> { + fn new(db: &'db dyn Db) -> Self { Self { - context, - elements: FxIndexSet::default(), + db, + elements: FxOrderSet::default(), } } /// Adds a type to this union. fn add(mut self, ty: Type<'db>) -> Self { match ty { - Type::Union(union_id) => { - let union = union_id.lookup(self.context); - self.elements.extend(&union.elements); + Type::Union(union) => { + self.elements.extend(&union.elements(self.db)); } _ => { self.elements.insert(ty); @@ -374,20 +251,7 @@ impl<'db, 'a> UnionTypeBuilder<'db, 'a> { } fn build(self) -> UnionType<'db> { - UnionType { - elements: self.elements, - } - } -} - -#[newtype_index] -pub struct ScopedIntersectionTypeId; - -impl ScopedTypeId for ScopedIntersectionTypeId { - type Ty<'db> = IntersectionType<'db>; - - fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> { - types.intersection_ty(self) + UnionType::new(self.db, self.elements) } } @@ -397,104 +261,12 @@ impl ScopedTypeId for ScopedIntersectionTypeId { // case where a Not appears outside an intersection (unclear when that could even happen, but we'd // have to represent it as a single-element intersection if it did) in exchange for better // efficiency in the within-intersection case. -#[derive(Debug, PartialEq, Eq, Clone)] +#[salsa::interned] pub struct IntersectionType<'db> { // the intersection type includes only values in all of these types - positive: FxIndexSet>, + positive: FxOrderSet>, // the intersection type does not include any value in any of these types - negative: FxIndexSet>, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct ScopedModuleTypeId; - -impl ScopedTypeId for ScopedModuleTypeId { - type Ty<'db> = ModuleType; - - fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> { - types.module_ty() - } -} - -impl<'db> TypeId<'db, ScopedModuleTypeId> { - fn member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { - context.public_symbol_ty(self.scope.file(context.db), name) - } -} - -#[derive(Debug, Eq, PartialEq, Clone)] -pub struct ModuleType { - file: VfsFile, -} - -/// Context in which to resolve types. -/// -/// This abstraction is necessary to support a uniform API that can be used -/// while in the process of building the type inference structure for a scope -/// but also when all types should be resolved by querying the db. -pub struct TypingContext<'db, 'inference> { - db: &'db dyn Db, - - /// The Local type inference scope that is in the process of being built. - /// - /// Bypass the `db` when resolving the types for this scope. - local: Option<(ScopeId<'db>, &'inference TypeInference<'db>)>, -} - -impl<'db, 'inference> TypingContext<'db, 'inference> { - /// Creates a context that resolves all types by querying the db. - #[allow(unused)] - pub(super) fn global(db: &'db dyn Db) -> Self { - Self { db, local: None } - } - - /// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`. - fn scoped( - db: &'db dyn Db, - scope_id: ScopeId<'db>, - types: &'inference TypeInference<'db>, - ) -> Self { - Self { - db, - local: Some((scope_id, types)), - } - } - - /// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`. - fn types(&self, scope_id: ScopeId<'db>) -> &'inference TypeInference<'db> { - if let Some((scope, local_types)) = self.local { - if scope == scope_id { - return local_types; - } - } - - infer_types(self.db, scope_id) - } - - fn module_ty(&self, file: VfsFile) -> Type<'db> { - let scope = root_scope(self.db, file); - - Type::Module(TypeId { - scope, - scoped: ScopedModuleTypeId, - }) - } - - /// Resolves the public type of a symbol named `name` defined in `file`. - /// - /// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`. - /// It otherwise tries to resolve the symbol type locally. - fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option> { - let symbol = public_symbol(self.db, file, name)?; - - if let Some((scope, local_types)) = self.local { - if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file { - return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db))); - } - } - - Some(public_symbol_ty(self.db, symbol)) - } + negative: FxOrderSet>, } #[cfg(test)] @@ -508,7 +280,7 @@ mod tests { assert_will_not_run_function_query, assert_will_run_function_query, TestDb, }; use crate::semantic_index::root_scope; - use crate::types::{infer_types, public_symbol_ty_by_name, TypingContext}; + use crate::types::{infer_types, public_symbol_ty_by_name}; use crate::{HasTy, SemanticModel}; fn setup_db() -> TestDb { @@ -540,10 +312,7 @@ mod tests { let literal_ty = statement.value.ty(&model); - assert_eq!( - format!("{}", literal_ty.display(&TypingContext::global(&db))), - "Literal[10]" - ); + assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]"); Ok(()) } @@ -560,10 +329,7 @@ mod tests { let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty.display(&TypingContext::global(&db)).to_string(), - "Literal[10]" - ); + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); // Change `x` to a different value db.memory_file_system() @@ -577,10 +343,7 @@ mod tests { db.clear_salsa_events(); let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty_2.display(&TypingContext::global(&db)).to_string(), - "Literal[20]" - ); + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]"); let events = db.take_salsa_events(); @@ -607,10 +370,7 @@ mod tests { let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty.display(&TypingContext::global(&db)).to_string(), - "Literal[10]" - ); + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); db.memory_file_system() .write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; @@ -624,10 +384,7 @@ mod tests { let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty_2.display(&TypingContext::global(&db)).to_string(), - "Literal[10]" - ); + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); let events = db.take_salsa_events(); @@ -655,10 +412,7 @@ mod tests { let a = system_path_to_file(&db, "/src/a.py").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty.display(&TypingContext::global(&db)).to_string(), - "Literal[10]" - ); + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); db.memory_file_system() .write_file("/src/foo.py", "x = 10\ny = 30")?; @@ -672,10 +426,7 @@ mod tests { let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - assert_eq!( - x_ty_2.display(&TypingContext::global(&db)).to_string(), - "Literal[10]" - ); + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); let events = db.take_salsa_events(); diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index d038512cd8..d42119e4b7 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -2,18 +2,19 @@ use std::fmt::{Display, Formatter}; -use crate::types::{IntersectionType, Type, TypingContext, UnionType}; +use crate::types::{IntersectionType, Type, UnionType}; +use crate::Db; -impl Type<'_> { - pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> { - DisplayType { ty: self, context } +impl<'db> Type<'db> { + pub fn display(&'db self, db: &'db dyn Db) -> DisplayType<'db> { + DisplayType { ty: self, db } } } #[derive(Copy, Clone)] -pub struct DisplayType<'a> { - ty: &'a Type<'a>, - context: &'a TypingContext<'a, 'a>, +pub struct DisplayType<'db> { + ty: &'db Type<'db>, + db: &'db dyn Db, } impl Display for DisplayType<'_> { @@ -24,42 +25,19 @@ impl Display for DisplayType<'_> { Type::Unknown => f.write_str("Unknown"), Type::Unbound => f.write_str("Unbound"), Type::None => f.write_str("None"), - Type::Module(module_id) => { - write!( - f, - "", - module_id - .scope - .file(self.context.db) - .path(self.context.db.upcast()) - ) + Type::Module(file) => { + write!(f, "", file.path(self.db.upcast())) } // TODO functions and classes should display using a fully qualified name - Type::Class(class_id) => { - let class = class_id.lookup(self.context); - + Type::Class(class) => { f.write_str("Literal[")?; - f.write_str(class.name())?; + f.write_str(&class.name(self.db))?; f.write_str("]") } - Type::Instance(class_id) => { - let class = class_id.lookup(self.context); - f.write_str(class.name()) - } - Type::Function(function_id) => { - let function = function_id.lookup(self.context); - f.write_str(function.name()) - } - Type::Union(union_id) => { - let union = union_id.lookup(self.context); - - union.display(self.context).fmt(f) - } - Type::Intersection(intersection_id) => { - let intersection = intersection_id.lookup(self.context); - - intersection.display(self.context).fmt(f) - } + Type::Instance(class) => f.write_str(&class.name(self.db)), + Type::Function(function) => f.write_str(&function.name(self.db)), + Type::Union(union) => union.display(self.db).fmt(f), + Type::Intersection(intersection) => intersection.display(self.db).fmt(f), Type::IntLiteral(n) => write!(f, "Literal[{n}]"), } } @@ -71,15 +49,15 @@ impl std::fmt::Debug for DisplayType<'_> { } } -impl UnionType<'_> { - fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayUnionType<'a> { - DisplayUnionType { context, ty: self } +impl<'db> UnionType<'db> { + fn display(&'db self, db: &'db dyn Db) -> DisplayUnionType<'db> { + DisplayUnionType { db, ty: self } } } -struct DisplayUnionType<'a> { - ty: &'a UnionType<'a>, - context: &'a TypingContext<'a, 'a>, +struct DisplayUnionType<'db> { + ty: &'db UnionType<'db>, + db: &'db dyn Db, } impl Display for DisplayUnionType<'_> { @@ -87,7 +65,7 @@ impl Display for DisplayUnionType<'_> { let union = self.ty; let (int_literals, other_types): (Vec, Vec) = union - .elements + .elements(self.db) .iter() .copied() .partition(|ty| matches!(ty, Type::IntLiteral(_))); @@ -121,7 +99,7 @@ impl Display for DisplayUnionType<'_> { f.write_str(" | ")?; }; first = false; - write!(f, "{}", ty.display(self.context))?; + write!(f, "{}", ty.display(self.db))?; } Ok(()) @@ -134,15 +112,15 @@ impl std::fmt::Debug for DisplayUnionType<'_> { } } -impl IntersectionType<'_> { - fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayIntersectionType<'a> { - DisplayIntersectionType { ty: self, context } +impl<'db> IntersectionType<'db> { + fn display(&'db self, db: &'db dyn Db) -> DisplayIntersectionType<'db> { + DisplayIntersectionType { db, ty: self } } } -struct DisplayIntersectionType<'a> { - ty: &'a IntersectionType<'a>, - context: &'a TypingContext<'a, 'a>, +struct DisplayIntersectionType<'db> { + ty: &'db IntersectionType<'db>, + db: &'db dyn Db, } impl Display for DisplayIntersectionType<'_> { @@ -150,10 +128,10 @@ impl Display for DisplayIntersectionType<'_> { let mut first = true; for (neg, ty) in self .ty - .positive + .positive(self.db) .iter() .map(|ty| (false, ty)) - .chain(self.ty.negative.iter().map(|ty| (true, ty))) + .chain(self.ty.negative(self.db).iter().map(|ty| (true, ty))) { if !first { f.write_str(" & ")?; @@ -162,7 +140,7 @@ impl Display for DisplayIntersectionType<'_> { if neg { f.write_str("~")?; }; - write!(f, "{}", ty.display(self.context))?; + write!(f, "{}", ty.display(self.db))?; } Ok(()) } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index fb7a39c4bd..59811fc9ae 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2,8 +2,7 @@ use rustc_hash::FxHashMap; use std::borrow::Cow; use std::sync::Arc; -use red_knot_module_resolver::resolve_module; -use red_knot_module_resolver::ModuleName; +use red_knot_module_resolver::{resolve_module, ModuleName}; use ruff_db::vfs::VfsFile; use ruff_index::IndexVec; use ruff_python_ast as ast; @@ -15,81 +14,40 @@ use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable, }; use crate::semantic_index::{symbol_table, SemanticIndex}; -use crate::types::{ - infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, - ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, - UnionType, UnionTypeBuilder, -}; +use crate::types::{infer_types, ClassType, FunctionType, Name, Type, UnionTypeBuilder}; use crate::Db; /// The inferred types for a single scope. #[derive(Debug, Eq, PartialEq, Default, Clone)] pub(crate) struct TypeInference<'db> { - /// The type of the module if the scope is a module scope. - module_type: Option, - - /// The types of the defined classes in this scope. - class_types: IndexVec>, - - /// The types of the defined functions in this scope. - function_types: IndexVec>, - - union_types: IndexVec>, - intersection_types: IndexVec>, - /// The types of every expression in this scope. - expression_tys: IndexVec>, + expressions: IndexVec>, /// The public types of every symbol in this scope. - symbol_tys: IndexVec>, + symbols: IndexVec>, /// The type of a definition. - definition_tys: FxHashMap, Type<'db>>, + definitions: FxHashMap, Type<'db>>, } impl<'db> TypeInference<'db> { #[allow(unused)] pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { - self.expression_tys[expression] + self.expressions[expression] } pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type<'db> { - self.symbol_tys[symbol] + self.symbols[symbol] } - pub(super) fn module_ty(&self) -> &ModuleType { - self.module_type.as_ref().unwrap() - } - - pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType<'db> { - &self.class_types[id] - } - - pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType<'db> { - &self.function_types[id] - } - - pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType<'db> { - &self.union_types[id] - } - - pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType<'db> { - &self.intersection_types[id] - } - - pub(crate) fn definition_ty(&self, definition: Definition) -> Type<'db> { - self.definition_tys[&definition] + pub(crate) fn definition_ty(&self, definition: Definition<'db>) -> Type<'db> { + self.definitions[&definition] } fn shrink_to_fit(&mut self) { - self.class_types.shrink_to_fit(); - self.function_types.shrink_to_fit(); - self.union_types.shrink_to_fit(); - self.intersection_types.shrink_to_fit(); - - self.expression_tys.shrink_to_fit(); - self.symbol_tys.shrink_to_fit(); - self.definition_tys.shrink_to_fit(); + self.expressions.shrink_to_fit(); + self.symbols.shrink_to_fit(); + self.definitions.shrink_to_fit(); } } @@ -99,7 +57,6 @@ pub(super) struct TypeInferenceBuilder<'db> { // Cached lookups index: &'db SemanticIndex<'db>, - scope: ScopeId<'db>, file_scope_id: FileScopeId, file_id: VfsFile, symbol_table: Arc>, @@ -123,7 +80,6 @@ impl<'db> TypeInferenceBuilder<'db> { index, file_scope_id, file_id: file, - scope, symbol_table, db, @@ -205,13 +161,11 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(return_ty); } - let function_ty = self.function_ty(FunctionType { - name: name.id.clone(), - decorators: decorator_tys, - }); + let function_ty = + Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys)); let definition = self.index.definition(function); - self.types.definition_tys.insert(definition, function_ty); + self.types.definitions.insert(definition, function_ty); } fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { @@ -233,16 +187,15 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|arguments| self.infer_arguments(arguments)) .unwrap_or(Vec::new()); - let body_scope = self.index.node_scope(NodeWithScopeRef::Class(class)); + let body_scope = self + .index + .node_scope(NodeWithScopeRef::Class(class)) + .to_scope_id(self.db, self.file_id); - let class_ty = self.class_ty(ClassType { - name: name.id.clone(), - bases, - body_scope: body_scope.to_scope_id(self.db, self.file_id), - }); + let class_ty = Type::Class(ClassType::new(self.db, name.id.clone(), bases, body_scope)); let definition = self.index.definition(class); - self.types.definition_tys.insert(definition, class_ty); + self.types.definitions.insert(definition, class_ty); } fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { @@ -283,7 +236,7 @@ impl<'db> TypeInferenceBuilder<'db> { for target in targets { self.infer_expression(target); - self.types.definition_tys.insert( + self.types.definitions.insert( self.index.definition(DefinitionNodeRef::Target(target)), value_ty, ); @@ -306,7 +259,7 @@ impl<'db> TypeInferenceBuilder<'db> { let annotation_ty = self.infer_expression(annotation); self.infer_expression(target); - self.types.definition_tys.insert( + self.types.definitions.insert( self.index.definition(DefinitionNodeRef::Target(target)), annotation_ty, ); @@ -341,12 +294,12 @@ impl<'db> TypeInferenceBuilder<'db> { let module_name = ModuleName::new(&name.id); let module = module_name.and_then(|name| resolve_module(self.db.upcast(), name)); let module_ty = module - .map(|module| self.typing_context().module_ty(module.file())) + .map(|module| Type::Module(module.file())) .unwrap_or(Type::Unknown); let definition = self.index.definition(alias); - self.types.definition_tys.insert(definition, module_ty); + self.types.definitions.insert(definition, module_ty); } } @@ -363,7 +316,7 @@ impl<'db> TypeInferenceBuilder<'db> { let module = module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name)); let module_ty = module - .map(|module| self.typing_context().module_ty(module.file())) + .map(|module| Type::Module(module.file())) .unwrap_or(Type::Unknown); for alias in names { @@ -374,11 +327,11 @@ impl<'db> TypeInferenceBuilder<'db> { } = alias; let ty = module_ty - .member(&self.typing_context(), &name.id) + .member(self.db, &Name::new(&name.id)) .unwrap_or(Type::Unknown); let definition = self.index.definition(alias); - self.types.definition_tys.insert(definition, ty); + self.types.definitions.insert(definition, ty); } } @@ -425,7 +378,7 @@ impl<'db> TypeInferenceBuilder<'db> { _ => todo!("expression type resolution for {:?}", expression), }; - self.types.expression_tys.push(ty); + self.types.expressions.push(ty); ty } @@ -455,7 +408,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(target); self.types - .definition_tys + .definitions .insert(self.index.definition(named), value_ty); value_ty @@ -475,12 +428,12 @@ impl<'db> TypeInferenceBuilder<'db> { let body_ty = self.infer_expression(body); let orelse_ty = self.infer_expression(orelse); - let union = UnionTypeBuilder::new(&self.typing_context()) + let union = UnionTypeBuilder::new(self.db) .add(body_ty) .add(orelse_ty) .build(); - self.union_ty(union) + Type::Union(union) } fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { @@ -537,7 +490,7 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); let member_ty = value_ty - .member(&self.typing_context(), &attr.id) + .member(self.db, &Name::new(&attr.id)) .unwrap_or(Type::Unknown); match ctx { @@ -612,57 +565,31 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|symbol| self.local_definition_ty(symbol)) .collect(); - self.types.symbol_tys = symbol_tys; + self.types.symbols = symbol_tys; self.types.shrink_to_fit(); self.types } - fn union_ty(&mut self, ty: UnionType<'db>) -> Type<'db> { - Type::Union(TypeId { - scope: self.scope, - scoped: self.types.union_types.push(ty), - }) - } - - fn function_ty(&mut self, ty: FunctionType<'db>) -> Type<'db> { - Type::Function(TypeId { - scope: self.scope, - scoped: self.types.function_types.push(ty), - }) - } - - fn class_ty(&mut self, ty: ClassType<'db>) -> Type<'db> { - Type::Class(TypeId { - scope: self.scope, - scoped: self.types.class_types.push(ty), - }) - } - - fn typing_context(&self) -> TypingContext<'db, '_> { - TypingContext::scoped(self.db, self.scope, &self.types) - } - fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type<'db> { let symbol = self.symbol_table.symbol(symbol); let mut definitions = symbol .definitions() .iter() - .filter_map(|definition| self.types.definition_tys.get(definition).copied()); + .filter_map(|definition| self.types.definitions.get(definition).copied()); let Some(first) = definitions.next() else { return Type::Unbound; }; if let Some(second) = definitions.next() { - let context = self.typing_context(); - let mut builder = UnionTypeBuilder::new(&context); + let mut builder = UnionTypeBuilder::new(self.db); builder = builder.add(first).add(second); for variant in definitions { builder = builder.add(variant); } - self.union_ty(builder.build()) + Type::Union(builder.build()) } else { first } @@ -677,7 +604,7 @@ mod tests { use ruff_python_ast::name::Name; use crate::db::tests::TestDb; - use crate::types::{public_symbol_ty_by_name, Type, TypingContext}; + use crate::types::{public_symbol_ty_by_name, Type}; fn setup_db() -> TestDb { let mut db = TestDb::new(); @@ -699,7 +626,7 @@ mod tests { let file = system_path_to_file(db, file_name).expect("Expected file to exist."); let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown); - assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected); + assert_eq!(ty.display(db).to_string(), expected); } #[test] @@ -733,17 +660,14 @@ class Sub(Base): let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist"); - let Type::Class(class_id) = ty else { + let Type::Class(class) = ty else { panic!("Sub is not a Class") }; - let context = TypingContext::global(&db); - - let base_names: Vec<_> = class_id - .lookup(&context) - .bases() + let base_names: Vec<_> = class + .bases(&db) .iter() - .map(|base_ty| format!("{}", base_ty.display(&context))) + .map(|base_ty| format!("{}", base_ty.display(&db))) .collect(); assert_eq!(base_names, vec!["Literal[Base]"]); @@ -770,15 +694,13 @@ class C: panic!("C is not a Class"); }; - let context = TypingContext::global(&db); - let member_ty = class_id.class_member(&context, &Name::new_static("f")); + let member_ty = class_id.class_member(&db, &Name::new_static("f")); - let Some(Type::Function(func_id)) = member_ty else { + let Some(Type::Function(func)) = member_ty else { panic!("C.f is not a Function"); }; - let function_ty = func_id.lookup(&context); - assert_eq!(function_ty.name(), "f"); + assert_eq!(func.name(&db), "f"); Ok(()) }