[red-knot] Fix absolute imports in `module.resolve_name` (#11180)

This commit is contained in:
Micha Reiser 2024-04-27 20:07:07 +02:00 committed by GitHub
parent 983a06cec3
commit 00d7c01cfc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 98 additions and 54 deletions

View File

@ -1,4 +1,5 @@
use std::fmt::Formatter;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
@ -8,6 +9,7 @@ use smol_str::SmolStr;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::symbols::Dependency;
use crate::FxDashMap;
/// ID uniquely identifying a module.
@ -42,27 +44,30 @@ impl Module {
modules.modules.get(self).unwrap().kind
}
pub fn relative_name<Db>(&self, db: &Db, level: u32, module: Option<&str>) -> Option<ModuleName>
pub fn resolve_dependency<Db>(&self, db: &Db, dependency: &Dependency) -> Option<ModuleName>
where
Db: HasJar<SemanticJar>,
{
let (level, module) = match dependency {
Dependency::Module(module) => return Some(ModuleName::new(module)),
Dependency::Relative { level, module } => (*level, module.as_deref()),
};
let name = self.name(db);
let kind = self.kind(db);
let mut components = name.components().peekable();
if level > 0 {
let start = match kind {
// `.` resolves to the enclosing package
ModuleKind::Module => 0,
// `.` resolves to the current package
ModuleKind::Package => 1,
};
let start = match kind {
// `.` resolves to the enclosing package
ModuleKind::Module => 0,
// `.` resolves to the current package
ModuleKind::Package => 1,
};
// Skip over the relative parts.
for _ in start..level {
components.next_back()?;
}
// Skip over the relative parts.
for _ in start..level.get() {
components.next_back()?;
}
let mut name = String::new();
@ -141,6 +146,14 @@ impl ModuleName {
}
}
impl Deref for ModuleName {
type Target = str;
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl std::fmt::Display for ModuleName {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
@ -638,6 +651,8 @@ mod tests {
use crate::db::tests::TestDb;
use crate::db::{SemanticDb, SourceDb};
use crate::module::{ModuleKind, ModuleName, ModuleSearchPath, ModuleSearchPathKind};
use crate::symbols::Dependency;
use std::num::NonZeroU32;
struct TestCase {
temp_dir: tempfile::TempDir,
@ -971,7 +986,7 @@ mod tests {
}
#[test]
fn relative_name() -> std::io::Result<()> {
fn resolve_dependency() -> std::io::Result<()> {
let TestCase {
src,
db,
@ -993,40 +1008,73 @@ mod tests {
// `from . import bar` in `foo/__init__.py` resolves to `foo`
assert_eq!(
Some(ModuleName::new("foo")),
foo_module.relative_name(&db, 1, None)
foo_module.resolve_dependency(
&db,
&Dependency::Relative {
level: NonZeroU32::new(1).unwrap(),
module: None
}
)
);
// `from baz import bar` in `foo/__init__.py` should resolve to `foo/baz.py`
// `from baz import bar` in `foo/__init__.py` should resolve to `baz.py`
assert_eq!(
Some(ModuleName::new("foo.baz")),
foo_module.relative_name(&db, 0, Some("baz"))
Some(ModuleName::new("baz")),
foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))
);
// from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py`
assert_eq!(
Some(ModuleName::new("foo.bar")),
foo_module.relative_name(&db, 1, Some("bar"))
foo_module.resolve_dependency(
&db,
&Dependency::Relative {
level: NonZeroU32::new(1).unwrap(),
module: Some(ModuleName::new("bar"))
}
)
);
// from .. import test in `foo/__init__.py` resolves to `` which is not a module
assert_eq!(None, foo_module.relative_name(&db, 2, None));
assert_eq!(
None,
foo_module.resolve_dependency(
&db,
&Dependency::Relative {
level: NonZeroU32::new(2).unwrap(),
module: None
}
)
);
// `from . import test` in `foo/bar.py` resolves to `foo`
assert_eq!(
Some(ModuleName::new("foo")),
bar_module.relative_name(&db, 1, None)
bar_module.resolve_dependency(
&db,
&Dependency::Relative {
level: NonZeroU32::new(1).unwrap(),
module: None
}
)
);
// `from baz import test` in `foo/bar.py` resolves to `foo.bar.baz`
// `from baz import test` in `foo/bar.py` resolves to `baz`
assert_eq!(
Some(ModuleName::new("foo.bar.baz")),
bar_module.relative_name(&db, 0, Some("baz"))
Some(ModuleName::new("baz")),
bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))
);
// `from .baz import test` in `foo/bar.py` resolves to `foo.baz`.
assert_eq!(
Some(ModuleName::new("foo.baz")),
bar_module.relative_name(&db, 1, Some("baz"))
bar_module.resolve_dependency(
&db,
&Dependency::Relative {
level: NonZeroU32::new(1).unwrap(),
module: Some(ModuleName::new("baz"))
}
)
);
Ok(())

View File

@ -3,6 +3,7 @@ use crate::db::{SemanticDb, SourceDb};
use crate::files::FileId;
use crate::lint::Diagnostics;
use crate::program::Program;
use crate::symbols::Dependency;
use rayon::max_num_threads;
use rustc_hash::FxHashSet;
use std::num::NonZeroUsize;
@ -52,7 +53,12 @@ impl Program {
// Anyway, we need to figure out a way to retrieve the dependencies of a module
// from the persistent cache. So maybe it should be a separate query after all.
for dependency in dependencies {
let dependency_name = dependency.module_name(self, module);
let dependency_name = match dependency {
Dependency::Module(name) => Some(name.clone()),
Dependency::Relative { .. } => module
.as_ref()
.and_then(|module| module.resolve_dependency(self, dependency)),
};
if let Some(dependency_name) = dependency_name {
// TODO We may want to have a different check functions for non-first-party

View File

@ -2,6 +2,7 @@
use std::hash::{Hash, Hasher};
use std::iter::{Copied, DoubleEndedIterator, FusedIterator};
use std::num::NonZeroU32;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
@ -16,7 +17,7 @@ use crate::ast_ids::TypedNodeKey;
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::{Module, ModuleName};
use crate::module::ModuleName;
use crate::Name;
#[allow(unreachable_pub)]
@ -111,34 +112,23 @@ pub(crate) enum Definition {
#[derive(Debug)]
pub(crate) struct ImportDefinition {
pub(crate) module: Name,
pub(crate) module: ModuleName,
}
#[derive(Debug)]
pub(crate) struct ImportFromDefinition {
pub(crate) module: Option<Name>,
pub(crate) module: Option<ModuleName>,
pub(crate) name: Name,
pub(crate) level: u32,
}
#[derive(Debug, Clone)]
pub(crate) enum Dependency {
Module(Name),
Relative { level: u32, module: Option<Name> },
}
impl Dependency {
pub(crate) fn module_name<Db>(&self, db: &Db, relative_to: Option<Module>) -> Option<ModuleName>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
match self {
Dependency::Module(name) => Some(ModuleName::new(name.as_str())),
Dependency::Relative { level, module } => {
relative_to?.relative_name(db, *level, module.as_deref())
}
}
}
pub enum Dependency {
Module(ModuleName),
Relative {
level: NonZeroU32,
module: Option<ModuleName>,
},
}
/// Table of all symbols in all scopes for a module.
@ -473,7 +463,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
alias.name.id.split('.').next().unwrap()
};
let module = Name::new(&alias.name.id);
let module = ModuleName::new(&alias.name.id);
let def = Definition::Import(ImportDefinition {
module: module.clone(),
@ -488,7 +478,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
level,
..
}) => {
let module = module.as_ref().map(|m| Name::new(&m.id));
let module = module.as_ref().map(|m| ModuleName::new(&m.id));
for alias in names {
let symbol_name = if let Some(asname) = &alias.asname {
@ -505,17 +495,17 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
}
let dependency = if let Some(module) = module {
if *level == 0 {
Dependency::Module(module)
} else {
Dependency::Relative {
level: *level,
match NonZeroU32::new(*level) {
Some(level) => Dependency::Relative {
level,
module: Some(module),
}
},
None => Dependency::Module(module),
}
} else {
Dependency::Relative {
level: *level,
level: NonZeroU32::new(*level)
.expect("Import without a module to have a level > 0"),
module,
}
};