[red-knot] intern types using Salsa (#12061)

Intern types using Salsa interning instead of in the `TypeInference`
result.

This eliminates the need for `TypingContext`, and also paves the way for
finer-grained type inference queries.
This commit is contained in:
Carl Meyer 2024-07-05 12:16:37 -07:00 committed by GitHub
parent 7b50061b43
commit 0e44235981
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 190 additions and 534 deletions

11
Cargo.lock generated
View File

@ -1532,6 +1532,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordermap"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5a8e22be64dfa1123429350872e7be33594dbf5ae5212c90c5890e71966d1d"
dependencies = [
"indexmap",
]
[[package]] [[package]]
name = "os_str_bytes" name = "os_str_bytes"
version = "6.6.1" version = "6.6.1"
@ -1902,7 +1911,7 @@ dependencies = [
"anyhow", "anyhow",
"bitflags 2.6.0", "bitflags 2.6.0",
"hashbrown 0.14.5", "hashbrown 0.14.5",
"indexmap", "ordermap",
"red_knot_module_resolver", "red_knot_module_resolver",
"ruff_db", "ruff_db",
"ruff_index", "ruff_index",

View File

@ -72,7 +72,6 @@ hashbrown = "0.14.3"
ignore = { version = "0.4.22" } ignore = { version = "0.4.22" }
imara-diff = { version = "0.1.5" } imara-diff = { version = "0.1.5" }
imperative = { version = "1.0.4" } imperative = { version = "1.0.4" }
indexmap = { version = "2.2.6" }
indicatif = { version = "0.17.8" } indicatif = { version = "0.17.8" }
indoc = { version = "2.0.4" } indoc = { version = "2.0.4" }
insta = { version = "1.35.1" } insta = { version = "1.35.1" }
@ -95,6 +94,7 @@ mimalloc = { version = "0.1.39" }
natord = { version = "1.0.9" } natord = { version = "1.0.9" }
notify = { version = "6.1.1" } notify = { version = "6.1.1" }
once_cell = { version = "1.19.0" } once_cell = { version = "1.19.0" }
ordermap = { version = "0.5.0" }
path-absolutize = { version = "3.1.1" } path-absolutize = { version = "3.1.1" }
path-slash = { version = "0.2.1" } path-slash = { version = "0.2.1" }
pathdiff = { version = "0.2.1" } pathdiff = { version = "0.2.1" }

View File

@ -122,7 +122,6 @@ fn lint_unresolved_imports(context: &SemanticLintContext, import: AnyImportRef)
fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {
let semantic = &context.semantic; let semantic = &context.semantic;
let typing_context = semantic.typing_context();
// TODO we should have a special marker on the real typing module (from typeshed) so if you // TODO we should have a special marker on the real typing module (from typeshed) so if you
// have your own "typing" module in your project, we don't consider it THE typing module (and // have your own "typing" module in your project, we don't consider it THE typing module (and
@ -150,17 +149,17 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {
return; return;
}; };
if ty.has_decorator(&typing_context, override_ty) { // TODO this shouldn't make direct use of the Db; see comment on SemanticModel::db
let method_name = ty.name(&typing_context); let db = semantic.db();
if class_ty
.inherited_class_member(&typing_context, method_name) if ty.has_decorator(db, override_ty) {
.is_none() let method_name = ty.name(db);
{ if class_ty.inherited_class_member(db, &method_name).is_none() {
// TODO should have a qualname() method to support nested classes // TODO should have a qualname() method to support nested classes
context.push_diagnostic( context.push_diagnostic(
format!( format!(
"Method {}.{} is decorated with `typing.override` but does not override any base class method", "Method {}.{} is decorated with `typing.override` but does not override any base class method",
class_ty.name(&typing_context), class_ty.name(db),
method_name, method_name,
)); ));
} }

View File

@ -18,7 +18,7 @@ ruff_python_ast = { workspace = true }
ruff_text_size = { workspace = true } ruff_text_size = { workspace = true }
bitflags = { workspace = true } bitflags = { workspace = true }
indexmap = { workspace = true } ordermap = { workspace = true }
salsa = { workspace = true } salsa = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
rustc-hash = { workspace = true } rustc-hash = { workspace = true }

View File

@ -7,13 +7,19 @@ use red_knot_module_resolver::Db as ResolverDb;
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId}; use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table}; use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::types::{infer_types, public_symbol_ty}; use crate::types::{
infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType,
};
#[salsa::jar(db=Db)] #[salsa::jar(db=Db)]
pub struct Jar( pub struct Jar(
ScopeId<'_>, ScopeId<'_>,
PublicSymbolId<'_>, PublicSymbolId<'_>,
Definition<'_>, Definition<'_>,
FunctionType<'_>,
ClassType<'_>,
UnionType<'_>,
IntersectionType<'_>,
symbol_table, symbol_table,
root_scope, root_scope,
semantic_index, semantic_index,

View File

@ -12,4 +12,4 @@ pub mod semantic_index;
mod semantic_model; mod semantic_model;
pub mod types; pub mod types;
type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>; type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;

View File

@ -1,10 +0,0 @@
use std::hash::BuildHasherDefault;
use rustc_hash::FxHasher;
pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;
pub mod types;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;

View File

@ -155,7 +155,6 @@ impl<'db> PublicSymbolsMap<'db> {
/// A cross-module identifier of a scope that can be used as a salsa query parameter. /// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked] #[salsa::tracked]
pub struct ScopeId<'db> { pub struct ScopeId<'db> {
#[allow(clippy::used_underscore_binding)]
#[id] #[id]
pub file: VfsFile, pub file: VfsFile,
#[id] #[id]

View File

@ -6,7 +6,7 @@ use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef};
use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::symbol::PublicSymbolId; use crate::semantic_index::symbol::PublicSymbolId;
use crate::semantic_index::{public_symbol, semantic_index}; use crate::semantic_index::{public_symbol, semantic_index};
use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; use crate::types::{infer_types, public_symbol_ty, Type};
use crate::Db; use crate::Db;
pub struct SemanticModel<'db> { pub struct SemanticModel<'db> {
@ -19,6 +19,12 @@ impl<'db> SemanticModel<'db> {
Self { db, file } Self { db, file }
} }
// TODO we don't actually want to expose the Db directly to lint rules, but we need to find a
// solution for exposing information from types
pub fn db(&self) -> &dyn Db {
self.db
}
pub fn resolve_module(&self, module_name: ModuleName) -> Option<Module> { pub fn resolve_module(&self, module_name: ModuleName) -> Option<Module> {
resolve_module(self.db.upcast(), module_name) resolve_module(self.db.upcast(), module_name)
} }
@ -27,13 +33,9 @@ impl<'db> SemanticModel<'db> {
public_symbol(self.db, module.file(), symbol_name) public_symbol(self.db, module.file(), symbol_name)
} }
pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type<'db> { pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type {
public_symbol_ty(self.db, symbol) public_symbol_ty(self.db, symbol)
} }
pub fn typing_context(&self) -> TypingContext<'db, '_> {
TypingContext::global(self.db)
}
} }
pub trait HasTy { pub trait HasTy {

View File

@ -1,13 +1,11 @@
use ruff_db::parsed::parsed_module; use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile; use ruff_db::vfs::VfsFile;
use ruff_index::newtype_index;
use ruff_python_ast::name::Name; use ruff_python_ast::name::Name;
use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, PublicSymbolId, ScopeId}; use crate::semantic_index::symbol::{NodeWithScopeKind, PublicSymbolId, ScopeId};
use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table}; use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table};
use crate::types::infer::{TypeInference, TypeInferenceBuilder}; use crate::types::infer::{TypeInference, TypeInferenceBuilder};
use crate::Db; use crate::{Db, FxOrderSet};
use crate::FxIndexSet;
mod display; mod display;
mod infer; mod infer;
@ -43,12 +41,12 @@ pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db>
let file = symbol.file(db); let file = symbol.file(db);
let scope = root_scope(db, file); let scope = root_scope(db, file);
// TODO switch to inferring just the definition(s), not the whole scope
let inference = infer_types(db, scope); let inference = infer_types(db, scope);
inference.symbol_ty(symbol.scoped_symbol_id(db)) inference.symbol_ty(symbol.scoped_symbol_id(db))
} }
/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`]. /// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`].
#[allow(unused)]
pub(crate) fn public_symbol_ty_by_name<'db>( pub(crate) fn public_symbol_ty_by_name<'db>(
db: &'db dyn Db, db: &'db dyn Db,
file: VfsFile, file: VfsFile,
@ -91,7 +89,7 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe
} }
/// unique ID for a type /// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub enum Type<'db> { pub enum Type<'db> {
/// the dynamic type: a statically-unknown set of values /// the dynamic type: a statically-unknown set of values
Any, Any,
@ -105,15 +103,15 @@ pub enum Type<'db> {
/// the None object (TODO remove this in favor of Instance(types.NoneType) /// the None object (TODO remove this in favor of Instance(types.NoneType)
None, None,
/// a specific function object /// a specific function object
Function(TypeId<'db, ScopedFunctionTypeId>), Function(FunctionType<'db>),
/// a specific module object /// a specific module object
Module(TypeId<'db, ScopedModuleTypeId>), Module(VfsFile),
/// a specific class object /// a specific class object
Class(TypeId<'db, ScopedClassTypeId>), Class(ClassType<'db>),
/// the set of Python objects with the given class in their __class__'s method resolution order /// the set of Python objects with the given class in their __class__'s method resolution order
Instance(TypeId<'db, ScopedClassTypeId>), Instance(ClassType<'db>),
Union(TypeId<'db, ScopedUnionTypeId>), Union(UnionType<'db>),
Intersection(TypeId<'db, ScopedIntersectionTypeId>), Intersection(IntersectionType<'db>),
IntLiteral(i64), IntLiteral(i64),
// TODO protocols, callable types, overloads, generics, type vars // TODO protocols, callable types, overloads, generics, type vars
} }
@ -127,7 +125,7 @@ impl<'db> Type<'db> {
matches!(self, Type::Unknown) matches!(self, Type::Unknown)
} }
pub fn member(&self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> { pub fn member(&self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
match self { match self {
Type::Any => Some(Type::Any), Type::Any => Some(Type::Any),
Type::Never => todo!("attribute lookup on Never type"), Type::Never => todo!("attribute lookup on Never type"),
@ -135,14 +133,13 @@ impl<'db> Type<'db> {
Type::Unbound => todo!("attribute lookup on Unbound type"), Type::Unbound => todo!("attribute lookup on Unbound type"),
Type::None => todo!("attribute lookup on None type"), Type::None => todo!("attribute lookup on None type"),
Type::Function(_) => todo!("attribute lookup on Function type"), Type::Function(_) => todo!("attribute lookup on Function type"),
Type::Module(module) => module.member(context, name), Type::Module(file) => public_symbol_ty_by_name(db, *file, name),
Type::Class(class) => class.class_member(context, name), Type::Class(class) => class.class_member(db, name),
Type::Instance(_) => { Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member // TODO MRO? get_own_instance_member, get_instance_member
todo!("attribute lookup on Instance type") todo!("attribute lookup on Instance type")
} }
Type::Union(union_id) => { Type::Union(_) => {
let _union = union_id.lookup(context);
// TODO perform the get_member on each type in the union // TODO perform the get_member on each type in the union
// TODO return the union of those results // TODO return the union of those results
// TODO if any of those results is `None` then include Unknown in the result union // TODO if any of those results is `None` then include Unknown in the result union
@ -161,155 +158,25 @@ impl<'db> Type<'db> {
} }
} }
/// ID that uniquely identifies a type in a program. #[salsa::interned]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct FunctionType<'db> {
pub struct TypeId<'db, L> {
/// The scope in which this type is defined or was created.
scope: ScopeId<'db>,
/// The type's local ID in its scope.
scoped: L,
}
impl<'db, Id> TypeId<'db, Id>
where
Id: Copy,
{
pub fn scope(&self) -> ScopeId<'db> {
self.scope
}
pub fn scoped_id(&self) -> Id {
self.scoped
}
/// Resolves the type ID to the actual type.
pub(crate) fn lookup<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Id::Ty<'db>
where
Id: ScopedTypeId,
{
let types = context.types(self.scope);
self.scoped.lookup_scoped(types)
}
}
/// ID that uniquely identifies a type in a scope.
pub(crate) trait ScopedTypeId {
/// The type that this ID points to.
type Ty<'db>;
/// Looks up the type in `index`.
///
/// ## Panics
/// May panic if this type is from another scope than `index`, or might just return an invalid type.
fn lookup_scoped<'a, 'db>(self, index: &'a TypeInference<'db>) -> &'a Self::Ty<'db>;
}
/// ID uniquely identifying a function type in a `scope`.
#[newtype_index]
pub struct ScopedFunctionTypeId;
impl ScopedTypeId for ScopedFunctionTypeId {
type Ty<'db> = FunctionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.function_ty(self)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct FunctionType<'a> {
/// name of the function at definition /// name of the function at definition
name: Name, pub name: Name,
/// types of all decorators on this function /// types of all decorators on this function
decorators: Vec<Type<'a>>, decorators: Vec<Type<'db>>,
} }
impl<'a> FunctionType<'a> { impl<'db> FunctionType<'db> {
fn name(&self) -> &str { pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.name.as_str() self.decorators(db).contains(&decorator)
}
#[allow(unused)]
pub(crate) fn decorators(&self) -> &[Type<'a>] {
self.decorators.as_slice()
} }
} }
impl<'db> TypeId<'db, ScopedFunctionTypeId> { #[salsa::interned]
pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name {
let function_ty = self.lookup(context);
&function_ty.name
}
pub fn has_decorator(self, context: &TypingContext, decorator: Type<'db>) -> bool {
let function_ty = self.lookup(context);
function_ty.decorators.contains(&decorator)
}
}
#[newtype_index]
pub struct ScopedClassTypeId;
impl ScopedTypeId for ScopedClassTypeId {
type Ty<'db> = ClassType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.class_ty(self)
}
}
impl<'db> TypeId<'db, ScopedClassTypeId> {
pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name {
let class_ty = self.lookup(context);
&class_ty.name
}
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> {
if let Some(member) = self.own_class_member(context, name) {
return Some(member);
}
self.inherited_class_member(context, name)
}
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(
self,
context: &TypingContext<'db, '_>,
name: &Name,
) -> Option<Type<'db>> {
let class = self.lookup(context);
let symbols = symbol_table(context.db, class.body_scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = context.types(class.body_scope);
Some(types.symbol_ty(symbol))
}
pub fn inherited_class_member(
self,
context: &TypingContext<'db, '_>,
name: &Name,
) -> Option<Type<'db>> {
let class = self.lookup(context);
for base in &class.bases {
if let Some(member) = base.member(context, name) {
return Some(member);
}
}
None
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ClassType<'db> { pub struct ClassType<'db> {
/// Name of the class at definition /// Name of the class at definition
name: Name, pub name: Name,
/// Types of all class bases /// Types of all class bases
bases: Vec<Type<'db>>, bases: Vec<Type<'db>>,
@ -318,52 +185,62 @@ pub struct ClassType<'db> {
} }
impl<'db> ClassType<'db> { impl<'db> ClassType<'db> {
fn name(&self) -> &str { /// Returns the class member of this class named `name`.
self.name.as_str() ///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
if let Some(member) = self.own_class_member(db, name) {
return Some(member);
} }
#[allow(unused)] self.inherited_class_member(db, name)
pub(super) fn bases(&self) -> &'db [Type] { }
self.bases.as_slice()
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
let scope = self.body_scope(db);
let symbols = symbol_table(db, scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = infer_types(db, scope);
Some(types.symbol_ty(symbol))
}
pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
for base in self.bases(db) {
if let Some(member) = base.member(db, name) {
return Some(member);
} }
} }
#[newtype_index] None
pub struct ScopedUnionTypeId;
impl ScopedTypeId for ScopedUnionTypeId {
type Ty<'db> = UnionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.union_ty(self)
} }
} }
#[derive(Debug, Eq, PartialEq, Clone)] #[salsa::interned]
pub struct UnionType<'db> { pub struct UnionType<'db> {
// the union type includes values in any of these types /// the union type includes values in any of these types
elements: FxIndexSet<Type<'db>>, elements: FxOrderSet<Type<'db>>,
} }
struct UnionTypeBuilder<'db, 'a> { struct UnionTypeBuilder<'db> {
elements: FxIndexSet<Type<'db>>, elements: FxOrderSet<Type<'db>>,
context: &'a TypingContext<'db, 'a>, db: &'db dyn Db,
} }
impl<'db, 'a> UnionTypeBuilder<'db, 'a> { impl<'db> UnionTypeBuilder<'db> {
fn new(context: &'a TypingContext<'db, 'a>) -> Self { fn new(db: &'db dyn Db) -> Self {
Self { Self {
context, db,
elements: FxIndexSet::default(), elements: FxOrderSet::default(),
} }
} }
/// Adds a type to this union. /// Adds a type to this union.
fn add(mut self, ty: Type<'db>) -> Self { fn add(mut self, ty: Type<'db>) -> Self {
match ty { match ty {
Type::Union(union_id) => { Type::Union(union) => {
let union = union_id.lookup(self.context); self.elements.extend(&union.elements(self.db));
self.elements.extend(&union.elements);
} }
_ => { _ => {
self.elements.insert(ty); self.elements.insert(ty);
@ -374,20 +251,7 @@ impl<'db, 'a> UnionTypeBuilder<'db, 'a> {
} }
fn build(self) -> UnionType<'db> { fn build(self) -> UnionType<'db> {
UnionType { UnionType::new(self.db, self.elements)
elements: self.elements,
}
}
}
#[newtype_index]
pub struct ScopedIntersectionTypeId;
impl ScopedTypeId for ScopedIntersectionTypeId {
type Ty<'db> = IntersectionType<'db>;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.intersection_ty(self)
} }
} }
@ -397,104 +261,12 @@ impl ScopedTypeId for ScopedIntersectionTypeId {
// case where a Not appears outside an intersection (unclear when that could even happen, but we'd // case where a Not appears outside an intersection (unclear when that could even happen, but we'd
// have to represent it as a single-element intersection if it did) in exchange for better // have to represent it as a single-element intersection if it did) in exchange for better
// efficiency in the within-intersection case. // efficiency in the within-intersection case.
#[derive(Debug, PartialEq, Eq, Clone)] #[salsa::interned]
pub struct IntersectionType<'db> { pub struct IntersectionType<'db> {
// the intersection type includes only values in all of these types // the intersection type includes only values in all of these types
positive: FxIndexSet<Type<'db>>, positive: FxOrderSet<Type<'db>>,
// the intersection type does not include any value in any of these types // the intersection type does not include any value in any of these types
negative: FxIndexSet<Type<'db>>, negative: FxOrderSet<Type<'db>>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ScopedModuleTypeId;
impl ScopedTypeId for ScopedModuleTypeId {
type Ty<'db> = ModuleType;
fn lookup_scoped<'a, 'db>(self, types: &'a TypeInference<'db>) -> &'a Self::Ty<'db> {
types.module_ty()
}
}
impl<'db> TypeId<'db, ScopedModuleTypeId> {
fn member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option<Type<'db>> {
context.public_symbol_ty(self.scope.file(context.db), name)
}
}
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct ModuleType {
file: VfsFile,
}
/// Context in which to resolve types.
///
/// This abstraction is necessary to support a uniform API that can be used
/// while in the process of building the type inference structure for a scope
/// but also when all types should be resolved by querying the db.
pub struct TypingContext<'db, 'inference> {
db: &'db dyn Db,
/// The Local type inference scope that is in the process of being built.
///
/// Bypass the `db` when resolving the types for this scope.
local: Option<(ScopeId<'db>, &'inference TypeInference<'db>)>,
}
impl<'db, 'inference> TypingContext<'db, 'inference> {
/// Creates a context that resolves all types by querying the db.
#[allow(unused)]
pub(super) fn global(db: &'db dyn Db) -> Self {
Self { db, local: None }
}
/// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`.
fn scoped(
db: &'db dyn Db,
scope_id: ScopeId<'db>,
types: &'inference TypeInference<'db>,
) -> Self {
Self {
db,
local: Some((scope_id, types)),
}
}
/// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`.
fn types(&self, scope_id: ScopeId<'db>) -> &'inference TypeInference<'db> {
if let Some((scope, local_types)) = self.local {
if scope == scope_id {
return local_types;
}
}
infer_types(self.db, scope_id)
}
fn module_ty(&self, file: VfsFile) -> Type<'db> {
let scope = root_scope(self.db, file);
Type::Module(TypeId {
scope,
scoped: ScopedModuleTypeId,
})
}
/// Resolves the public type of a symbol named `name` defined in `file`.
///
/// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`.
/// It otherwise tries to resolve the symbol type locally.
fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option<Type<'db>> {
let symbol = public_symbol(self.db, file, name)?;
if let Some((scope, local_types)) = self.local {
if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file {
return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db)));
}
}
Some(public_symbol_ty(self.db, symbol))
}
} }
#[cfg(test)] #[cfg(test)]
@ -508,7 +280,7 @@ mod tests {
assert_will_not_run_function_query, assert_will_run_function_query, TestDb, assert_will_not_run_function_query, assert_will_run_function_query, TestDb,
}; };
use crate::semantic_index::root_scope; use crate::semantic_index::root_scope;
use crate::types::{infer_types, public_symbol_ty_by_name, TypingContext}; use crate::types::{infer_types, public_symbol_ty_by_name};
use crate::{HasTy, SemanticModel}; use crate::{HasTy, SemanticModel};
fn setup_db() -> TestDb { fn setup_db() -> TestDb {
@ -540,10 +312,7 @@ mod tests {
let literal_ty = statement.value.ty(&model); let literal_ty = statement.value.ty(&model);
assert_eq!( assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]");
format!("{}", literal_ty.display(&TypingContext::global(&db))),
"Literal[10]"
);
Ok(()) Ok(())
} }
@ -560,10 +329,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap(); let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
// Change `x` to a different value // Change `x` to a different value
db.memory_file_system() db.memory_file_system()
@ -577,10 +343,7 @@ mod tests {
db.clear_salsa_events(); db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]");
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[20]"
);
let events = db.take_salsa_events(); let events = db.take_salsa_events();
@ -607,10 +370,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap(); let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system() db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; .write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
@ -624,10 +384,7 @@ mod tests {
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events(); let events = db.take_salsa_events();
@ -655,10 +412,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap(); let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
x_ty.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
db.memory_file_system() db.memory_file_system()
.write_file("/src/foo.py", "x = 10\ny = 30")?; .write_file("/src/foo.py", "x = 10\ny = 30")?;
@ -672,10 +426,7 @@ mod tests {
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!( assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
x_ty_2.display(&TypingContext::global(&db)).to_string(),
"Literal[10]"
);
let events = db.take_salsa_events(); let events = db.take_salsa_events();

View File

@ -2,18 +2,19 @@
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use crate::types::{IntersectionType, Type, TypingContext, UnionType}; use crate::types::{IntersectionType, Type, UnionType};
use crate::Db;
impl Type<'_> { impl<'db> Type<'db> {
pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> { pub fn display(&'db self, db: &'db dyn Db) -> DisplayType<'db> {
DisplayType { ty: self, context } DisplayType { ty: self, db }
} }
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct DisplayType<'a> { pub struct DisplayType<'db> {
ty: &'a Type<'a>, ty: &'db Type<'db>,
context: &'a TypingContext<'a, 'a>, db: &'db dyn Db,
} }
impl Display for DisplayType<'_> { impl Display for DisplayType<'_> {
@ -24,42 +25,19 @@ impl Display for DisplayType<'_> {
Type::Unknown => f.write_str("Unknown"), Type::Unknown => f.write_str("Unknown"),
Type::Unbound => f.write_str("Unbound"), Type::Unbound => f.write_str("Unbound"),
Type::None => f.write_str("None"), Type::None => f.write_str("None"),
Type::Module(module_id) => { Type::Module(file) => {
write!( write!(f, "<module '{:?}'>", file.path(self.db.upcast()))
f,
"<module '{:?}'>",
module_id
.scope
.file(self.context.db)
.path(self.context.db.upcast())
)
} }
// TODO functions and classes should display using a fully qualified name // TODO functions and classes should display using a fully qualified name
Type::Class(class_id) => { Type::Class(class) => {
let class = class_id.lookup(self.context);
f.write_str("Literal[")?; f.write_str("Literal[")?;
f.write_str(class.name())?; f.write_str(&class.name(self.db))?;
f.write_str("]") f.write_str("]")
} }
Type::Instance(class_id) => { Type::Instance(class) => f.write_str(&class.name(self.db)),
let class = class_id.lookup(self.context); Type::Function(function) => f.write_str(&function.name(self.db)),
f.write_str(class.name()) Type::Union(union) => union.display(self.db).fmt(f),
} Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::Function(function_id) => {
let function = function_id.lookup(self.context);
f.write_str(function.name())
}
Type::Union(union_id) => {
let union = union_id.lookup(self.context);
union.display(self.context).fmt(f)
}
Type::Intersection(intersection_id) => {
let intersection = intersection_id.lookup(self.context);
intersection.display(self.context).fmt(f)
}
Type::IntLiteral(n) => write!(f, "Literal[{n}]"), Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
} }
} }
@ -71,15 +49,15 @@ impl std::fmt::Debug for DisplayType<'_> {
} }
} }
impl UnionType<'_> { impl<'db> UnionType<'db> {
fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayUnionType<'a> { fn display(&'db self, db: &'db dyn Db) -> DisplayUnionType<'db> {
DisplayUnionType { context, ty: self } DisplayUnionType { db, ty: self }
} }
} }
struct DisplayUnionType<'a> { struct DisplayUnionType<'db> {
ty: &'a UnionType<'a>, ty: &'db UnionType<'db>,
context: &'a TypingContext<'a, 'a>, db: &'db dyn Db,
} }
impl Display for DisplayUnionType<'_> { impl Display for DisplayUnionType<'_> {
@ -87,7 +65,7 @@ impl Display for DisplayUnionType<'_> {
let union = self.ty; let union = self.ty;
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
.elements .elements(self.db)
.iter() .iter()
.copied() .copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_))); .partition(|ty| matches!(ty, Type::IntLiteral(_)));
@ -121,7 +99,7 @@ impl Display for DisplayUnionType<'_> {
f.write_str(" | ")?; f.write_str(" | ")?;
}; };
first = false; first = false;
write!(f, "{}", ty.display(self.context))?; write!(f, "{}", ty.display(self.db))?;
} }
Ok(()) Ok(())
@ -134,15 +112,15 @@ impl std::fmt::Debug for DisplayUnionType<'_> {
} }
} }
impl IntersectionType<'_> { impl<'db> IntersectionType<'db> {
fn display<'a>(&'a self, context: &'a TypingContext<'a, 'a>) -> DisplayIntersectionType<'a> { fn display(&'db self, db: &'db dyn Db) -> DisplayIntersectionType<'db> {
DisplayIntersectionType { ty: self, context } DisplayIntersectionType { db, ty: self }
} }
} }
struct DisplayIntersectionType<'a> { struct DisplayIntersectionType<'db> {
ty: &'a IntersectionType<'a>, ty: &'db IntersectionType<'db>,
context: &'a TypingContext<'a, 'a>, db: &'db dyn Db,
} }
impl Display for DisplayIntersectionType<'_> { impl Display for DisplayIntersectionType<'_> {
@ -150,10 +128,10 @@ impl Display for DisplayIntersectionType<'_> {
let mut first = true; let mut first = true;
for (neg, ty) in self for (neg, ty) in self
.ty .ty
.positive .positive(self.db)
.iter() .iter()
.map(|ty| (false, ty)) .map(|ty| (false, ty))
.chain(self.ty.negative.iter().map(|ty| (true, ty))) .chain(self.ty.negative(self.db).iter().map(|ty| (true, ty)))
{ {
if !first { if !first {
f.write_str(" & ")?; f.write_str(" & ")?;
@ -162,7 +140,7 @@ impl Display for DisplayIntersectionType<'_> {
if neg { if neg {
f.write_str("~")?; f.write_str("~")?;
}; };
write!(f, "{}", ty.display(self.context))?; write!(f, "{}", ty.display(self.db))?;
} }
Ok(()) Ok(())
} }

View File

@ -2,8 +2,7 @@ use rustc_hash::FxHashMap;
use std::borrow::Cow; use std::borrow::Cow;
use std::sync::Arc; use std::sync::Arc;
use red_knot_module_resolver::resolve_module; use red_knot_module_resolver::{resolve_module, ModuleName};
use red_knot_module_resolver::ModuleName;
use ruff_db::vfs::VfsFile; use ruff_db::vfs::VfsFile;
use ruff_index::IndexVec; use ruff_index::IndexVec;
use ruff_python_ast as ast; use ruff_python_ast as ast;
@ -15,81 +14,40 @@ use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable, FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable,
}; };
use crate::semantic_index::{symbol_table, SemanticIndex}; use crate::semantic_index::{symbol_table, SemanticIndex};
use crate::types::{ use crate::types::{infer_types, ClassType, FunctionType, Name, Type, UnionTypeBuilder};
infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId,
ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext,
UnionType, UnionTypeBuilder,
};
use crate::Db; use crate::Db;
/// The inferred types for a single scope. /// The inferred types for a single scope.
#[derive(Debug, Eq, PartialEq, Default, Clone)] #[derive(Debug, Eq, PartialEq, Default, Clone)]
pub(crate) struct TypeInference<'db> { pub(crate) struct TypeInference<'db> {
/// The type of the module if the scope is a module scope.
module_type: Option<ModuleType>,
/// The types of the defined classes in this scope.
class_types: IndexVec<ScopedClassTypeId, ClassType<'db>>,
/// The types of the defined functions in this scope.
function_types: IndexVec<ScopedFunctionTypeId, FunctionType<'db>>,
union_types: IndexVec<ScopedUnionTypeId, UnionType<'db>>,
intersection_types: IndexVec<ScopedIntersectionTypeId, IntersectionType<'db>>,
/// The types of every expression in this scope. /// The types of every expression in this scope.
expression_tys: IndexVec<ScopedExpressionId, Type<'db>>, expressions: IndexVec<ScopedExpressionId, Type<'db>>,
/// The public types of every symbol in this scope. /// The public types of every symbol in this scope.
symbol_tys: IndexVec<ScopedSymbolId, Type<'db>>, symbols: IndexVec<ScopedSymbolId, Type<'db>>,
/// The type of a definition. /// The type of a definition.
definition_tys: FxHashMap<Definition<'db>, Type<'db>>, definitions: FxHashMap<Definition<'db>, Type<'db>>,
} }
impl<'db> TypeInference<'db> { impl<'db> TypeInference<'db> {
#[allow(unused)] #[allow(unused)]
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expression_tys[expression] self.expressions[expression]
} }
pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type<'db> { pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type<'db> {
self.symbol_tys[symbol] self.symbols[symbol]
} }
pub(super) fn module_ty(&self) -> &ModuleType { pub(crate) fn definition_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.module_type.as_ref().unwrap() self.definitions[&definition]
}
pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType<'db> {
&self.class_types[id]
}
pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType<'db> {
&self.function_types[id]
}
pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType<'db> {
&self.union_types[id]
}
pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType<'db> {
&self.intersection_types[id]
}
pub(crate) fn definition_ty(&self, definition: Definition) -> Type<'db> {
self.definition_tys[&definition]
} }
fn shrink_to_fit(&mut self) { fn shrink_to_fit(&mut self) {
self.class_types.shrink_to_fit(); self.expressions.shrink_to_fit();
self.function_types.shrink_to_fit(); self.symbols.shrink_to_fit();
self.union_types.shrink_to_fit(); self.definitions.shrink_to_fit();
self.intersection_types.shrink_to_fit();
self.expression_tys.shrink_to_fit();
self.symbol_tys.shrink_to_fit();
self.definition_tys.shrink_to_fit();
} }
} }
@ -99,7 +57,6 @@ pub(super) struct TypeInferenceBuilder<'db> {
// Cached lookups // Cached lookups
index: &'db SemanticIndex<'db>, index: &'db SemanticIndex<'db>,
scope: ScopeId<'db>,
file_scope_id: FileScopeId, file_scope_id: FileScopeId,
file_id: VfsFile, file_id: VfsFile,
symbol_table: Arc<SymbolTable<'db>>, symbol_table: Arc<SymbolTable<'db>>,
@ -123,7 +80,6 @@ impl<'db> TypeInferenceBuilder<'db> {
index, index,
file_scope_id, file_scope_id,
file_id: file, file_id: file,
scope,
symbol_table, symbol_table,
db, db,
@ -205,13 +161,11 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(return_ty); self.infer_expression(return_ty);
} }
let function_ty = self.function_ty(FunctionType { let function_ty =
name: name.id.clone(), Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys));
decorators: decorator_tys,
});
let definition = self.index.definition(function); let definition = self.index.definition(function);
self.types.definition_tys.insert(definition, function_ty); self.types.definitions.insert(definition, function_ty);
} }
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
@ -233,16 +187,15 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|arguments| self.infer_arguments(arguments)) .map(|arguments| self.infer_arguments(arguments))
.unwrap_or(Vec::new()); .unwrap_or(Vec::new());
let body_scope = self.index.node_scope(NodeWithScopeRef::Class(class)); let body_scope = self
.index
.node_scope(NodeWithScopeRef::Class(class))
.to_scope_id(self.db, self.file_id);
let class_ty = self.class_ty(ClassType { let class_ty = Type::Class(ClassType::new(self.db, name.id.clone(), bases, body_scope));
name: name.id.clone(),
bases,
body_scope: body_scope.to_scope_id(self.db, self.file_id),
});
let definition = self.index.definition(class); let definition = self.index.definition(class);
self.types.definition_tys.insert(definition, class_ty); self.types.definitions.insert(definition, class_ty);
} }
fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) {
@ -283,7 +236,7 @@ impl<'db> TypeInferenceBuilder<'db> {
for target in targets { for target in targets {
self.infer_expression(target); self.infer_expression(target);
self.types.definition_tys.insert( self.types.definitions.insert(
self.index.definition(DefinitionNodeRef::Target(target)), self.index.definition(DefinitionNodeRef::Target(target)),
value_ty, value_ty,
); );
@ -306,7 +259,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let annotation_ty = self.infer_expression(annotation); let annotation_ty = self.infer_expression(annotation);
self.infer_expression(target); self.infer_expression(target);
self.types.definition_tys.insert( self.types.definitions.insert(
self.index.definition(DefinitionNodeRef::Target(target)), self.index.definition(DefinitionNodeRef::Target(target)),
annotation_ty, annotation_ty,
); );
@ -341,12 +294,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let module_name = ModuleName::new(&name.id); let module_name = ModuleName::new(&name.id);
let module = module_name.and_then(|name| resolve_module(self.db.upcast(), name)); let module = module_name.and_then(|name| resolve_module(self.db.upcast(), name));
let module_ty = module let module_ty = module
.map(|module| self.typing_context().module_ty(module.file())) .map(|module| Type::Module(module.file()))
.unwrap_or(Type::Unknown); .unwrap_or(Type::Unknown);
let definition = self.index.definition(alias); let definition = self.index.definition(alias);
self.types.definition_tys.insert(definition, module_ty); self.types.definitions.insert(definition, module_ty);
} }
} }
@ -363,7 +316,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let module = let module =
module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name)); module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name));
let module_ty = module let module_ty = module
.map(|module| self.typing_context().module_ty(module.file())) .map(|module| Type::Module(module.file()))
.unwrap_or(Type::Unknown); .unwrap_or(Type::Unknown);
for alias in names { for alias in names {
@ -374,11 +327,11 @@ impl<'db> TypeInferenceBuilder<'db> {
} = alias; } = alias;
let ty = module_ty let ty = module_ty
.member(&self.typing_context(), &name.id) .member(self.db, &Name::new(&name.id))
.unwrap_or(Type::Unknown); .unwrap_or(Type::Unknown);
let definition = self.index.definition(alias); let definition = self.index.definition(alias);
self.types.definition_tys.insert(definition, ty); self.types.definitions.insert(definition, ty);
} }
} }
@ -425,7 +378,7 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => todo!("expression type resolution for {:?}", expression), _ => todo!("expression type resolution for {:?}", expression),
}; };
self.types.expression_tys.push(ty); self.types.expressions.push(ty);
ty ty
} }
@ -455,7 +408,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_expression(target); self.infer_expression(target);
self.types self.types
.definition_tys .definitions
.insert(self.index.definition(named), value_ty); .insert(self.index.definition(named), value_ty);
value_ty value_ty
@ -475,12 +428,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let body_ty = self.infer_expression(body); let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse); let orelse_ty = self.infer_expression(orelse);
let union = UnionTypeBuilder::new(&self.typing_context()) let union = UnionTypeBuilder::new(self.db)
.add(body_ty) .add(body_ty)
.add(orelse_ty) .add(orelse_ty)
.build(); .build();
self.union_ty(union) Type::Union(union)
} }
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> {
@ -537,7 +490,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value);
let member_ty = value_ty let member_ty = value_ty
.member(&self.typing_context(), &attr.id) .member(self.db, &Name::new(&attr.id))
.unwrap_or(Type::Unknown); .unwrap_or(Type::Unknown);
match ctx { match ctx {
@ -612,57 +565,31 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|symbol| self.local_definition_ty(symbol)) .map(|symbol| self.local_definition_ty(symbol))
.collect(); .collect();
self.types.symbol_tys = symbol_tys; self.types.symbols = symbol_tys;
self.types.shrink_to_fit(); self.types.shrink_to_fit();
self.types self.types
} }
fn union_ty(&mut self, ty: UnionType<'db>) -> Type<'db> {
Type::Union(TypeId {
scope: self.scope,
scoped: self.types.union_types.push(ty),
})
}
fn function_ty(&mut self, ty: FunctionType<'db>) -> Type<'db> {
Type::Function(TypeId {
scope: self.scope,
scoped: self.types.function_types.push(ty),
})
}
fn class_ty(&mut self, ty: ClassType<'db>) -> Type<'db> {
Type::Class(TypeId {
scope: self.scope,
scoped: self.types.class_types.push(ty),
})
}
fn typing_context(&self) -> TypingContext<'db, '_> {
TypingContext::scoped(self.db, self.scope, &self.types)
}
fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type<'db> { fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type<'db> {
let symbol = self.symbol_table.symbol(symbol); let symbol = self.symbol_table.symbol(symbol);
let mut definitions = symbol let mut definitions = symbol
.definitions() .definitions()
.iter() .iter()
.filter_map(|definition| self.types.definition_tys.get(definition).copied()); .filter_map(|definition| self.types.definitions.get(definition).copied());
let Some(first) = definitions.next() else { let Some(first) = definitions.next() else {
return Type::Unbound; return Type::Unbound;
}; };
if let Some(second) = definitions.next() { if let Some(second) = definitions.next() {
let context = self.typing_context(); let mut builder = UnionTypeBuilder::new(self.db);
let mut builder = UnionTypeBuilder::new(&context);
builder = builder.add(first).add(second); builder = builder.add(first).add(second);
for variant in definitions { for variant in definitions {
builder = builder.add(variant); builder = builder.add(variant);
} }
self.union_ty(builder.build()) Type::Union(builder.build())
} else { } else {
first first
} }
@ -677,7 +604,7 @@ mod tests {
use ruff_python_ast::name::Name; use ruff_python_ast::name::Name;
use crate::db::tests::TestDb; use crate::db::tests::TestDb;
use crate::types::{public_symbol_ty_by_name, Type, TypingContext}; use crate::types::{public_symbol_ty_by_name, Type};
fn setup_db() -> TestDb { fn setup_db() -> TestDb {
let mut db = TestDb::new(); let mut db = TestDb::new();
@ -699,7 +626,7 @@ mod tests {
let file = system_path_to_file(db, file_name).expect("Expected file to exist."); let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown); let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown);
assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected); assert_eq!(ty.display(db).to_string(), expected);
} }
#[test] #[test]
@ -733,17 +660,14 @@ class Sub(Base):
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist"); let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist");
let Type::Class(class_id) = ty else { let Type::Class(class) = ty else {
panic!("Sub is not a Class") panic!("Sub is not a Class")
}; };
let context = TypingContext::global(&db); let base_names: Vec<_> = class
.bases(&db)
let base_names: Vec<_> = class_id
.lookup(&context)
.bases()
.iter() .iter()
.map(|base_ty| format!("{}", base_ty.display(&context))) .map(|base_ty| format!("{}", base_ty.display(&db)))
.collect(); .collect();
assert_eq!(base_names, vec!["Literal[Base]"]); assert_eq!(base_names, vec!["Literal[Base]"]);
@ -770,15 +694,13 @@ class C:
panic!("C is not a Class"); panic!("C is not a Class");
}; };
let context = TypingContext::global(&db); let member_ty = class_id.class_member(&db, &Name::new_static("f"));
let member_ty = class_id.class_member(&context, &Name::new_static("f"));
let Some(Type::Function(func_id)) = member_ty else { let Some(Type::Function(func)) = member_ty else {
panic!("C.f is not a Function"); panic!("C.f is not a Function");
}; };
let function_ty = func_id.lookup(&context); assert_eq!(func.name(&db), "f");
assert_eq!(function_ty.name(), "f");
Ok(()) Ok(())
} }