diff --git a/Cargo.lock b/Cargo.lock index feb2ae7df2..2e4ebb9a95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1946,6 +1946,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "shellexpand", + "smallvec", "strum", "strum_macros", "test-case", diff --git a/Cargo.toml b/Cargo.toml index 93c5f9c866..2fd114f543 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ schemars = { version = "0.8.11" } semver = { version = "1.0.16" } serde = { version = "1.0.147", features = ["derive"] } shellexpand = { version = "3.0.0" } +smallvec = { version = "1.10.0" } strum = { version = "0.24.1", features = ["strum_macros"] } strum_macros = { version = "0.24.3" } textwrap = { version = "0.16.0" } diff --git a/src/ast/function_type.rs b/src/ast/function_type.rs index 290f559b76..00472d4998 100644 --- a/src/ast/function_type.rs +++ b/src/ast/function_type.rs @@ -33,7 +33,7 @@ pub fn classify( checker.resolve_call_path(expr).map_or(false, |call_path| { METACLASS_BASES .iter() - .any(|(module, member)| call_path == [*module, *member]) + .any(|(module, member)| call_path.as_slice() == [*module, *member]) }) }) || decorator_list.iter().any(|expr| { diff --git a/src/ast/helpers.rs b/src/ast/helpers.rs index 103fbb5cbb..ac9cf620ab 100644 --- a/src/ast/helpers.rs +++ b/src/ast/helpers.rs @@ -10,8 +10,9 @@ use rustpython_ast::{ use rustpython_parser::lexer; use rustpython_parser::lexer::Tok; use rustpython_parser::token::StringKind; +use smallvec::smallvec; -use crate::ast::types::{Binding, BindingKind, Range}; +use crate::ast::types::{Binding, BindingKind, CallPath, Range}; use crate::checkers::ast::Checker; use crate::source_code::{Generator, Indexer, Locator, Stylist}; @@ -62,6 +63,29 @@ pub fn collect_call_path(expr: &Expr) -> Vec<&str> { segments } +fn collect_small_path_inner<'a>(expr: &'a Expr, parts: &mut CallPath<'a>) { + match &expr.node { + ExprKind::Call { func, .. } => { + collect_small_path_inner(func, parts); + } + ExprKind::Attribute { value, attr, .. } => { + collect_small_path_inner(value, parts); + parts.push(attr); + } + ExprKind::Name { id, .. } => { + parts.push(id); + } + _ => {} + } +} + +/// Convert an `Expr` to its call path segments (like ["typing", "List"]). +pub fn collect_small_path(expr: &Expr) -> CallPath { + let mut segments = smallvec![]; + collect_small_path_inner(expr, &mut segments); + segments +} + /// Convert an `Expr` to its call path (like `List`, or `typing.List`). pub fn compose_call_path(expr: &Expr) -> Option { let call_path = collect_call_path(expr); @@ -90,7 +114,7 @@ pub fn contains_call_path(checker: &Checker, expr: &Expr, target: &[&str]) -> bo any_over_expr(expr, &|expr| { checker .resolve_call_path(expr) - .map_or(false, |call_path| call_path == target) + .map_or(false, |call_path| call_path.as_slice() == target) }) } @@ -415,11 +439,11 @@ pub fn format_import_from_member( } /// Split a target string (like `typing.List`) into (`typing`, `List`). -pub fn to_call_path(target: &str) -> Vec<&str> { +pub fn to_call_path(target: &str) -> CallPath { if target.contains('.') { target.split('.').collect() } else { - vec!["", target] + smallvec!["", target] } } diff --git a/src/ast/types.rs b/src/ast/types.rs index 8b1844949e..19e323d9ca 100644 --- a/src/ast/types.rs +++ b/src/ast/types.rs @@ -242,3 +242,5 @@ impl<'a> From<&RefEquality<'a, Expr>> for &'a Expr { r.0 } } + +pub type CallPath<'a> = smallvec::SmallVec<[&'a str; 8]>; diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index 572f8b62e9..6b44257303 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -13,12 +13,14 @@ use rustpython_parser::ast::{ KeywordData, Operator, Stmt, StmtKind, Suite, }; use rustpython_parser::parser; +use smallvec::smallvec; -use crate::ast::helpers::{binding_range, collect_call_path, extract_handler_names}; +use crate::ast::helpers::{binding_range, collect_small_path, extract_handler_names}; use crate::ast::operations::extract_all_names; use crate::ast::relocate::relocate_expr; use crate::ast::types::{ - Binding, BindingKind, ClassDef, FunctionDef, Lambda, Node, Range, RefEquality, Scope, ScopeKind, + Binding, BindingKind, CallPath, ClassDef, FunctionDef, Lambda, Node, Range, RefEquality, Scope, + ScopeKind, }; use crate::ast::visitor::{walk_excepthandler, Visitor}; use crate::ast::{branch_detection, cast, helpers, operations, visitor}; @@ -27,7 +29,7 @@ use crate::noqa::Directive; use crate::python::builtins::{BUILTINS, MAGIC_GLOBALS}; use crate::python::future::ALL_FEATURE_NAMES; use crate::python::typing; -use crate::python::typing::SubscriptKind; +use crate::python::typing::{Callable, SubscriptKind}; use crate::registry::{Diagnostic, RuleCode}; use crate::rules::{ flake8_2020, flake8_annotations, flake8_bandit, flake8_blind_except, flake8_boolean_trap, @@ -170,21 +172,21 @@ impl<'a> Checker<'a> { } /// Return `true` if the call path is a reference to `typing.${target}`. - pub fn match_typing_call_path(&self, call_path: &[&str], target: &str) -> bool { - if call_path == ["typing", target] { + pub fn match_typing_call_path(&self, call_path: &CallPath, target: &str) -> bool { + if call_path.as_slice() == ["typing", target] { return true; } if typing::TYPING_EXTENSIONS.contains(target) { - if call_path == ["typing_extensions", target] { + if call_path.as_slice() == ["typing_extensions", target] { return true; } } if self.settings.typing_modules.iter().any(|module| { - let mut module = module.split('.').collect::>(); + let mut module: CallPath = module.split('.').collect(); module.push(target); - call_path == module.as_slice() + *call_path == module }) { return true; } @@ -206,11 +208,11 @@ impl<'a> Checker<'a> { }) } - pub fn resolve_call_path<'b>(&'a self, value: &'b Expr) -> Option> + pub fn resolve_call_path<'b>(&'a self, value: &'b Expr) -> Option> where 'b: 'a, { - let call_path = collect_call_path(value); + let call_path = collect_small_path(value); if let Some(head) = call_path.first() { if let Some(binding) = self.find_binding(head) { match &binding.kind { @@ -219,8 +221,8 @@ impl<'a> Checker<'a> { if name.starts_with('.') { return None; } - let mut source_path: Vec<&str> = name.split('.').collect(); - source_path.extend(call_path.iter().skip(1)); + let mut source_path: CallPath = name.split('.').collect(); + source_path.extend(call_path.into_iter().skip(1)); return Some(source_path); } BindingKind::SubmoduleImportation(name, ..) => { @@ -228,8 +230,8 @@ impl<'a> Checker<'a> { if name.starts_with('.') { return None; } - let mut source_path: Vec<&str> = name.split('.').collect(); - source_path.extend(call_path.iter().skip(1)); + let mut source_path: CallPath = name.split('.').collect(); + source_path.extend(call_path.into_iter().skip(1)); return Some(source_path); } BindingKind::FromImportation(.., name) => { @@ -237,12 +239,12 @@ impl<'a> Checker<'a> { if name.starts_with('.') { return None; } - let mut source_path: Vec<&str> = name.split('.').collect(); - source_path.extend(call_path.iter().skip(1)); + let mut source_path: CallPath = name.split('.').collect(); + source_path.extend(call_path.into_iter().skip(1)); return Some(source_path); } BindingKind::Builtin => { - let mut source_path: Vec<&str> = Vec::with_capacity(call_path.len() + 1); + let mut source_path: CallPath = smallvec![]; source_path.push(""); source_path.extend(call_path); return Some(source_path); @@ -2712,150 +2714,167 @@ where args, keywords, } => { - let call_path = self.resolve_call_path(func); - if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "ForwardRef") - }) { - self.visit_expr(func); - for expr in args { - self.in_type_definition = true; - self.visit_expr(expr); - self.in_type_definition = prev_in_type_definition; + let callable = self.resolve_call_path(func).and_then(|call_path| { + if self.match_typing_call_path(&call_path, "ForwardRef") { + Some(Callable::ForwardRef) + } else if self.match_typing_call_path(&call_path, "cast") { + Some(Callable::Cast) + } else if self.match_typing_call_path(&call_path, "NewType") { + Some(Callable::NewType) + } else if self.match_typing_call_path(&call_path, "TypeVar") { + Some(Callable::TypeVar) + } else if self.match_typing_call_path(&call_path, "NamedTuple") { + Some(Callable::NamedTuple) + } else if self.match_typing_call_path(&call_path, "TypedDict") { + Some(Callable::TypedDict) + } else if ["Arg", "DefaultArg", "NamedArg", "DefaultNamedArg"] + .iter() + .any(|target| call_path.as_slice() == ["mypy_extensions", target]) + { + Some(Callable::MypyExtension) + } else { + None } - } else if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "cast") - }) { - self.visit_expr(func); - if !args.is_empty() { - self.in_type_definition = true; - self.visit_expr(&args[0]); - self.in_type_definition = prev_in_type_definition; - } - for expr in args.iter().skip(1) { - self.visit_expr(expr); - } - } else if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "NewType") - }) { - self.visit_expr(func); - for expr in args.iter().skip(1) { - self.in_type_definition = true; - self.visit_expr(expr); - self.in_type_definition = prev_in_type_definition; - } - } else if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "TypeVar") - }) { - self.visit_expr(func); - for expr in args.iter().skip(1) { - self.in_type_definition = true; - self.visit_expr(expr); - self.in_type_definition = prev_in_type_definition; - } - for keyword in keywords { - let KeywordData { arg, value } = &keyword.node; - if let Some(id) = arg { - if id == "bound" { - self.in_type_definition = true; - self.visit_expr(value); - self.in_type_definition = prev_in_type_definition; - } else { - self.in_type_definition = false; - self.visit_expr(value); - self.in_type_definition = prev_in_type_definition; - } + }); + match callable { + Some(Callable::ForwardRef) => { + self.visit_expr(func); + for expr in args { + self.in_type_definition = true; + self.visit_expr(expr); + self.in_type_definition = prev_in_type_definition; } } - } else if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "NamedTuple") - }) { - self.visit_expr(func); - - // Ex) NamedTuple("a", [("a", int)]) - if args.len() > 1 { - match &args[1].node { - ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { - for elt in elts { - match &elt.node { - ExprKind::List { elts, .. } - | ExprKind::Tuple { elts, .. } => { - if elts.len() == 2 { - self.in_type_definition = false; - self.visit_expr(&elts[0]); - self.in_type_definition = prev_in_type_definition; - - self.in_type_definition = true; - self.visit_expr(&elts[1]); - self.in_type_definition = prev_in_type_definition; - } - } - _ => {} - } + Some(Callable::Cast) => { + self.visit_expr(func); + if !args.is_empty() { + self.in_type_definition = true; + self.visit_expr(&args[0]); + self.in_type_definition = prev_in_type_definition; + } + for expr in args.iter().skip(1) { + self.visit_expr(expr); + } + } + Some(Callable::NewType) => { + self.visit_expr(func); + for expr in args.iter().skip(1) { + self.in_type_definition = true; + self.visit_expr(expr); + self.in_type_definition = prev_in_type_definition; + } + } + Some(Callable::TypeVar) => { + self.visit_expr(func); + for expr in args.iter().skip(1) { + self.in_type_definition = true; + self.visit_expr(expr); + self.in_type_definition = prev_in_type_definition; + } + for keyword in keywords { + let KeywordData { arg, value } = &keyword.node; + if let Some(id) = arg { + if id == "bound" { + self.in_type_definition = true; + self.visit_expr(value); + self.in_type_definition = prev_in_type_definition; + } else { + self.in_type_definition = false; + self.visit_expr(value); + self.in_type_definition = prev_in_type_definition; } } - _ => {} } } + Some(Callable::NamedTuple) => { + self.visit_expr(func); - // Ex) NamedTuple("a", a=int) - for keyword in keywords { - let KeywordData { value, .. } = &keyword.node; - self.in_type_definition = true; - self.visit_expr(value); - self.in_type_definition = prev_in_type_definition; - } - } else if call_path.as_ref().map_or(false, |call_path| { - self.match_typing_call_path(call_path, "TypedDict") - }) { - self.visit_expr(func); + // Ex) NamedTuple("a", [("a", int)]) + if args.len() > 1 { + match &args[1].node { + ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { + for elt in elts { + match &elt.node { + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } => { + if elts.len() == 2 { + self.in_type_definition = false; - // Ex) TypedDict("a", {"a": int}) - if args.len() > 1 { - if let ExprKind::Dict { keys, values } = &args[1].node { - for key in keys { - self.in_type_definition = false; - self.visit_expr(key); - self.in_type_definition = prev_in_type_definition; - } - for value in values { - self.in_type_definition = true; - self.visit_expr(value); - self.in_type_definition = prev_in_type_definition; + self.visit_expr(&elts[0]); + self.in_type_definition = + prev_in_type_definition; + + self.in_type_definition = true; + self.visit_expr(&elts[1]); + self.in_type_definition = + prev_in_type_definition; + } + } + _ => {} + } + } + } + _ => {} } } - } - // Ex) TypedDict("a", a=int) - for keyword in keywords { - let KeywordData { value, .. } = &keyword.node; - self.in_type_definition = true; - self.visit_expr(value); - self.in_type_definition = prev_in_type_definition; + // Ex) NamedTuple("a", a=int) + for keyword in keywords { + let KeywordData { value, .. } = &keyword.node; + self.in_type_definition = true; + self.visit_expr(value); + self.in_type_definition = prev_in_type_definition; + } } - } else if call_path.as_ref().map_or(false, |call_path| { - ["Arg", "DefaultArg", "NamedArg", "DefaultNamedArg"] - .iter() - .any(|target| *call_path == ["mypy_extensions", target]) - }) { - self.visit_expr(func); + Some(Callable::TypedDict) => { + self.visit_expr(func); - // Ex) DefaultNamedArg(bool | None, name="some_prop_name") - let mut arguments = args.iter().chain(keywords.iter().map(|keyword| { - let KeywordData { value, .. } = &keyword.node; - value - })); - if let Some(expr) = arguments.next() { - self.in_type_definition = true; - self.visit_expr(expr); - self.in_type_definition = prev_in_type_definition; + // Ex) TypedDict("a", {"a": int}) + if args.len() > 1 { + if let ExprKind::Dict { keys, values } = &args[1].node { + for key in keys { + self.in_type_definition = false; + self.visit_expr(key); + self.in_type_definition = prev_in_type_definition; + } + for value in values { + self.in_type_definition = true; + self.visit_expr(value); + self.in_type_definition = prev_in_type_definition; + } + } + } + + // Ex) TypedDict("a", a=int) + for keyword in keywords { + let KeywordData { value, .. } = &keyword.node; + self.in_type_definition = true; + self.visit_expr(value); + self.in_type_definition = prev_in_type_definition; + } } - for expr in arguments { - self.in_type_definition = false; - self.visit_expr(expr); - self.in_type_definition = prev_in_type_definition; + Some(Callable::MypyExtension) => { + self.visit_expr(func); + + // Ex) DefaultNamedArg(bool | None, name="some_prop_name") + let mut arguments = args.iter().chain(keywords.iter().map(|keyword| { + let KeywordData { value, .. } = &keyword.node; + value + })); + if let Some(expr) = arguments.next() { + self.in_type_definition = true; + self.visit_expr(expr); + self.in_type_definition = prev_in_type_definition; + } + for expr in arguments { + self.in_type_definition = false; + self.visit_expr(expr); + self.in_type_definition = prev_in_type_definition; + } + } + None => { + visitor::walk_expr(self, expr); } - } else { - visitor::walk_expr(self, expr); } } ExprKind::Subscript { value, slice, ctx } => { diff --git a/src/python/typing.rs b/src/python/typing.rs index 82407c794a..a96d698f12 100644 --- a/src/python/typing.rs +++ b/src/python/typing.rs @@ -249,3 +249,13 @@ pub fn is_pep585_builtin(checker: &Checker, expr: &Expr) -> bool { PEP_585_BUILTINS_ELIGIBLE.contains(&call_path.as_slice()) }) } + +pub enum Callable { + ForwardRef, + Cast, + NewType, + TypeVar, + NamedTuple, + TypedDict, + MypyExtension, +} diff --git a/src/rules/flake8_2020/rules.rs b/src/rules/flake8_2020/rules.rs index 388dd7b25d..398c3dc0b5 100644 --- a/src/rules/flake8_2020/rules.rs +++ b/src/rules/flake8_2020/rules.rs @@ -9,7 +9,7 @@ use crate::violations; fn is_sys(checker: &Checker, expr: &Expr, target: &str) -> bool { checker .resolve_call_path(expr) - .map_or(false, |path| path == ["sys", target]) + .map_or(false, |call_path| call_path.as_slice() == ["sys", target]) } /// YTT101, YTT102, YTT301, YTT303 @@ -182,7 +182,7 @@ pub fn compare(checker: &mut Checker, left: &Expr, ops: &[Cmpop], comparators: & pub fn name_or_attribute(checker: &mut Checker, expr: &Expr) { if checker .resolve_call_path(expr) - .map_or(false, |path| path == ["six", "PY3"]) + .map_or(false, |call_path| call_path.as_slice() == ["six", "PY3"]) { checker.diagnostics.push(Diagnostic::new( violations::SixPY3Referenced, diff --git a/src/rules/flake8_bandit/rules/bad_file_permissions.rs b/src/rules/flake8_bandit/rules/bad_file_permissions.rs index 8f78589ea8..3557a2c523 100644 --- a/src/rules/flake8_bandit/rules/bad_file_permissions.rs +++ b/src/rules/flake8_bandit/rules/bad_file_permissions.rs @@ -94,7 +94,7 @@ pub fn bad_file_permissions( ) { if checker .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["os", "chmod"]) + .map_or(false, |call_path| call_path.as_slice() == ["os", "chmod"]) { let call_args = SimpleCallArgs::new(args, keywords); if let Some(mode_arg) = call_args.get_argument("mode", Some(1)) { diff --git a/src/rules/flake8_bandit/rules/hashlib_insecure_hash_functions.rs b/src/rules/flake8_bandit/rules/hashlib_insecure_hash_functions.rs index 09263e6192..caf393f38f 100644 --- a/src/rules/flake8_bandit/rules/hashlib_insecure_hash_functions.rs +++ b/src/rules/flake8_bandit/rules/hashlib_insecure_hash_functions.rs @@ -22,6 +22,11 @@ fn is_used_for_security(call_args: &SimpleCallArgs) -> bool { } } +enum HashlibCall { + New, + WeakHash(&'static str), +} + /// S324 pub fn hashlib_insecure_hash_functions( checker: &mut Checker, @@ -29,39 +34,46 @@ pub fn hashlib_insecure_hash_functions( args: &[Expr], keywords: &[Keyword], ) { - if let Some(call_path) = checker.resolve_call_path(func) { - if call_path == ["hashlib", "new"] { - let call_args = SimpleCallArgs::new(args, keywords); - - if !is_used_for_security(&call_args) { - return; - } - - if let Some(name_arg) = call_args.get_argument("name", Some(0)) { - if let Some(hash_func_name) = string_literal(name_arg) { - if WEAK_HASHES.contains(&hash_func_name.to_lowercase().as_str()) { - checker.diagnostics.push(Diagnostic::new( - violations::HashlibInsecureHashFunction(hash_func_name.to_string()), - Range::from_located(name_arg), - )); - } - } - } + if let Some(hashlib_call) = checker.resolve_call_path(func).and_then(|call_path| { + if call_path.as_slice() == ["hashlib", "new"] { + Some(HashlibCall::New) } else { - for func_name in &WEAK_HASHES { - if call_path == ["hashlib", func_name] { - let call_args = SimpleCallArgs::new(args, keywords); + WEAK_HASHES + .iter() + .find(|hash| call_path.as_slice() == ["hashlib", hash]) + .map(|hash| HashlibCall::WeakHash(hash)) + } + }) { + match hashlib_call { + HashlibCall::New => { + let call_args = SimpleCallArgs::new(args, keywords); - if !is_used_for_security(&call_args) { - return; - } - - checker.diagnostics.push(Diagnostic::new( - violations::HashlibInsecureHashFunction((*func_name).to_string()), - Range::from_located(func), - )); + if !is_used_for_security(&call_args) { return; } + + if let Some(name_arg) = call_args.get_argument("name", Some(0)) { + if let Some(hash_func_name) = string_literal(name_arg) { + if WEAK_HASHES.contains(&hash_func_name.to_lowercase().as_str()) { + checker.diagnostics.push(Diagnostic::new( + violations::HashlibInsecureHashFunction(hash_func_name.to_string()), + Range::from_located(name_arg), + )); + } + } + } + } + HashlibCall::WeakHash(func_name) => { + let call_args = SimpleCallArgs::new(args, keywords); + + if !is_used_for_security(&call_args) { + return; + } + + checker.diagnostics.push(Diagnostic::new( + violations::HashlibInsecureHashFunction((*func_name).to_string()), + Range::from_located(func), + )); } } } diff --git a/src/rules/flake8_bandit/rules/jinja2_autoescape_false.rs b/src/rules/flake8_bandit/rules/jinja2_autoescape_false.rs index c0385af332..923477648a 100644 --- a/src/rules/flake8_bandit/rules/jinja2_autoescape_false.rs +++ b/src/rules/flake8_bandit/rules/jinja2_autoescape_false.rs @@ -14,10 +14,9 @@ pub fn jinja2_autoescape_false( args: &[Expr], keywords: &[Keyword], ) { - if checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["jinja2", "Environment"]) - { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["jinja2", "Environment"] + }) { let call_args = SimpleCallArgs::new(args, keywords); if let Some(autoescape_arg) = call_args.get_argument("autoescape", None) { diff --git a/src/rules/flake8_bandit/rules/request_with_no_cert_validation.rs b/src/rules/flake8_bandit/rules/request_with_no_cert_validation.rs index 70b8851128..505d6014be 100644 --- a/src/rules/flake8_bandit/rules/request_with_no_cert_validation.rs +++ b/src/rules/flake8_bandit/rules/request_with_no_cert_validation.rs @@ -29,40 +29,28 @@ pub fn request_with_no_cert_validation( args: &[Expr], keywords: &[Keyword], ) { - if let Some(call_path) = checker.resolve_call_path(func) { - let call_args = SimpleCallArgs::new(args, keywords); - for func_name in &REQUESTS_HTTP_VERBS { - if call_path == ["requests", func_name] { - if let Some(verify_arg) = call_args.get_argument("verify", None) { - if let ExprKind::Constant { - value: Constant::Bool(false), - .. - } = &verify_arg.node - { - checker.diagnostics.push(Diagnostic::new( - violations::RequestWithNoCertValidation("requests".to_string()), - Range::from_located(verify_arg), - )); - } - } - return; + if let Some(target) = checker.resolve_call_path(func).and_then(|call_path| { + if call_path.len() == 2 { + if call_path[0] == "requests" && REQUESTS_HTTP_VERBS.contains(&call_path[1]) { + return Some("requests"); + } + if call_path[0] == "httpx" && HTTPX_METHODS.contains(&call_path[1]) { + return Some("httpx"); } } - for func_name in &HTTPX_METHODS { - if call_path == ["httpx", func_name] { - if let Some(verify_arg) = call_args.get_argument("verify", None) { - if let ExprKind::Constant { - value: Constant::Bool(false), - .. - } = &verify_arg.node - { - checker.diagnostics.push(Diagnostic::new( - violations::RequestWithNoCertValidation("httpx".to_string()), - Range::from_located(verify_arg), - )); - } - } - return; + None + }) { + let call_args = SimpleCallArgs::new(args, keywords); + if let Some(verify_arg) = call_args.get_argument("verify", None) { + if let ExprKind::Constant { + value: Constant::Bool(false), + .. + } = &verify_arg.node + { + checker.diagnostics.push(Diagnostic::new( + violations::RequestWithNoCertValidation(target.to_string()), + Range::from_located(verify_arg), + )); } } } diff --git a/src/rules/flake8_bandit/rules/request_without_timeout.rs b/src/rules/flake8_bandit/rules/request_without_timeout.rs index e304656883..80738ef085 100644 --- a/src/rules/flake8_bandit/rules/request_without_timeout.rs +++ b/src/rules/flake8_bandit/rules/request_without_timeout.rs @@ -19,7 +19,7 @@ pub fn request_without_timeout( if checker.resolve_call_path(func).map_or(false, |call_path| { HTTP_VERBS .iter() - .any(|func_name| call_path == ["requests", func_name]) + .any(|func_name| call_path.as_slice() == ["requests", func_name]) }) { let call_args = SimpleCallArgs::new(args, keywords); if let Some(timeout_arg) = call_args.get_argument("timeout", None) { diff --git a/src/rules/flake8_bandit/rules/snmp_insecure_version.rs b/src/rules/flake8_bandit/rules/snmp_insecure_version.rs index f93525b10f..014d1256f1 100644 --- a/src/rules/flake8_bandit/rules/snmp_insecure_version.rs +++ b/src/rules/flake8_bandit/rules/snmp_insecure_version.rs @@ -16,7 +16,7 @@ pub fn snmp_insecure_version( keywords: &[Keyword], ) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["pysnmp", "hlapi", "CommunityData"] + call_path.as_slice() == ["pysnmp", "hlapi", "CommunityData"] }) { let call_args = SimpleCallArgs::new(args, keywords); if let Some(mp_model_arg) = call_args.get_argument("mpModel", None) { diff --git a/src/rules/flake8_bandit/rules/snmp_weak_cryptography.rs b/src/rules/flake8_bandit/rules/snmp_weak_cryptography.rs index 4fb91f66f7..90c8f7f21d 100644 --- a/src/rules/flake8_bandit/rules/snmp_weak_cryptography.rs +++ b/src/rules/flake8_bandit/rules/snmp_weak_cryptography.rs @@ -14,7 +14,7 @@ pub fn snmp_weak_cryptography( keywords: &[Keyword], ) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["pysnmp", "hlapi", "UsmUserData"] + call_path.as_slice() == ["pysnmp", "hlapi", "UsmUserData"] }) { let call_args = SimpleCallArgs::new(args, keywords); if call_args.len() < 3 { diff --git a/src/rules/flake8_bandit/rules/unsafe_yaml_load.rs b/src/rules/flake8_bandit/rules/unsafe_yaml_load.rs index 7392cb8d6b..fb00b8d739 100644 --- a/src/rules/flake8_bandit/rules/unsafe_yaml_load.rs +++ b/src/rules/flake8_bandit/rules/unsafe_yaml_load.rs @@ -10,14 +10,15 @@ use crate::violations; pub fn unsafe_yaml_load(checker: &mut Checker, func: &Expr, args: &[Expr], keywords: &[Keyword]) { if checker .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["yaml", "load"]) + .map_or(false, |call_path| call_path.as_slice() == ["yaml", "load"]) { let call_args = SimpleCallArgs::new(args, keywords); if let Some(loader_arg) = call_args.get_argument("Loader", Some(1)) { if !checker .resolve_call_path(loader_arg) .map_or(false, |call_path| { - call_path == ["yaml", "SafeLoader"] || call_path == ["yaml", "CSafeLoader"] + call_path.as_slice() == ["yaml", "SafeLoader"] + || call_path.as_slice() == ["yaml", "CSafeLoader"] }) { let loader = match &loader_arg.node { diff --git a/src/rules/flake8_bugbear/rules/abstract_base_class.rs b/src/rules/flake8_bugbear/rules/abstract_base_class.rs index 4ae1dcd3d5..3579c0547e 100644 --- a/src/rules/flake8_bugbear/rules/abstract_base_class.rs +++ b/src/rules/flake8_bugbear/rules/abstract_base_class.rs @@ -15,11 +15,13 @@ fn is_abc_class(checker: &Checker, bases: &[Expr], keywords: &[Keyword]) -> bool .map_or(false, |arg| arg == "metaclass") && checker .resolve_call_path(&keyword.node.value) - .map_or(false, |call_path| call_path == ["abc", "ABCMeta"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["abc", "ABCMeta"] + }) }) || bases.iter().any(|base| { checker .resolve_call_path(base) - .map_or(false, |call_path| call_path == ["abc", "ABC"]) + .map_or(false, |call_path| call_path.as_slice() == ["abc", "ABC"]) }) } diff --git a/src/rules/flake8_bugbear/rules/assert_raises_exception.rs b/src/rules/flake8_bugbear/rules/assert_raises_exception.rs index 537c8fa19c..5a301b54f6 100644 --- a/src/rules/flake8_bugbear/rules/assert_raises_exception.rs +++ b/src/rules/flake8_bugbear/rules/assert_raises_exception.rs @@ -25,7 +25,7 @@ pub fn assert_raises_exception(checker: &mut Checker, stmt: &Stmt, items: &[With } if !checker .resolve_call_path(args.first().unwrap()) - .map_or(false, |call_path| call_path == ["", "Exception"]) + .map_or(false, |call_path| call_path.as_slice() == ["", "Exception"]) { return; } diff --git a/src/rules/flake8_bugbear/rules/cached_instance_method.rs b/src/rules/flake8_bugbear/rules/cached_instance_method.rs index 72723c5fe4..5f638815ed 100644 --- a/src/rules/flake8_bugbear/rules/cached_instance_method.rs +++ b/src/rules/flake8_bugbear/rules/cached_instance_method.rs @@ -7,7 +7,8 @@ use crate::violations; fn is_cache_func(checker: &Checker, expr: &Expr) -> bool { checker.resolve_call_path(expr).map_or(false, |call_path| { - call_path == ["functools", "lru_cache"] || call_path == ["functools", "cache"] + call_path.as_slice() == ["functools", "lru_cache"] + || call_path.as_slice() == ["functools", "cache"] }) } diff --git a/src/rules/flake8_bugbear/rules/function_call_argument_default.rs b/src/rules/flake8_bugbear/rules/function_call_argument_default.rs index 2c3fa53ef9..3e2c72b2da 100644 --- a/src/rules/flake8_bugbear/rules/function_call_argument_default.rs +++ b/src/rules/flake8_bugbear/rules/function_call_argument_default.rs @@ -2,7 +2,7 @@ use rustpython_ast::{Arguments, Constant, Expr, ExprKind}; use super::mutable_argument_default::is_mutable_func; use crate::ast::helpers::{compose_call_path, to_call_path}; -use crate::ast::types::Range; +use crate::ast::types::{CallPath, Range}; use crate::ast::visitor; use crate::ast::visitor::Visitor; use crate::checkers::ast::Checker; @@ -19,9 +19,11 @@ const IMMUTABLE_FUNCS: &[&[&str]] = &[ &["re", "compile"], ]; -fn is_immutable_func(checker: &Checker, expr: &Expr, extend_immutable_calls: &[Vec<&str>]) -> bool { +fn is_immutable_func(checker: &Checker, expr: &Expr, extend_immutable_calls: &[CallPath]) -> bool { checker.resolve_call_path(expr).map_or(false, |call_path| { - IMMUTABLE_FUNCS.iter().any(|target| call_path == *target) + IMMUTABLE_FUNCS + .iter() + .any(|target| call_path.as_slice() == *target) || extend_immutable_calls .iter() .any(|target| call_path == *target) @@ -31,7 +33,7 @@ fn is_immutable_func(checker: &Checker, expr: &Expr, extend_immutable_calls: &[V struct ArgumentDefaultVisitor<'a> { checker: &'a Checker<'a>, diagnostics: Vec<(DiagnosticKind, Range)>, - extend_immutable_calls: Vec>, + extend_immutable_calls: Vec>, } impl<'a, 'b> Visitor<'b> for ArgumentDefaultVisitor<'b> @@ -84,7 +86,7 @@ fn is_nan_or_infinity(expr: &Expr, args: &[Expr]) -> bool { /// B008 pub fn function_call_argument_default(checker: &mut Checker, arguments: &Arguments) { // Map immutable calls to (module, member) format. - let extend_immutable_calls: Vec> = checker + let extend_immutable_calls: Vec = checker .settings .flake8_bugbear .extend_immutable_calls diff --git a/src/rules/flake8_bugbear/rules/mutable_argument_default.rs b/src/rules/flake8_bugbear/rules/mutable_argument_default.rs index 9d25687134..8990c873ea 100644 --- a/src/rules/flake8_bugbear/rules/mutable_argument_default.rs +++ b/src/rules/flake8_bugbear/rules/mutable_argument_default.rs @@ -58,7 +58,9 @@ const IMMUTABLE_GENERIC_TYPES: &[&[&str]] = &[ pub fn is_mutable_func(checker: &Checker, expr: &Expr) -> bool { checker.resolve_call_path(expr).map_or(false, |call_path| { - MUTABLE_FUNCS.iter().any(|target| call_path == *target) + MUTABLE_FUNCS + .iter() + .any(|target| call_path.as_slice() == *target) }) } @@ -82,25 +84,25 @@ fn is_immutable_annotation(checker: &Checker, expr: &Expr) -> bool { IMMUTABLE_TYPES .iter() .chain(IMMUTABLE_GENERIC_TYPES) - .any(|target| call_path == *target) + .any(|target| call_path.as_slice() == *target) }) } ExprKind::Subscript { value, slice, .. } => { checker.resolve_call_path(value).map_or(false, |call_path| { if IMMUTABLE_GENERIC_TYPES .iter() - .any(|target| call_path == *target) + .any(|target| call_path.as_slice() == *target) { true - } else if call_path == ["typing", "Union"] { + } else if call_path.as_slice() == ["typing", "Union"] { if let ExprKind::Tuple { elts, .. } = &slice.node { elts.iter().all(|elt| is_immutable_annotation(checker, elt)) } else { false } - } else if call_path == ["typing", "Optional"] { + } else if call_path.as_slice() == ["typing", "Optional"] { is_immutable_annotation(checker, slice) - } else if call_path == ["typing", "Annotated"] { + } else if call_path.as_slice() == ["typing", "Annotated"] { if let ExprKind::Tuple { elts, .. } = &slice.node { elts.first() .map_or(false, |elt| is_immutable_annotation(checker, elt)) diff --git a/src/rules/flake8_bugbear/rules/useless_contextlib_suppress.rs b/src/rules/flake8_bugbear/rules/useless_contextlib_suppress.rs index cae52017b8..b7e1e6bfa9 100644 --- a/src/rules/flake8_bugbear/rules/useless_contextlib_suppress.rs +++ b/src/rules/flake8_bugbear/rules/useless_contextlib_suppress.rs @@ -8,9 +8,9 @@ use crate::violations; /// B005 pub fn useless_contextlib_suppress(checker: &mut Checker, expr: &Expr, args: &[Expr]) { if args.is_empty() - && checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["contextlib", "suppress"]) + && checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["contextlib", "suppress"] + }) { checker.diagnostics.push(Diagnostic::new( violations::UselessContextlibSuppress, diff --git a/src/rules/flake8_datetimez/rules.rs b/src/rules/flake8_datetimez/rules.rs index 6016cbb664..48d6bde71a 100644 --- a/src/rules/flake8_datetimez/rules.rs +++ b/src/rules/flake8_datetimez/rules.rs @@ -13,10 +13,9 @@ pub fn call_datetime_without_tzinfo( keywords: &[Keyword], location: Range, ) { - if !checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["datetime", "datetime"]) - { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["datetime", "datetime"] + }) { return; } @@ -41,7 +40,7 @@ pub fn call_datetime_without_tzinfo( /// DTZ002 pub fn call_datetime_today(checker: &mut Checker, func: &Expr, location: Range) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "today"] + call_path.as_slice() == ["datetime", "datetime", "today"] }) { checker .diagnostics @@ -52,7 +51,7 @@ pub fn call_datetime_today(checker: &mut Checker, func: &Expr, location: Range) /// DTZ003 pub fn call_datetime_utcnow(checker: &mut Checker, func: &Expr, location: Range) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "utcnow"] + call_path.as_slice() == ["datetime", "datetime", "utcnow"] }) { checker .diagnostics @@ -63,7 +62,7 @@ pub fn call_datetime_utcnow(checker: &mut Checker, func: &Expr, location: Range) /// DTZ004 pub fn call_datetime_utcfromtimestamp(checker: &mut Checker, func: &Expr, location: Range) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "utcfromtimestamp"] + call_path.as_slice() == ["datetime", "datetime", "utcfromtimestamp"] }) { checker.diagnostics.push(Diagnostic::new( violations::CallDatetimeUtcfromtimestamp, @@ -81,7 +80,7 @@ pub fn call_datetime_now_without_tzinfo( location: Range, ) { if !checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "now"] + call_path.as_slice() == ["datetime", "datetime", "now"] }) { return; } @@ -122,7 +121,7 @@ pub fn call_datetime_fromtimestamp( location: Range, ) { if !checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "fromtimestamp"] + call_path.as_slice() == ["datetime", "datetime", "fromtimestamp"] }) { return; } @@ -162,7 +161,7 @@ pub fn call_datetime_strptime_without_zone( location: Range, ) { if !checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "datetime", "strptime"] + call_path.as_slice() == ["datetime", "datetime", "strptime"] }) { return; } @@ -211,7 +210,7 @@ pub fn call_datetime_strptime_without_zone( /// DTZ011 pub fn call_date_today(checker: &mut Checker, func: &Expr, location: Range) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "date", "today"] + call_path.as_slice() == ["datetime", "date", "today"] }) { checker .diagnostics @@ -222,7 +221,7 @@ pub fn call_date_today(checker: &mut Checker, func: &Expr, location: Range) { /// DTZ012 pub fn call_date_fromtimestamp(checker: &mut Checker, func: &Expr, location: Range) { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["datetime", "date", "fromtimestamp"] + call_path.as_slice() == ["datetime", "date", "fromtimestamp"] }) { checker .diagnostics diff --git a/src/rules/flake8_debugger/rules.rs b/src/rules/flake8_debugger/rules.rs index dd8f709e96..03ea1e37ea 100644 --- a/src/rules/flake8_debugger/rules.rs +++ b/src/rules/flake8_debugger/rules.rs @@ -27,13 +27,15 @@ const DEBUGGERS: &[&[&str]] = &[ /// Checks for the presence of a debugger call. pub fn debugger_call(checker: &mut Checker, expr: &Expr, func: &Expr) { - if let Some(call_path) = checker.resolve_call_path(func) { - if DEBUGGERS.iter().any(|target| call_path == *target) { - checker.diagnostics.push(Diagnostic::new( - violations::Debugger(DebuggerUsingType::Call(format_call_path(&call_path))), - Range::from_located(expr), - )); - } + if let Some(target) = checker.resolve_call_path(func).and_then(|call_path| { + DEBUGGERS + .iter() + .find(|target| call_path.as_slice() == **target) + }) { + checker.diagnostics.push(Diagnostic::new( + violations::Debugger(DebuggerUsingType::Call(format_call_path(target))), + Range::from_located(expr), + )); } } diff --git a/src/rules/flake8_pie/rules.rs b/src/rules/flake8_pie/rules.rs index 4def9afc38..799b33e3d8 100644 --- a/src/rules/flake8_pie/rules.rs +++ b/src/rules/flake8_pie/rules.rs @@ -120,7 +120,7 @@ where if !bases.iter().any(|expr| { checker .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["enum", "Enum"]) + .map_or(false, |call_path| call_path.as_slice() == ["enum", "Enum"]) }) { return; } @@ -134,7 +134,7 @@ where if let ExprKind::Call { func, .. } = &value.node { if checker .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["enum", "auto"]) + .map_or(false, |call_path| call_path.as_slice() == ["enum", "auto"]) { continue; } diff --git a/src/rules/flake8_print/rules/print_call.rs b/src/rules/flake8_print/rules/print_call.rs index 8cafe2ae0d..b368c868e9 100644 --- a/src/rules/flake8_print/rules/print_call.rs +++ b/src/rules/flake8_print/rules/print_call.rs @@ -14,7 +14,7 @@ pub fn print_call(checker: &mut Checker, func: &Expr, keywords: &[Keyword]) { let call_path = checker.resolve_call_path(func); if call_path .as_ref() - .map_or(false, |call_path| *call_path == ["", "print"]) + .map_or(false, |call_path| *call_path.as_slice() == ["", "print"]) { // If the print call has a `file=` argument (that isn't `None`, `"sys.stdout"`, // or `"sys.stderr"`), don't trigger T201. @@ -26,7 +26,8 @@ pub fn print_call(checker: &mut Checker, func: &Expr, keywords: &[Keyword]) { if checker .resolve_call_path(&keyword.node.value) .map_or(true, |call_path| { - call_path != ["sys", "stdout"] && call_path != ["sys", "stderr"] + call_path.as_slice() != ["sys", "stdout"] + && call_path.as_slice() != ["sys", "stderr"] }) { return; @@ -34,10 +35,9 @@ pub fn print_call(checker: &mut Checker, func: &Expr, keywords: &[Keyword]) { } } Diagnostic::new(violations::PrintFound, Range::from_located(func)) - } else if call_path - .as_ref() - .map_or(false, |call_path| *call_path == ["pprint", "pprint"]) - { + } else if call_path.as_ref().map_or(false, |call_path| { + *call_path.as_slice() == ["pprint", "pprint"] + }) { Diagnostic::new(violations::PPrintFound, Range::from_located(func)) } else { return; diff --git a/src/rules/flake8_pytest_style/rules/helpers.rs b/src/rules/flake8_pytest_style/rules/helpers.rs index 56ca1fcfca..5290d39efe 100644 --- a/src/rules/flake8_pytest_style/rules/helpers.rs +++ b/src/rules/flake8_pytest_style/rules/helpers.rs @@ -18,15 +18,17 @@ pub fn get_mark_name(decorator: &Expr) -> &str { } pub fn is_pytest_fail(call: &Expr, checker: &Checker) -> bool { - checker - .resolve_call_path(call) - .map_or(false, |call_path| call_path == ["pytest", "fail"]) + checker.resolve_call_path(call).map_or(false, |call_path| { + call_path.as_slice() == ["pytest", "fail"] + }) } pub fn is_pytest_fixture(decorator: &Expr, checker: &Checker) -> bool { checker .resolve_call_path(decorator) - .map_or(false, |call_path| call_path == ["pytest", "fixture"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["pytest", "fixture"] + }) } pub fn is_pytest_mark(decorator: &Expr) -> bool { @@ -41,13 +43,17 @@ pub fn is_pytest_mark(decorator: &Expr) -> bool { pub fn is_pytest_yield_fixture(decorator: &Expr, checker: &Checker) -> bool { checker .resolve_call_path(decorator) - .map_or(false, |call_path| call_path == ["pytest", "yield_fixture"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["pytest", "yield_fixture"] + }) } pub fn is_abstractmethod_decorator(decorator: &Expr, checker: &Checker) -> bool { checker .resolve_call_path(decorator) - .map_or(false, |call_path| call_path == ["abc", "abstractmethod"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["abc", "abstractmethod"] + }) } /// Check if the expression is a constant that evaluates to false. @@ -96,7 +102,7 @@ pub fn is_pytest_parametrize(decorator: &Expr, checker: &Checker) -> bool { checker .resolve_call_path(decorator) .map_or(false, |call_path| { - call_path == ["pytest", "mark", "parametrize"] + call_path.as_slice() == ["pytest", "mark", "parametrize"] }) } diff --git a/src/rules/flake8_pytest_style/rules/raises.rs b/src/rules/flake8_pytest_style/rules/raises.rs index 6468a9121d..4e6338863c 100644 --- a/src/rules/flake8_pytest_style/rules/raises.rs +++ b/src/rules/flake8_pytest_style/rules/raises.rs @@ -8,9 +8,9 @@ use crate::registry::{Diagnostic, RuleCode}; use crate::violations; fn is_pytest_raises(checker: &Checker, func: &Expr) -> bool { - checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["pytest", "raises"]) + checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["pytest", "raises"] + }) } fn is_non_trivial_with_body(body: &[Stmt]) -> bool { @@ -93,7 +93,7 @@ pub fn complex_raises(checker: &mut Checker, stmt: &Stmt, items: &[Withitem], bo /// PT011 fn exception_needs_match(checker: &mut Checker, exception: &Expr) { - if let Some(call_path) = checker.resolve_call_path(exception) { + if let Some(call_path) = checker.resolve_call_path(exception).and_then(|call_path| { let is_broad_exception = checker .settings .flake8_pytest_style @@ -107,10 +107,14 @@ fn exception_needs_match(checker: &mut Checker, exception: &Expr) { ) .any(|target| call_path == to_call_path(target)); if is_broad_exception { - checker.diagnostics.push(Diagnostic::new( - violations::RaisesTooBroad(format_call_path(&call_path)), - Range::from_located(exception), - )); + Some(format_call_path(&call_path)) + } else { + None } + }) { + checker.diagnostics.push(Diagnostic::new( + violations::RaisesTooBroad(call_path), + Range::from_located(exception), + )); } } diff --git a/src/rules/flake8_simplify/rules/ast_expr.rs b/src/rules/flake8_simplify/rules/ast_expr.rs index 32508adf49..b6c05942fd 100644 --- a/src/rules/flake8_simplify/rules/ast_expr.rs +++ b/src/rules/flake8_simplify/rules/ast_expr.rs @@ -17,7 +17,7 @@ pub fn use_capital_environment_variables(checker: &mut Checker, expr: &Expr) { // check `os.environ.get('foo')` and `os.getenv('foo')`` if !checker.resolve_call_path(expr).map_or(false, |call_path| { - call_path == ["os", "environ", "get"] || call_path == ["os", "getenv"] + call_path.as_slice() == ["os", "environ", "get"] || call_path.as_slice() == ["os", "getenv"] }) { return; } diff --git a/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs b/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs index 95f3929ddd..00e2e4b749 100644 --- a/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs +++ b/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs @@ -29,7 +29,7 @@ fn match_async_exit_stack(checker: &Checker) -> bool { for item in items { if let ExprKind::Call { func, .. } = &item.context_expr.node { if checker.resolve_call_path(func).map_or(false, |call_path| { - call_path == ["contextlib", "AsyncExitStack"] + call_path.as_slice() == ["contextlib", "AsyncExitStack"] }) { return true; } @@ -59,10 +59,9 @@ fn match_exit_stack(checker: &Checker) -> bool { if let StmtKind::With { items, .. } = &parent.node { for item in items { if let ExprKind::Call { func, .. } = &item.context_expr.node { - if checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["contextlib", "ExitStack"]) - { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["contextlib", "ExitStack"] + }) { return true; } } @@ -76,7 +75,7 @@ fn match_exit_stack(checker: &Checker) -> bool { pub fn open_file_with_context_handler(checker: &mut Checker, func: &Expr) { if checker .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["", "open"]) + .map_or(false, |call_path| call_path.as_slice() == ["", "open"]) { if checker.is_builtin("open") { // Ex) `with open("foo.txt") as f: ...` diff --git a/src/rules/flake8_tidy_imports/banned_api.rs b/src/rules/flake8_tidy_imports/banned_api.rs index e80f66765d..f25aee476d 100644 --- a/src/rules/flake8_tidy_imports/banned_api.rs +++ b/src/rules/flake8_tidy_imports/banned_api.rs @@ -3,7 +3,7 @@ use rustpython_ast::{Alias, Expr, Located}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use crate::ast::types::Range; +use crate::ast::types::{CallPath, Range}; use crate::checkers::ast::Checker; use crate::define_violation; use crate::registry::Diagnostic; @@ -86,19 +86,21 @@ pub fn name_or_parent_is_banned( /// TID251 pub fn banned_attribute_access(checker: &mut Checker, expr: &Expr) { - if let Some(call_path) = checker.resolve_call_path(expr) { - for (banned_path, ban) in checker.settings.flake8_tidy_imports.banned_api.iter() { - if call_path == banned_path.split('.').collect::>() { - checker.diagnostics.push(Diagnostic::new( - BannedApi { - name: banned_path.to_string(), - message: ban.msg.to_string(), - }, - Range::from_located(expr), - )); - return; - } - } + if let Some((banned_path, ban)) = checker.resolve_call_path(expr).and_then(|call_path| { + checker + .settings + .flake8_tidy_imports + .banned_api + .iter() + .find(|(banned_path, ..)| call_path == banned_path.split('.').collect::()) + }) { + checker.diagnostics.push(Diagnostic::new( + BannedApi { + name: banned_path.to_string(), + message: ban.msg.to_string(), + }, + Range::from_located(expr), + )); } } diff --git a/src/rules/pep8_naming/helpers.rs b/src/rules/pep8_naming/helpers.rs index 58a0df0cd6..38641f8503 100644 --- a/src/rules/pep8_naming/helpers.rs +++ b/src/rules/pep8_naming/helpers.rs @@ -27,7 +27,7 @@ pub fn is_namedtuple_assignment(checker: &Checker, stmt: &Stmt) -> bool { return false; }; checker.resolve_call_path(value).map_or(false, |call_path| { - call_path == ["collections", "namedtuple"] + call_path.as_slice() == ["collections", "namedtuple"] }) } diff --git a/src/rules/pygrep_hooks/rules/deprecated_log_warn.rs b/src/rules/pygrep_hooks/rules/deprecated_log_warn.rs index 63d8ee7d3a..cd41024029 100644 --- a/src/rules/pygrep_hooks/rules/deprecated_log_warn.rs +++ b/src/rules/pygrep_hooks/rules/deprecated_log_warn.rs @@ -7,10 +7,9 @@ use crate::violations; /// PGH002 - deprecated use of logging.warn pub fn deprecated_log_warn(checker: &mut Checker, func: &Expr) { - if checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["logging", "warn"]) - { + if checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["logging", "warn"] + }) { checker.diagnostics.push(Diagnostic::new( violations::DeprecatedLogWarn, Range::from_located(func), diff --git a/src/rules/pyupgrade/rules/convert_named_tuple_functional_to_class.rs b/src/rules/pyupgrade/rules/convert_named_tuple_functional_to_class.rs index 04127d8b77..034148cbb6 100644 --- a/src/rules/pyupgrade/rules/convert_named_tuple_functional_to_class.rs +++ b/src/rules/pyupgrade/rules/convert_named_tuple_functional_to_class.rs @@ -29,10 +29,9 @@ fn match_named_tuple_assign<'a>( } = &value.node else { return None; }; - if !checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["typing", "NamedTuple"]) - { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["typing", "NamedTuple"] + }) { return None; } Some((typename, args, keywords, func)) diff --git a/src/rules/pyupgrade/rules/convert_typed_dict_functional_to_class.rs b/src/rules/pyupgrade/rules/convert_typed_dict_functional_to_class.rs index 4ff67e3237..6aa1458001 100644 --- a/src/rules/pyupgrade/rules/convert_typed_dict_functional_to_class.rs +++ b/src/rules/pyupgrade/rules/convert_typed_dict_functional_to_class.rs @@ -30,10 +30,9 @@ fn match_typed_dict_assign<'a>( } = &value.node else { return None; }; - if !checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["typing", "TypedDict"]) - { + if !checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["typing", "TypedDict"] + }) { return None; } Some((class_name, args, keywords, func)) diff --git a/src/rules/pyupgrade/rules/datetime_utc_alias.rs b/src/rules/pyupgrade/rules/datetime_utc_alias.rs index 11ace9f879..173921863e 100644 --- a/src/rules/pyupgrade/rules/datetime_utc_alias.rs +++ b/src/rules/pyupgrade/rules/datetime_utc_alias.rs @@ -10,7 +10,7 @@ use crate::violations; /// UP017 pub fn datetime_utc_alias(checker: &mut Checker, expr: &Expr) { if checker.resolve_call_path(expr).map_or(false, |call_path| { - call_path == ["datetime", "timezone", "utc"] + call_path.as_slice() == ["datetime", "timezone", "utc"] }) { let straight_import = collect_call_path(expr) == ["datetime", "timezone", "utc"]; let mut diagnostic = Diagnostic::new( diff --git a/src/rules/pyupgrade/rules/functools_cache.rs b/src/rules/pyupgrade/rules/functools_cache.rs index 5c5f290ed8..1ac047f179 100644 --- a/src/rules/pyupgrade/rules/functools_cache.rs +++ b/src/rules/pyupgrade/rules/functools_cache.rs @@ -22,9 +22,9 @@ pub fn functools_cache(checker: &mut Checker, decorator_list: &[Expr]) { // Look for, e.g., `import functools; @functools.lru_cache(maxsize=None)`. if args.is_empty() && keywords.len() == 1 - && checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["functools", "lru_cache"]) + && checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["functools", "lru_cache"] + }) { let KeywordData { arg, value } = &keywords[0].node; if arg.as_ref().map_or(false, |arg| arg == "maxsize") diff --git a/src/rules/pyupgrade/rules/lru_cache_without_parameters.rs b/src/rules/pyupgrade/rules/lru_cache_without_parameters.rs index b967f99d6f..617e5e8060 100644 --- a/src/rules/pyupgrade/rules/lru_cache_without_parameters.rs +++ b/src/rules/pyupgrade/rules/lru_cache_without_parameters.rs @@ -22,9 +22,9 @@ pub fn lru_cache_without_parameters(checker: &mut Checker, decorator_list: &[Exp // Look for, e.g., `import functools; @functools.lru_cache()`. if args.is_empty() && keywords.is_empty() - && checker - .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["functools", "lru_cache"]) + && checker.resolve_call_path(func).map_or(false, |call_path| { + call_path.as_slice() == ["functools", "lru_cache"] + }) { let mut diagnostic = Diagnostic::new( violations::LRUCacheWithoutParameters, diff --git a/src/rules/pyupgrade/rules/open_alias.rs b/src/rules/pyupgrade/rules/open_alias.rs index 0ebc6948f3..f0e2cfe853 100644 --- a/src/rules/pyupgrade/rules/open_alias.rs +++ b/src/rules/pyupgrade/rules/open_alias.rs @@ -10,7 +10,7 @@ use crate::violations; pub fn open_alias(checker: &mut Checker, expr: &Expr, func: &Expr) { if checker .resolve_call_path(func) - .map_or(false, |call_path| call_path == ["io", "open"]) + .map_or(false, |call_path| call_path.as_slice() == ["io", "open"]) { let mut diagnostic = Diagnostic::new(violations::OpenAlias, Range::from_located(expr)); if checker.patch(&RuleCode::UP020) { diff --git a/src/rules/pyupgrade/rules/os_error_alias.rs b/src/rules/pyupgrade/rules/os_error_alias.rs index bc33b89326..42bf035d8e 100644 --- a/src/rules/pyupgrade/rules/os_error_alias.rs +++ b/src/rules/pyupgrade/rules/os_error_alias.rs @@ -39,7 +39,7 @@ fn check_module(checker: &Checker, expr: &Expr) -> (Vec, Vec) { let mut before_replace: Vec = vec![]; if let Some(call_path) = checker.resolve_call_path(expr) { for module in ERROR_MODULES.iter() { - if call_path == [module, "error"] { + if call_path.as_slice() == [module, "error"] { replacements.push("OSError".to_string()); before_replace.push(format!("{module}.error")); break; diff --git a/src/rules/pyupgrade/rules/replace_stdout_stderr.rs b/src/rules/pyupgrade/rules/replace_stdout_stderr.rs index c081b28dab..9e8444ed21 100644 --- a/src/rules/pyupgrade/rules/replace_stdout_stderr.rs +++ b/src/rules/pyupgrade/rules/replace_stdout_stderr.rs @@ -80,10 +80,9 @@ fn generate_fix(locator: &Locator, stdout: &Keyword, stderr: &Keyword) -> Option /// UP022 pub fn replace_stdout_stderr(checker: &mut Checker, expr: &Expr, kwargs: &[Keyword]) { - if checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["subprocess", "run"]) - { + if checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["subprocess", "run"] + }) { // Find `stdout` and `stderr` kwargs. let Some(stdout) = find_keyword(kwargs, "stdout") else { return; @@ -95,10 +94,14 @@ pub fn replace_stdout_stderr(checker: &mut Checker, expr: &Expr, kwargs: &[Keywo // Verify that they're both set to `subprocess.PIPE`. if !checker .resolve_call_path(&stdout.node.value) - .map_or(false, |call_path| call_path == ["subprocess", "PIPE"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["subprocess", "PIPE"] + }) || !checker .resolve_call_path(&stderr.node.value) - .map_or(false, |call_path| call_path == ["subprocess", "PIPE"]) + .map_or(false, |call_path| { + call_path.as_slice() == ["subprocess", "PIPE"] + }) { return; } diff --git a/src/rules/pyupgrade/rules/replace_universal_newlines.rs b/src/rules/pyupgrade/rules/replace_universal_newlines.rs index a65d03fb44..4f99597a6a 100644 --- a/src/rules/pyupgrade/rules/replace_universal_newlines.rs +++ b/src/rules/pyupgrade/rules/replace_universal_newlines.rs @@ -9,10 +9,9 @@ use crate::violations; /// UP021 pub fn replace_universal_newlines(checker: &mut Checker, expr: &Expr, kwargs: &[Keyword]) { - if checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["subprocess", "run"]) - { + if checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["subprocess", "run"] + }) { let Some(kwarg) = find_keyword(kwargs, "universal_newlines") else { return; }; let range = Range::new( kwarg.location, diff --git a/src/rules/pyupgrade/rules/typing_text_str_alias.rs b/src/rules/pyupgrade/rules/typing_text_str_alias.rs index 3e61dfb4ef..b4804d4628 100644 --- a/src/rules/pyupgrade/rules/typing_text_str_alias.rs +++ b/src/rules/pyupgrade/rules/typing_text_str_alias.rs @@ -8,10 +8,9 @@ use crate::violations; /// UP019 pub fn typing_text_str_alias(checker: &mut Checker, expr: &Expr) { - if checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["typing", "Text"]) - { + if checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["typing", "Text"] + }) { let mut diagnostic = Diagnostic::new(violations::TypingTextStrAlias, Range::from_located(expr)); if checker.patch(diagnostic.kind.code()) { diff --git a/src/rules/pyupgrade/rules/use_pep585_annotation.rs b/src/rules/pyupgrade/rules/use_pep585_annotation.rs index 58f05a3402..afb2856519 100644 --- a/src/rules/pyupgrade/rules/use_pep585_annotation.rs +++ b/src/rules/pyupgrade/rules/use_pep585_annotation.rs @@ -8,14 +8,17 @@ use crate::violations; /// UP006 pub fn use_pep585_annotation(checker: &mut Checker, expr: &Expr) { - if let Some(call_path) = checker.resolve_call_path(expr) { + if let Some(binding) = checker + .resolve_call_path(expr) + .and_then(|call_path| call_path.last().copied()) + { let mut diagnostic = Diagnostic::new( - violations::UsePEP585Annotation(call_path[call_path.len() - 1].to_string()), + violations::UsePEP585Annotation(binding.to_string()), Range::from_located(expr), ); if checker.patch(diagnostic.kind.code()) { diagnostic.amend(Fix::replacement( - call_path[call_path.len() - 1].to_lowercase(), + binding.to_lowercase(), expr.location, expr.end_location.unwrap(), )); diff --git a/src/rules/pyupgrade/rules/use_pep604_annotation.rs b/src/rules/pyupgrade/rules/use_pep604_annotation.rs index dc44968761..a572e8f55a 100644 --- a/src/rules/pyupgrade/rules/use_pep604_annotation.rs +++ b/src/rules/pyupgrade/rules/use_pep604_annotation.rs @@ -54,6 +54,11 @@ fn any_arg_is_str(slice: &Expr) -> bool { } } +enum TypingMember { + Union, + Optional, +} + /// UP007 pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, slice: &Expr) { // Avoid rewriting forward annotations. @@ -61,53 +66,63 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s return; } - let call_path = checker.resolve_call_path(value); - if call_path.as_ref().map_or(false, |call_path| { - checker.match_typing_call_path(call_path, "Optional") - }) { - let mut diagnostic = - Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); - if checker.patch(diagnostic.kind.code()) { - let mut generator: Generator = checker.stylist.into(); - generator.unparse_expr(&optional(slice), 0); - diagnostic.amend(Fix::replacement( - generator.generate(), - expr.location, - expr.end_location.unwrap(), - )); + let Some(typing_member) = checker.resolve_call_path(value).as_ref().and_then(|call_path| { + if checker.match_typing_call_path(call_path, "Optional") { + Some(TypingMember::Optional) + } else if checker.match_typing_call_path(call_path, "Union") { + Some(TypingMember::Union) + } else { + None } - checker.diagnostics.push(diagnostic); - } else if call_path.as_ref().map_or(false, |call_path| { - checker.match_typing_call_path(call_path, "Union") - }) { - let mut diagnostic = - Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); - if checker.patch(diagnostic.kind.code()) { - match &slice.node { - ExprKind::Slice { .. } => { - // Invalid type annotation. - } - ExprKind::Tuple { elts, .. } => { - let mut generator: Generator = checker.stylist.into(); - generator.unparse_expr(&union(elts), 0); - diagnostic.amend(Fix::replacement( - generator.generate(), - expr.location, - expr.end_location.unwrap(), - )); - } - _ => { - // Single argument. - let mut generator: Generator = checker.stylist.into(); - generator.unparse_expr(slice, 0); - diagnostic.amend(Fix::replacement( - generator.generate(), - expr.location, - expr.end_location.unwrap(), - )); + }) else { + return; + }; + + match typing_member { + TypingMember::Optional => { + let mut diagnostic = + Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); + if checker.patch(diagnostic.kind.code()) { + let mut generator: Generator = checker.stylist.into(); + generator.unparse_expr(&optional(slice), 0); + diagnostic.amend(Fix::replacement( + generator.generate(), + expr.location, + expr.end_location.unwrap(), + )); + } + checker.diagnostics.push(diagnostic); + } + TypingMember::Union => { + let mut diagnostic = + Diagnostic::new(violations::UsePEP604Annotation, Range::from_located(expr)); + if checker.patch(diagnostic.kind.code()) { + match &slice.node { + ExprKind::Slice { .. } => { + // Invalid type annotation. + } + ExprKind::Tuple { elts, .. } => { + let mut generator: Generator = checker.stylist.into(); + generator.unparse_expr(&union(elts), 0); + diagnostic.amend(Fix::replacement( + generator.generate(), + expr.location, + expr.end_location.unwrap(), + )); + } + _ => { + // Single argument. + let mut generator: Generator = checker.stylist.into(); + generator.unparse_expr(slice, 0); + diagnostic.amend(Fix::replacement( + generator.generate(), + expr.location, + expr.end_location.unwrap(), + )); + } } } + checker.diagnostics.push(diagnostic); } - checker.diagnostics.push(diagnostic); } } diff --git a/src/visibility.rs b/src/visibility.rs index d224e3a809..b266c1ef89 100644 --- a/src/visibility.rs +++ b/src/visibility.rs @@ -31,18 +31,18 @@ pub struct VisibleScope { /// Returns `true` if a function is a "static method". pub fn is_staticmethod(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { - checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["", "staticmethod"]) + checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["", "staticmethod"] + }) }) } /// Returns `true` if a function is a "class method". pub fn is_classmethod(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { - checker - .resolve_call_path(expr) - .map_or(false, |call_path| call_path == ["", "classmethod"]) + checker.resolve_call_path(expr).map_or(false, |call_path| { + call_path.as_slice() == ["", "classmethod"] + }) }) } @@ -64,7 +64,8 @@ pub fn is_override(checker: &Checker, decorator_list: &[Expr]) -> bool { pub fn is_abstract(checker: &Checker, decorator_list: &[Expr]) -> bool { decorator_list.iter().any(|expr| { checker.resolve_call_path(expr).map_or(false, |call_path| { - call_path == ["abc", "abstractmethod"] || call_path == ["abc", "abstractproperty"] + call_path.as_slice() == ["abc", "abstractmethod"] + || call_path.as_slice() == ["abc", "abstractproperty"] }) }) }