diff --git a/resources/test/fixtures/flake8_annotations/allow_nested_overload.py b/resources/test/fixtures/flake8_annotations/allow_nested_overload.py new file mode 100644 index 0000000000..19cefc8eaa --- /dev/null +++ b/resources/test/fixtures/flake8_annotations/allow_nested_overload.py @@ -0,0 +1,9 @@ +class C: + from typing import overload + + @overload + def f(self, x: int, y: int) -> None: + ... + + def f(self, x, y): + pass diff --git a/resources/test/fixtures/pyflakes/F401_6.py b/resources/test/fixtures/pyflakes/F401_6.py index 16826a4f14..a44ecef8ea 100644 --- a/resources/test/fixtures/pyflakes/F401_6.py +++ b/resources/test/fixtures/pyflakes/F401_6.py @@ -9,7 +9,6 @@ from .background import BackgroundTasks # F401 `datastructures.UploadFile` imported but unused from .datastructures import UploadFile as FileUpload - # OK import applications as applications diff --git a/resources/test/fixtures/pygrep-hooks/PGH002_0.py b/resources/test/fixtures/pygrep-hooks/PGH002_0.py index e6aa467b13..523b65bf41 100644 --- a/resources/test/fixtures/pygrep-hooks/PGH002_0.py +++ b/resources/test/fixtures/pygrep-hooks/PGH002_0.py @@ -5,4 +5,3 @@ from warnings import warn warnings.warn("this is ok") warn("by itself is also ok") logging.warning("this is fine") -log.warning("this is ok") diff --git a/resources/test/fixtures/pygrep-hooks/PGH002_1.py b/resources/test/fixtures/pygrep-hooks/PGH002_1.py index 861e6ed99c..c2866d0969 100644 --- a/resources/test/fixtures/pygrep-hooks/PGH002_1.py +++ b/resources/test/fixtures/pygrep-hooks/PGH002_1.py @@ -2,14 +2,4 @@ import logging from logging import warn logging.warn("this is not ok") -log.warn("this is also not ok") warn("not ok") - - -def foo(): - from logging import warn - - def warn(): - pass - - warn("has been redefined, but we will still report it") diff --git a/src/ast/function_type.rs b/src/ast/function_type.rs index 07cb9d5f5b..290f559b76 100644 --- a/src/ast/function_type.rs +++ b/src/ast/function_type.rs @@ -1,10 +1,8 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::Expr; -use crate::ast::helpers::{ - collect_call_paths, dealias_call_path, match_call_path, to_module_and_member, -}; +use crate::ast::helpers::to_call_path; use crate::ast::types::{Scope, ScopeKind}; +use crate::checkers::ast::Checker; const CLASS_METHODS: [&str; 3] = ["__new__", "__init_subclass__", "__class_getitem__"]; const METACLASS_BASES: [(&str, &str); 2] = [("", "type"), ("abc", "ABCMeta")]; @@ -18,11 +16,10 @@ pub enum FunctionType { /// Classify a function based on its scope, name, and decorators. pub fn classify( + checker: &Checker, scope: &Scope, name: &str, decorator_list: &[Expr], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, classmethod_decorators: &[String], staticmethod_decorators: &[String], ) -> FunctionType { @@ -33,17 +30,18 @@ pub fn classify( if CLASS_METHODS.contains(&name) || scope.bases.iter().any(|expr| { // The class itself extends a known metaclass, so all methods are class methods. - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - METACLASS_BASES - .iter() - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) + checker.resolve_call_path(expr).map_or(false, |call_path| { + METACLASS_BASES + .iter() + .any(|(module, member)| call_path == [*module, *member]) + }) }) || decorator_list.iter().any(|expr| { // The method is decorated with a class method decorator (like `@classmethod`). - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - classmethod_decorators.iter().any(|decorator| { - let (module, member) = to_module_and_member(decorator); - match_call_path(&call_path, module, member, from_imports) + checker.resolve_call_path(expr).map_or(false, |call_path| { + classmethod_decorators + .iter() + .any(|decorator| call_path == to_call_path(decorator)) }) }) { @@ -51,10 +49,10 @@ pub fn classify( } else if decorator_list.iter().any(|expr| { // The method is decorated with a static method decorator (like // `@staticmethod`). - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - staticmethod_decorators.iter().any(|decorator| { - let (module, member) = to_module_and_member(decorator); - match_call_path(&call_path, module, member, from_imports) + checker.resolve_call_path(expr).map_or(false, |call_path| { + staticmethod_decorators + .iter() + .any(|decorator| call_path == to_call_path(decorator)) }) }) { FunctionType::StaticMethod diff --git a/src/ast/helpers.rs b/src/ast/helpers.rs index 41164a09f8..4234c4d723 100644 --- a/src/ast/helpers.rs +++ b/src/ast/helpers.rs @@ -12,6 +12,7 @@ use rustpython_parser::lexer::Tok; use rustpython_parser::token::StringKind; use crate::ast::types::{Binding, BindingKind, Range}; +use crate::checkers::ast::Checker; use crate::source_code::{Generator, Locator, Stylist}; /// Create an `Expr` with default location from an `ExprKind`. @@ -54,150 +55,42 @@ fn collect_call_path_inner<'a>(expr: &'a Expr, parts: &mut Vec<&'a str>) { } } -/// Convert an `Expr` to its call path (like `List`, or `typing.List`). -pub fn compose_call_path(expr: &Expr) -> Option { - let segments = collect_call_paths(expr); - if segments.is_empty() { - None - } else { - Some(segments.join(".")) - } -} - /// Convert an `Expr` to its call path segments (like ["typing", "List"]). -pub fn collect_call_paths(expr: &Expr) -> Vec<&str> { +pub fn collect_call_path(expr: &Expr) -> Vec<&str> { let mut segments = vec![]; collect_call_path_inner(expr, &mut segments); segments } -/// Rewrite any import aliases on a call path. -pub fn dealias_call_path<'a>( - call_path: Vec<&'a str>, - import_aliases: &FxHashMap<&str, &'a str>, -) -> Vec<&'a str> { - if let Some(head) = call_path.first() { - if let Some(origin) = import_aliases.get(head) { - let tail = &call_path[1..]; - let mut call_path: Vec<&str> = vec![]; - call_path.extend(origin.split('.')); - call_path.extend(tail); - call_path - } else { - call_path - } +/// Convert an `Expr` to its call path (like `List`, or `typing.List`). +pub fn compose_call_path(expr: &Expr) -> Option { + let call_path = collect_call_path(expr); + if call_path.is_empty() { + None } else { - call_path + Some(format_call_path(&call_path)) } } -/// Return `true` if the `Expr` is a reference to `${module}.${target}`. -/// -/// Useful for, e.g., ensuring that a `Union` reference represents -/// `typing.Union`. -pub fn match_module_member( - expr: &Expr, - module: &str, - member: &str, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - match_call_path( - &dealias_call_path(collect_call_paths(expr), import_aliases), - module, - member, - from_imports, - ) -} - -/// Return `true` if the `call_path` is a reference to `${module}.${target}`. -/// -/// Optimized version of `match_module_member` for pre-computed call paths. -pub fn match_call_path( - call_path: &[&str], - module: &str, - member: &str, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, -) -> bool { - // If we have no segments, we can't ever match. - let num_segments = call_path.len(); - if num_segments == 0 { - return false; - } - - // If the last segment doesn't match the member, we can't ever match. - if call_path[num_segments - 1] != member { - return false; - } - - // We now only need the module path, so throw out the member name. - let call_path = &call_path[..num_segments - 1]; - let num_segments = call_path.len(); - - // Case (1): It's a builtin (like `list`). - // Case (2a): We imported from the parent (`from typing.re import Match`, - // `Match`). - // Case (2b): We imported star from the parent (`from typing.re import *`, - // `Match`). - if num_segments == 0 { - module.is_empty() - || from_imports.get(module).map_or(false, |imports| { - imports.contains(member) || imports.contains("*") - }) +/// Format a call path for display. +pub fn format_call_path(call_path: &[&str]) -> String { + if call_path + .first() + .expect("Unable to format empty call path") + .is_empty() + { + call_path[1..].join(".") } else { - let components: Vec<&str> = module.split('.').collect(); - - // Case (3a): it's a fully qualified call path (`import typing`, - // `typing.re.Match`). Case (3b): it's a fully qualified call path (`import - // typing.re`, `typing.re.Match`). - if components == call_path { - return true; - } - - // Case (4): We imported from the grandparent (`from typing import re`, - // `re.Match`) - let num_matches = (0..components.len()) - .take(num_segments) - .take_while(|i| components[components.len() - 1 - i] == call_path[num_segments - 1 - i]) - .count(); - if num_matches > 0 { - let cut = components.len() - num_matches; - // TODO(charlie): Rewrite to avoid this allocation. - let module = components[..cut].join("."); - let member = components[cut]; - if from_imports - .get(&module.as_str()) - .map_or(false, |imports| imports.contains(member)) - { - return true; - } - } - - false + call_path.join(".") } } /// Return `true` if the `Expr` contains a reference to `${module}.${target}`. -pub fn contains_call_path( - expr: &Expr, - module: &str, - member: &str, - import_aliases: &FxHashMap<&str, &str>, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, -) -> bool { +pub fn contains_call_path(checker: &Checker, expr: &Expr, target: &[&str]) -> bool { any_over_expr(expr, &|expr| { - let call_path = collect_call_paths(expr); - if !call_path.is_empty() { - if match_call_path( - &dealias_call_path(call_path, import_aliases), - module, - member, - from_imports, - ) { - return true; - } - } - false + checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == target) }) } @@ -389,13 +282,13 @@ pub fn extract_handler_names(handlers: &[Excepthandler]) -> Vec> { if let Some(type_) = type_ { if let ExprKind::Tuple { elts, .. } = &type_.node { for type_ in elts { - let call_path = collect_call_paths(type_); + let call_path = collect_call_path(type_); if !call_path.is_empty() { handler_names.push(call_path); } } } else { - let call_path = collect_call_paths(type_); + let call_path = collect_call_path(type_); if !call_path.is_empty() { handler_names.push(call_path); } @@ -458,12 +351,37 @@ pub fn format_import_from(level: Option<&usize>, module: Option<&str>) -> String module_name } +/// Format the member reference name for a relative import. +pub fn format_import_from_member( + level: Option<&usize>, + module: Option<&str>, + member: &str, +) -> String { + let mut full_name = String::with_capacity( + level.map_or(0, |level| *level) + + module.as_ref().map_or(0, |module| module.len()) + + 1 + + member.len(), + ); + if let Some(level) = level { + for _ in 0..*level { + full_name.push('.'); + } + } + if let Some(module) = module { + full_name.push_str(module); + full_name.push('.'); + } + full_name.push_str(member); + full_name +} + /// Split a target string (like `typing.List`) into (`typing`, `List`). -pub fn to_module_and_member(target: &str) -> (&str, &str) { - if let Some(index) = target.rfind('.') { - (&target[..index], &target[index + 1..]) +pub fn to_call_path(target: &str) -> Vec<&str> { + if target.contains('.') { + target.split('.').collect() } else { - ("", target) + vec!["", target] } } @@ -786,159 +704,13 @@ impl<'a> SimpleCallArgs<'a> { #[cfg(test)] mod tests { use anyhow::Result; - use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::Location; use rustpython_parser::parser; - use crate::ast::helpers::{ - else_range, identifier_range, match_module_member, match_trailing_content, - }; + use crate::ast::helpers::{else_range, identifier_range, match_trailing_content}; use crate::ast::types::Range; use crate::source_code::Locator; - #[test] - fn builtin() -> Result<()> { - let expr = parser::parse_expression("list", "")?; - assert!(match_module_member( - &expr, - "", - "list", - &FxHashMap::default(), - &FxHashMap::default(), - )); - Ok(()) - } - - #[test] - fn fully_qualified() -> Result<()> { - let expr = parser::parse_expression("typing.re.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::default(), - &FxHashMap::default(), - )); - Ok(()) - } - - #[test] - fn unimported() -> Result<()> { - let expr = parser::parse_expression("Match", "")?; - assert!(!match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::default(), - &FxHashMap::default(), - )); - let expr = parser::parse_expression("re.Match", "")?; - assert!(!match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::default(), - &FxHashMap::default(), - )); - Ok(()) - } - - #[test] - fn from_star() -> Result<()> { - let expr = parser::parse_expression("Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::from_iter([("typing.re", FxHashSet::from_iter(["*"]))]), - &FxHashMap::default() - )); - Ok(()) - } - - #[test] - fn from_parent() -> Result<()> { - let expr = parser::parse_expression("Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::from_iter([("typing.re", FxHashSet::from_iter(["Match"]))]), - &FxHashMap::default() - )); - Ok(()) - } - - #[test] - fn from_grandparent() -> Result<()> { - let expr = parser::parse_expression("re.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::from_iter([("typing", FxHashSet::from_iter(["re"]))]), - &FxHashMap::default() - )); - - let expr = parser::parse_expression("match.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re.match", - "Match", - &FxHashMap::from_iter([("typing.re", FxHashSet::from_iter(["match"]))]), - &FxHashMap::default() - )); - - let expr = parser::parse_expression("re.match.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re.match", - "Match", - &FxHashMap::from_iter([("typing", FxHashSet::from_iter(["re"]))]), - &FxHashMap::default() - )); - Ok(()) - } - - #[test] - fn from_alias() -> Result<()> { - let expr = parser::parse_expression("IMatch", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::from_iter([("typing.re", FxHashSet::from_iter(["Match"]))]), - &FxHashMap::from_iter([("IMatch", "Match")]), - )); - Ok(()) - } - - #[test] - fn from_aliased_parent() -> Result<()> { - let expr = parser::parse_expression("t.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::default(), - &FxHashMap::from_iter([("t", "typing.re")]), - )); - Ok(()) - } - - #[test] - fn from_aliased_grandparent() -> Result<()> { - let expr = parser::parse_expression("t.re.Match", "")?; - assert!(match_module_member( - &expr, - "typing.re", - "Match", - &FxHashMap::default(), - &FxHashMap::from_iter([("t", "typing")]), - )); - Ok(()) - } - #[test] fn trailing_content() -> Result<()> { let contents = "x = 1"; diff --git a/src/autofix/helpers.rs b/src/autofix/helpers.rs index 2f3724cc3c..4b864c4d9b 100644 --- a/src/autofix/helpers.rs +++ b/src/autofix/helpers.rs @@ -210,14 +210,17 @@ pub fn remove_unused_imports<'a>( Some(SmallStatement::Import(import_body)) => (&mut import_body.names, None), Some(SmallStatement::ImportFrom(import_body)) => { if let ImportNames::Aliases(names) = &mut import_body.names { - (names, import_body.module.as_ref()) + ( + names, + Some((&import_body.relative, import_body.module.as_ref())), + ) } else if let ImportNames::Star(..) = &import_body.names { // Special-case: if the import is a `from ... import *`, then we delete the // entire statement. let mut found_star = false; for unused_import in unused_imports { let full_name = match import_body.module.as_ref() { - Some(module_name) => format!("{}.*", compose_module_path(module_name),), + Some(module_name) => format!("{}.*", compose_module_path(module_name)), None => "*".to_string(), }; if unused_import == full_name { @@ -246,11 +249,25 @@ pub fn remove_unused_imports<'a>( for unused_import in unused_imports { let alias_index = aliases.iter().position(|alias| { let full_name = match import_module { - Some(module_name) => format!( - "{}.{}", - compose_module_path(module_name), - compose_module_path(&alias.name) - ), + Some((relative, module)) => { + let module = module.map(compose_module_path); + let member = compose_module_path(&alias.name); + let mut full_name = String::with_capacity( + relative.len() + + module.as_ref().map_or(0, std::string::String::len) + + member.len() + + 1, + ); + for _ in 0..relative.len() { + full_name.push('.'); + } + if let Some(module) = module { + full_name.push_str(&module); + full_name.push('.'); + } + full_name.push_str(&member); + full_name + } None => compose_module_path(&alias.name), }; full_name == unused_import diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index 917c2f8679..77297add5b 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -14,9 +14,7 @@ use rustpython_parser::ast::{ }; use rustpython_parser::parser; -use crate::ast::helpers::{ - binding_range, collect_call_paths, dealias_call_path, extract_handler_names, match_call_path, -}; +use crate::ast::helpers::{binding_range, collect_call_path, extract_handler_names}; use crate::ast::operations::extract_all_names; use crate::ast::relocate::relocate_expr; use crate::ast::types::{ @@ -63,13 +61,10 @@ pub struct Checker<'a> { // Computed diagnostics. pub(crate) diagnostics: Vec, // Function and class definition tracking (e.g., for docstring enforcement). - definitions: Vec<(Definition<'a>, Visibility)>, + definitions: Vec<(Definition<'a>, Visibility, DeferralContext<'a>)>, // Edit tracking. // TODO(charlie): Instead of exposing deletions, wrap in a public API. pub(crate) deletions: FxHashSet>, - // Import tracking. - pub(crate) from_imports: FxHashMap<&'a str, FxHashSet<&'a str>>, - pub(crate) import_aliases: FxHashMap<&'a str, &'a str>, // Retain all scopes and parent nodes, along with a stack of indexes to track which are active // at various points in time. pub(crate) parents: Vec>, @@ -123,8 +118,6 @@ impl<'a> Checker<'a> { diagnostics: vec![], definitions: vec![], deletions: FxHashSet::default(), - from_imports: FxHashMap::default(), - import_aliases: FxHashMap::default(), parents: vec![], depths: FxHashMap::default(), child_to_parent: FxHashMap::default(), @@ -167,28 +160,28 @@ impl<'a> Checker<'a> { /// Return `true` if the `Expr` is a reference to `typing.${target}`. pub fn match_typing_expr(&self, expr: &Expr, target: &str) -> bool { - let call_path = dealias_call_path(collect_call_paths(expr), &self.import_aliases); - self.match_typing_call_path(&call_path, target) + self.resolve_call_path(expr).map_or(false, |call_path| { + self.match_typing_call_path(&call_path, target) + }) } /// Return `true` if the call path is a reference to `typing.${target}`. pub fn match_typing_call_path(&self, call_path: &[&str], target: &str) -> bool { - if match_call_path(call_path, "typing", target, &self.from_imports) { + if call_path == ["typing", target] { return true; } if typing::TYPING_EXTENSIONS.contains(target) { - if match_call_path(call_path, "typing_extensions", target, &self.from_imports) { + if call_path == ["typing_extensions", target] { return true; } } - if self - .settings - .typing_modules - .iter() - .any(|module| match_call_path(call_path, module, target, &self.from_imports)) - { + if self.settings.typing_modules.iter().any(|module| { + let mut module = module.split('.').collect::>(); + module.push(target); + call_path == module.as_slice() + }) { return true; } @@ -209,6 +202,35 @@ impl<'a> Checker<'a> { }) } + pub fn resolve_call_path<'b>(&'a self, value: &'b Expr) -> Option> + where + 'b: 'a, + { + let call_path = collect_call_path(value); + if let Some(head) = call_path.first() { + if let Some(binding) = self.find_binding(head) { + if let BindingKind::Importation(.., name) + | BindingKind::SubmoduleImportation(name, ..) + | BindingKind::FromImportation(.., name) = &binding.kind + { + // Ignore relative imports. + if name.starts_with('.') { + return None; + } + let mut source_path: Vec<&str> = name.split('.').collect(); + source_path.extend(call_path.iter().skip(1)); + return Some(source_path); + } else if let BindingKind::Builtin = &binding.kind { + let mut source_path: Vec<&str> = Vec::with_capacity(call_path.len() + 1); + source_path.push(""); + source_path.extend(call_path); + return Some(source_path); + } + } + } + None + } + /// Return `true` if a `RuleCode` is disabled by a `noqa` directive. pub fn is_ignored(&self, code: &RuleCode, lineno: usize) -> bool { // TODO(charlie): `noqa` directives are mostly enforced in `check_lines.rs`. @@ -415,13 +437,11 @@ where if self.settings.enabled.contains(&RuleCode::N804) { if let Some(diagnostic) = pep8_naming::rules::invalid_first_argument_name_for_class_method( + self, self.current_scope(), name, decorator_list, args, - &self.from_imports, - &self.import_aliases, - &self.settings.pep8_naming, ) { self.diagnostics.push(diagnostic); @@ -431,13 +451,11 @@ where if self.settings.enabled.contains(&RuleCode::N805) { if let Some(diagnostic) = pep8_naming::rules::invalid_first_argument_name_for_method( + self, self.current_scope(), name, decorator_list, args, - &self.from_imports, - &self.import_aliases, - &self.settings.pep8_naming, ) { self.diagnostics.push(diagnostic); @@ -788,12 +806,6 @@ where } if let Some(asname) = &alias.node.asname { - for alias in names { - if let Some(asname) = &alias.node.asname { - self.import_aliases.insert(asname, &alias.node.name); - } - } - let name = alias.node.name.split('.').last().unwrap(); if self.settings.enabled.contains(&RuleCode::N811) { if let Some(diagnostic) = @@ -890,25 +902,6 @@ where module, level, } => { - // Track `import from` statements, to ensure that we can correctly attribute - // references like `from typing import Union`. - if self.settings.enabled.contains(&RuleCode::UP023) { - pyupgrade::rules::replace_c_element_tree(self, stmt); - } - if level.map(|level| level == 0).unwrap_or(true) { - if let Some(module) = module { - self.from_imports - .entry(module) - .or_insert_with(FxHashSet::default) - .extend(names.iter().map(|alias| alias.node.name.as_str())); - } - for alias in names { - if let Some(asname) = &alias.node.asname { - self.import_aliases.insert(asname, &alias.node.name); - } - } - } - if self.settings.enabled.contains(&RuleCode::E402) { if self.seen_import_boundary && stmt.location.column() == 0 { self.diagnostics.push(Diagnostic::new( @@ -926,6 +919,9 @@ where if self.settings.enabled.contains(&RuleCode::UP026) { pyupgrade::rules::rewrite_mock_import(self, stmt); } + if self.settings.enabled.contains(&RuleCode::UP023) { + pyupgrade::rules::replace_c_element_tree(self, stmt); + } if self.settings.enabled.contains(&RuleCode::UP029) { if let Some(module) = module.as_deref() { pyupgrade::rules::unnecessary_builtin_import(self, stmt, module, names); @@ -1057,10 +1053,11 @@ where // be "foo.bar". Given `from foo import bar as baz`, `name` would be "baz" // and `full_name` would be "foo.bar". let name = alias.node.asname.as_ref().unwrap_or(&alias.node.name); - let full_name = match module { - None => alias.node.name.to_string(), - Some(parent) => format!("{parent}.{}", alias.node.name), - }; + let full_name = helpers::format_import_from_member( + level.as_ref(), + module.as_deref(), + &alias.node.name, + ); let range = Range::from_located(alias); self.add_binding( name, @@ -1453,8 +1450,11 @@ where pyupgrade::rules::rewrite_yield_from(self, stmt); } let scope = transition_scope(&self.visible_scope, stmt, &Documentable::Function); - self.definitions - .push((definition, scope.visibility.clone())); + self.definitions.push(( + definition, + scope.visibility.clone(), + (self.scope_stack.clone(), self.parents.clone()), + )); self.visible_scope = scope; // If any global bindings don't already exist in the global scope, add it. @@ -1511,8 +1511,11 @@ where &Documentable::Class, ); let scope = transition_scope(&self.visible_scope, stmt, &Documentable::Class); - self.definitions - .push((definition, scope.visibility.clone())); + self.definitions.push(( + definition, + scope.visibility.clone(), + (self.scope_stack.clone(), self.parents.clone()), + )); self.visible_scope = scope; // If any global bindings don't already exist in the global scope, add it. @@ -1708,13 +1711,9 @@ where && !self.settings.pyupgrade.keep_runtime_typing && self.annotations_future_enabled && self.in_annotation)) - && typing::is_pep585_builtin( - expr, - &self.from_imports, - &self.import_aliases, - ) + && typing::is_pep585_builtin(self, expr) { - pyupgrade::rules::use_pep585_annotation(self, expr, id); + pyupgrade::rules::use_pep585_annotation(self, expr); } self.handle_node_load(expr); @@ -1752,9 +1751,9 @@ where || (self.settings.target_version >= PythonVersion::Py37 && self.annotations_future_enabled && self.in_annotation)) - && typing::is_pep585_builtin(expr, &self.from_imports, &self.import_aliases) + && typing::is_pep585_builtin(self, expr) { - pyupgrade::rules::use_pep585_annotation(self, expr, attr); + pyupgrade::rules::use_pep585_annotation(self, expr); } if self.settings.enabled.contains(&RuleCode::UP016) { @@ -1822,12 +1821,7 @@ where } if self.settings.enabled.contains(&RuleCode::TID251) { - flake8_tidy_imports::rules::banned_attribute_access( - self, - &dealias_call_path(collect_call_paths(expr), &self.import_aliases), - expr, - &self.settings.flake8_tidy_imports.banned_api, - ); + flake8_tidy_imports::rules::banned_attribute_access(self, expr); } } ExprKind::Call { @@ -1976,96 +1970,36 @@ where } } if self.settings.enabled.contains(&RuleCode::S103) { - if let Some(diagnostic) = flake8_bandit::rules::bad_file_permissions( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::bad_file_permissions(self, func, args, keywords); } if self.settings.enabled.contains(&RuleCode::S501) { - if let Some(diagnostic) = flake8_bandit::rules::request_with_no_cert_validation( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::request_with_no_cert_validation( + self, func, args, keywords, + ); } if self.settings.enabled.contains(&RuleCode::S506) { - if let Some(diagnostic) = flake8_bandit::rules::unsafe_yaml_load( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::unsafe_yaml_load(self, func, args, keywords); } if self.settings.enabled.contains(&RuleCode::S508) { - if let Some(diagnostic) = flake8_bandit::rules::snmp_insecure_version( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::snmp_insecure_version(self, func, args, keywords); } if self.settings.enabled.contains(&RuleCode::S509) { - if let Some(diagnostic) = flake8_bandit::rules::snmp_weak_cryptography( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::snmp_weak_cryptography(self, func, args, keywords); } if self.settings.enabled.contains(&RuleCode::S701) { - if let Some(diagnostic) = flake8_bandit::rules::jinja2_autoescape_false( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::jinja2_autoescape_false(self, func, args, keywords); } if self.settings.enabled.contains(&RuleCode::S106) { self.diagnostics .extend(flake8_bandit::rules::hardcoded_password_func_arg(keywords)); } if self.settings.enabled.contains(&RuleCode::S324) { - if let Some(diagnostic) = flake8_bandit::rules::hashlib_insecure_hash_functions( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::hashlib_insecure_hash_functions( + self, func, args, keywords, + ); } if self.settings.enabled.contains(&RuleCode::S113) { - if let Some(diagnostic) = flake8_bandit::rules::request_without_timeout( - func, - args, - keywords, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_bandit::rules::request_without_timeout(self, func, args, keywords); } // flake8-comprehensions @@ -2157,14 +2091,7 @@ where // flake8-debugger if self.settings.enabled.contains(&RuleCode::T100) { - if let Some(diagnostic) = flake8_debugger::rules::debugger_call( - expr, - func, - &self.from_imports, - &self.import_aliases, - ) { - self.diagnostics.push(diagnostic); - } + flake8_debugger::rules::debugger_call(self, expr, func); } // pandas-vet @@ -2747,15 +2674,19 @@ where args, keywords, } => { - let call_path = dealias_call_path(collect_call_paths(func), &self.import_aliases); - if self.match_typing_call_path(&call_path, "ForwardRef") { + let call_path = self.resolve_call_path(func); + if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "ForwardRef") + }) { self.visit_expr(func); for expr in args { self.in_type_definition = true; self.visit_expr(expr); self.in_type_definition = prev_in_type_definition; } - } else if self.match_typing_call_path(&call_path, "cast") { + } else if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "cast") + }) { self.visit_expr(func); if !args.is_empty() { self.in_type_definition = true; @@ -2765,14 +2696,18 @@ where for expr in args.iter().skip(1) { self.visit_expr(expr); } - } else if self.match_typing_call_path(&call_path, "NewType") { + } else if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "NewType") + }) { self.visit_expr(func); for expr in args.iter().skip(1) { self.in_type_definition = true; self.visit_expr(expr); self.in_type_definition = prev_in_type_definition; } - } else if self.match_typing_call_path(&call_path, "TypeVar") { + } else if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "TypeVar") + }) { self.visit_expr(func); for expr in args.iter().skip(1) { self.in_type_definition = true; @@ -2793,7 +2728,9 @@ where } } } - } else if self.match_typing_call_path(&call_path, "NamedTuple") { + } else if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "NamedTuple") + }) { self.visit_expr(func); // Ex) NamedTuple("a", [("a", int)]) @@ -2829,7 +2766,9 @@ where self.visit_expr(value); self.in_type_definition = prev_in_type_definition; } - } else if self.match_typing_call_path(&call_path, "TypedDict") { + } else if call_path.as_ref().map_or(false, |call_path| { + self.match_typing_call_path(call_path, "TypedDict") + }) { self.visit_expr(func); // Ex) TypedDict("a", {"a": int}) @@ -2855,12 +2794,11 @@ where self.visit_expr(value); self.in_type_definition = prev_in_type_definition; } - } else if ["Arg", "DefaultArg", "NamedArg", "DefaultNamedArg"] - .iter() - .any(|target| { - match_call_path(&call_path, "mypy_extensions", target, &self.from_imports) - }) - { + } else if call_path.as_ref().map_or(false, |call_path| { + ["Arg", "DefaultArg", "NamedArg", "DefaultNamedArg"] + .iter() + .any(|target| *call_path == ["mypy_extensions", target]) + }) { self.visit_expr(func); // Ex) DefaultNamedArg(bool | None, name="some_prop_name") @@ -2894,13 +2832,7 @@ where self.in_subscript = true; visitor::walk_expr(self, expr); } else { - match typing::match_annotated_subscript( - value, - &self.from_imports, - &self.import_aliases, - self.settings.typing_modules.iter().map(String::as_str), - |member| self.is_builtin(member), - ) { + match typing::match_annotated_subscript(self, value) { Some(subscript) => { match subscript { // Ex) Optional[int] @@ -3697,6 +3629,7 @@ impl<'a> Checker<'a> { docstring, }, self.visible_scope.visibility.clone(), + (self.scope_stack.clone(), self.parents.clone()), )); docstring.is_some() } @@ -4150,7 +4083,10 @@ impl<'a> Checker<'a> { let mut overloaded_name: Option = None; self.definitions.reverse(); - while let Some((definition, visibility)) = self.definitions.pop() { + while let Some((definition, visibility, (scopes, parents))) = self.definitions.pop() { + self.scope_stack = scopes.clone(); + self.parents = parents.clone(); + // flake8-annotations if enforce_annotations { // TODO(charlie): This should be even stricter, in that an overload @@ -4363,13 +4299,13 @@ pub fn check_ast( let mut allocator = vec![]; checker.check_deferred_string_type_definitions(&mut allocator); + // Check docstrings. + checker.check_definitions(); + // Reset the scope to module-level, and check all consumed scopes. checker.scope_stack = vec![GLOBAL_SCOPE_INDEX]; checker.pop_scope(); checker.check_dead_scopes(); - // Check docstrings. - checker.check_definitions(); - checker.diagnostics } diff --git a/src/flake8_2020/rules.rs b/src/flake8_2020/rules.rs index e9c898a4e6..c1e477a963 100644 --- a/src/flake8_2020/rules.rs +++ b/src/flake8_2020/rules.rs @@ -1,20 +1,15 @@ use num_bigint::BigInt; use rustpython_ast::{Cmpop, Constant, Expr, ExprKind, Located}; -use crate::ast::helpers::match_module_member; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::{Diagnostic, RuleCode}; use crate::violations; fn is_sys(checker: &Checker, expr: &Expr, target: &str) -> bool { - match_module_member( - expr, - "sys", - target, - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(expr) + .map_or(false, |path| path == ["sys", target]) } /// YTT101, YTT102, YTT301, YTT303 @@ -187,13 +182,10 @@ pub fn compare(checker: &mut Checker, left: &Expr, ops: &[Cmpop], comparators: & /// YTT202 pub fn name_or_attribute(checker: &mut Checker, expr: &Expr) { - if match_module_member( - expr, - "six", - "PY3", - &checker.from_imports, - &checker.import_aliases, - ) { + if checker + .resolve_call_path(expr) + .map_or(false, |path| path == ["six", "PY3"]) + { checker.diagnostics.push(Diagnostic::new( violations::SixPY3Referenced, Range::from_located(expr), diff --git a/src/flake8_annotations/mod.rs b/src/flake8_annotations/mod.rs index 0395dc1048..affb89750a 100644 --- a/src/flake8_annotations/mod.rs +++ b/src/flake8_annotations/mod.rs @@ -145,4 +145,22 @@ mod tests { insta::assert_yaml_snapshot!(diagnostics); Ok(()) } + + #[test] + fn allow_nested_overload() -> Result<()> { + let diagnostics = test_path( + Path::new("./resources/test/fixtures/flake8_annotations/allow_nested_overload.py"), + &Settings { + ..Settings::for_rules(vec![ + RuleCode::ANN201, + RuleCode::ANN202, + RuleCode::ANN204, + RuleCode::ANN205, + RuleCode::ANN206, + ]) + }, + )?; + insta::assert_yaml_snapshot!(diagnostics); + Ok(()) + } } diff --git a/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_nested_overload.snap b/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_nested_overload.snap new file mode 100644 index 0000000000..d52cf629b5 --- /dev/null +++ b/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_nested_overload.snap @@ -0,0 +1,6 @@ +--- +source: src/flake8_annotations/mod.rs +expression: diagnostics +--- +[] + diff --git a/src/flake8_bandit/rules/bad_file_permissions.rs b/src/flake8_bandit/rules/bad_file_permissions.rs index 3606e17482..8f78589ea8 100644 --- a/src/flake8_bandit/rules/bad_file_permissions.rs +++ b/src/flake8_bandit/rules/bad_file_permissions.rs @@ -1,10 +1,11 @@ use num_traits::ToPrimitive; use once_cell::sync::Lazy; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashMap; use rustpython_ast::{Constant, Expr, ExprKind, Keyword, Operator}; -use crate::ast::helpers::{compose_call_path, match_module_member, SimpleCallArgs}; +use crate::ast::helpers::{compose_call_path, SimpleCallArgs}; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; @@ -86,18 +87,20 @@ fn get_int_value(expr: &Expr) -> Option { /// S103 pub fn bad_file_permissions( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - if match_module_member(func, "os", "chmod", from_imports, import_aliases) { +) { + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["os", "chmod"]) + { let call_args = SimpleCallArgs::new(args, keywords); if let Some(mode_arg) = call_args.get_argument("mode", Some(1)) { if let Some(int_value) = get_int_value(mode_arg) { if (int_value & WRITE_WORLD > 0) || (int_value & EXECUTE_GROUP > 0) { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::BadFilePermissions(int_value), Range::from_located(mode_arg), )); @@ -105,5 +108,4 @@ pub fn bad_file_permissions( } } } - None } diff --git a/src/flake8_bandit/rules/hashlib_insecure_hash_functions.rs b/src/flake8_bandit/rules/hashlib_insecure_hash_functions.rs index b0d79d28d0..3f13d881c6 100644 --- a/src/flake8_bandit/rules/hashlib_insecure_hash_functions.rs +++ b/src/flake8_bandit/rules/hashlib_insecure_hash_functions.rs @@ -1,8 +1,8 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Constant, Expr, ExprKind, Keyword}; -use crate::ast::helpers::{match_module_member, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::flake8_bandit::helpers::string_literal; use crate::registry::Diagnostic; use crate::violations; @@ -24,44 +24,45 @@ fn is_used_for_security(call_args: &SimpleCallArgs) -> bool { /// S324 pub fn hashlib_insecure_hash_functions( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - if match_module_member(func, "hashlib", "new", from_imports, import_aliases) { - let call_args = SimpleCallArgs::new(args, keywords); +) { + if let Some(call_path) = checker.resolve_call_path(func) { + if call_path == ["hashlib", "new"] { + let call_args = SimpleCallArgs::new(args, keywords); - if !is_used_for_security(&call_args) { - return None; - } - - if let Some(name_arg) = call_args.get_argument("name", Some(0)) { - let hash_func_name = string_literal(name_arg)?; - - if WEAK_HASHES.contains(&hash_func_name.to_lowercase().as_str()) { - return Some(Diagnostic::new( - violations::HashlibInsecureHashFunction(hash_func_name.to_string()), - Range::from_located(name_arg), - )); + if !is_used_for_security(&call_args) { + return; } - } - } else { - for func_name in &WEAK_HASHES { - if match_module_member(func, "hashlib", func_name, from_imports, import_aliases) { - let call_args = SimpleCallArgs::new(args, keywords); - if !is_used_for_security(&call_args) { - return None; + if let Some(name_arg) = call_args.get_argument("name", Some(0)) { + if let Some(hash_func_name) = string_literal(name_arg) { + if WEAK_HASHES.contains(&hash_func_name.to_lowercase().as_str()) { + checker.diagnostics.push(Diagnostic::new( + violations::HashlibInsecureHashFunction(hash_func_name.to_string()), + Range::from_located(name_arg), + )); + } } + } + } else { + for func_name in &WEAK_HASHES { + if call_path == ["hashlib", func_name] { + let call_args = SimpleCallArgs::new(args, keywords); - return Some(Diagnostic::new( - violations::HashlibInsecureHashFunction((*func_name).to_string()), - Range::from_located(func), - )); + if !is_used_for_security(&call_args) { + return; + } + + checker.diagnostics.push(Diagnostic::new( + violations::HashlibInsecureHashFunction((*func_name).to_string()), + Range::from_located(func), + )); + return; + } } } } - None } diff --git a/src/flake8_bandit/rules/jinja2_autoescape_false.rs b/src/flake8_bandit/rules/jinja2_autoescape_false.rs index 400acd6862..c0385af332 100644 --- a/src/flake8_bandit/rules/jinja2_autoescape_false.rs +++ b/src/flake8_bandit/rules/jinja2_autoescape_false.rs @@ -1,26 +1,23 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword}; use rustpython_parser::ast::Constant; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; /// S701 pub fn jinja2_autoescape_false( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - if match_call_path( - &dealias_call_path(collect_call_paths(func), import_aliases), - "jinja2", - "Environment", - from_imports, - ) { +) { + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["jinja2", "Environment"]) + { let call_args = SimpleCallArgs::new(args, keywords); if let Some(autoescape_arg) = call_args.get_argument("autoescape", None) { @@ -32,26 +29,23 @@ pub fn jinja2_autoescape_false( ExprKind::Call { func, .. } => { if let ExprKind::Name { id, .. } = &func.node { if id.as_str() != "select_autoescape" { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::Jinja2AutoescapeFalse(true), Range::from_located(autoescape_arg), )); } } } - _ => { - return Some(Diagnostic::new( - violations::Jinja2AutoescapeFalse(true), - Range::from_located(autoescape_arg), - )) - } + _ => checker.diagnostics.push(Diagnostic::new( + violations::Jinja2AutoescapeFalse(true), + Range::from_located(autoescape_arg), + )), } } else { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::Jinja2AutoescapeFalse(false), Range::from_located(func), )); } } - None } diff --git a/src/flake8_bandit/rules/request_with_no_cert_validation.rs b/src/flake8_bandit/rules/request_with_no_cert_validation.rs index b20f583d5a..70b8851128 100644 --- a/src/flake8_bandit/rules/request_with_no_cert_validation.rs +++ b/src/flake8_bandit/rules/request_with_no_cert_validation.rs @@ -1,9 +1,9 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword}; use rustpython_parser::ast::Constant; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; @@ -24,47 +24,46 @@ const HTTPX_METHODS: [&str; 11] = [ /// S501 pub fn request_with_no_cert_validation( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - let call_path = dealias_call_path(collect_call_paths(func), import_aliases); - let call_args = SimpleCallArgs::new(args, keywords); - - for func_name in &REQUESTS_HTTP_VERBS { - if match_call_path(&call_path, "requests", func_name, from_imports) { - if let Some(verify_arg) = call_args.get_argument("verify", None) { - if let ExprKind::Constant { - value: Constant::Bool(false), - .. - } = &verify_arg.node - { - return Some(Diagnostic::new( - violations::RequestWithNoCertValidation("requests".to_string()), - Range::from_located(verify_arg), - )); +) { + if let Some(call_path) = checker.resolve_call_path(func) { + let call_args = SimpleCallArgs::new(args, keywords); + for func_name in &REQUESTS_HTTP_VERBS { + if call_path == ["requests", func_name] { + if let Some(verify_arg) = call_args.get_argument("verify", None) { + if let ExprKind::Constant { + value: Constant::Bool(false), + .. + } = &verify_arg.node + { + checker.diagnostics.push(Diagnostic::new( + violations::RequestWithNoCertValidation("requests".to_string()), + Range::from_located(verify_arg), + )); + } } + return; + } + } + for func_name in &HTTPX_METHODS { + if call_path == ["httpx", func_name] { + if let Some(verify_arg) = call_args.get_argument("verify", None) { + if let ExprKind::Constant { + value: Constant::Bool(false), + .. + } = &verify_arg.node + { + checker.diagnostics.push(Diagnostic::new( + violations::RequestWithNoCertValidation("httpx".to_string()), + Range::from_located(verify_arg), + )); + } + } + return; } } } - - for func_name in &HTTPX_METHODS { - if match_call_path(&call_path, "httpx", func_name, from_imports) { - if let Some(verify_arg) = call_args.get_argument("verify", None) { - if let ExprKind::Constant { - value: Constant::Bool(false), - .. - } = &verify_arg.node - { - return Some(Diagnostic::new( - violations::RequestWithNoCertValidation("httpx".to_string()), - Range::from_located(verify_arg), - )); - } - } - } - } - None } diff --git a/src/flake8_bandit/rules/request_without_timeout.rs b/src/flake8_bandit/rules/request_without_timeout.rs index 0d5cabc228..e304656883 100644 --- a/src/flake8_bandit/rules/request_without_timeout.rs +++ b/src/flake8_bandit/rules/request_without_timeout.rs @@ -1,9 +1,9 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword}; use rustpython_parser::ast::Constant; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; @@ -11,36 +11,35 @@ const HTTP_VERBS: [&str; 7] = ["get", "options", "head", "post", "put", "patch", /// S113 pub fn request_without_timeout( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - let call_path = dealias_call_path(collect_call_paths(func), import_aliases); - for func_name in &HTTP_VERBS { - if match_call_path(&call_path, "requests", func_name, from_imports) { - let call_args = SimpleCallArgs::new(args, keywords); - if let Some(timeout_arg) = call_args.get_argument("timeout", None) { - if let Some(timeout) = match &timeout_arg.node { - ExprKind::Constant { - value: value @ Constant::None, - .. - } => Some(value.to_string()), - _ => None, - } { - return Some(Diagnostic::new( - violations::RequestWithoutTimeout(Some(timeout)), - Range::from_located(timeout_arg), - )); - } - } else { - return Some(Diagnostic::new( - violations::RequestWithoutTimeout(None), - Range::from_located(func), +) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + HTTP_VERBS + .iter() + .any(|func_name| call_path == ["requests", func_name]) + }) { + let call_args = SimpleCallArgs::new(args, keywords); + if let Some(timeout_arg) = call_args.get_argument("timeout", None) { + if let Some(timeout) = match &timeout_arg.node { + ExprKind::Constant { + value: value @ Constant::None, + .. + } => Some(value.to_string()), + _ => None, + } { + checker.diagnostics.push(Diagnostic::new( + violations::RequestWithoutTimeout(Some(timeout)), + Range::from_located(timeout_arg), )); } + } else { + checker.diagnostics.push(Diagnostic::new( + violations::RequestWithoutTimeout(None), + Range::from_located(func), + )); } } - None } diff --git a/src/flake8_bandit/rules/snmp_insecure_version.rs b/src/flake8_bandit/rules/snmp_insecure_version.rs index 272fd3b8f2..f93525b10f 100644 --- a/src/flake8_bandit/rules/snmp_insecure_version.rs +++ b/src/flake8_bandit/rules/snmp_insecure_version.rs @@ -1,26 +1,24 @@ use num_traits::{One, Zero}; -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword}; use rustpython_parser::ast::Constant; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; /// S508 pub fn snmp_insecure_version( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - let call_path = dealias_call_path(collect_call_paths(func), import_aliases); - - if match_call_path(&call_path, "pysnmp.hlapi", "CommunityData", from_imports) { +) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["pysnmp", "hlapi", "CommunityData"] + }) { let call_args = SimpleCallArgs::new(args, keywords); - if let Some(mp_model_arg) = call_args.get_argument("mpModel", None) { if let ExprKind::Constant { value: Constant::Int(value), @@ -28,7 +26,7 @@ pub fn snmp_insecure_version( } = &mp_model_arg.node { if value.is_zero() || value.is_one() { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::SnmpInsecureVersion, Range::from_located(mp_model_arg), )); @@ -36,5 +34,4 @@ pub fn snmp_insecure_version( } } } - None } diff --git a/src/flake8_bandit/rules/snmp_weak_cryptography.rs b/src/flake8_bandit/rules/snmp_weak_cryptography.rs index 1111898b10..4fb91f66f7 100644 --- a/src/flake8_bandit/rules/snmp_weak_cryptography.rs +++ b/src/flake8_bandit/rules/snmp_weak_cryptography.rs @@ -1,30 +1,27 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, Keyword}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; /// S509 pub fn snmp_weak_cryptography( + checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - let call_path = dealias_call_path(collect_call_paths(func), import_aliases); - - if match_call_path(&call_path, "pysnmp.hlapi", "UsmUserData", from_imports) { +) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["pysnmp", "hlapi", "UsmUserData"] + }) { let call_args = SimpleCallArgs::new(args, keywords); - if call_args.len() < 3 { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::SnmpWeakCryptography, Range::from_located(func), )); } } - None } diff --git a/src/flake8_bandit/rules/unsafe_yaml_load.rs b/src/flake8_bandit/rules/unsafe_yaml_load.rs index 0b9827f1c6..7392cb8d6b 100644 --- a/src/flake8_bandit/rules/unsafe_yaml_load.rs +++ b/src/flake8_bandit/rules/unsafe_yaml_load.rs @@ -1,51 +1,40 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword}; -use crate::ast::helpers::{match_module_member, SimpleCallArgs}; +use crate::ast::helpers::SimpleCallArgs; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; /// S506 -pub fn unsafe_yaml_load( - func: &Expr, - args: &[Expr], - keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - if match_module_member(func, "yaml", "load", from_imports, import_aliases) { +pub fn unsafe_yaml_load(checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword]) { + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["yaml", "load"]) + { let call_args = SimpleCallArgs::new(args, keywords); if let Some(loader_arg) = call_args.get_argument("Loader", Some(1)) { - if !match_module_member( - loader_arg, - "yaml", - "SafeLoader", - from_imports, - import_aliases, - ) && !match_module_member( - loader_arg, - "yaml", - "CSafeLoader", - from_imports, - import_aliases, - ) { + if !checker + .resolve_call_path(loader_arg) + .map_or(false, |call_path| { + call_path == ["yaml", "SafeLoader"] || call_path == ["yaml", "CSafeLoader"] + }) + { let loader = match &loader_arg.node { ExprKind::Attribute { attr, .. } => Some(attr.to_string()), ExprKind::Name { id, .. } => Some(id.to_string()), _ => None, }; - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::UnsafeYAMLLoad(loader), Range::from_located(loader_arg), )); } } else { - return Some(Diagnostic::new( + checker.diagnostics.push(Diagnostic::new( violations::UnsafeYAMLLoad(None), Range::from_located(func), )); } } - None } diff --git a/src/flake8_bugbear/rules/abstract_base_class.rs b/src/flake8_bugbear/rules/abstract_base_class.rs index b1023fd079..8fe8b8d267 100644 --- a/src/flake8_bugbear/rules/abstract_base_class.rs +++ b/src/flake8_bugbear/rules/abstract_base_class.rs @@ -1,34 +1,26 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Constant, Expr, ExprKind, Keyword, Stmt, StmtKind}; -use crate::ast::helpers::match_module_member; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::{Diagnostic, RuleCode}; use crate::violations; +use crate::visibility::{is_abstract, is_overload}; -fn is_abc_class( - bases: &[Expr], - keywords: &[Keyword], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { +fn is_abc_class(checker: &Checker, bases: &[Expr], keywords: &[Keyword]) -> bool { keywords.iter().any(|keyword| { keyword .node .arg .as_ref() - .map_or(false, |a| a == "metaclass") - && match_module_member( - &keyword.node.value, - "abc", - "ABCMeta", - from_imports, - import_aliases, - ) - }) || bases - .iter() - .any(|base| match_module_member(base, "abc", "ABC", from_imports, import_aliases)) + .map_or(false, |arg| arg == "metaclass") + && checker + .resolve_call_path(&keyword.node.value) + .map_or(false, |call_path| call_path == ["abc", "ABCMeta"]) + }) || bases.iter().any(|base| { + checker + .resolve_call_path(base) + .map_or(false, |call_path| call_path == ["abc", "ABC"]) + }) } fn is_empty_body(body: &[Stmt]) -> bool { @@ -44,36 +36,6 @@ fn is_empty_body(body: &[Stmt]) -> bool { }) } -fn is_abstractmethod( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - match_module_member(expr, "abc", "abstractmethod", from_imports, import_aliases) -} - -fn is_abstractproperty( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - match_module_member( - expr, - "abc", - "abstractproperty", - from_imports, - import_aliases, - ) -} - -fn is_overload( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - match_module_member(expr, "typing", "overload", from_imports, import_aliases) -} - pub fn abstract_base_class( checker: &mut Checker, stmt: &Stmt, @@ -85,12 +47,7 @@ pub fn abstract_base_class( if bases.len() + keywords.len() != 1 { return; } - if !is_abc_class( - bases, - keywords, - &checker.from_imports, - &checker.import_aliases, - ) { + if !is_abc_class(checker, bases, keywords) { return; } @@ -116,23 +73,14 @@ pub fn abstract_base_class( continue; }; - let has_abstract_decorator = decorator_list.iter().any(|d| { - is_abstractmethod(d, &checker.from_imports, &checker.import_aliases) - || is_abstractproperty(d, &checker.from_imports, &checker.import_aliases) - }); - + let has_abstract_decorator = is_abstract(checker, decorator_list); has_abstract_method |= has_abstract_decorator; if !checker.settings.enabled.contains(&RuleCode::B027) { continue; } - if !has_abstract_decorator - && is_empty_body(body) - && !decorator_list - .iter() - .any(|d| is_overload(d, &checker.from_imports, &checker.import_aliases)) - { + if !has_abstract_decorator && is_empty_body(body) && !is_overload(checker, decorator_list) { checker.diagnostics.push(Diagnostic::new( violations::EmptyMethodWithoutAbstractDecorator(name.to_string()), Range::from_located(stmt), diff --git a/src/flake8_bugbear/rules/assert_raises_exception.rs b/src/flake8_bugbear/rules/assert_raises_exception.rs index 19fced319c..537c8fa19c 100644 --- a/src/flake8_bugbear/rules/assert_raises_exception.rs +++ b/src/flake8_bugbear/rules/assert_raises_exception.rs @@ -1,6 +1,5 @@ use rustpython_ast::{ExprKind, Stmt, Withitem}; -use crate::ast::helpers::match_module_member; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; @@ -24,13 +23,10 @@ pub fn assert_raises_exception(checker: &mut Checker, stmt: &Stmt, items: &[With if !matches!(&func.node, ExprKind::Attribute { attr, .. } if attr == "assertRaises") { return; } - if !match_module_member( - args.first().unwrap(), - "", - "Exception", - &checker.from_imports, - &checker.import_aliases, - ) { + if !checker + .resolve_call_path(args.first().unwrap()) + .map_or(false, |call_path| call_path == ["", "Exception"]) + { return; } diff --git a/src/flake8_bugbear/rules/cached_instance_method.rs b/src/flake8_bugbear/rules/cached_instance_method.rs index d537c32f31..72723c5fe4 100644 --- a/src/flake8_bugbear/rules/cached_instance_method.rs +++ b/src/flake8_bugbear/rules/cached_instance_method.rs @@ -1,15 +1,14 @@ use rustpython_ast::{Expr, ExprKind}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; use crate::ast::types::{Range, ScopeKind}; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; fn is_cache_func(checker: &Checker, expr: &Expr) -> bool { - let call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases); - match_call_path(&call_path, "functools", "lru_cache", &checker.from_imports) - || match_call_path(&call_path, "functools", "cache", &checker.from_imports) + checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path == ["functools", "lru_cache"] || call_path == ["functools", "cache"] + }) } /// B019 diff --git a/src/flake8_bugbear/rules/duplicate_exceptions.rs b/src/flake8_bugbear/rules/duplicate_exceptions.rs index 1f0d0e084a..95ea698165 100644 --- a/src/flake8_bugbear/rules/duplicate_exceptions.rs +++ b/src/flake8_bugbear/rules/duplicate_exceptions.rs @@ -30,7 +30,7 @@ fn duplicate_handler_exceptions<'a>( let mut duplicates: FxHashSet> = FxHashSet::default(); let mut unique_elts: Vec<&Expr> = Vec::default(); for type_ in elts { - let call_path = helpers::collect_call_paths(type_); + let call_path = helpers::collect_call_path(type_); if !call_path.is_empty() { if seen.contains_key(&call_path) { duplicates.insert(call_path); @@ -83,7 +83,7 @@ pub fn duplicate_exceptions(checker: &mut Checker, handlers: &[Excepthandler]) { }; match &type_.node { ExprKind::Attribute { .. } | ExprKind::Name { .. } => { - let call_path = helpers::collect_call_paths(type_); + let call_path = helpers::collect_call_path(type_); if !call_path.is_empty() { if seen.contains(&call_path) { duplicates.entry(call_path).or_default().push(type_); diff --git a/src/flake8_bugbear/rules/function_call_argument_default.rs b/src/flake8_bugbear/rules/function_call_argument_default.rs index 20c5d3074f..68781cf6d4 100644 --- a/src/flake8_bugbear/rules/function_call_argument_default.rs +++ b/src/flake8_bugbear/rules/function_call_argument_default.rs @@ -1,9 +1,6 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Arguments, Constant, Expr, ExprKind}; -use crate::ast::helpers::{ - collect_call_paths, compose_call_path, dealias_call_path, match_call_path, to_module_and_member, -}; +use crate::ast::helpers::{compose_call_path, to_call_path}; use crate::ast::types::Range; use crate::ast::visitor; use crate::ast::visitor::Visitor; @@ -12,34 +9,29 @@ use crate::flake8_bugbear::rules::mutable_argument_default::is_mutable_func; use crate::registry::{Diagnostic, DiagnosticKind}; use crate::violations; -const IMMUTABLE_FUNCS: [(&str, &str); 7] = [ - ("", "tuple"), - ("", "frozenset"), - ("operator", "attrgetter"), - ("operator", "itemgetter"), - ("operator", "methodcaller"), - ("types", "MappingProxyType"), - ("re", "compile"), +const IMMUTABLE_FUNCS: &[&[&str]] = &[ + &["", "tuple"], + &["", "frozenset"], + &["operator", "attrgetter"], + &["operator", "itemgetter"], + &["operator", "methodcaller"], + &["types", "MappingProxyType"], + &["re", "compile"], ]; -fn is_immutable_func( - expr: &Expr, - extend_immutable_calls: &[(&str, &str)], - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - IMMUTABLE_FUNCS - .iter() - .chain(extend_immutable_calls) - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) +fn is_immutable_func(checker: &Checker, expr: &Expr, extend_immutable_calls: &[Vec<&str>]) -> bool { + checker.resolve_call_path(expr).map_or(false, |call_path| { + IMMUTABLE_FUNCS.iter().any(|target| call_path == *target) + || extend_immutable_calls + .iter() + .any(|target| call_path == *target) + }) } struct ArgumentDefaultVisitor<'a> { + checker: &'a Checker<'a>, diagnostics: Vec<(DiagnosticKind, Range)>, - extend_immutable_calls: &'a [(&'a str, &'a str)], - from_imports: &'a FxHashMap<&'a str, FxHashSet<&'a str>>, - import_aliases: &'a FxHashMap<&'a str, &'a str>, + extend_immutable_calls: Vec>, } impl<'a, 'b> Visitor<'b> for ArgumentDefaultVisitor<'b> @@ -49,13 +41,8 @@ where fn visit_expr(&mut self, expr: &'b Expr) { match &expr.node { ExprKind::Call { func, args, .. } => { - if !is_mutable_func(func, self.from_imports, self.import_aliases) - && !is_immutable_func( - func, - self.extend_immutable_calls, - self.from_imports, - self.import_aliases, - ) + if !is_mutable_func(self.checker, func) + && !is_immutable_func(self.checker, func, &self.extend_immutable_calls) && !is_nan_or_infinity(func, args) { self.diagnostics.push(( @@ -97,27 +84,29 @@ fn is_nan_or_infinity(expr: &Expr, args: &[Expr]) -> bool { /// B008 pub fn function_call_argument_default(checker: &mut Checker, arguments: &Arguments) { // Map immutable calls to (module, member) format. - let extend_immutable_cells: Vec<(&str, &str)> = checker + let extend_immutable_calls: Vec> = checker .settings .flake8_bugbear .extend_immutable_calls .iter() - .map(|target| to_module_and_member(target)) + .map(|target| to_call_path(target)) .collect(); - let mut visitor = ArgumentDefaultVisitor { - diagnostics: vec![], - extend_immutable_calls: &extend_immutable_cells, - from_imports: &checker.from_imports, - import_aliases: &checker.import_aliases, + let diagnostics = { + let mut visitor = ArgumentDefaultVisitor { + checker, + diagnostics: vec![], + extend_immutable_calls, + }; + for expr in arguments + .defaults + .iter() + .chain(arguments.kw_defaults.iter()) + { + visitor.visit_expr(expr); + } + visitor.diagnostics }; - for expr in arguments - .defaults - .iter() - .chain(arguments.kw_defaults.iter()) - { - visitor.visit_expr(expr); - } - for (check, range) in visitor.diagnostics { + for (check, range) in diagnostics { checker.diagnostics.push(Diagnostic::new(check, range)); } } diff --git a/src/flake8_bugbear/rules/mutable_argument_default.rs b/src/flake8_bugbear/rules/mutable_argument_default.rs index 0bdbaa7561..9d25687134 100644 --- a/src/flake8_bugbear/rules/mutable_argument_default.rs +++ b/src/flake8_bugbear/rules/mutable_argument_default.rs @@ -1,79 +1,68 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Arguments, Constant, Expr, ExprKind, Operator}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; -const MUTABLE_FUNCS: &[(&str, &str)] = &[ - ("", "dict"), - ("", "list"), - ("", "set"), - ("collections", "Counter"), - ("collections", "OrderedDict"), - ("collections", "defaultdict"), - ("collections", "deque"), +const MUTABLE_FUNCS: &[&[&str]] = &[ + &["", "dict"], + &["", "list"], + &["", "set"], + &["collections", "Counter"], + &["collections", "OrderedDict"], + &["collections", "defaultdict"], + &["collections", "deque"], ]; -const IMMUTABLE_TYPES: &[(&str, &str)] = &[ - ("", "bool"), - ("", "bytes"), - ("", "complex"), - ("", "float"), - ("", "frozenset"), - ("", "int"), - ("", "object"), - ("", "range"), - ("", "str"), - ("collections.abc", "Sized"), - ("typing", "LiteralString"), - ("typing", "Sized"), +const IMMUTABLE_TYPES: &[&[&str]] = &[ + &["", "bool"], + &["", "bytes"], + &["", "complex"], + &["", "float"], + &["", "frozenset"], + &["", "int"], + &["", "object"], + &["", "range"], + &["", "str"], + &["collections", "abc", "Sized"], + &["typing", "LiteralString"], + &["typing", "Sized"], ]; -const IMMUTABLE_GENERIC_TYPES: &[(&str, &str)] = &[ - ("", "tuple"), - ("collections.abc", "ByteString"), - ("collections.abc", "Collection"), - ("collections.abc", "Container"), - ("collections.abc", "Iterable"), - ("collections.abc", "Mapping"), - ("collections.abc", "Reversible"), - ("collections.abc", "Sequence"), - ("collections.abc", "Set"), - ("typing", "AbstractSet"), - ("typing", "ByteString"), - ("typing", "Callable"), - ("typing", "Collection"), - ("typing", "Container"), - ("typing", "FrozenSet"), - ("typing", "Iterable"), - ("typing", "Literal"), - ("typing", "Mapping"), - ("typing", "Never"), - ("typing", "NoReturn"), - ("typing", "Reversible"), - ("typing", "Sequence"), - ("typing", "Tuple"), +const IMMUTABLE_GENERIC_TYPES: &[&[&str]] = &[ + &["", "tuple"], + &["collections", "abc", "ByteString"], + &["collections", "abc", "Collection"], + &["collections", "abc", "Container"], + &["collections", "abc", "Iterable"], + &["collections", "abc", "Mapping"], + &["collections", "abc", "Reversible"], + &["collections", "abc", "Sequence"], + &["collections", "abc", "Set"], + &["typing", "AbstractSet"], + &["typing", "ByteString"], + &["typing", "Callable"], + &["typing", "Collection"], + &["typing", "Container"], + &["typing", "FrozenSet"], + &["typing", "Iterable"], + &["typing", "Literal"], + &["typing", "Mapping"], + &["typing", "Never"], + &["typing", "NoReturn"], + &["typing", "Reversible"], + &["typing", "Sequence"], + &["typing", "Tuple"], ]; -pub fn is_mutable_func( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - MUTABLE_FUNCS - .iter() - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) +pub fn is_mutable_func(checker: &Checker, expr: &Expr) -> bool { + checker.resolve_call_path(expr).map_or(false, |call_path| { + MUTABLE_FUNCS.iter().any(|target| call_path == *target) + }) } -fn is_mutable_expr( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { +fn is_mutable_expr(checker: &Checker, expr: &Expr) -> bool { match &expr.node { ExprKind::List { .. } | ExprKind::Dict { .. } @@ -81,60 +70,53 @@ fn is_mutable_expr( | ExprKind::ListComp { .. } | ExprKind::DictComp { .. } | ExprKind::SetComp { .. } => true, - ExprKind::Call { func, .. } => is_mutable_func(func, from_imports, import_aliases), + ExprKind::Call { func, .. } => is_mutable_func(checker, func), _ => false, } } -fn is_immutable_annotation( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { +fn is_immutable_annotation(checker: &Checker, expr: &Expr) -> bool { match &expr.node { ExprKind::Name { .. } | ExprKind::Attribute { .. } => { - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - IMMUTABLE_TYPES - .iter() - .chain(IMMUTABLE_GENERIC_TYPES) - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) + checker.resolve_call_path(expr).map_or(false, |call_path| { + IMMUTABLE_TYPES + .iter() + .chain(IMMUTABLE_GENERIC_TYPES) + .any(|target| call_path == *target) + }) } ExprKind::Subscript { value, slice, .. } => { - let call_path = dealias_call_path(collect_call_paths(value), import_aliases); - if IMMUTABLE_GENERIC_TYPES - .iter() - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) - { - true - } else if match_call_path(&call_path, "typing", "Union", from_imports) { - if let ExprKind::Tuple { elts, .. } = &slice.node { - elts.iter() - .all(|elt| is_immutable_annotation(elt, from_imports, import_aliases)) + checker.resolve_call_path(value).map_or(false, |call_path| { + if IMMUTABLE_GENERIC_TYPES + .iter() + .any(|target| call_path == *target) + { + true + } else if call_path == ["typing", "Union"] { + if let ExprKind::Tuple { elts, .. } = &slice.node { + elts.iter().all(|elt| is_immutable_annotation(checker, elt)) + } else { + false + } + } else if call_path == ["typing", "Optional"] { + is_immutable_annotation(checker, slice) + } else if call_path == ["typing", "Annotated"] { + if let ExprKind::Tuple { elts, .. } = &slice.node { + elts.first() + .map_or(false, |elt| is_immutable_annotation(checker, elt)) + } else { + false + } } else { false } - } else if match_call_path(&call_path, "typing", "Optional", from_imports) { - is_immutable_annotation(slice, from_imports, import_aliases) - } else if match_call_path(&call_path, "typing", "Annotated", from_imports) { - if let ExprKind::Tuple { elts, .. } = &slice.node { - elts.first().map_or(false, |elt| { - is_immutable_annotation(elt, from_imports, import_aliases) - }) - } else { - false - } - } else { - false - } + }) } ExprKind::BinOp { left, op: Operator::BitOr, right, - } => { - is_immutable_annotation(left, from_imports, import_aliases) - && is_immutable_annotation(right, from_imports, import_aliases) - } + } => is_immutable_annotation(checker, left) && is_immutable_annotation(checker, right), ExprKind::Constant { value: Constant::None, .. @@ -145,7 +127,7 @@ fn is_immutable_annotation( /// B006 pub fn mutable_argument_default(checker: &mut Checker, arguments: &Arguments) { - // Scan in reverse order to right-align zip() + // Scan in reverse order to right-align zip(). for (arg, default) in arguments .kwonlyargs .iter() @@ -160,10 +142,12 @@ pub fn mutable_argument_default(checker: &mut Checker, arguments: &Arguments) { .zip(arguments.defaults.iter().rev()), ) { - if is_mutable_expr(default, &checker.from_imports, &checker.import_aliases) - && arg.node.annotation.as_ref().map_or(true, |expr| { - !is_immutable_annotation(expr, &checker.from_imports, &checker.import_aliases) - }) + if is_mutable_expr(checker, default) + && !arg + .node + .annotation + .as_ref() + .map_or(false, |expr| is_immutable_annotation(checker, expr)) { checker.diagnostics.push(Diagnostic::new( violations::MutableArgumentDefault, diff --git a/src/flake8_bugbear/rules/useless_contextlib_suppress.rs b/src/flake8_bugbear/rules/useless_contextlib_suppress.rs index 68119353be..cae52017b8 100644 --- a/src/flake8_bugbear/rules/useless_contextlib_suppress.rs +++ b/src/flake8_bugbear/rules/useless_contextlib_suppress.rs @@ -1,6 +1,5 @@ use rustpython_ast::Expr; -use crate::ast::helpers::{collect_call_paths, match_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; @@ -8,12 +7,10 @@ use crate::violations; /// B005 pub fn useless_contextlib_suppress(checker: &mut Checker, expr: &Expr, args: &[Expr]) { - if match_call_path( - &collect_call_paths(expr), - "contextlib", - "suppress", - &checker.from_imports, - ) && args.is_empty() + if args.is_empty() + && checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["contextlib", "suppress"]) { checker.diagnostics.push(Diagnostic::new( violations::UselessContextlibSuppress, diff --git a/src/flake8_datetimez/rules.rs b/src/flake8_datetimez/rules.rs index a3ea4e2060..6016cbb664 100644 --- a/src/flake8_datetimez/rules.rs +++ b/src/flake8_datetimez/rules.rs @@ -1,8 +1,6 @@ use rustpython_ast::{Constant, Expr, ExprKind, Keyword}; -use crate::ast::helpers::{ - collect_call_paths, dealias_call_path, has_non_none_keyword, is_const_none, match_call_path, -}; +use crate::ast::helpers::{has_non_none_keyword, is_const_none}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; @@ -15,8 +13,10 @@ pub fn call_datetime_without_tzinfo( keywords: &[Keyword], location: Range, ) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if !match_call_path(&call_path, "datetime", "datetime", &checker.from_imports) { + if !checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["datetime", "datetime"]) + { return; } @@ -40,13 +40,9 @@ pub fn call_datetime_without_tzinfo( /// DTZ002 pub fn call_datetime_today(checker: &mut Checker, func: &Expr, location: Range) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path( - &call_path, - "datetime.datetime", - "today", - &checker.from_imports, - ) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "today"] + }) { checker .diagnostics .push(Diagnostic::new(violations::CallDatetimeToday, location)); @@ -55,13 +51,9 @@ pub fn call_datetime_today(checker: &mut Checker, func: &Expr, location: Range) /// DTZ003 pub fn call_datetime_utcnow(checker: &mut Checker, func: &Expr, location: Range) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path( - &call_path, - "datetime.datetime", - "utcnow", - &checker.from_imports, - ) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "utcnow"] + }) { checker .diagnostics .push(Diagnostic::new(violations::CallDatetimeUtcnow, location)); @@ -70,13 +62,9 @@ pub fn call_datetime_utcnow(checker: &mut Checker, func: &Expr, location: Range) /// DTZ004 pub fn call_datetime_utcfromtimestamp(checker: &mut Checker, func: &Expr, location: Range) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path( - &call_path, - "datetime.datetime", - "utcfromtimestamp", - &checker.from_imports, - ) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "utcfromtimestamp"] + }) { checker.diagnostics.push(Diagnostic::new( violations::CallDatetimeUtcfromtimestamp, location, @@ -92,13 +80,9 @@ pub fn call_datetime_now_without_tzinfo( keywords: &[Keyword], location: Range, ) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if !match_call_path( - &call_path, - "datetime.datetime", - "now", - &checker.from_imports, - ) { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "now"] + }) { return; } @@ -137,13 +121,9 @@ pub fn call_datetime_fromtimestamp( keywords: &[Keyword], location: Range, ) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if !match_call_path( - &call_path, - "datetime.datetime", - "fromtimestamp", - &checker.from_imports, - ) { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "fromtimestamp"] + }) { return; } @@ -181,13 +161,9 @@ pub fn call_datetime_strptime_without_zone( args: &[Expr], location: Range, ) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if !match_call_path( - &call_path, - "datetime.datetime", - "strptime", - &checker.from_imports, - ) { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "datetime", "strptime"] + }) { return; } @@ -234,8 +210,9 @@ pub fn call_datetime_strptime_without_zone( /// DTZ011 pub fn call_date_today(checker: &mut Checker, func: &Expr, location: Range) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path(&call_path, "datetime.date", "today", &checker.from_imports) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "date", "today"] + }) { checker .diagnostics .push(Diagnostic::new(violations::CallDateToday, location)); @@ -244,13 +221,9 @@ pub fn call_date_today(checker: &mut Checker, func: &Expr, location: Range) { /// DTZ012 pub fn call_date_fromtimestamp(checker: &mut Checker, func: &Expr, location: Range) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path( - &call_path, - "datetime.date", - "fromtimestamp", - &checker.from_imports, - ) { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path == ["datetime", "date", "fromtimestamp"] + }) { checker .diagnostics .push(Diagnostic::new(violations::CallDateFromtimestamp, location)); diff --git a/src/flake8_debugger/rules.rs b/src/flake8_debugger/rules.rs index d9aef6dd38..29b2c5a31a 100644 --- a/src/flake8_debugger/rules.rs +++ b/src/flake8_debugger/rules.rs @@ -1,42 +1,39 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, Stmt}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; +use crate::ast::helpers::format_call_path; use crate::ast::types::Range; +use crate::checkers::ast::Checker; use crate::flake8_debugger::types::DebuggerUsingType; use crate::registry::Diagnostic; use crate::violations; -const DEBUGGERS: &[(&str, &str)] = &[ - ("pdb", "set_trace"), - ("pudb", "set_trace"), - ("ipdb", "set_trace"), - ("ipdb", "sset_trace"), - ("IPython.terminal.embed", "InteractiveShellEmbed"), - ("IPython.frontend.terminal.embed", "InteractiveShellEmbed"), - ("celery.contrib.rdb", "set_trace"), - ("builtins", "breakpoint"), - ("", "breakpoint"), +const DEBUGGERS: &[&[&str]] = &[ + &["pdb", "set_trace"], + &["pudb", "set_trace"], + &["ipdb", "set_trace"], + &["ipdb", "sset_trace"], + &["IPython", "terminal", "embed", "InteractiveShellEmbed"], + &[ + "IPython", + "frontend", + "terminal", + "embed", + "InteractiveShellEmbed", + ], + &["celery", "contrib", "rdb", "set_trace"], + &["builtins", "breakpoint"], + &["", "breakpoint"], ]; /// Checks for the presence of a debugger call. -pub fn debugger_call( - expr: &Expr, - func: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> Option { - let call_path = dealias_call_path(collect_call_paths(func), import_aliases); - if DEBUGGERS - .iter() - .any(|(module, member)| match_call_path(&call_path, module, member, from_imports)) - { - Some(Diagnostic::new( - violations::Debugger(DebuggerUsingType::Call(call_path.join("."))), - Range::from_located(expr), - )) - } else { - None +pub fn debugger_call(checker: &mut Checker, expr: &Expr, func: &Expr) { + if let Some(call_path) = checker.resolve_call_path(func) { + if DEBUGGERS.iter().any(|target| call_path == *target) { + checker.diagnostics.push(Diagnostic::new( + violations::Debugger(DebuggerUsingType::Call(format_call_path(&call_path))), + Range::from_located(expr), + )); + } } } @@ -49,23 +46,25 @@ pub fn debugger_import(stmt: &Stmt, module: Option<&str>, name: &str) -> Option< } if let Some(module) = module { - if let Some((module_name, member)) = DEBUGGERS - .iter() - .find(|(module_name, member)| module_name == &module && member == &name) - { + let mut call_path = module.split('.').collect::>(); + call_path.push(name); + if DEBUGGERS.iter().any(|target| call_path == **target) { return Some(Diagnostic::new( - violations::Debugger(DebuggerUsingType::Import(format!("{module_name}.{member}"))), + violations::Debugger(DebuggerUsingType::Import(format_call_path(&call_path))), + Range::from_located(stmt), + )); + } + } else { + let parts = name.split('.').collect::>(); + if DEBUGGERS + .iter() + .any(|call_path| call_path[..call_path.len() - 1] == parts) + { + return Some(Diagnostic::new( + violations::Debugger(DebuggerUsingType::Import(name.to_string())), Range::from_located(stmt), )); } - } else if DEBUGGERS - .iter() - .any(|(module_name, ..)| module_name == &name) - { - return Some(Diagnostic::new( - violations::Debugger(DebuggerUsingType::Import(name.to_string())), - Range::from_located(stmt), - )); } None } diff --git a/src/flake8_debugger/snapshots/ruff__flake8_debugger__tests__T100_T100.py.snap b/src/flake8_debugger/snapshots/ruff__flake8_debugger__tests__T100_T100.py.snap index 20ea6a5d5c..1f7b547c10 100644 --- a/src/flake8_debugger/snapshots/ruff__flake8_debugger__tests__T100_T100.py.snap +++ b/src/flake8_debugger/snapshots/ruff__flake8_debugger__tests__T100_T100.py.snap @@ -1,6 +1,6 @@ --- source: src/flake8_debugger/mod.rs -expression: checks +expression: diagnostics --- - kind: Debugger: @@ -70,7 +70,7 @@ expression: checks parent: ~ - kind: Debugger: - Call: breakpoint + Call: builtins.breakpoint location: row: 11 column: 0 @@ -81,7 +81,7 @@ expression: checks parent: ~ - kind: Debugger: - Call: set_trace + Call: pdb.set_trace location: row: 12 column: 0 @@ -92,7 +92,7 @@ expression: checks parent: ~ - kind: Debugger: - Call: set_trace + Call: celery.contrib.rdb.set_trace location: row: 13 column: 0 diff --git a/src/flake8_print/rules/print_call.rs b/src/flake8_print/rules/print_call.rs index eefa7c042c..9e4d2e9810 100644 --- a/src/flake8_print/rules/print_call.rs +++ b/src/flake8_print/rules/print_call.rs @@ -1,7 +1,7 @@ use log::error; use rustpython_ast::{Expr, Keyword, Stmt, StmtKind}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, is_const_none, match_call_path}; +use crate::ast::helpers::is_const_none; use crate::ast::types::Range; use crate::autofix::helpers; use crate::checkers::ast::Checker; @@ -11,8 +11,11 @@ use crate::violations; /// T201, T203 pub fn print_call(checker: &mut Checker, func: &Expr, keywords: &[Keyword]) { let mut diagnostic = { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if match_call_path(&call_path, "", "print", &checker.from_imports) { + let call_path = checker.resolve_call_path(func); + if call_path + .as_ref() + .map_or(false, |call_path| *call_path == ["", "print"]) + { // If the print call has a `file=` argument (that isn't `None`, `"sys.stdout"`, // or `"sys.stderr"`), don't trigger T201. if let Some(keyword) = keywords @@ -20,16 +23,21 @@ pub fn print_call(checker: &mut Checker, func: &Expr, keywords: &[Keyword]) { .find(|keyword| keyword.node.arg.as_ref().map_or(false, |arg| arg == "file")) { if !is_const_none(&keyword.node.value) { - let call_path = collect_call_paths(&keyword.node.value); - if !(match_call_path(&call_path, "sys", "stdout", &checker.from_imports) - || match_call_path(&call_path, "sys", "stderr", &checker.from_imports)) + if checker + .resolve_call_path(&keyword.node.value) + .map_or(true, |call_path| { + call_path != ["sys", "stdout"] && call_path != ["sys", "stderr"] + }) { return; } } } Diagnostic::new(violations::PrintFound, Range::from_located(func)) - } else if match_call_path(&call_path, "pprint", "pprint", &checker.from_imports) { + } else if call_path + .as_ref() + .map_or(false, |call_path| *call_path == ["pprint", "pprint"]) + { Diagnostic::new(violations::PPrintFound, Range::from_located(func)) } else { return; diff --git a/src/flake8_pytest_style/rules/fixture.rs b/src/flake8_pytest_style/rules/fixture.rs index 00f38ed7fe..e49fd306a8 100644 --- a/src/flake8_pytest_style/rules/fixture.rs +++ b/src/flake8_pytest_style/rules/fixture.rs @@ -4,7 +4,7 @@ use super::helpers::{ get_mark_decorators, get_mark_name, is_abstractmethod_decorator, is_pytest_fixture, is_pytest_yield_fixture, keyword_is_literal, }; -use crate::ast::helpers::{collect_arg_names, collect_call_paths}; +use crate::ast::helpers::{collect_arg_names, collect_call_path}; use crate::ast::types::Range; use crate::ast::visitor; use crate::ast::visitor::Visitor; @@ -50,7 +50,7 @@ where } } ExprKind::Call { func, .. } => { - if collect_call_paths(func) == vec!["request", "addfinalizer"] { + if collect_call_path(func) == vec!["request", "addfinalizer"] { self.addfinalizer_call = Some(expr); }; visitor::walk_expr(self, expr); diff --git a/src/flake8_pytest_style/rules/helpers.rs b/src/flake8_pytest_style/rules/helpers.rs index 84c63aa464..56ca1fcfca 100644 --- a/src/flake8_pytest_style/rules/helpers.rs +++ b/src/flake8_pytest_style/rules/helpers.rs @@ -1,7 +1,7 @@ use num_traits::identities::Zero; use rustpython_ast::{Constant, Expr, ExprKind, Keyword}; -use crate::ast::helpers::{collect_call_paths, compose_call_path, match_module_member}; +use crate::ast::helpers::collect_call_path; use crate::checkers::ast::Checker; const ITERABLE_INITIALIZERS: &[&str] = &["dict", "frozenset", "list", "tuple", "set"]; @@ -14,55 +14,40 @@ pub fn get_mark_decorators(decorators: &[Expr]) -> Vec<&Expr> { } pub fn get_mark_name(decorator: &Expr) -> &str { - collect_call_paths(decorator).last().unwrap() + collect_call_path(decorator).last().unwrap() } pub fn is_pytest_fail(call: &Expr, checker: &Checker) -> bool { - match_module_member( - call, - "pytest", - "fail", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(call) + .map_or(false, |call_path| call_path == ["pytest", "fail"]) } pub fn is_pytest_fixture(decorator: &Expr, checker: &Checker) -> bool { - match_module_member( - decorator, - "pytest", - "fixture", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(decorator) + .map_or(false, |call_path| call_path == ["pytest", "fixture"]) } pub fn is_pytest_mark(decorator: &Expr) -> bool { - if let Some(qualname) = compose_call_path(decorator) { - qualname.starts_with("pytest.mark.") + let segments = collect_call_path(decorator); + if segments.len() > 2 { + segments[0] == "pytest" && segments[1] == "mark" } else { false } } pub fn is_pytest_yield_fixture(decorator: &Expr, checker: &Checker) -> bool { - match_module_member( - decorator, - "pytest", - "yield_fixture", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(decorator) + .map_or(false, |call_path| call_path == ["pytest", "yield_fixture"]) } pub fn is_abstractmethod_decorator(decorator: &Expr, checker: &Checker) -> bool { - match_module_member( - decorator, - "abc", - "abstractmethod", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(decorator) + .map_or(false, |call_path| call_path == ["abc", "abstractmethod"]) } /// Check if the expression is a constant that evaluates to false. @@ -108,13 +93,11 @@ pub fn is_falsy_constant(expr: &Expr) -> bool { } pub fn is_pytest_parametrize(decorator: &Expr, checker: &Checker) -> bool { - match_module_member( - decorator, - "pytest.mark", - "parametrize", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(decorator) + .map_or(false, |call_path| { + call_path == ["pytest", "mark", "parametrize"] + }) } pub fn keyword_is_literal(kw: &Keyword, literal: &str) -> bool { diff --git a/src/flake8_pytest_style/rules/raises.rs b/src/flake8_pytest_style/rules/raises.rs index 8690d3d246..ecf05e38f4 100644 --- a/src/flake8_pytest_style/rules/raises.rs +++ b/src/flake8_pytest_style/rules/raises.rs @@ -1,22 +1,16 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Expr, ExprKind, Keyword, Stmt, StmtKind, Withitem}; use super::helpers::is_empty_or_null_string; -use crate::ast::helpers::{ - collect_call_paths, dealias_call_path, match_call_path, match_module_member, - to_module_and_member, -}; +use crate::ast::helpers::{format_call_path, to_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::{Diagnostic, RuleCode}; use crate::violations; -fn is_pytest_raises( - func: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - match_module_member(func, "pytest", "raises", from_imports, import_aliases) +fn is_pytest_raises(checker: &Checker, func: &Expr) -> bool { + checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["pytest", "raises"]) } fn is_non_trivial_with_body(body: &[Stmt]) -> bool { @@ -30,7 +24,7 @@ fn is_non_trivial_with_body(body: &[Stmt]) -> bool { } pub fn raises_call(checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword]) { - if is_pytest_raises(func, &checker.from_imports, &checker.import_aliases) { + if is_pytest_raises(checker, func) { if checker.settings.enabled.contains(&RuleCode::PT010) { if args.is_empty() && keywords.is_empty() { checker.diagnostics.push(Diagnostic::new( @@ -62,9 +56,7 @@ pub fn complex_raises(checker: &mut Checker, stmt: &Stmt, items: &[Withitem], bo let mut is_too_complex = false; let raises_called = items.iter().any(|item| match &item.context_expr.node { - ExprKind::Call { func, .. } => { - is_pytest_raises(func, &checker.from_imports, &checker.import_aliases) - } + ExprKind::Call { func, .. } => is_pytest_raises(checker, func), _ => false, }); @@ -101,26 +93,24 @@ pub fn complex_raises(checker: &mut Checker, stmt: &Stmt, items: &[Withitem], bo /// PT011 fn exception_needs_match(checker: &mut Checker, exception: &Expr) { - let call_path = dealias_call_path(collect_call_paths(exception), &checker.import_aliases); - - let is_broad_exception = checker - .settings - .flake8_pytest_style - .raises_require_match_for - .iter() - .chain( - &checker - .settings - .flake8_pytest_style - .raises_extend_require_match_for, - ) - .map(|target| to_module_and_member(target)) - .any(|(module, member)| match_call_path(&call_path, module, member, &checker.from_imports)); - - if is_broad_exception { - checker.diagnostics.push(Diagnostic::new( - violations::RaisesTooBroad(call_path.join(".")), - Range::from_located(exception), - )); + if let Some(call_path) = checker.resolve_call_path(exception) { + let is_broad_exception = checker + .settings + .flake8_pytest_style + .raises_require_match_for + .iter() + .chain( + &checker + .settings + .flake8_pytest_style + .raises_extend_require_match_for, + ) + .any(|target| call_path == to_call_path(target)); + if is_broad_exception { + checker.diagnostics.push(Diagnostic::new( + violations::RaisesTooBroad(format_call_path(&call_path)), + Range::from_located(exception), + )); + } } } diff --git a/src/flake8_simplify/rules/ast_expr.rs b/src/flake8_simplify/rules/ast_expr.rs index c305dfc2e7..a9db52d4c5 100644 --- a/src/flake8_simplify/rules/ast_expr.rs +++ b/src/flake8_simplify/rules/ast_expr.rs @@ -1,6 +1,6 @@ use rustpython_ast::{Constant, Expr, ExprKind}; -use crate::ast::helpers::{create_expr, match_module_member, unparse_expr}; +use crate::ast::helpers::{create_expr, unparse_expr}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -16,21 +16,9 @@ pub fn use_capital_environment_variables(checker: &mut Checker, expr: &Expr) { } // check `os.environ.get('foo')` and `os.getenv('foo')`` - let is_os_environ_get = match_module_member( - expr, - "os.environ", - "get", - &checker.from_imports, - &checker.import_aliases, - ); - let is_os_getenv = match_module_member( - expr, - "os", - "getenv", - &checker.from_imports, - &checker.import_aliases, - ); - if !(is_os_environ_get || is_os_getenv) { + if !checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path == ["os", "environ", "get"] || call_path == ["os", "getenv"] + }) { return; } diff --git a/src/flake8_simplify/rules/ast_if.rs b/src/flake8_simplify/rules/ast_if.rs index 2e8776c558..d87f4904fa 100644 --- a/src/flake8_simplify/rules/ast_if.rs +++ b/src/flake8_simplify/rules/ast_if.rs @@ -149,25 +149,13 @@ pub fn use_ternary_operator(checker: &mut Checker, stmt: &Stmt, parent: Option<& } // Avoid suggesting ternary for `if sys.version_info >= ...`-style checks. - if contains_call_path( - test, - "sys", - "version_info", - &checker.import_aliases, - &checker.from_imports, - ) { + if contains_call_path(checker, test, &["sys", "version_info"]) { return; } // Avoid suggesting ternary for `if sys.platform.startswith("...")`-style // checks. - if contains_call_path( - test, - "sys", - "platform", - &checker.import_aliases, - &checker.from_imports, - ) { + if contains_call_path(checker, test, &["sys", "platform"]) { return; } diff --git a/src/flake8_simplify/rules/open_file_with_context_handler.rs b/src/flake8_simplify/rules/open_file_with_context_handler.rs index f740ed1c25..b169473529 100644 --- a/src/flake8_simplify/rules/open_file_with_context_handler.rs +++ b/src/flake8_simplify/rules/open_file_with_context_handler.rs @@ -1,7 +1,6 @@ use rustpython_ast::Expr; use rustpython_parser::ast::StmtKind; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; @@ -9,12 +8,10 @@ use crate::violations; /// SIM115 pub fn open_file_with_context_handler(checker: &mut Checker, func: &Expr) { - if match_call_path( - &dealias_call_path(collect_call_paths(func), &checker.import_aliases), - "", - "open", - &checker.from_imports, - ) { + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["", "open"]) + { if checker.is_builtin("open") { match checker.current_stmt().node { StmtKind::With { .. } => (), diff --git a/src/flake8_tidy_imports/mod.rs b/src/flake8_tidy_imports/mod.rs index 19f391f27e..2293e91d6b 100644 --- a/src/flake8_tidy_imports/mod.rs +++ b/src/flake8_tidy_imports/mod.rs @@ -74,25 +74,4 @@ mod tests { insta::assert_yaml_snapshot!(diagnostics); Ok(()) } - - #[test] - fn banned_api_false_positives() -> Result<()> { - let diagnostics = test_path( - Path::new("./resources/test/fixtures/flake8_tidy_imports/TID251_false_positives.py"), - &Settings { - flake8_tidy_imports: flake8_tidy_imports::settings::Settings { - banned_api: FxHashMap::from_iter([( - "typing.TypedDict".to_string(), - BannedApi { - msg: "Use typing_extensions.TypedDict instead.".to_string(), - }, - )]), - ..Default::default() - }, - ..Settings::for_rules(vec![RuleCode::TID251]) - }, - )?; - insta::assert_yaml_snapshot!(diagnostics); - Ok(()) - } } diff --git a/src/flake8_tidy_imports/rules.rs b/src/flake8_tidy_imports/rules.rs index b6b60db6d7..a6047133c2 100644 --- a/src/flake8_tidy_imports/rules.rs +++ b/src/flake8_tidy_imports/rules.rs @@ -2,7 +2,6 @@ use rustc_hash::FxHashMap; use rustpython_ast::{Alias, Expr, Located, Stmt}; use super::settings::BannedApi; -use crate::ast::helpers::match_call_path; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::flake8_tidy_imports::settings::Strictness; @@ -75,15 +74,10 @@ pub fn name_or_parent_is_banned( } /// TID251 -pub fn banned_attribute_access( - checker: &mut Checker, - call_path: &[&str], - expr: &Expr, - banned_apis: &FxHashMap, -) { - for (banned_path, ban) in banned_apis { - if let Some((module, member)) = banned_path.rsplit_once('.') { - if match_call_path(call_path, module, member, &checker.from_imports) { +pub fn banned_attribute_access(checker: &mut Checker, expr: &Expr) { + if let Some(call_path) = checker.resolve_call_path(expr) { + for (banned_path, ban) in &checker.settings.flake8_tidy_imports.banned_api { + if call_path == banned_path.split('.').collect::>() { checker.diagnostics.push(Diagnostic::new( violations::BannedApi { name: banned_path.to_string(), diff --git a/src/flake8_tidy_imports/snapshots/ruff__flake8_tidy_imports__tests__banned_api_false_positives.snap b/src/flake8_tidy_imports/snapshots/ruff__flake8_tidy_imports__tests__banned_api_false_positives.snap deleted file mode 100644 index b77f9a16fa..0000000000 --- a/src/flake8_tidy_imports/snapshots/ruff__flake8_tidy_imports__tests__banned_api_false_positives.snap +++ /dev/null @@ -1,41 +0,0 @@ ---- -source: src/flake8_tidy_imports/mod.rs -expression: checks ---- -- kind: - BannedApi: - name: typing.TypedDict - message: Use typing_extensions.TypedDict instead. - location: - row: 2 - column: 7 - end_location: - row: 2 - column: 23 - fix: ~ - parent: ~ -- kind: - BannedApi: - name: typing.TypedDict - message: Use typing_extensions.TypedDict instead. - location: - row: 7 - column: 0 - end_location: - row: 7 - column: 16 - fix: ~ - parent: ~ -- kind: - BannedApi: - name: typing.TypedDict - message: Use typing_extensions.TypedDict instead. - location: - row: 11 - column: 4 - end_location: - row: 11 - column: 20 - fix: ~ - parent: ~ - diff --git a/src/flake8_unused_arguments/rules.rs b/src/flake8_unused_arguments/rules.rs index df5183d130..f3a87ab464 100644 --- a/src/flake8_unused_arguments/rules.rs +++ b/src/flake8_unused_arguments/rules.rs @@ -118,11 +118,10 @@ pub fn unused_arguments( .. }) => { match function_type::classify( + checker, parent, name, decorator_list, - &checker.from_imports, - &checker.import_aliases, &checker.settings.pep8_naming.classmethod_decorators, &checker.settings.pep8_naming.staticmethod_decorators, ) { diff --git a/src/lib.rs b/src/lib.rs index 93319fc0cd..d672c8625e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,8 @@ )] #![forbid(unsafe_code)] +extern crate core; + mod ast; mod autofix; mod cache; diff --git a/src/pep8_naming/helpers.rs b/src/pep8_naming/helpers.rs index ffb23db431..bb49cc4978 100644 --- a/src/pep8_naming/helpers.rs +++ b/src/pep8_naming/helpers.rs @@ -1,8 +1,7 @@ use itertools::Itertools; -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Stmt, StmtKind}; -use crate::ast::helpers::{collect_call_paths, match_call_path}; +use crate::checkers::ast::Checker; use crate::python::string::{is_lower, is_upper}; pub fn is_camelcase(name: &str) -> bool { @@ -23,19 +22,13 @@ pub fn is_acronym(name: &str, asname: &str) -> bool { name.chars().filter(|c| c.is_uppercase()).join("") == asname } -pub fn is_namedtuple_assignment( - stmt: &Stmt, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, -) -> bool { +pub fn is_namedtuple_assignment(checker: &Checker, stmt: &Stmt) -> bool { let StmtKind::Assign { value, .. } = &stmt.node else { return false; }; - match_call_path( - &collect_call_paths(value), - "collections", - "namedtuple", - from_imports, - ) + checker.resolve_call_path(value).map_or(false, |call_path| { + call_path == ["collections", "namedtuple"] + }) } #[cfg(test)] diff --git a/src/pep8_naming/rules.rs b/src/pep8_naming/rules.rs index dafa4bd54a..e2d07bb42d 100644 --- a/src/pep8_naming/rules.rs +++ b/src/pep8_naming/rules.rs @@ -1,4 +1,3 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Arg, Arguments, Expr, ExprKind, Stmt}; use crate::ast::function_type; @@ -6,7 +5,6 @@ use crate::ast::helpers::identifier_range; use crate::ast::types::{Range, Scope, ScopeKind}; use crate::checkers::ast::Checker; use crate::pep8_naming::helpers; -use crate::pep8_naming::settings::Settings; use crate::python::string::{self}; use crate::registry::Diagnostic; use crate::source_code::Locator; @@ -53,23 +51,20 @@ pub fn invalid_argument_name(name: &str, arg: &Arg) -> Option { /// N804 pub fn invalid_first_argument_name_for_class_method( + checker: &Checker, scope: &Scope, name: &str, decorator_list: &[Expr], args: &Arguments, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, - settings: &Settings, ) -> Option { if !matches!( function_type::classify( + checker, scope, name, decorator_list, - from_imports, - import_aliases, - &settings.classmethod_decorators, - &settings.staticmethod_decorators, + &checker.settings.pep8_naming.classmethod_decorators, + &checker.settings.pep8_naming.staticmethod_decorators, ), function_type::FunctionType::ClassMethod ) { @@ -95,23 +90,20 @@ pub fn invalid_first_argument_name_for_class_method( /// N805 pub fn invalid_first_argument_name_for_method( + checker: &Checker, scope: &Scope, name: &str, decorator_list: &[Expr], args: &Arguments, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, - settings: &Settings, ) -> Option { if !matches!( function_type::classify( + checker, scope, name, decorator_list, - from_imports, - import_aliases, - &settings.classmethod_decorators, - &settings.staticmethod_decorators, + &checker.settings.pep8_naming.classmethod_decorators, + &checker.settings.pep8_naming.staticmethod_decorators, ), function_type::FunctionType::Method ) { @@ -134,9 +126,7 @@ pub fn non_lowercase_variable_in_function( stmt: &Stmt, name: &str, ) { - if name.to_lowercase() != name - && !helpers::is_namedtuple_assignment(stmt, &checker.from_imports) - { + if name.to_lowercase() != name && !helpers::is_namedtuple_assignment(checker, stmt) { checker.diagnostics.push(Diagnostic::new( violations::NonLowercaseVariableInFunction(name.to_string()), Range::from_located(expr), @@ -243,9 +233,7 @@ pub fn mixed_case_variable_in_class_scope( stmt: &Stmt, name: &str, ) { - if helpers::is_mixed_case(name) - && !helpers::is_namedtuple_assignment(stmt, &checker.from_imports) - { + if helpers::is_mixed_case(name) && !helpers::is_namedtuple_assignment(checker, stmt) { checker.diagnostics.push(Diagnostic::new( violations::MixedCaseVariableInClassScope(name.to_string()), Range::from_located(expr), @@ -260,9 +248,7 @@ pub fn mixed_case_variable_in_global_scope( stmt: &Stmt, name: &str, ) { - if helpers::is_mixed_case(name) - && !helpers::is_namedtuple_assignment(stmt, &checker.from_imports) - { + if helpers::is_mixed_case(name) && !helpers::is_namedtuple_assignment(checker, stmt) { checker.diagnostics.push(Diagnostic::new( violations::MixedCaseVariableInGlobalScope(name.to_string()), Range::from_located(expr), diff --git a/src/pyflakes/snapshots/ruff__pyflakes__tests__F401_F401_6.py.snap b/src/pyflakes/snapshots/ruff__pyflakes__tests__F401_F401_6.py.snap index d507e44765..56d9566e62 100644 --- a/src/pyflakes/snapshots/ruff__pyflakes__tests__F401_F401_6.py.snap +++ b/src/pyflakes/snapshots/ruff__pyflakes__tests__F401_F401_6.py.snap @@ -1,10 +1,10 @@ --- source: src/pyflakes/mod.rs -expression: checks +expression: diagnostics --- - kind: UnusedImport: - - background.BackgroundTasks + - ".background.BackgroundTasks" - false - false location: @@ -24,7 +24,7 @@ expression: checks parent: ~ - kind: UnusedImport: - - datastructures.UploadFile + - ".datastructures.UploadFile" - false - false location: @@ -48,18 +48,18 @@ expression: checks - false - false location: - row: 17 + row: 16 column: 7 end_location: - row: 17 + row: 16 column: 17 fix: content: "" location: - row: 17 + row: 16 column: 0 end_location: - row: 18 + row: 17 column: 0 parent: ~ - kind: @@ -68,18 +68,18 @@ expression: checks - false - false location: - row: 20 + row: 19 column: 7 end_location: - row: 20 + row: 19 column: 35 fix: content: "" location: - row: 20 + row: 19 column: 0 end_location: - row: 21 + row: 20 column: 0 parent: ~ diff --git a/src/pygrep_hooks/rules/deprecated_log_warn.rs b/src/pygrep_hooks/rules/deprecated_log_warn.rs index 875b347bc1..63d8ee7d3a 100644 --- a/src/pygrep_hooks/rules/deprecated_log_warn.rs +++ b/src/pygrep_hooks/rules/deprecated_log_warn.rs @@ -1,6 +1,5 @@ use rustpython_ast::Expr; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; @@ -8,9 +7,9 @@ use crate::violations; /// PGH002 - deprecated use of logging.warn pub fn deprecated_log_warn(checker: &mut Checker, func: &Expr) { - let call_path = dealias_call_path(collect_call_paths(func), &checker.import_aliases); - if call_path == ["log", "warn"] - || match_call_path(&call_path, "logging", "warn", &checker.from_imports) + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["logging", "warn"]) { checker.diagnostics.push(Diagnostic::new( violations::DeprecatedLogWarn, diff --git a/src/pygrep_hooks/snapshots/ruff__pygrep_hooks__tests__PGH002_PGH002_1.py.snap b/src/pygrep_hooks/snapshots/ruff__pygrep_hooks__tests__PGH002_PGH002_1.py.snap index 61c4f1c528..fb3887f292 100644 --- a/src/pygrep_hooks/snapshots/ruff__pygrep_hooks__tests__PGH002_PGH002_1.py.snap +++ b/src/pygrep_hooks/snapshots/ruff__pygrep_hooks__tests__PGH002_PGH002_1.py.snap @@ -1,6 +1,6 @@ --- source: src/pygrep_hooks/mod.rs -expression: checks +expression: diagnostics --- - kind: DeprecatedLogWarn: ~ @@ -19,27 +19,7 @@ expression: checks column: 0 end_location: row: 5 - column: 8 - fix: ~ - parent: ~ -- kind: - DeprecatedLogWarn: ~ - location: - row: 6 - column: 0 - end_location: - row: 6 column: 4 fix: ~ parent: ~ -- kind: - DeprecatedLogWarn: ~ - location: - row: 15 - column: 4 - end_location: - row: 15 - column: 8 - fix: ~ - parent: ~ diff --git a/src/python/typing.rs b/src/python/typing.rs index d372b91205..82407c794a 100644 --- a/src/python/typing.rs +++ b/src/python/typing.rs @@ -1,8 +1,8 @@ use once_cell::sync::Lazy; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashSet; use rustpython_ast::{Expr, ExprKind}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; +use crate::checkers::ast::Checker; // See: https://pypi.org/project/typing-extensions/ pub static TYPING_EXTENSIONS: Lazy> = Lazy::new(|| { @@ -62,225 +62,190 @@ pub static TYPING_EXTENSIONS: Lazy> = Lazy::new(|| { }); // See: https://docs.python.org/3/library/typing.html -static SUBSCRIPTS: Lazy>> = Lazy::new(|| { - let mut subscripts: FxHashMap<&'static str, Vec<&'static str>> = FxHashMap::default(); - for (module, name) in [ - // builtins - ("", "dict"), - ("", "frozenset"), - ("", "list"), - ("", "set"), - ("", "tuple"), - ("", "type"), - // `collections` - ("collections", "ChainMap"), - ("collections", "Counter"), - ("collections", "OrderedDict"), - ("collections", "defaultdict"), - ("collections", "deque"), - // `collections.abc` - ("collections.abc", "AsyncGenerator"), - ("collections.abc", "AsyncIterable"), - ("collections.abc", "AsyncIterator"), - ("collections.abc", "Awaitable"), - ("collections.abc", "ByteString"), - ("collections.abc", "Callable"), - ("collections.abc", "Collection"), - ("collections.abc", "Container"), - ("collections.abc", "Coroutine"), - ("collections.abc", "Generator"), - ("collections.abc", "ItemsView"), - ("collections.abc", "Iterable"), - ("collections.abc", "Iterator"), - ("collections.abc", "KeysView"), - ("collections.abc", "Mapping"), - ("collections.abc", "MappingView"), - ("collections.abc", "MutableMapping"), - ("collections.abc", "MutableSequence"), - ("collections.abc", "MutableSet"), - ("collections.abc", "Reversible"), - ("collections.abc", "Sequence"), - ("collections.abc", "Set"), - ("collections.abc", "ValuesView"), - // `contextlib` - ("contextlib", "AbstractAsyncContextManager"), - ("contextlib", "AbstractContextManager"), - // `re` - ("re", "Match"), - ("re", "Pattern"), - // `typing` - ("typing", "AbstractSet"), - ("typing", "AsyncContextManager"), - ("typing", "AsyncGenerator"), - ("typing", "AsyncIterator"), - ("typing", "Awaitable"), - ("typing", "BinaryIO"), - ("typing", "ByteString"), - ("typing", "Callable"), - ("typing", "ChainMap"), - ("typing", "ClassVar"), - ("typing", "Collection"), - ("typing", "Concatenate"), - ("typing", "Container"), - ("typing", "ContextManager"), - ("typing", "Coroutine"), - ("typing", "Counter"), - ("typing", "DefaultDict"), - ("typing", "Deque"), - ("typing", "Dict"), - ("typing", "Final"), - ("typing", "FrozenSet"), - ("typing", "Generator"), - ("typing", "Generic"), - ("typing", "IO"), - ("typing", "ItemsView"), - ("typing", "Iterable"), - ("typing", "Iterator"), - ("typing", "KeysView"), - ("typing", "List"), - ("typing", "Mapping"), - ("typing", "Match"), - ("typing", "MutableMapping"), - ("typing", "MutableSequence"), - ("typing", "MutableSet"), - ("typing", "Optional"), - ("typing", "OrderedDict"), - ("typing", "Pattern"), - ("typing", "Reversible"), - ("typing", "Sequence"), - ("typing", "Set"), - ("typing", "TextIO"), - ("typing", "Tuple"), - ("typing", "Type"), - ("typing", "TypeGuard"), - ("typing", "Union"), - ("typing", "Unpack"), - ("typing", "ValuesView"), - // `typing.io` - ("typing.io", "BinaryIO"), - ("typing.io", "IO"), - ("typing.io", "TextIO"), - // `typing.re` - ("typing.re", "Match"), - ("typing.re", "Pattern"), - // `typing_extensions` - ("typing_extensions", "AsyncContextManager"), - ("typing_extensions", "AsyncGenerator"), - ("typing_extensions", "AsyncIterable"), - ("typing_extensions", "AsyncIterator"), - ("typing_extensions", "Awaitable"), - ("typing_extensions", "ChainMap"), - ("typing_extensions", "ClassVar"), - ("typing_extensions", "Concatenate"), - ("typing_extensions", "ContextManager"), - ("typing_extensions", "Coroutine"), - ("typing_extensions", "Counter"), - ("typing_extensions", "DefaultDict"), - ("typing_extensions", "Deque"), - ("typing_extensions", "Type"), - // `weakref` - ("weakref", "WeakKeyDictionary"), - ("weakref", "WeakSet"), - ("weakref", "WeakValueDictionary"), - ] { - subscripts.entry(name).or_default().push(module); - } - subscripts -}); +const SUBSCRIPTS: &[&[&str]] = &[ + // builtins + &["", "dict"], + &["", "frozenset"], + &["", "list"], + &["", "set"], + &["", "tuple"], + &["", "type"], + // `collections` + &["collections", "ChainMap"], + &["collections", "Counter"], + &["collections", "OrderedDict"], + &["collections", "defaultdict"], + &["collections", "deque"], + // `collections.abc` + &["collections", "abc", "AsyncGenerator"], + &["collections", "abc", "AsyncIterable"], + &["collections", "abc", "AsyncIterator"], + &["collections", "abc", "Awaitable"], + &["collections", "abc", "ByteString"], + &["collections", "abc", "Callable"], + &["collections", "abc", "Collection"], + &["collections", "abc", "Container"], + &["collections", "abc", "Coroutine"], + &["collections", "abc", "Generator"], + &["collections", "abc", "ItemsView"], + &["collections", "abc", "Iterable"], + &["collections", "abc", "Iterator"], + &["collections", "abc", "KeysView"], + &["collections", "abc", "Mapping"], + &["collections", "abc", "MappingView"], + &["collections", "abc", "MutableMapping"], + &["collections", "abc", "MutableSequence"], + &["collections", "abc", "MutableSet"], + &["collections", "abc", "Reversible"], + &["collections", "abc", "Sequence"], + &["collections", "abc", "Set"], + &["collections", "abc", "ValuesView"], + // `contextlib` + &["contextlib", "AbstractAsyncContextManager"], + &["contextlib", "AbstractContextManager"], + // `re` + &["re", "Match"], + &["re", "Pattern"], + // `typing` + &["typing", "AbstractSet"], + &["typing", "AsyncContextManager"], + &["typing", "AsyncGenerator"], + &["typing", "AsyncIterator"], + &["typing", "Awaitable"], + &["typing", "BinaryIO"], + &["typing", "ByteString"], + &["typing", "Callable"], + &["typing", "ChainMap"], + &["typing", "ClassVar"], + &["typing", "Collection"], + &["typing", "Concatenate"], + &["typing", "Container"], + &["typing", "ContextManager"], + &["typing", "Coroutine"], + &["typing", "Counter"], + &["typing", "DefaultDict"], + &["typing", "Deque"], + &["typing", "Dict"], + &["typing", "Final"], + &["typing", "FrozenSet"], + &["typing", "Generator"], + &["typing", "Generic"], + &["typing", "IO"], + &["typing", "ItemsView"], + &["typing", "Iterable"], + &["typing", "Iterator"], + &["typing", "KeysView"], + &["typing", "List"], + &["typing", "Mapping"], + &["typing", "Match"], + &["typing", "MutableMapping"], + &["typing", "MutableSequence"], + &["typing", "MutableSet"], + &["typing", "Optional"], + &["typing", "OrderedDict"], + &["typing", "Pattern"], + &["typing", "Reversible"], + &["typing", "Sequence"], + &["typing", "Set"], + &["typing", "TextIO"], + &["typing", "Tuple"], + &["typing", "Type"], + &["typing", "TypeGuard"], + &["typing", "Union"], + &["typing", "Unpack"], + &["typing", "ValuesView"], + // `typing.io` + &["typing", "io", "BinaryIO"], + &["typing", "io", "IO"], + &["typing", "io", "TextIO"], + // `typing.re` + &["typing", "re", "Match"], + &["typing", "re", "Pattern"], + // `typing_extensions` + &["typing_extensions", "AsyncContextManager"], + &["typing_extensions", "AsyncGenerator"], + &["typing_extensions", "AsyncIterable"], + &["typing_extensions", "AsyncIterator"], + &["typing_extensions", "Awaitable"], + &["typing_extensions", "ChainMap"], + &["typing_extensions", "ClassVar"], + &["typing_extensions", "Concatenate"], + &["typing_extensions", "ContextManager"], + &["typing_extensions", "Coroutine"], + &["typing_extensions", "Counter"], + &["typing_extensions", "DefaultDict"], + &["typing_extensions", "Deque"], + &["typing_extensions", "Type"], + // `weakref` + &["weakref", "WeakKeyDictionary"], + &["weakref", "WeakSet"], + &["weakref", "WeakValueDictionary"], +]; // See: https://docs.python.org/3/library/typing.html -static PEP_593_SUBSCRIPTS: Lazy>> = Lazy::new(|| { - let mut subscripts: FxHashMap<&'static str, Vec<&'static str>> = FxHashMap::default(); - for (module, name) in [ - // `typing` - ("typing", "Annotated"), - // `typing_extensions` - ("typing_extensions", "Annotated"), - ] { - subscripts.entry(name).or_default().push(module); - } - subscripts -}); +const PEP_593_SUBSCRIPTS: &[&[&str]] = &[ + // `typing` + &["typing", "Annotated"], + // `typing_extensions` + &["typing_extensions", "Annotated"], +]; pub enum SubscriptKind { AnnotatedSubscript, PEP593AnnotatedSubscript, } -pub fn match_annotated_subscript<'a, F>( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, - typing_modules: impl Iterator, - is_builtin: F, -) -> Option -where - F: Fn(&str) -> bool, -{ +pub fn match_annotated_subscript(checker: &Checker, expr: &Expr) -> Option { if !matches!( expr.node, ExprKind::Name { .. } | ExprKind::Attribute { .. } ) { return None; } - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - if let Some(member) = call_path.last() { - if let Some(modules) = SUBSCRIPTS.get(member) { - for module in modules { - if match_call_path(&call_path, module, member, from_imports) - && (!module.is_empty() || is_builtin(member)) - { - return Some(SubscriptKind::AnnotatedSubscript); + + checker.resolve_call_path(expr).and_then(|call_path| { + if SUBSCRIPTS.contains(&call_path.as_slice()) { + return Some(SubscriptKind::AnnotatedSubscript); + } + if PEP_593_SUBSCRIPTS.contains(&call_path.as_slice()) { + return Some(SubscriptKind::PEP593AnnotatedSubscript); + } + + for module in &checker.settings.typing_modules { + let module_call_path = module.split('.').collect::>(); + if call_path.starts_with(&module_call_path) { + for subscript in SUBSCRIPTS.iter() { + if call_path.last() == subscript.last() { + return Some(SubscriptKind::AnnotatedSubscript); + } } - } - for module in typing_modules { - if match_call_path(&call_path, module, member, from_imports) { - return Some(SubscriptKind::AnnotatedSubscript); - } - } - } else if let Some(modules) = PEP_593_SUBSCRIPTS.get(member) { - for module in modules { - if match_call_path(&call_path, module, member, from_imports) - && (!module.is_empty() || is_builtin(member)) - { - return Some(SubscriptKind::PEP593AnnotatedSubscript); - } - } - for module in typing_modules { - if match_call_path(&call_path, module, member, from_imports) { - return Some(SubscriptKind::PEP593AnnotatedSubscript); + for subscript in PEP_593_SUBSCRIPTS.iter() { + if call_path.last() == subscript.last() { + return Some(SubscriptKind::PEP593AnnotatedSubscript); + } } } } - } - None + + None + }) } // See: https://peps.python.org/pep-0585/ -const PEP_585_BUILTINS_ELIGIBLE: &[(&str, &str)] = &[ - ("typing", "Dict"), - ("typing", "FrozenSet"), - ("typing", "List"), - ("typing", "Set"), - ("typing", "Tuple"), - ("typing", "Type"), - ("typing_extensions", "Type"), +const PEP_585_BUILTINS_ELIGIBLE: &[&[&str]] = &[ + &["typing", "Dict"], + &["typing", "FrozenSet"], + &["typing", "List"], + &["typing", "Set"], + &["typing", "Tuple"], + &["typing", "Type"], + &["typing_extensions", "Type"], ]; /// Returns `true` if `Expr` represents a reference to a typing object with a /// PEP 585 built-in. -pub fn is_pep585_builtin( - expr: &Expr, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, -) -> bool { - let call_path = dealias_call_path(collect_call_paths(expr), import_aliases); - if !call_path.is_empty() { - for (module, member) in PEP_585_BUILTINS_ELIGIBLE { - if match_call_path(&call_path, module, member, from_imports) { - return true; - } - } - } - false +pub fn is_pep585_builtin(checker: &Checker, expr: &Expr) -> bool { + checker.resolve_call_path(expr).map_or(false, |call_path| { + PEP_585_BUILTINS_ELIGIBLE.contains(&call_path.as_slice()) + }) } diff --git a/src/pyupgrade/rules/convert_named_tuple_functional_to_class.rs b/src/pyupgrade/rules/convert_named_tuple_functional_to_class.rs index 2985c61c93..9952d6f52d 100644 --- a/src/pyupgrade/rules/convert_named_tuple_functional_to_class.rs +++ b/src/pyupgrade/rules/convert_named_tuple_functional_to_class.rs @@ -2,7 +2,7 @@ use anyhow::{bail, Result}; use log::debug; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Keyword, Stmt, StmtKind}; -use crate::ast::helpers::{create_expr, create_stmt, match_module_member, unparse_stmt}; +use crate::ast::helpers::{create_expr, create_stmt, unparse_stmt}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -29,13 +29,10 @@ fn match_named_tuple_assign<'a>( } = &value.node else { return None; }; - if !match_module_member( - func, - "typing", - "NamedTuple", - &checker.from_imports, - &checker.import_aliases, - ) { + if !checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["typing", "NamedTuple"]) + { return None; } Some((typename, args, keywords, func)) diff --git a/src/pyupgrade/rules/convert_typed_dict_functional_to_class.rs b/src/pyupgrade/rules/convert_typed_dict_functional_to_class.rs index 156c5ea84f..9e2745a856 100644 --- a/src/pyupgrade/rules/convert_typed_dict_functional_to_class.rs +++ b/src/pyupgrade/rules/convert_typed_dict_functional_to_class.rs @@ -2,7 +2,7 @@ use anyhow::{bail, Result}; use log::debug; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Keyword, Stmt, StmtKind}; -use crate::ast::helpers::{create_expr, create_stmt, match_module_member, unparse_stmt}; +use crate::ast::helpers::{create_expr, create_stmt, unparse_stmt}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -30,13 +30,10 @@ fn match_typed_dict_assign<'a>( } = &value.node else { return None; }; - if !match_module_member( - func, - "typing", - "TypedDict", - &checker.from_imports, - &checker.import_aliases, - ) { + if !checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["typing", "TypedDict"]) + { return None; } Some((class_name, args, keywords, func)) diff --git a/src/pyupgrade/rules/datetime_utc_alias.rs b/src/pyupgrade/rules/datetime_utc_alias.rs index 6e231665ec..11ace9f879 100644 --- a/src/pyupgrade/rules/datetime_utc_alias.rs +++ b/src/pyupgrade/rules/datetime_utc_alias.rs @@ -1,6 +1,6 @@ use rustpython_ast::Expr; -use crate::ast::helpers::{collect_call_paths, compose_call_path, dealias_call_path}; +use crate::ast::helpers::collect_call_path; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -9,18 +9,22 @@ use crate::violations; /// UP017 pub fn datetime_utc_alias(checker: &mut Checker, expr: &Expr) { - let dealiased_call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases); - if dealiased_call_path == ["datetime", "timezone", "utc"] { - let mut diagnostic = - Diagnostic::new(violations::DatetimeTimezoneUTC, Range::from_located(expr)); + if checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path == ["datetime", "timezone", "utc"] + }) { + let straight_import = collect_call_path(expr) == ["datetime", "timezone", "utc"]; + let mut diagnostic = Diagnostic::new( + violations::DatetimeTimezoneUTC { straight_import }, + Range::from_located(expr), + ); if checker.patch(&RuleCode::UP017) { - diagnostic.amend(Fix::replacement( - compose_call_path(expr) - .unwrap() - .replace("timezone.utc", "UTC"), - expr.location, - expr.end_location.unwrap(), - )); + if straight_import { + diagnostic.amend(Fix::replacement( + "datetime.UTC".to_string(), + expr.location, + expr.end_location.unwrap(), + )); + } } checker.diagnostics.push(diagnostic); } diff --git a/src/pyupgrade/rules/open_alias.rs b/src/pyupgrade/rules/open_alias.rs index 7a2ae745a5..0ebc6948f3 100644 --- a/src/pyupgrade/rules/open_alias.rs +++ b/src/pyupgrade/rules/open_alias.rs @@ -1,6 +1,5 @@ use rustpython_ast::Expr; -use crate::ast::helpers::{collect_call_paths, dealias_call_path, match_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -9,9 +8,10 @@ use crate::violations; /// UP020 pub fn open_alias(checker: &mut Checker, expr: &Expr, func: &Expr) { - let call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases); - - if match_call_path(&call_path, "io", "open", &checker.from_imports) { + if checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["io", "open"]) + { let mut diagnostic = Diagnostic::new(violations::OpenAlias, Range::from_located(expr)); if checker.patch(&RuleCode::UP020) { diagnostic.amend(Fix::replacement( diff --git a/src/pyupgrade/rules/os_error_alias.rs b/src/pyupgrade/rules/os_error_alias.rs index 59f7ee4136..bc33b89326 100644 --- a/src/pyupgrade/rules/os_error_alias.rs +++ b/src/pyupgrade/rules/os_error_alias.rs @@ -1,7 +1,7 @@ use itertools::Itertools; use rustpython_ast::{Excepthandler, ExcepthandlerKind, Expr, ExprKind, Located}; -use crate::ast::helpers::{compose_call_path, match_module_member}; +use crate::ast::helpers::compose_call_path; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -37,17 +37,13 @@ fn get_before_replace(elts: &[Expr]) -> Vec { fn check_module(checker: &Checker, expr: &Expr) -> (Vec, Vec) { let mut replacements: Vec = vec![]; let mut before_replace: Vec = vec![]; - for module in ERROR_MODULES.iter() { - if match_module_member( - expr, - module, - "error", - &checker.from_imports, - &checker.import_aliases, - ) { - replacements.push("OSError".to_string()); - before_replace.push(format!("{module}.error")); - break; + if let Some(call_path) = checker.resolve_call_path(expr) { + for module in ERROR_MODULES.iter() { + if call_path == [module, "error"] { + replacements.push("OSError".to_string()); + before_replace.push(format!("{module}.error")); + break; + } } } (replacements, before_replace) @@ -140,17 +136,6 @@ fn handle_making_changes( replacements: &[String], ) { if before_replace != replacements && !replacements.is_empty() { - let range = Range::new(target.location, target.end_location.unwrap()); - let contents = checker.locator.slice_source_code_range(&range); - // Pyyupgrade does not want imports changed if a module only is - // surrounded by parentheses. For example: `except mmap.error:` - // would be changed, but: `(mmap).error:` would not. One issue with - // this implementation is that any valid changes will also be - // ignored. Let me know if you want me to go with a more - // complicated solution that avoids this. - if contents.contains(").") { - return; - } let mut final_str: String; if replacements.len() == 1 { final_str = replacements.get(0).unwrap().to_string(); @@ -159,13 +144,15 @@ fn handle_making_changes( final_str.insert(0, '('); final_str.push(')'); } - let mut diagnostic = - Diagnostic::new(violations::OSErrorAlias(compose_call_path(target)), range); + let mut diagnostic = Diagnostic::new( + violations::OSErrorAlias(compose_call_path(target)), + Range::from_located(target), + ); if checker.patch(diagnostic.kind.code()) { diagnostic.amend(Fix::replacement( final_str, - range.location, - range.end_location, + target.location, + target.end_location.unwrap(), )); } checker.diagnostics.push(diagnostic); diff --git a/src/pyupgrade/rules/remove_six_compat.rs b/src/pyupgrade/rules/remove_six_compat.rs index bbea35a6c3..a4a85330e5 100644 --- a/src/pyupgrade/rules/remove_six_compat.rs +++ b/src/pyupgrade/rules/remove_six_compat.rs @@ -1,6 +1,6 @@ use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Keyword, StmtKind}; -use crate::ast::helpers::{collect_call_paths, create_expr, create_stmt, dealias_call_path}; +use crate::ast::helpers::{create_expr, create_stmt}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -364,8 +364,10 @@ fn handle_next_on_six_dict(expr: &Expr, patch: bool, checker: &Checker) -> Optio return None; } let [arg] = &args[..] else { return None; }; - let call_path = dealias_call_path(collect_call_paths(arg), &checker.import_aliases); - if !is_module_member(&call_path, "six") { + if !checker + .resolve_call_path(arg) + .map_or(false, |call_path| is_module_member(&call_path, "six")) + { return None; } let ExprKind::Call { func, args, .. } = &arg.node else {return None;}; @@ -409,8 +411,10 @@ pub fn remove_six_compat(checker: &mut Checker, expr: &Expr) { return; } - let call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases); - if is_module_member(&call_path, "six") { + if checker + .resolve_call_path(expr) + .map_or(false, |call_path| is_module_member(&call_path, "six")) + { let patch = checker.patch(&RuleCode::UP016); let diagnostic = match &expr.node { ExprKind::Call { diff --git a/src/pyupgrade/rules/replace_stdout_stderr.rs b/src/pyupgrade/rules/replace_stdout_stderr.rs index 33bfe3cf88..a010ef9d3a 100644 --- a/src/pyupgrade/rules/replace_stdout_stderr.rs +++ b/src/pyupgrade/rules/replace_stdout_stderr.rs @@ -1,6 +1,6 @@ use rustpython_ast::{Expr, Keyword}; -use crate::ast::helpers::{find_keyword, match_module_member}; +use crate::ast::helpers::find_keyword; use crate::ast::types::Range; use crate::ast::whitespace::indentation; use crate::checkers::ast::Checker; @@ -43,13 +43,10 @@ fn extract_middle(contents: &str) -> Option { /// UP022 pub fn replace_stdout_stderr(checker: &mut Checker, expr: &Expr, kwargs: &[Keyword]) { - if match_module_member( - expr, - "subprocess", - "run", - &checker.from_imports, - &checker.import_aliases, - ) { + if checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["subprocess", "run"]) + { // Find `stdout` and `stderr` kwargs. let Some(stdout) = find_keyword(kwargs, "stdout") else { return; @@ -59,19 +56,13 @@ pub fn replace_stdout_stderr(checker: &mut Checker, expr: &Expr, kwargs: &[Keywo }; // Verify that they're both set to `subprocess.PIPE`. - if !match_module_member( - &stdout.node.value, - "subprocess", - "PIPE", - &checker.from_imports, - &checker.import_aliases, - ) || !match_module_member( - &stderr.node.value, - "subprocess", - "PIPE", - &checker.from_imports, - &checker.import_aliases, - ) { + if !checker + .resolve_call_path(&stdout.node.value) + .map_or(false, |call_path| call_path == ["subprocess", "PIPE"]) + || !checker + .resolve_call_path(&stderr.node.value) + .map_or(false, |call_path| call_path == ["subprocess", "PIPE"]) + { return; } diff --git a/src/pyupgrade/rules/replace_universal_newlines.rs b/src/pyupgrade/rules/replace_universal_newlines.rs index bc797d6f04..a65d03fb44 100644 --- a/src/pyupgrade/rules/replace_universal_newlines.rs +++ b/src/pyupgrade/rules/replace_universal_newlines.rs @@ -1,6 +1,6 @@ use rustpython_ast::{Expr, Keyword, Location}; -use crate::ast::helpers::{find_keyword, match_module_member}; +use crate::ast::helpers::find_keyword; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -9,13 +9,10 @@ use crate::violations; /// UP021 pub fn replace_universal_newlines(checker: &mut Checker, expr: &Expr, kwargs: &[Keyword]) { - if match_module_member( - expr, - "subprocess", - "run", - &checker.from_imports, - &checker.import_aliases, - ) { + if checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["subprocess", "run"]) + { let Some(kwarg) = find_keyword(kwargs, "universal_newlines") else { return; }; let range = Range::new( kwarg.location, diff --git a/src/pyupgrade/rules/rewrite_mock_import.rs b/src/pyupgrade/rules/rewrite_mock_import.rs index 5273a6c6f1..0141d53483 100644 --- a/src/pyupgrade/rules/rewrite_mock_import.rs +++ b/src/pyupgrade/rules/rewrite_mock_import.rs @@ -6,7 +6,7 @@ use libcst_native::{ use log::error; use rustpython_ast::{Expr, ExprKind, Stmt, StmtKind}; -use crate::ast::helpers::collect_call_paths; +use crate::ast::helpers::collect_call_path; use crate::ast::types::Range; use crate::ast::whitespace::indentation; use crate::checkers::ast::Checker; @@ -199,7 +199,7 @@ fn format_import_from( /// UP026 pub fn rewrite_mock_attribute(checker: &mut Checker, expr: &Expr) { if let ExprKind::Attribute { value, .. } = &expr.node { - if collect_call_paths(value) == ["mock", "mock"] { + if collect_call_path(value) == ["mock", "mock"] { let mut diagnostic = Diagnostic::new( violations::RewriteMockImport(MockReference::Attribute), Range::from_located(value), diff --git a/src/pyupgrade/rules/typing_text_str_alias.rs b/src/pyupgrade/rules/typing_text_str_alias.rs index 95eb059fab..3e61dfb4ef 100644 --- a/src/pyupgrade/rules/typing_text_str_alias.rs +++ b/src/pyupgrade/rules/typing_text_str_alias.rs @@ -1,6 +1,5 @@ use rustpython_ast::Expr; -use crate::ast::helpers::match_module_member; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -9,13 +8,10 @@ use crate::violations; /// UP019 pub fn typing_text_str_alias(checker: &mut Checker, expr: &Expr) { - if match_module_member( - expr, - "typing", - "Text", - &checker.from_imports, - &checker.import_aliases, - ) { + if checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["typing", "Text"]) + { let mut diagnostic = Diagnostic::new(violations::TypingTextStrAlias, Range::from_located(expr)); if checker.patch(diagnostic.kind.code()) { diff --git a/src/pyupgrade/rules/unnecessary_lru_cache_params.rs b/src/pyupgrade/rules/unnecessary_lru_cache_params.rs index 8c91f0a161..37abeae97b 100644 --- a/src/pyupgrade/rules/unnecessary_lru_cache_params.rs +++ b/src/pyupgrade/rules/unnecessary_lru_cache_params.rs @@ -1,8 +1,6 @@ -use rustc_hash::{FxHashMap, FxHashSet}; use rustpython_ast::{Constant, ExprKind, KeywordData}; use rustpython_parser::ast::Expr; -use crate::ast::helpers; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -11,10 +9,9 @@ use crate::settings::types::PythonVersion; use crate::violations; fn rule( + checker: &Checker, decorator_list: &[Expr], target_version: PythonVersion, - from_imports: &FxHashMap<&str, FxHashSet<&str>>, - import_aliases: &FxHashMap<&str, &str>, ) -> Option { for expr in decorator_list.iter() { let ExprKind::Call { @@ -27,13 +24,9 @@ fn rule( }; if !(args.is_empty() - && helpers::match_module_member( - func, - "functools", - "lru_cache", - from_imports, - import_aliases, - )) + && checker + .resolve_call_path(func) + .map_or(false, |call_path| call_path == ["functools", "lru_cache"])) { continue; } @@ -74,10 +67,9 @@ fn rule( /// UP011 pub fn unnecessary_lru_cache_params(checker: &mut Checker, decorator_list: &[Expr]) { let Some(mut diagnostic) = rule( + checker, decorator_list, checker.settings.target_version, - &checker.from_imports, - &checker.import_aliases, ) else { return; }; diff --git a/src/pyupgrade/rules/use_pep585_annotation.rs b/src/pyupgrade/rules/use_pep585_annotation.rs index b26f532e83..58f05a3402 100644 --- a/src/pyupgrade/rules/use_pep585_annotation.rs +++ b/src/pyupgrade/rules/use_pep585_annotation.rs @@ -7,18 +7,19 @@ use crate::registry::Diagnostic; use crate::violations; /// UP006 -pub fn use_pep585_annotation(checker: &mut Checker, expr: &Expr, id: &str) { - let replacement = *checker.import_aliases.get(id).unwrap_or(&id); - let mut diagnostic = Diagnostic::new( - violations::UsePEP585Annotation(replacement.to_string()), - Range::from_located(expr), - ); - if checker.patch(diagnostic.kind.code()) { - diagnostic.amend(Fix::replacement( - replacement.to_lowercase(), - expr.location, - expr.end_location.unwrap(), - )); +pub fn use_pep585_annotation(checker: &mut Checker, expr: &Expr) { + if let Some(call_path) = checker.resolve_call_path(expr) { + let mut diagnostic = Diagnostic::new( + violations::UsePEP585Annotation(call_path[call_path.len() - 1].to_string()), + Range::from_located(expr), + ); + if checker.patch(diagnostic.kind.code()) { + diagnostic.amend(Fix::replacement( + call_path[call_path.len() - 1].to_lowercase(), + expr.location, + expr.end_location.unwrap(), + )); + } + checker.diagnostics.push(diagnostic); } - checker.diagnostics.push(diagnostic); } diff --git a/src/pyupgrade/rules/use_pep604_annotation.rs b/src/pyupgrade/rules/use_pep604_annotation.rs index bd6923b006..c19e6d08d1 100644 --- a/src/pyupgrade/rules/use_pep604_annotation.rs +++ b/src/pyupgrade/rules/use_pep604_annotation.rs @@ -1,6 +1,5 @@ use rustpython_ast::{Constant, Expr, ExprKind, Location, Operator}; -use crate::ast::helpers::{collect_call_paths, dealias_call_path}; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::fix::Fix; @@ -62,8 +61,10 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s return; } - let call_path = dealias_call_path(collect_call_paths(value), &checker.import_aliases); - if checker.match_typing_call_path(&call_path, "Optional") { + let call_path = checker.resolve_call_path(value); + if call_path.as_ref().map_or(false, |call_path| { + checker.match_typing_call_path(call_path, "Optional") + }) { let mut diagnostic = Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); if checker.patch(diagnostic.kind.code()) { @@ -76,7 +77,9 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s )); } checker.diagnostics.push(diagnostic); - } else if checker.match_typing_call_path(&call_path, "Union") { + } else if call_path.as_ref().map_or(false, |call_path| { + checker.match_typing_call_path(call_path, "Union") + }) { let mut diagnostic = Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); if checker.patch(diagnostic.kind.code()) { diff --git a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP024_UP024_0.py.snap b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP024_UP024_0.py.snap index 6da0bb67df..3b5035aeb8 100644 --- a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP024_UP024_0.py.snap +++ b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP024_UP024_0.py.snap @@ -1,6 +1,6 @@ --- source: src/pyupgrade/mod.rs -expression: checks +expression: diagnostics --- - kind: OSErrorAlias: EnvironmentError @@ -206,4 +206,21 @@ expression: checks row: 65 column: 23 parent: ~ +- kind: + OSErrorAlias: mmap.error + location: + row: 87 + column: 7 + end_location: + row: 87 + column: 19 + fix: + content: OSError + location: + row: 87 + column: 7 + end_location: + row: 87 + column: 19 + parent: ~ diff --git a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__datetime_utc_alias_py311.snap b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__datetime_utc_alias_py311.snap index 286db221ee..26e908c8c8 100644 --- a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__datetime_utc_alias_py311.snap +++ b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__datetime_utc_alias_py311.snap @@ -1,9 +1,32 @@ --- source: src/pyupgrade/mod.rs -expression: checks +expression: diagnostics --- - kind: - DatetimeTimezoneUTC: ~ + DatetimeTimezoneUTC: + straight_import: false + location: + row: 7 + column: 6 + end_location: + row: 7 + column: 18 + fix: ~ + parent: ~ +- kind: + DatetimeTimezoneUTC: + straight_import: false + location: + row: 8 + column: 6 + end_location: + row: 8 + column: 12 + fix: ~ + parent: ~ +- kind: + DatetimeTimezoneUTC: + straight_import: true location: row: 10 column: 6 @@ -20,20 +43,14 @@ expression: checks column: 27 parent: ~ - kind: - DatetimeTimezoneUTC: ~ + DatetimeTimezoneUTC: + straight_import: false location: row: 11 column: 6 end_location: row: 11 column: 21 - fix: - content: dt.UTC - location: - row: 11 - column: 6 - end_location: - row: 11 - column: 21 + fix: ~ parent: ~ diff --git a/src/violations.rs b/src/violations.rs index 25742ac599..60e43f1689 100644 --- a/src/violations.rs +++ b/src/violations.rs @@ -3554,19 +3554,27 @@ impl AlwaysAutofixableViolation for RemoveSixCompat { } define_violation!( - pub struct DatetimeTimezoneUTC; + pub struct DatetimeTimezoneUTC { + pub straight_import: bool, + } ); -impl AlwaysAutofixableViolation for DatetimeTimezoneUTC { +impl Violation for DatetimeTimezoneUTC { fn message(&self) -> String { "Use `datetime.UTC` alias".to_string() } - fn autofix_title(&self) -> String { - "Convert to `datetime.UTC` alias".to_string() + fn autofix_title_formatter(&self) -> Option String> { + if self.straight_import { + Some(|_| "Convert to `datetime.UTC` alias".to_string()) + } else { + None + } } fn placeholder() -> Self { - DatetimeTimezoneUTC + DatetimeTimezoneUTC { + straight_import: true, + } } } diff --git a/src/visibility.rs b/src/visibility.rs index 95ac409a9e..2719ae826e 100644 --- a/src/visibility.rs +++ b/src/visibility.rs @@ -5,7 +5,6 @@ use std::path::Path; use rustpython_ast::{Expr, Stmt, StmtKind}; -use crate::ast::helpers::match_module_member; use crate::checkers::ast::Checker; use crate::docstrings::definition::Documentable; @@ -31,26 +30,18 @@ pub struct VisibleScope { /// Returns `true` if a function is a "static method". pub fn is_staticmethod(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { - match_module_member( - expr, - "", - "staticmethod", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["", "staticmethod"]) }) } /// Returns `true` if a function is a "class method". pub fn is_classmethod(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { - match_module_member( - expr, - "", - "classmethod", - &checker.from_imports, - &checker.import_aliases, - ) + checker + .resolve_call_path(expr) + .map_or(false, |call_path| call_path == ["", "classmethod"]) }) } @@ -71,13 +62,9 @@ pub fn is_override(checker: &Checker, decorator_list: &[Expr]) -> bool { /// Returns `true` if a function definition is an `@abstractmethod`. pub fn is_abstract(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { - match_module_member( - expr, - "abc", - "abstractmethod", - &checker.from_imports, - &checker.import_aliases, - ) + checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path == ["abc", "abstractmethod"] || call_path == ["abc", "abstractproperty"] + }) }) }