From 4cb6a09fc0bebc9e116c2a834ecb1cd3a9aa813c Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 1 Jul 2024 10:22:34 +0200 Subject: [PATCH] Use `CompactString` for `ModuleName` (#12131) --- Cargo.lock | 12 +- Cargo.toml | 1 - crates/red_knot/Cargo.toml | 1 - crates/red_knot/src/lint.rs | 11 +- crates/red_knot/src/module.rs | 207 +++++++----------- crates/red_knot/src/program/check.rs | 2 +- crates/red_knot/src/semantic.rs | 6 +- crates/red_knot/src/semantic/definitions.rs | 2 +- crates/red_knot/src/semantic/symbol_table.rs | 2 +- crates/red_knot/src/semantic/types.rs | 3 +- crates/red_knot/src/semantic/types/infer.rs | 18 +- crates/red_knot_module_resolver/Cargo.toml | 2 +- crates/red_knot_module_resolver/src/module.rs | 27 +-- 13 files changed, 121 insertions(+), 173 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf476acf41..e1da51d675 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1881,7 +1881,6 @@ dependencies = [ "ruff_python_parser", "ruff_text_size", "rustc-hash 2.0.0", - "smol_str", "tempfile", "tracing", "tracing-subscriber", @@ -1893,13 +1892,13 @@ name = "red_knot_module_resolver" version = "0.0.0" dependencies = [ "anyhow", + "compact_str", "insta", "path-slash", "ruff_db", "ruff_python_stdlib", "rustc-hash 2.0.0", "salsa", - "smol_str", "tempfile", "tracing", "walkdir", @@ -2860,15 +2859,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "smol_str" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" -dependencies = [ - "serde", -] - [[package]] name = "spin" version = "0.9.8" diff --git a/Cargo.toml b/Cargo.toml index 42c8f16000..3eaa8702e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -120,7 +120,6 @@ serde_with = { version = "3.6.0", default-features = false, features = [ shellexpand = { version = "3.0.0" } similar = { version = "2.4.0", features = ["inline"] } smallvec = { version = "1.13.2" } -smol_str = { version = "0.2.2" } static_assertions = "1.1.0" strum = { version = "0.26.0", features = ["strum_macros"] } strum_macros = { version = "0.26.0" } diff --git a/crates/red_knot/Cargo.toml b/crates/red_knot/Cargo.toml index 1fc5534a84..6ac07c1777 100644 --- a/crates/red_knot/Cargo.toml +++ b/crates/red_knot/Cargo.toml @@ -32,7 +32,6 @@ notify = { workspace = true } parking_lot = { workspace = true } rayon = { workspace = true } rustc-hash = { workspace = true } -smol_str = { version = "0.2.1" } tracing = { workspace = true } tracing-subscriber = { workspace = true } tracing-tree = { workspace = true } diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index a0e5a9cf0b..a801bf9196 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -1,3 +1,4 @@ +use red_knot_module_resolver::ModuleName; use std::cell::RefCell; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -10,7 +11,7 @@ use ruff_python_parser::Parsed; use crate::cache::KeyValueCache; use crate::db::{LintDb, LintJar, QueryResult}; use crate::files::FileId; -use crate::module::{resolve_module, ModuleName}; +use crate::module::resolve_module; use crate::parse::parse; use crate::semantic::{infer_definition_type, infer_symbol_public_type, Type}; use crate::semantic::{ @@ -145,7 +146,9 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> { // TODO we should have a special marker on the real typing module (from typeshed) so if you // have your own "typing" module in your project, we don't consider it THE typing module (and // same for other stdlib modules that our lint rules care about) - let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else { + let Some(typing_override) = + context.resolve_global_symbol(&ModuleName::new_static("typing").unwrap(), "override")? + else { // TODO once we bundle typeshed, this should be unreachable!() return Ok(()); }; @@ -236,10 +239,10 @@ impl<'a> SemanticLintContext<'a> { pub fn resolve_global_symbol( &self, - module: &str, + module: &ModuleName, symbol_name: &str, ) -> QueryResult> { - let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else { + let Some(module) = resolve_module(self.db.upcast(), module)? else { return Ok(None); }; diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs index 4dfb9e74b9..3e7672b899 100644 --- a/crates/red_knot/src/module.rs +++ b/crates/red_knot/src/module.rs @@ -5,9 +5,8 @@ use std::sync::atomic::AtomicU32; use std::sync::Arc; use dashmap::mapref::entry::Entry; -use smol_str::SmolStr; -use red_knot_module_resolver::ModuleKind; +use red_knot_module_resolver::{ModuleKind, ModuleName}; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; @@ -95,87 +94,7 @@ impl Module { name.push_str(part); } - Ok(if name.is_empty() { - None - } else { - Some(ModuleName(SmolStr::new(name))) - }) - } -} - -/// A module name, e.g. `foo.bar`. -/// -/// Always normalized to the absolute form -/// (never a relative module name, i.e., never `.foo`). -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct ModuleName(smol_str::SmolStr); - -impl ModuleName { - pub fn new(name: &str) -> Self { - debug_assert!(!name.is_empty()); - - Self(smol_str::SmolStr::new(name)) - } - - fn from_relative_path(path: &Path) -> Option { - let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") { - path.parent()? - } else { - path - }; - - let name = if let Some(parent) = path.parent() { - let mut name = String::with_capacity(path.as_os_str().len()); - - for component in parent.components() { - name.push_str(component.as_os_str().to_str()?); - name.push('.'); - } - - // SAFETY: Unwrap is safe here or `parent` would have returned `None`. - name.push_str(path.file_stem().unwrap().to_str()?); - - smol_str::SmolStr::from(name) - } else { - smol_str::SmolStr::new(path.file_stem()?.to_str()?) - }; - - Some(Self(name)) - } - - /// An iterator over the components of the module name: - /// `foo.bar.baz` -> `foo`, `bar`, `baz` - pub fn components(&self) -> impl DoubleEndedIterator { - self.0.split('.') - } - - /// The name of this module's immediate parent, if it has a parent - pub fn parent(&self) -> Option { - let (_, parent) = self.0.rsplit_once('.')?; - - Some(Self(smol_str::SmolStr::new(parent))) - } - - pub fn starts_with(&self, other: &ModuleName) -> bool { - self.0.starts_with(other.0.as_str()) - } - - pub fn as_str(&self) -> &str { - &self.0 - } -} - -impl Deref for ModuleName { - type Target = str; - - fn deref(&self) -> &Self::Target { - self.as_str() - } -} - -impl std::fmt::Display for ModuleName { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.0) + Ok(ModuleName::new(&name)) } } @@ -262,7 +181,7 @@ pub struct ModuleData { /// and, therefore, cannot be used as part of a query. /// For this to work with salsa, it would be necessary to intern all `ModuleName`s. #[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_module(db: &dyn SemanticDb, name: ModuleName) -> QueryResult> { +pub fn resolve_module(db: &dyn SemanticDb, name: &ModuleName) -> QueryResult> { let jar: &SemanticJar = db.jar()?; let modules = &jar.module_resolver; @@ -271,7 +190,7 @@ pub fn resolve_module(db: &dyn SemanticDb, name: ModuleName) -> QueryResult Ok(Some(*entry.get())), Entry::Vacant(entry) => { - let Some((root_path, absolute_path, kind)) = resolve_name(&name, &modules.search_paths) + let Some((root_path, absolute_path, kind)) = resolve_name(name, &modules.search_paths) else { return Ok(None); }; @@ -288,9 +207,14 @@ pub fn resolve_module(db: &dyn SemanticDb, name: ModuleName) -> QueryResult QueryResult QueryResult QueryResult Option { + let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") { + path.parent()? + } else { + path + }; + + let name = if let Some(parent) = path.parent() { + let mut name = String::with_capacity(path.to_str().unwrap().len()); + + for component in parent.components() { + name.push_str(component.as_os_str().to_str()?); + name.push('.'); + } + + // SAFETY: Unwrap is safe here or `parent` would have returned `None`. + name.push_str(path.file_stem().unwrap().to_str().unwrap()); + + name + } else { + path.file_stem()?.to_str().unwrap().to_string() + }; + + ModuleName::new(&name) +} + ////////////////////////////////////////////////////// // Mutations ////////////////////////////////////////////////////// @@ -763,13 +713,14 @@ impl PackageKind { #[cfg(test)] mod tests { + use red_knot_module_resolver::ModuleName; use std::num::NonZeroU32; use std::path::PathBuf; use crate::db::tests::TestDb; use crate::db::SourceDb; use crate::module::{ - path_to_module, resolve_module, set_module_search_paths, ModuleKind, ModuleName, + path_to_module, resolve_module, set_module_search_paths, ModuleKind, ModuleResolutionInputs, TYPESHED_STDLIB_DIRECTORY, }; use crate::semantic::Dependency; @@ -829,14 +780,12 @@ mod tests { let foo_path = src.join("foo.py"); std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let foo_name = ModuleName::new_static("foo").unwrap(); + let foo_module = resolve_module(&db, &foo_name)?.unwrap(); - assert_eq!( - Some(foo_module), - resolve_module(&db, ModuleName::new("foo"))? - ); + assert_eq!(Some(foo_module), resolve_module(&db, &foo_name)?); - assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); + assert_eq!(foo_name, foo_module.name(&db)?); assert_eq!(&src, foo_module.path(&db)?.root().path()); assert_eq!(ModuleKind::Module, foo_module.kind(&db)?); assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); @@ -855,13 +804,14 @@ mod tests { } = create_resolver()?; let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY); std::fs::create_dir_all(&stdlib_dir).unwrap(); + let functools_name = ModuleName::new_static("functools").unwrap(); let functools_path = stdlib_dir.join("functools.py"); std::fs::write(&functools_path, "def update_wrapper(): ...").unwrap(); - let functools_module = resolve_module(&db, ModuleName::new("functools"))?.unwrap(); + let functools_module = resolve_module(&db, &functools_name)?.unwrap(); assert_eq!( Some(functools_module), - resolve_module(&db, ModuleName::new("functools"))? + resolve_module(&db, &functools_name)? ); assert_eq!(&stdlib_dir, functools_module.path(&db)?.root().path()); assert_eq!(ModuleKind::Module, functools_module.kind(&db)?); @@ -895,11 +845,12 @@ mod tests { let first_party_functools_path = src.join("functools.py"); std::fs::write(stdlib_functools_path, "def update_wrapper(): ...").unwrap(); std::fs::write(&first_party_functools_path, "def update_wrapper(): ...").unwrap(); - let functools_module = resolve_module(&db, ModuleName::new("functools"))?.unwrap(); + let functools_name = ModuleName::new_static("functools").unwrap(); + let functools_module = resolve_module(&db, &functools_name)?.unwrap(); assert_eq!( Some(functools_module), - resolve_module(&db, ModuleName::new("functools"))? + resolve_module(&db, &functools_name)? ); assert_eq!(&src, functools_module.path(&db).unwrap().root().path()); assert_eq!(ModuleKind::Module, functools_module.kind(&db)?); @@ -925,14 +876,15 @@ mod tests { .. } = create_resolver()?; + let foo_name = ModuleName::new("foo").unwrap(); let foo_dir = src.join("foo"); let foo_path = foo_dir.join("__init__.py"); std::fs::create_dir(&foo_dir)?; std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, &foo_name)?.unwrap(); - assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); + assert_eq!(foo_name, foo_module.name(&db)?); assert_eq!(&src, foo_module.path(&db)?.root().path()); assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); @@ -961,7 +913,7 @@ mod tests { let foo_py = src.join("foo.py"); std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); assert_eq!(&src, foo_module.path(&db)?.root().path()); assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file())); @@ -987,7 +939,7 @@ mod tests { std::fs::write(&foo_stub, "x: int")?; std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let foo = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); assert_eq!(&src, foo.path(&db)?.root().path()); assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file())); @@ -1016,7 +968,7 @@ mod tests { std::fs::write(bar.join("__init__.py"), "")?; std::fs::write(&baz, "print('Hello, world!')")?; - let baz_module = resolve_module(&db, ModuleName::new("foo.bar.baz"))?.unwrap(); + let baz_module = resolve_module(&db, &ModuleName::new("foo.bar.baz").unwrap())?.unwrap(); assert_eq!(&src, baz_module.path(&db)?.root().path()); assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file())); @@ -1063,11 +1015,13 @@ mod tests { std::fs::create_dir_all(&child2)?; std::fs::write(&two, "print('Hello, world!')")?; - let one_module = resolve_module(&db, ModuleName::new("parent.child.one"))?.unwrap(); + let one_module = + resolve_module(&db, &ModuleName::new("parent.child.one").unwrap())?.unwrap(); assert_eq!(Some(one_module), path_to_module(&db, &one)?); - let two_module = resolve_module(&db, ModuleName::new("parent.child.two"))?.unwrap(); + let two_module = + resolve_module(&db, &ModuleName::new("parent.child.two").unwrap())?.unwrap(); assert_eq!(Some(two_module), path_to_module(&db, &two)?); Ok(()) @@ -1111,13 +1065,14 @@ mod tests { std::fs::create_dir_all(&child2)?; std::fs::write(two, "print('Hello, world!')")?; - let one_module = resolve_module(&db, ModuleName::new("parent.child.one"))?.unwrap(); + let one_module = + resolve_module(&db, &ModuleName::new("parent.child.one").unwrap())?.unwrap(); assert_eq!(Some(one_module), path_to_module(&db, &one)?); assert_eq!( None, - resolve_module(&db, ModuleName::new("parent.child.two"))? + resolve_module(&db, &ModuleName::new("parent.child.two").unwrap())? ); Ok(()) } @@ -1138,7 +1093,7 @@ mod tests { std::fs::write(&foo_src, "")?; std::fs::write(&foo_site_packages, "")?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); + let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); assert_eq!(&src, foo_module.path(&db)?.root().path()); assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file())); @@ -1165,8 +1120,8 @@ mod tests { std::fs::write(&foo, "")?; std::os::unix::fs::symlink(&foo, &bar)?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); - let bar_module = resolve_module(&db, ModuleName::new("bar"))?.unwrap(); + let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); + let bar_module = resolve_module(&db, &ModuleName::new("bar").unwrap())?.unwrap(); assert_ne!(foo_module, bar_module); @@ -1202,12 +1157,12 @@ mod tests { std::fs::write(foo_path, "from .bar import test")?; std::fs::write(bar_path, "test = 'Hello world'")?; - let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap(); - let bar_module = resolve_module(&db, ModuleName::new("foo.bar"))?.unwrap(); + let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); + let bar_module = resolve_module(&db, &ModuleName::new("foo.bar").unwrap())?.unwrap(); // `from . import bar` in `foo/__init__.py` resolves to `foo` assert_eq!( - Some(ModuleName::new("foo")), + ModuleName::new("foo"), foo_module.resolve_dependency( &db, &Dependency::Relative { @@ -1219,18 +1174,19 @@ mod tests { // `from baz import bar` in `foo/__init__.py` should resolve to `baz.py` assert_eq!( - Some(ModuleName::new("baz")), - foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))? + ModuleName::new("baz"), + foo_module + .resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz").unwrap()))? ); // from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py` assert_eq!( - Some(ModuleName::new("foo.bar")), + ModuleName::new("foo.bar"), foo_module.resolve_dependency( &db, &Dependency::Relative { level: NonZeroU32::new(1).unwrap(), - module: Some(ModuleName::new("bar")) + module: ModuleName::new("bar") } )? ); @@ -1249,7 +1205,7 @@ mod tests { // `from . import test` in `foo/bar.py` resolves to `foo` assert_eq!( - Some(ModuleName::new("foo")), + ModuleName::new("foo"), bar_module.resolve_dependency( &db, &Dependency::Relative { @@ -1261,18 +1217,19 @@ mod tests { // `from baz import test` in `foo/bar.py` resolves to `baz` assert_eq!( - Some(ModuleName::new("baz")), - bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))? + ModuleName::new("baz"), + bar_module + .resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz").unwrap()))? ); // `from .baz import test` in `foo/bar.py` resolves to `foo.baz`. assert_eq!( - Some(ModuleName::new("foo.baz")), + ModuleName::new("foo.baz"), bar_module.resolve_dependency( &db, &Dependency::Relative { level: NonZeroU32::new(1).unwrap(), - module: Some(ModuleName::new("baz")) + module: ModuleName::new("baz") } )? ); diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs index bf2bfa71af..872b52e9f7 100644 --- a/crates/red_knot/src/program/check.rs +++ b/crates/red_knot/src/program/check.rs @@ -51,7 +51,7 @@ impl Program { // TODO We may want to have a different check functions for non-first-party // files because we only need to index them and not check them. // Supporting non-first-party code also requires supporting typing stubs. - if let Some(dependency) = resolve_module(self, dependency_name)? { + if let Some(dependency) = resolve_module(self, &dependency_name)? { if dependency.path(self)?.root().kind().is_first_party() { context.schedule_dependency(dependency.path(self)?.file()); } diff --git a/crates/red_knot/src/semantic.rs b/crates/red_knot/src/semantic.rs index 73d57c8e33..be4753be96 100644 --- a/crates/red_knot/src/semantic.rs +++ b/crates/red_knot/src/semantic.rs @@ -9,12 +9,12 @@ use crate::cache::KeyValueCache; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::module::Module; -use crate::module::ModuleName; use crate::parse::parse; pub(crate) use definitions::Definition; use definitions::{ImportDefinition, ImportFromDefinition}; pub(crate) use flow_graph::ConstrainedDefinition; use flow_graph::{FlowGraph, FlowGraphBuilder, FlowNodeId, ReachableDefinitionsIterator}; +use red_knot_module_resolver::ModuleName; use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; use std::ops::{Deref, DerefMut}; @@ -410,7 +410,7 @@ impl SourceOrderVisitor<'_> for SemanticIndexer { alias.name.id.split('.').next().unwrap() }; - let module = ModuleName::new(&alias.name.id); + let module = ModuleName::new(&alias.name.id).unwrap(); let def = Definition::Import(ImportDefinition { module: module.clone(), @@ -426,7 +426,7 @@ impl SourceOrderVisitor<'_> for SemanticIndexer { level, .. }) => { - let module = module.as_ref().map(|m| ModuleName::new(&m.id)); + let module = module.as_ref().and_then(|m| ModuleName::new(&m.id)); for alias in names { let symbol_name = if let Some(asname) = &alias.asname { diff --git a/crates/red_knot/src/semantic/definitions.rs b/crates/red_knot/src/semantic/definitions.rs index 149fcb4bf2..112e9d03b9 100644 --- a/crates/red_knot/src/semantic/definitions.rs +++ b/crates/red_knot/src/semantic/definitions.rs @@ -1,5 +1,5 @@ use crate::ast_ids::TypedNodeKey; -use crate::semantic::ModuleName; +use red_knot_module_resolver::ModuleName; use ruff_python_ast as ast; use ruff_python_ast::name::Name; diff --git a/crates/red_knot/src/semantic/symbol_table.rs b/crates/red_knot/src/semantic/symbol_table.rs index a272a6ae4e..9bca6ce0b8 100644 --- a/crates/red_knot/src/semantic/symbol_table.rs +++ b/crates/red_knot/src/semantic/symbol_table.rs @@ -6,13 +6,13 @@ use std::num::NonZeroU32; use bitflags::bitflags; use hashbrown::hash_map::{Keys, RawEntryMut}; +use red_knot_module_resolver::ModuleName; use rustc_hash::{FxHashMap, FxHasher}; use ruff_index::{newtype_index, IndexVec}; use ruff_python_ast::name::Name; use crate::ast_ids::NodeKey; -use crate::module::ModuleName; use crate::semantic::{Definition, ExpressionId}; type Map = hashbrown::HashMap; diff --git a/crates/red_knot/src/semantic/types.rs b/crates/red_knot/src/semantic/types.rs index a9bf11241b..1d0d8a798e 100644 --- a/crates/red_knot/src/semantic/types.rs +++ b/crates/red_knot/src/semantic/types.rs @@ -2,7 +2,7 @@ use crate::ast_ids::NodeKey; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::module::{Module, ModuleName}; +use crate::module::Module; use crate::semantic::{ resolve_global_symbol, semantic_index, GlobalSymbolId, ScopeId, ScopeKind, SymbolId, }; @@ -14,6 +14,7 @@ use rustc_hash::FxHashMap; pub(crate) mod infer; pub(crate) use infer::{infer_definition_type, infer_symbol_public_type}; +use red_knot_module_resolver::ModuleName; use ruff_python_ast::name::Name; /// unique ID for a type diff --git a/crates/red_knot/src/semantic/types/infer.rs b/crates/red_knot/src/semantic/types/infer.rs index 1aa8ac8808..af68e00a6e 100644 --- a/crates/red_knot/src/semantic/types/infer.rs +++ b/crates/red_knot/src/semantic/types/infer.rs @@ -1,12 +1,13 @@ #![allow(dead_code)] +use red_knot_module_resolver::ModuleName; use ruff_python_ast as ast; use ruff_python_ast::AstNode; use std::fmt::Debug; use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::module::{resolve_module, ModuleName}; +use crate::module::resolve_module; use crate::parse::parse; use crate::semantic::types::{ModuleTypeId, Type}; use crate::semantic::{ @@ -136,7 +137,7 @@ pub fn infer_definition_type( Definition::Import(ImportDefinition { module: module_name, }) => { - if let Some(module) = resolve_module(db, module_name.clone())? { + if let Some(module) = resolve_module(db, &module_name)? { Ok(Type::Module(ModuleTypeId { module, file_id })) } else { Ok(Type::Unknown) @@ -149,8 +150,9 @@ pub fn infer_definition_type( }) => { // TODO relative imports assert!(matches!(level, 0)); - let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - let Some(module) = resolve_module(db, module_name.clone())? else { + let module_name = + ModuleName::new(module.as_ref().expect("TODO relative imports")).unwrap(); + let Some(module) = resolve_module(db, &module_name)? else { return Ok(Type::Unknown); }; @@ -343,14 +345,13 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu #[cfg(test)] mod tests { + use red_knot_module_resolver::ModuleName; use ruff_python_ast::name::Name; use std::path::PathBuf; use crate::db::tests::TestDb; use crate::db::{HasJar, SemanticJar}; - use crate::module::{ - resolve_module, set_module_search_paths, ModuleName, ModuleResolutionInputs, - }; + use crate::module::{resolve_module, set_module_search_paths, ModuleResolutionInputs}; use crate::semantic::{infer_symbol_public_type, resolve_global_symbol, Type}; // TODO with virtual filesystem we shouldn't have to write files to disk for these @@ -395,7 +396,8 @@ mod tests { variable_name: &str, ) -> anyhow::Result { let db = &case.db; - let module = resolve_module(db, ModuleName::new(module_name))?.expect("Module to exist"); + let module = + resolve_module(db, &ModuleName::new(module_name).unwrap())?.expect("Module to exist"); let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist"); Ok(infer_symbol_public_type(db, symbol)?) diff --git a/crates/red_knot_module_resolver/Cargo.toml b/crates/red_knot_module_resolver/Cargo.toml index c409abb0f7..ec05ec525b 100644 --- a/crates/red_knot_module_resolver/Cargo.toml +++ b/crates/red_knot_module_resolver/Cargo.toml @@ -14,9 +14,9 @@ license = { workspace = true } ruff_db = { workspace = true } ruff_python_stdlib = { workspace = true } +compact_str = { workspace = true } rustc-hash = { workspace = true } salsa = { workspace = true } -smol_str = { workspace = true } tracing = { workspace = true } zip = { workspace = true } diff --git a/crates/red_knot_module_resolver/src/module.rs b/crates/red_knot_module_resolver/src/module.rs index 45ad78145c..8657c4a196 100644 --- a/crates/red_knot_module_resolver/src/module.rs +++ b/crates/red_knot_module_resolver/src/module.rs @@ -1,3 +1,4 @@ +use compact_str::ToCompactString; use std::fmt::Formatter; use std::ops::Deref; use std::sync::Arc; @@ -12,7 +13,7 @@ use crate::Db; /// /// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`). #[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] -pub struct ModuleName(smol_str::SmolStr); +pub struct ModuleName(compact_str::CompactString); impl ModuleName { /// Creates a new module name for `name`. Returns `Some` if `name` is a valid, absolute @@ -27,7 +28,7 @@ impl ModuleName { /// * A component of a name (the part between two dots) isn't a valid python identifier. #[inline] pub fn new(name: &str) -> Option { - Self::new_from_smol(smol_str::SmolStr::new(name)) + Self::is_valid_name(name).then(|| Self(compact_str::CompactString::from(name))) } /// Creates a new module name for `name` where `name` is a static string. @@ -56,19 +57,16 @@ impl ModuleName { /// ``` #[inline] pub fn new_static(name: &'static str) -> Option { - Self::new_from_smol(smol_str::SmolStr::new_static(name)) + // TODO(Micha): Use CompactString::const_new once we upgrade to 0.8 https://github.com/ParkMyCar/compact_str/pull/336 + Self::is_valid_name(name).then(|| Self(compact_str::CompactString::from(name))) } - fn new_from_smol(name: smol_str::SmolStr) -> Option { + fn is_valid_name(name: &str) -> bool { if name.is_empty() { - return None; + return false; } - if name.split('.').all(is_identifier) { - Some(Self(name)) - } else { - None - } + name.split('.').all(is_identifier) } /// An iterator over the components of the module name: @@ -97,8 +95,7 @@ impl ModuleName { /// ``` pub fn parent(&self) -> Option { let (parent, _) = self.0.rsplit_once('.')?; - - Some(Self(smol_str::SmolStr::new(parent))) + Some(Self(parent.to_compact_string())) } /// Returns `true` if the name starts with `other`. @@ -141,7 +138,7 @@ impl ModuleName { }; let name = if let Some(parent) = path.parent() { - let mut name = String::with_capacity(path.as_str().len()); + let mut name = compact_str::CompactString::with_capacity(path.as_str().len()); for component in parent.components() { name.push_str(component.as_os_str().to_str()?); @@ -151,9 +148,9 @@ impl ModuleName { // SAFETY: Unwrap is safe here or `parent` would have returned `None`. name.push_str(path.file_stem().unwrap()); - smol_str::SmolStr::from(name) + name } else { - smol_str::SmolStr::new(path.file_stem()?) + path.file_stem()?.to_compact_string() }; Some(Self(name))