From d0ad1ed0af15bf71664ae88e114c39aeee58767a Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 16 Jun 2023 13:34:42 -0400 Subject: [PATCH] Replace static `CallPath` vectors with `matches!` macros (#5148) ## Summary After #5140, I audited the codebase for similar patterns (defining a list of `CallPath` entities in a static vector, then looping over them to pattern-match). This PR migrates all other such cases to use `match` and `matches!` where possible. There are a few benefits to this: 1. It more clearly denotes the intended semantics (branches are exclusive). 2. The compiler can help deduplicate the patterns and detect unreachable branches. 3. Performance: in the benchmark below, the all-rules performance is increased by nearly 10%... ## Benchmarks I decided to benchmark against a large file in the Airflow repository with a lot of type annotations ([`views.py`](https://raw.githubusercontent.com/apache/airflow/f03f73100e8a7d6019249889de567cb00e71e457/airflow/www/views.py)): ``` linter/default-rules/airflow/views.py time: [10.871 ms 10.882 ms 10.894 ms] thrpt: [19.739 MiB/s 19.761 MiB/s 19.781 MiB/s] change: time: [-2.7182% -2.5687% -2.4204%] (p = 0.00 < 0.05) thrpt: [+2.4805% +2.6364% +2.7942%] Performance has improved. linter/all-rules/airflow/views.py time: [24.021 ms 24.038 ms 24.062 ms] thrpt: [8.9373 MiB/s 8.9461 MiB/s 8.9527 MiB/s] change: time: [-8.9537% -8.8516% -8.7527%] (p = 0.00 < 0.05) thrpt: [+9.5923% +9.7112% +9.8342%] Performance has improved. Found 12 outliers among 100 measurements (12.00%) 5 (5.00%) high mild 7 (7.00%) high severe ``` The impact is dramatic -- nearly a 10% improvement for `all-rules`. 
--- .../flake8_annotations/rules/definition.rs | 17 +- .../flake8_async/rules/blocking_http_call.rs | 51 +- .../flake8_async/rules/blocking_os_call.rs | 45 +- .../rules/open_sleep_or_subprocess_call.rs | 48 +- .../rules/function_call_argument_default.rs | 3 +- .../rules/mutable_argument_default.rs | 35 +- .../rules/flake8_debugger/rules/debugger.rs | 89 +-- .../rules/flake8_pyi/rules/simple_defaults.rs | 66 +- .../src/rules/flake8_return/rules/function.rs | 33 +- .../function_call_in_dataclass_default.rs | 4 +- crates/ruff/src/rules/ruff/rules/helpers.rs | 23 +- .../rules/ruff/rules/mutable_class_default.rs | 8 +- .../ruff/rules/mutable_dataclass_default.rs | 6 +- .../src/analyze/typing.rs | 116 ++- crates/ruff_python_semantic/src/model.rs | 4 +- crates/ruff_python_stdlib/src/typing.rs | 677 +++++++++++------- 16 files changed, 641 insertions(+), 584 deletions(-) diff --git a/crates/ruff/src/rules/flake8_annotations/rules/definition.rs b/crates/ruff/src/rules/flake8_annotations/rules/definition.rs index e839a092f7..22d63d89a3 100644 --- a/crates/ruff/src/rules/flake8_annotations/rules/definition.rs +++ b/crates/ruff/src/rules/flake8_annotations/rules/definition.rs @@ -1,6 +1,6 @@ use rustpython_parser::ast::{Expr, Ranged, Stmt}; -use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Violation}; +use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::cast; use ruff_python_ast::helpers::ReturnStatementVisitor; @@ -8,7 +8,7 @@ use ruff_python_ast::identifier::Identifier; use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::{Definition, Member, MemberKind, SemanticModel}; -use ruff_python_stdlib::typing::SIMPLE_MAGIC_RETURN_TYPES; +use ruff_python_stdlib::typing::simple_magic_return_type; use crate::checkers::ast::Checker; use crate::registry::{AsRule, Rule}; @@ -667,9 +667,9 @@ 
pub(crate) fn definition( stmt.identifier(checker.locator), ); if checker.patch(diagnostic.kind.rule()) { - #[allow(deprecated)] - diagnostic.try_set_fix_from_edit(|| { + diagnostic.try_set_fix(|| { fixes::add_return_annotation(checker.locator, stmt, "None") + .map(Fix::suggested) }); } diagnostics.push(diagnostic); @@ -683,12 +683,11 @@ pub(crate) fn definition( }, stmt.identifier(checker.locator), ); - let return_type = SIMPLE_MAGIC_RETURN_TYPES.get(name); - if let Some(return_type) = return_type { - if checker.patch(diagnostic.kind.rule()) { - #[allow(deprecated)] - diagnostic.try_set_fix_from_edit(|| { + if checker.patch(diagnostic.kind.rule()) { + if let Some(return_type) = simple_magic_return_type(name) { + diagnostic.try_set_fix(|| { fixes::add_return_annotation(checker.locator, stmt, return_type) + .map(Fix::suggested) }); } } diff --git a/crates/ruff/src/rules/flake8_async/rules/blocking_http_call.rs b/crates/ruff/src/rules/flake8_async/rules/blocking_http_call.rs index dbd00a3aaf..2ab4e94f4d 100644 --- a/crates/ruff/src/rules/flake8_async/rules/blocking_http_call.rs +++ b/crates/ruff/src/rules/flake8_async/rules/blocking_http_call.rs @@ -3,6 +3,7 @@ use rustpython_parser::ast::{Expr, Ranged}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::call_path::CallPath; use crate::checkers::ast::Checker; @@ -40,37 +41,35 @@ impl Violation for BlockingHttpCallInAsyncFunction { } } -const BLOCKING_HTTP_CALLS: &[&[&str]] = &[ - &["urllib", "request", "urlopen"], - &["httpx", "get"], - &["httpx", "post"], - &["httpx", "delete"], - &["httpx", "patch"], - &["httpx", "put"], - &["httpx", "head"], - &["httpx", "connect"], - &["httpx", "options"], - &["httpx", "trace"], - &["requests", "get"], - &["requests", "post"], - &["requests", "delete"], - &["requests", "patch"], - &["requests", "put"], - &["requests", "head"], - &["requests", "connect"], - &["requests", "options"], - &["requests", "trace"], 
-]; +fn is_blocking_http_call(call_path: &CallPath) -> bool { + matches!( + call_path.as_slice(), + ["urllib", "request", "urlopen"] + | [ + "httpx" | "requests", + "get" + | "post" + | "delete" + | "patch" + | "put" + | "head" + | "connect" + | "options" + | "trace" + ] + ) +} /// ASYNC100 pub(crate) fn blocking_http_call(checker: &mut Checker, expr: &Expr) { if checker.semantic().in_async_context() { if let Expr::Call(ast::ExprCall { func, .. }) = expr { - let call_path = checker.semantic().resolve_call_path(func); - let is_blocking = - call_path.map_or(false, |path| BLOCKING_HTTP_CALLS.contains(&path.as_slice())); - - if is_blocking { + if checker + .semantic() + .resolve_call_path(func) + .as_ref() + .map_or(false, is_blocking_http_call) + { checker.diagnostics.push(Diagnostic::new( BlockingHttpCallInAsyncFunction, func.range(), diff --git a/crates/ruff/src/rules/flake8_async/rules/blocking_os_call.rs b/crates/ruff/src/rules/flake8_async/rules/blocking_os_call.rs index 861689bb3d..c08dece6f4 100644 --- a/crates/ruff/src/rules/flake8_async/rules/blocking_os_call.rs +++ b/crates/ruff/src/rules/flake8_async/rules/blocking_os_call.rs @@ -3,6 +3,7 @@ use rustpython_parser::ast::{Expr, Ranged}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::call_path::CallPath; use crate::checkers::ast::Checker; @@ -39,31 +40,16 @@ impl Violation for BlockingOsCallInAsyncFunction { } } -const UNSAFE_OS_METHODS: &[&[&str]] = &[ - &["os", "popen"], - &["os", "posix_spawn"], - &["os", "posix_spawnp"], - &["os", "spawnl"], - &["os", "spawnle"], - &["os", "spawnlp"], - &["os", "spawnlpe"], - &["os", "spawnv"], - &["os", "spawnve"], - &["os", "spawnvp"], - &["os", "spawnvpe"], - &["os", "system"], -]; - /// ASYNC102 pub(crate) fn blocking_os_call(checker: &mut Checker, expr: &Expr) { if checker.semantic().in_async_context() { if let Expr::Call(ast::ExprCall { func, .. 
}) = expr { - let is_unsafe_os_method = checker + if checker .semantic() .resolve_call_path(func) - .map_or(false, |path| UNSAFE_OS_METHODS.contains(&path.as_slice())); - - if is_unsafe_os_method { + .as_ref() + .map_or(false, is_unsafe_os_method) + { checker .diagnostics .push(Diagnostic::new(BlockingOsCallInAsyncFunction, func.range())); @@ -71,3 +57,24 @@ pub(crate) fn blocking_os_call(checker: &mut Checker, expr: &Expr) { } } } + +fn is_unsafe_os_method(call_path: &CallPath) -> bool { + matches!( + call_path.as_slice(), + [ + "os", + "popen" + | "posix_spawn" + | "posix_spawnp" + | "spawnl" + | "spawnle" + | "spawnlp" + | "spawnlpe" + | "spawnv" + | "spawnve" + | "spawnvp" + | "spawnvpe" + | "system" + ] + ) +} diff --git a/crates/ruff/src/rules/flake8_async/rules/open_sleep_or_subprocess_call.rs b/crates/ruff/src/rules/flake8_async/rules/open_sleep_or_subprocess_call.rs index c0d370da82..0d1f813ec6 100644 --- a/crates/ruff/src/rules/flake8_async/rules/open_sleep_or_subprocess_call.rs +++ b/crates/ruff/src/rules/flake8_async/rules/open_sleep_or_subprocess_call.rs @@ -3,6 +3,7 @@ use rustpython_parser::ast::{Expr, Ranged}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::call_path::CallPath; use crate::checkers::ast::Checker; @@ -39,36 +40,16 @@ impl Violation for OpenSleepOrSubprocessInAsyncFunction { } } -const OPEN_SLEEP_OR_SUBPROCESS_CALL: &[&[&str]] = &[ - &["", "open"], - &["time", "sleep"], - &["subprocess", "run"], - &["subprocess", "Popen"], - // Deprecated subprocess calls: - &["subprocess", "call"], - &["subprocess", "check_call"], - &["subprocess", "check_output"], - &["subprocess", "getoutput"], - &["subprocess", "getstatusoutput"], - &["os", "wait"], - &["os", "wait3"], - &["os", "wait4"], - &["os", "waitid"], - &["os", "waitpid"], -]; - /// ASYNC101 pub(crate) fn open_sleep_or_subprocess_call(checker: &mut Checker, expr: &Expr) { if checker.semantic().in_async_context() { if 
let Expr::Call(ast::ExprCall { func, .. }) = expr { - let is_open_sleep_or_subprocess_call = checker + if checker .semantic() .resolve_call_path(func) - .map_or(false, |path| { - OPEN_SLEEP_OR_SUBPROCESS_CALL.contains(&path.as_slice()) - }); - - if is_open_sleep_or_subprocess_call { + .as_ref() + .map_or(false, is_open_sleep_or_subprocess_call) + { checker.diagnostics.push(Diagnostic::new( OpenSleepOrSubprocessInAsyncFunction, func.range(), @@ -77,3 +58,22 @@ pub(crate) fn open_sleep_or_subprocess_call(checker: &mut Checker, expr: &Expr) } } } + +fn is_open_sleep_or_subprocess_call(call_path: &CallPath) -> bool { + matches!( + call_path.as_slice(), + ["", "open"] + | ["time", "sleep"] + | [ + "subprocess", + "run" + | "Popen" + | "call" + | "check_call" + | "check_output" + | "getoutput" + | "getstatusoutput" + ] + | ["os", "wait" | "wait3" | "wait4" | "waitid" | "waitpid"] + ) +} diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/function_call_argument_default.rs b/crates/ruff/src/rules/flake8_bugbear/rules/function_call_argument_default.rs index 0443f328b8..6d0fcc4e21 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/function_call_argument_default.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/function_call_argument_default.rs @@ -7,11 +7,10 @@ use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::call_path::{compose_call_path, from_qualified_name, CallPath}; use ruff_python_ast::visitor; use ruff_python_ast::visitor::Visitor; -use ruff_python_semantic::analyze::typing::is_immutable_func; +use ruff_python_semantic::analyze::typing::{is_immutable_func, is_mutable_func}; use ruff_python_semantic::SemanticModel; use crate::checkers::ast::Checker; -use crate::rules::flake8_bugbear::rules::mutable_argument_default::is_mutable_func; /// ## What it does /// Checks for function calls in default function arguments. 
diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/mutable_argument_default.rs b/crates/ruff/src/rules/flake8_bugbear/rules/mutable_argument_default.rs index 064d1516b4..986aca83d4 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/mutable_argument_default.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/mutable_argument_default.rs @@ -1,9 +1,8 @@ -use rustpython_parser::ast::{self, Arguments, Expr, Ranged}; +use rustpython_parser::ast::{Arguments, Ranged}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_semantic::analyze::typing::is_immutable_annotation; -use ruff_python_semantic::SemanticModel; +use ruff_python_semantic::analyze::typing::{is_immutable_annotation, is_mutable_expr}; use crate::checkers::ast::Checker; @@ -16,36 +15,6 @@ impl Violation for MutableArgumentDefault { format!("Do not use mutable data structures for argument defaults") } } -const MUTABLE_FUNCS: &[&[&str]] = &[ - &["", "dict"], - &["", "list"], - &["", "set"], - &["collections", "Counter"], - &["collections", "OrderedDict"], - &["collections", "defaultdict"], - &["collections", "deque"], -]; - -pub(crate) fn is_mutable_func(func: &Expr, semantic: &SemanticModel) -> bool { - semantic.resolve_call_path(func).map_or(false, |call_path| { - MUTABLE_FUNCS - .iter() - .any(|target| call_path.as_slice() == *target) - }) -} - -fn is_mutable_expr(expr: &Expr, semantic: &SemanticModel) -> bool { - match expr { - Expr::List(_) - | Expr::Dict(_) - | Expr::Set(_) - | Expr::ListComp(_) - | Expr::DictComp(_) - | Expr::SetComp(_) => true, - Expr::Call(ast::ExprCall { func, .. 
}) => is_mutable_func(func, semantic), - _ => false, - } -} /// B006 pub(crate) fn mutable_argument_default(checker: &mut Checker, arguments: &Arguments) { diff --git a/crates/ruff/src/rules/flake8_debugger/rules/debugger.rs b/crates/ruff/src/rules/flake8_debugger/rules/debugger.rs index e0552c1bcd..94c151c95b 100644 --- a/crates/ruff/src/rules/flake8_debugger/rules/debugger.rs +++ b/crates/ruff/src/rules/flake8_debugger/rules/debugger.rs @@ -23,59 +23,32 @@ impl Violation for Debugger { } } -const DEBUGGERS: &[&[&str]] = &[ - &["pdb", "set_trace"], - &["pudb", "set_trace"], - &["ipdb", "set_trace"], - &["ipdb", "sset_trace"], - &["IPython", "terminal", "embed", "InteractiveShellEmbed"], - &[ - "IPython", - "frontend", - "terminal", - "embed", - "InteractiveShellEmbed", - ], - &["celery", "contrib", "rdb", "set_trace"], - &["builtins", "breakpoint"], - &["", "breakpoint"], -]; - /// Checks for the presence of a debugger call. pub(crate) fn debugger_call(checker: &mut Checker, expr: &Expr, func: &Expr) { - if let Some(target) = checker + if let Some(using_type) = checker .semantic() .resolve_call_path(func) .and_then(|call_path| { - DEBUGGERS - .iter() - .find(|target| call_path.as_slice() == **target) + if is_debugger_call(&call_path) { + Some(DebuggerUsingType::Call(format_call_path(&call_path))) + } else { + None + } }) { - checker.diagnostics.push(Diagnostic::new( - Debugger { - using_type: DebuggerUsingType::Call(format_call_path(target)), - }, - expr.range(), - )); + checker + .diagnostics + .push(Diagnostic::new(Debugger { using_type }, expr.range())); } } /// Checks for the presence of a debugger import. pub(crate) fn debugger_import(stmt: &Stmt, module: Option<&str>, name: &str) -> Option<Diagnostic> { - // Special-case: allow `import builtins`, which is far more general than (e.g.) - // `import celery.contrib.rdb`). 
- if module.is_none() && name == "builtins" { - return None; - } - if let Some(module) = module { let mut call_path: CallPath = from_unqualified_name(module); call_path.push(name); - if DEBUGGERS - .iter() - .any(|target| call_path.as_slice() == *target) - { + + if is_debugger_call(&call_path) { return Some(Diagnostic::new( Debugger { using_type: DebuggerUsingType::Import(format_call_path(&call_path)), @@ -84,11 +57,9 @@ pub(crate) fn debugger_import(stmt: &Stmt, module: Option<&str>, name: &str) -> )); } } else { - let parts: CallPath = from_unqualified_name(name); - if DEBUGGERS - .iter() - .any(|call_path| &call_path[..call_path.len() - 1] == parts.as_slice()) - { + let call_path: CallPath = from_unqualified_name(name); + + if is_debugger_import(&call_path) { return Some(Diagnostic::new( Debugger { using_type: DebuggerUsingType::Import(name.to_string()), @@ -99,3 +70,35 @@ pub(crate) fn debugger_import(stmt: &Stmt, module: Option<&str>, name: &str) -> } None } + +fn is_debugger_call(call_path: &CallPath) -> bool { + matches!( + call_path.as_slice(), + ["pdb" | "pudb" | "ipdb", "set_trace"] + | ["ipdb", "sset_trace"] + | ["IPython", "terminal", "embed", "InteractiveShellEmbed"] + | [ + "IPython", + "frontend", + "terminal", + "embed", + "InteractiveShellEmbed" + ] + | ["celery", "contrib", "rdb", "set_trace"] + | ["builtins" | "", "breakpoint"] + ) +} + +fn is_debugger_import(call_path: &CallPath) -> bool { + // Constructed by taking every pattern in `is_debugger_call`, removing the last element in + // each pattern, and de-duplicating the values. + // As a special-case, we omit `builtins` to allow `import builtins`, which is far more general + // than (e.g.) `import celery.contrib.rdb`. 
+ matches!( + call_path.as_slice(), + ["pdb" | "pudb" | "ipdb"] + | ["IPython", "terminal", "embed"] + | ["IPython", "frontend", "terminal", "embed",] + | ["celery", "contrib", "rdb"] + ) +} diff --git a/crates/ruff/src/rules/flake8_pyi/rules/simple_defaults.rs b/crates/ruff/src/rules/flake8_pyi/rules/simple_defaults.rs index c7299a7656..14df881ff0 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/simple_defaults.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/simple_defaults.rs @@ -2,6 +2,7 @@ use rustpython_parser::ast::{self, Arguments, Constant, Expr, Operator, Ranged, use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::call_path::CallPath; use ruff_python_ast::source_code::Locator; use ruff_python_semantic::{ScopeKind, SemanticModel}; @@ -94,30 +95,33 @@ impl Violation for UnassignedSpecialVariableInStub { } } -const ALLOWED_MATH_ATTRIBUTES_IN_DEFAULTS: &[&[&str]] = &[ - &["math", "inf"], - &["math", "nan"], - &["math", "e"], - &["math", "pi"], - &["math", "tau"], -]; +fn is_allowed_negated_math_attribute(call_path: &CallPath) -> bool { + matches!(call_path.as_slice(), ["math", "inf" | "e" | "pi" | "tau"]) +} -const ALLOWED_ATTRIBUTES_IN_DEFAULTS: &[&[&str]] = &[ - &["sys", "stdin"], - &["sys", "stdout"], - &["sys", "stderr"], - &["sys", "version"], - &["sys", "version_info"], - &["sys", "platform"], - &["sys", "executable"], - &["sys", "prefix"], - &["sys", "exec_prefix"], - &["sys", "base_prefix"], - &["sys", "byteorder"], - &["sys", "maxsize"], - &["sys", "hexversion"], - &["sys", "winver"], -]; +fn is_allowed_math_attribute(call_path: &CallPath) -> bool { + matches!( + call_path.as_slice(), + ["math", "inf" | "nan" | "e" | "pi" | "tau"] + | [ + "sys", + "stdin" + | "stdout" + | "stderr" + | "version" + | "version_info" + | "platform" + | "executable" + | "prefix" + | "exec_prefix" + | "base_prefix" + | "byteorder" + | "maxsize" + | "hexversion" + 
| "winver" + ] + ) +} fn is_valid_default_value_with_annotation( default: &Expr, @@ -166,12 +170,8 @@ fn is_valid_default_value_with_annotation( Expr::Attribute(_) => { if semantic .resolve_call_path(operand) - .map_or(false, |call_path| { - ALLOWED_MATH_ATTRIBUTES_IN_DEFAULTS.iter().any(|target| { - // reject `-math.nan` - call_path.as_slice() == *target && *target != ["math", "nan"] - }) - }) + .as_ref() + .map_or(false, is_allowed_negated_math_attribute) { return true; } @@ -219,12 +219,8 @@ fn is_valid_default_value_with_annotation( Expr::Attribute(_) => { if semantic .resolve_call_path(default) - .map_or(false, |call_path| { - ALLOWED_MATH_ATTRIBUTES_IN_DEFAULTS - .iter() - .chain(ALLOWED_ATTRIBUTES_IN_DEFAULTS.iter()) - .any(|target| call_path.as_slice() == *target) - }) + .as_ref() + .map_or(false, is_allowed_math_attribute) { return true; } diff --git a/crates/ruff/src/rules/flake8_return/rules/function.rs b/crates/ruff/src/rules/flake8_return/rules/function.rs index cae9574f5b..1fb986ca2b 100644 --- a/crates/ruff/src/rules/flake8_return/rules/function.rs +++ b/crates/ruff/src/rules/flake8_return/rules/function.rs @@ -370,34 +370,17 @@ fn implicit_return_value(checker: &mut Checker, stack: &Stack) { } } -const NORETURN_FUNCS: &[&[&str]] = &[ - // builtins - &["", "exit"], - &["", "quit"], - // stdlib - &["builtins", "exit"], - &["builtins", "quit"], - &["os", "_exit"], - &["os", "abort"], - &["posix", "_exit"], - &["posix", "abort"], - &["sys", "exit"], - &["_thread", "exit"], - &["_winapi", "ExitProcess"], - // third-party modules - &["pytest", "exit"], - &["pytest", "fail"], - &["pytest", "skip"], - &["pytest", "xfail"], -]; - /// Return `true` if the `func` is a known function that never returns. 
fn is_noreturn_func(func: &Expr, semantic: &SemanticModel) -> bool { semantic.resolve_call_path(func).map_or(false, |call_path| { - NORETURN_FUNCS - .iter() - .any(|target| call_path.as_slice() == *target) - || semantic.match_typing_call_path(&call_path, "assert_never") + matches!( + call_path.as_slice(), + ["" | "builtins" | "sys" | "_thread" | "pytest", "exit"] + | ["" | "builtins", "quit"] + | ["os" | "posix", "_exit" | "abort"] + | ["_winapi", "ExitProcess"] + | ["pytest", "fail" | "skip" | "xfail"] + ) || semantic.match_typing_call_path(&call_path, "assert_never") }) } diff --git a/crates/ruff/src/rules/ruff/rules/function_call_in_dataclass_default.rs b/crates/ruff/src/rules/ruff/rules/function_call_in_dataclass_default.rs index a41be00b35..f2bea235a7 100644 --- a/crates/ruff/src/rules/ruff/rules/function_call_in_dataclass_default.rs +++ b/crates/ruff/src/rules/ruff/rules/function_call_in_dataclass_default.rs @@ -8,7 +8,7 @@ use ruff_python_semantic::analyze::typing::is_immutable_func; use crate::checkers::ast::Checker; use crate::rules::ruff::rules::helpers::{ - is_allowed_dataclass_function, is_class_var_annotation, is_dataclass, + is_class_var_annotation, is_dataclass, is_dataclass_field, }; /// ## What it does @@ -97,7 +97,7 @@ pub(crate) fn function_call_in_dataclass_default( if let Expr::Call(ast::ExprCall { func, .. 
}) = expr.as_ref() { if !is_class_var_annotation(annotation, checker.semantic()) && !is_immutable_func(func, checker.semantic(), &extend_immutable_calls) - && !is_allowed_dataclass_function(func, checker.semantic()) + && !is_dataclass_field(func, checker.semantic()) { checker.diagnostics.push(Diagnostic::new( FunctionCallInDataclassDefaultArgument { diff --git a/crates/ruff/src/rules/ruff/rules/helpers.rs b/crates/ruff/src/rules/ruff/rules/helpers.rs index 65763b7bdc..b5d09012e5 100644 --- a/crates/ruff/src/rules/ruff/rules/helpers.rs +++ b/crates/ruff/src/rules/ruff/rules/helpers.rs @@ -1,27 +1,12 @@ -use ruff_python_ast::helpers::map_callable; use rustpython_parser::ast::{self, Expr}; +use ruff_python_ast::helpers::map_callable; use ruff_python_semantic::SemanticModel; -pub(super) fn is_mutable_expr(expr: &Expr) -> bool { - matches!( - expr, - Expr::List(_) - | Expr::Dict(_) - | Expr::Set(_) - | Expr::ListComp(_) - | Expr::DictComp(_) - | Expr::SetComp(_) - ) -} - -const ALLOWED_DATACLASS_SPECIFIC_FUNCTIONS: &[&[&str]] = &[&["dataclasses", "field"]]; - -pub(super) fn is_allowed_dataclass_function(func: &Expr, semantic: &SemanticModel) -> bool { +/// Returns `true` if the given [`Expr`] is a `dataclasses.field` call. 
+pub(super) fn is_dataclass_field(func: &Expr, semantic: &SemanticModel) -> bool { semantic.resolve_call_path(func).map_or(false, |call_path| { - ALLOWED_DATACLASS_SPECIFIC_FUNCTIONS - .iter() - .any(|target| call_path.as_slice() == *target) + matches!(call_path.as_slice(), ["dataclasses", "field"]) }) } diff --git a/crates/ruff/src/rules/ruff/rules/mutable_class_default.rs b/crates/ruff/src/rules/ruff/rules/mutable_class_default.rs index 113c3afc96..d2c263ec21 100644 --- a/crates/ruff/src/rules/ruff/rules/mutable_class_default.rs +++ b/crates/ruff/src/rules/ruff/rules/mutable_class_default.rs @@ -2,10 +2,10 @@ use rustpython_parser::ast::{self, Ranged, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_semantic::analyze::typing::is_immutable_annotation; +use ruff_python_semantic::analyze::typing::{is_immutable_annotation, is_mutable_expr}; use crate::checkers::ast::Checker; -use crate::rules::ruff::rules::helpers::{is_class_var_annotation, is_dataclass, is_mutable_expr}; +use crate::rules::ruff::rules::helpers::{is_class_var_annotation, is_dataclass}; /// ## What it does /// Checks for mutable default values in class attributes. @@ -52,7 +52,7 @@ pub(crate) fn mutable_class_default(checker: &mut Checker, class_def: &ast::Stmt value: Some(value), .. }) => { - if is_mutable_expr(value) + if is_mutable_expr(value, checker.semantic()) && !is_class_var_annotation(annotation, checker.semantic()) && !is_immutable_annotation(annotation, checker.semantic()) && !is_dataclass(class_def, checker.semantic()) @@ -63,7 +63,7 @@ pub(crate) fn mutable_class_default(checker: &mut Checker, class_def: &ast::Stmt } } Stmt::Assign(ast::StmtAssign { value, .. 
}) => { - if is_mutable_expr(value) { + if is_mutable_expr(value, checker.semantic()) { checker .diagnostics .push(Diagnostic::new(MutableClassDefault, value.range())); diff --git a/crates/ruff/src/rules/ruff/rules/mutable_dataclass_default.rs b/crates/ruff/src/rules/ruff/rules/mutable_dataclass_default.rs index ac30e0e214..2b47c32a46 100644 --- a/crates/ruff/src/rules/ruff/rules/mutable_dataclass_default.rs +++ b/crates/ruff/src/rules/ruff/rules/mutable_dataclass_default.rs @@ -2,10 +2,10 @@ use rustpython_parser::ast::{self, Ranged, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_semantic::analyze::typing::is_immutable_annotation; +use ruff_python_semantic::analyze::typing::{is_immutable_annotation, is_mutable_expr}; use crate::checkers::ast::Checker; -use crate::rules::ruff::rules::helpers::{is_class_var_annotation, is_dataclass, is_mutable_expr}; +use crate::rules::ruff::rules::helpers::{is_class_var_annotation, is_dataclass}; /// ## What it does /// Checks for mutable default values in dataclass attributes. @@ -74,7 +74,7 @@ pub(crate) fn mutable_dataclass_default(checker: &mut Checker, class_def: &ast:: .. 
}) = statement { - if is_mutable_expr(value) + if is_mutable_expr(value, checker.semantic()) && !is_class_var_annotation(annotation, checker.semantic()) && !is_immutable_annotation(annotation, checker.semantic()) { diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index 8dea42177e..4da230cd70 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -6,7 +6,10 @@ use rustpython_parser::ast::{self, Constant, Expr, Operator}; use ruff_python_ast::call_path::{from_qualified_name, from_unqualified_name, CallPath}; use ruff_python_ast::helpers::is_const_false; use ruff_python_stdlib::typing::{ - IMMUTABLE_GENERIC_TYPES, IMMUTABLE_TYPES, PEP_585_GENERICS, PEP_593_SUBSCRIPTS, SUBSCRIPTS, + as_pep_585_generic, has_pep_585_generic, is_immutable_generic_type, + is_immutable_non_generic_type, is_immutable_return_type, is_mutable_return_type, + is_pep_593_generic_member, is_pep_593_generic_type, is_standard_library_generic, + is_standard_library_generic_member, }; use crate::model::SemanticModel; @@ -34,12 +37,8 @@ pub fn match_annotated_subscript<'a>( typing_modules: impl Iterator<Item = &'a str>, extend_generics: &[String], ) -> Option<SubscriptKind> { - if !matches!(expr, Expr::Name(_) | Expr::Attribute(_)) { - return None; - } - semantic.resolve_call_path(expr).and_then(|call_path| { - if SUBSCRIPTS.contains(&call_path.as_slice()) + if is_standard_library_generic(call_path.as_slice()) || extend_generics .iter() .map(|target| from_qualified_name(target)) @@ -47,20 +46,19 @@ pub fn match_annotated_subscript<'a>( { return Some(SubscriptKind::AnnotatedSubscript); } - if PEP_593_SUBSCRIPTS.contains(&call_path.as_slice()) { + + if is_pep_593_generic_type(call_path.as_slice()) { return Some(SubscriptKind::PEP593AnnotatedSubscript); } for module in typing_modules { let module_call_path: CallPath = from_unqualified_name(module); if call_path.starts_with(&module_call_path) { - for 
subscript in SUBSCRIPTS.iter() { - if call_path.last() == subscript.last() { + if let Some(member) = call_path.last() { + if is_standard_library_generic_member(member) { return Some(SubscriptKind::AnnotatedSubscript); } - } - for subscript in PEP_593_SUBSCRIPTS.iter() { - if call_path.last() == subscript.last() { + if is_pep_593_generic_member(member) { return Some(SubscriptKind::PEP593AnnotatedSubscript); } } @@ -92,38 +90,27 @@ impl std::fmt::Display for ModuleMember { /// a variant exists. pub fn to_pep585_generic(expr: &Expr, semantic: &SemanticModel) -> Option<ModuleMember> { semantic.resolve_call_path(expr).and_then(|call_path| { - let [module, name] = call_path.as_slice() else { + let [module, member] = call_path.as_slice() else { return None; }; - PEP_585_GENERICS - .iter() - .find_map(|((from_module, from_member), (to_module, to_member))| { - if module == from_module && name == from_member { - if to_module.is_empty() { - Some(ModuleMember::BuiltIn(to_member)) - } else { - Some(ModuleMember::Member(to_module, to_member)) - } - } else { - None - } - }) + as_pep_585_generic(module, member).map(|(module, member)| { + if module.is_empty() { + ModuleMember::BuiltIn(member) + } else { + ModuleMember::Member(module, member) + } + }) }) } /// Return whether a given expression uses a PEP 585 standard library generic. 
pub fn is_pep585_generic(expr: &Expr, semantic: &SemanticModel) -> bool { - if let Some(call_path) = semantic.resolve_call_path(expr) { + semantic.resolve_call_path(expr).map_or(false, |call_path| { let [module, name] = call_path.as_slice() else { return false; }; - for (_, (to_module, to_member)) in PEP_585_GENERICS { - if module == to_module && name == to_member { - return true; - } - } - } - false + has_pep_585_generic(module, name) + }) } #[derive(Debug, Copy, Clone)] @@ -178,19 +165,14 @@ pub fn is_immutable_annotation(expr: &Expr, semantic: &SemanticModel) -> bool { match expr { Expr::Name(_) | Expr::Attribute(_) => { semantic.resolve_call_path(expr).map_or(false, |call_path| { - IMMUTABLE_TYPES - .iter() - .chain(IMMUTABLE_GENERIC_TYPES) - .any(|target| call_path.as_slice() == *target) + is_immutable_non_generic_type(call_path.as_slice()) + || is_immutable_generic_type(call_path.as_slice()) }) } Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => semantic .resolve_call_path(value) .map_or(false, |call_path| { - if IMMUTABLE_GENERIC_TYPES - .iter() - .any(|target| call_path.as_slice() == *target) - { + if is_immutable_generic_type(call_path.as_slice()) { true } else if matches!(call_path.as_slice(), ["typing", "Union"]) { if let Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() { @@ -226,43 +208,43 @@ pub fn is_immutable_annotation(expr: &Expr, semantic: &SemanticModel) -> bool { } } -const IMMUTABLE_FUNCS: &[&[&str]] = &[ - &["", "bool"], - &["", "complex"], - &["", "float"], - &["", "frozenset"], - &["", "int"], - &["", "str"], - &["", "tuple"], - &["datetime", "date"], - &["datetime", "datetime"], - &["datetime", "timedelta"], - &["decimal", "Decimal"], - &["fractions", "Fraction"], - &["operator", "attrgetter"], - &["operator", "itemgetter"], - &["operator", "methodcaller"], - &["pathlib", "Path"], - &["types", "MappingProxyType"], - &["re", "compile"], -]; - -/// Return `true` if `func` is a function that returns an immutable object. 
+/// Return `true` if `func` is a function that returns an immutable value. pub fn is_immutable_func( func: &Expr, semantic: &SemanticModel, extend_immutable_calls: &[CallPath], ) -> bool { semantic.resolve_call_path(func).map_or(false, |call_path| { - IMMUTABLE_FUNCS - .iter() - .any(|target| call_path.as_slice() == *target) + is_immutable_return_type(call_path.as_slice()) || extend_immutable_calls .iter() .any(|target| call_path == *target) }) } +/// Return `true` if `func` is a function that returns a mutable value. +pub fn is_mutable_func(func: &Expr, semantic: &SemanticModel) -> bool { + semantic + .resolve_call_path(func) + .as_ref() + .map(CallPath::as_slice) + .map_or(false, is_mutable_return_type) +} + +/// Return `true` if `expr` is an expression that resolves to a mutable value. +pub fn is_mutable_expr(expr: &Expr, semantic: &SemanticModel) -> bool { + match expr { + Expr::List(_) + | Expr::Dict(_) + | Expr::Set(_) + | Expr::ListComp(_) + | Expr::DictComp(_) + | Expr::SetComp(_) => true, + Expr::Call(ast::ExprCall { func, .. }) => is_mutable_func(func, semantic), + _ => false, + } +} + /// Return `true` if [`Expr`] is a guard for a type-checking block. pub fn is_type_checking_block(stmt: &ast::StmtIf, semantic: &SemanticModel) -> bool { let ast::StmtIf { test, .. 
} = stmt; diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index bae2be65ca..56778df02e 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -10,7 +10,7 @@ use smallvec::smallvec; use ruff_python_ast::call_path::{collect_call_path, from_unqualified_name, CallPath}; use ruff_python_ast::helpers::from_relative_import; use ruff_python_stdlib::path::is_python_stub_file; -use ruff_python_stdlib::typing::TYPING_EXTENSIONS; +use ruff_python_stdlib::typing::is_typing_extension; use crate::binding::{ Binding, BindingFlags, BindingId, BindingKind, Bindings, Exceptions, FromImportation, @@ -175,7 +175,7 @@ impl<'a> SemanticModel<'a> { return true; } - if TYPING_EXTENSIONS.contains(target) { + if is_typing_extension(target) { if call_path.as_slice() == ["typing_extensions", target] { return true; } diff --git a/crates/ruff_python_stdlib/src/typing.rs b/crates/ruff_python_stdlib/src/typing.rs index 48895413f6..796f7c3a07 100644 --- a/crates/ruff_python_stdlib/src/typing.rs +++ b/crates/ruff_python_stdlib/src/typing.rs @@ -1,279 +1,414 @@ -use once_cell::sync::Lazy; -use rustc_hash::{FxHashMap, FxHashSet}; +/// Returns `true` if a name is a member of Python's `typing_extensions` module. 
+/// +/// See: <https://pypi.org/project/typing-extensions/> +pub fn is_typing_extension(member: &str) -> bool { + matches!( + member, + "Annotated" + | "Any" + | "AsyncContextManager" + | "AsyncGenerator" + | "AsyncIterable" + | "AsyncIterator" + | "Awaitable" + | "ChainMap" + | "ClassVar" + | "Concatenate" + | "ContextManager" + | "Coroutine" + | "Counter" + | "DefaultDict" + | "Deque" + | "Final" + | "Literal" + | "LiteralString" + | "NamedTuple" + | "Never" + | "NewType" + | "NotRequired" + | "OrderedDict" + | "ParamSpec" + | "ParamSpecArgs" + | "ParamSpecKwargs" + | "Protocol" + | "Required" + | "Self" + | "TYPE_CHECKING" + | "Text" + | "Type" + | "TypeAlias" + | "TypeGuard" + | "TypeVar" + | "TypeVarTuple" + | "TypedDict" + | "Unpack" + | "assert_never" + | "assert_type" + | "clear_overloads" + | "final" + | "get_type_hints" + | "get_args" + | "get_origin" + | "get_overloads" + | "is_typeddict" + | "overload" + | "override" + | "reveal_type" + | "runtime_checkable" + ) +} -// See: https://pypi.org/project/typing-extensions/ -pub static TYPING_EXTENSIONS: Lazy> = Lazy::new(|| { - FxHashSet::from_iter([ - "Annotated", - "Any", - "AsyncContextManager", - "AsyncGenerator", - "AsyncIterable", - "AsyncIterator", - "Awaitable", - "ChainMap", - "ClassVar", - "Concatenate", - "ContextManager", - "Coroutine", - "Counter", - "DefaultDict", - "Deque", - "Final", - "Literal", - "LiteralString", - "NamedTuple", - "Never", - "NewType", - "NotRequired", - "OrderedDict", - "ParamSpec", - "ParamSpecArgs", - "ParamSpecKwargs", - "Protocol", - "Required", - "Self", - "TYPE_CHECKING", - "Text", - "Type", - "TypeAlias", - "TypeGuard", - "TypeVar", - "TypeVarTuple", - "TypedDict", - "Unpack", - "assert_never", - "assert_type", - "clear_overloads", - "final", - "get_type_hints", - "get_args", - "get_origin", - "get_overloads", - "is_typeddict", - "overload", - "override", - "reveal_type", - "runtime_checkable", - ]) -}); +/// Returns `true` if a call path is a generic from the Python standard library (e.g. 
`list`, which +/// can be used as `list[int]`). +/// +/// See: <https://docs.python.org/3/library/typing.html> +pub fn is_standard_library_generic(call_path: &[&str]) -> bool { + matches!( + call_path, + ["", "dict" | "frozenset" | "list" | "set" | "tuple" | "type"] + | [ + "collections" | "typing" | "typing_extensions", + "ChainMap" | "Counter" + ] + | ["collections" | "typing", "OrderedDict"] + | ["collections", "defaultdict" | "deque"] + | [ + "collections", + "abc", + "AsyncGenerator" + | "AsyncIterable" + | "AsyncIterator" + | "Awaitable" + | "ByteString" + | "Callable" + | "Collection" + | "Container" + | "Coroutine" + | "Generator" + | "ItemsView" + | "Iterable" + | "Iterator" + | "KeysView" + | "Mapping" + | "MappingView" + | "MutableMapping" + | "MutableSequence" + | "MutableSet" + | "Reversible" + | "Sequence" + | "Set" + | "ValuesView" + ] + | [ + "contextlib", + "AbstractAsyncContextManager" | "AbstractContextManager" + ] + | ["re" | "typing", "Match" | "Pattern"] + | [ + "typing", + "AbstractSet" + | "AsyncContextManager" + | "AsyncGenerator" + | "AsyncIterator" + | "Awaitable" + | "BinaryIO" + | "ByteString" + | "Callable" + | "ClassVar" + | "Collection" + | "Concatenate" + | "Container" + | "ContextManager" + | "Coroutine" + | "DefaultDict" + | "Deque" + | "Dict" + | "Final" + | "FrozenSet" + | "Generator" + | "Generic" + | "IO" + | "ItemsView" + | "Iterable" + | "Iterator" + | "KeysView" + | "List" + | "Mapping" + | "MutableMapping" + | "MutableSequence" + | "MutableSet" + | "Optional" + | "Reversible" + | "Sequence" + | "Set" + | "TextIO" + | "Tuple" + | "Type" + | "TypeGuard" + | "Union" + | "Unpack" + | "ValuesView" + ] + | ["typing", "io", "BinaryIO" | "IO" | "TextIO"] + | ["typing", "re", "Match" | "Pattern"] + | [ + "typing_extensions", + "AsyncContextManager" + | "AsyncGenerator" + | "AsyncIterable" + | "AsyncIterator" + | "Awaitable" + | "ClassVar" + | "Concatenate" + | "ContextManager" + | "Coroutine" + | "DefaultDict" + | "Deque" + | "Type" + ] + | [ + "weakref", + "WeakKeyDictionary" | 
"WeakSet" | "WeakValueDictionary" + ] + ) +} -// See: https://docs.python.org/3/library/typing.html -pub const SUBSCRIPTS: &[&[&str]] = &[ - // builtins - &["", "dict"], - &["", "frozenset"], - &["", "list"], - &["", "set"], - &["", "tuple"], - &["", "type"], - // `collections` - &["collections", "ChainMap"], - &["collections", "Counter"], - &["collections", "OrderedDict"], - &["collections", "defaultdict"], - &["collections", "deque"], - // `collections.abc` - &["collections", "abc", "AsyncGenerator"], - &["collections", "abc", "AsyncIterable"], - &["collections", "abc", "AsyncIterator"], - &["collections", "abc", "Awaitable"], - &["collections", "abc", "ByteString"], - &["collections", "abc", "Callable"], - &["collections", "abc", "Collection"], - &["collections", "abc", "Container"], - &["collections", "abc", "Coroutine"], - &["collections", "abc", "Generator"], - &["collections", "abc", "ItemsView"], - &["collections", "abc", "Iterable"], - &["collections", "abc", "Iterator"], - &["collections", "abc", "KeysView"], - &["collections", "abc", "Mapping"], - &["collections", "abc", "MappingView"], - &["collections", "abc", "MutableMapping"], - &["collections", "abc", "MutableSequence"], - &["collections", "abc", "MutableSet"], - &["collections", "abc", "Reversible"], - &["collections", "abc", "Sequence"], - &["collections", "abc", "Set"], - &["collections", "abc", "ValuesView"], - // `contextlib` - &["contextlib", "AbstractAsyncContextManager"], - &["contextlib", "AbstractContextManager"], - // `re` - &["re", "Match"], - &["re", "Pattern"], - // `typing` - &["typing", "AbstractSet"], - &["typing", "AsyncContextManager"], - &["typing", "AsyncGenerator"], - &["typing", "AsyncIterator"], - &["typing", "Awaitable"], - &["typing", "BinaryIO"], - &["typing", "ByteString"], - &["typing", "Callable"], - &["typing", "ChainMap"], - &["typing", "ClassVar"], - &["typing", "Collection"], - &["typing", "Concatenate"], - &["typing", "Container"], - &["typing", "ContextManager"], 
- &["typing", "Coroutine"], - &["typing", "Counter"], - &["typing", "DefaultDict"], - &["typing", "Deque"], - &["typing", "Dict"], - &["typing", "Final"], - &["typing", "FrozenSet"], - &["typing", "Generator"], - &["typing", "Generic"], - &["typing", "IO"], - &["typing", "ItemsView"], - &["typing", "Iterable"], - &["typing", "Iterator"], - &["typing", "KeysView"], - &["typing", "List"], - &["typing", "Mapping"], - &["typing", "Match"], - &["typing", "MutableMapping"], - &["typing", "MutableSequence"], - &["typing", "MutableSet"], - &["typing", "Optional"], - &["typing", "OrderedDict"], - &["typing", "Pattern"], - &["typing", "Reversible"], - &["typing", "Sequence"], - &["typing", "Set"], - &["typing", "TextIO"], - &["typing", "Tuple"], - &["typing", "Type"], - &["typing", "TypeGuard"], - &["typing", "Union"], - &["typing", "Unpack"], - &["typing", "ValuesView"], - // `typing.io` - &["typing", "io", "BinaryIO"], - &["typing", "io", "IO"], - &["typing", "io", "TextIO"], - // `typing.re` - &["typing", "re", "Match"], - &["typing", "re", "Pattern"], - // `typing_extensions` - &["typing_extensions", "AsyncContextManager"], - &["typing_extensions", "AsyncGenerator"], - &["typing_extensions", "AsyncIterable"], - &["typing_extensions", "AsyncIterator"], - &["typing_extensions", "Awaitable"], - &["typing_extensions", "ChainMap"], - &["typing_extensions", "ClassVar"], - &["typing_extensions", "Concatenate"], - &["typing_extensions", "ContextManager"], - &["typing_extensions", "Coroutine"], - &["typing_extensions", "Counter"], - &["typing_extensions", "DefaultDict"], - &["typing_extensions", "Deque"], - &["typing_extensions", "Type"], - // `weakref` - &["weakref", "WeakKeyDictionary"], - &["weakref", "WeakSet"], - &["weakref", "WeakValueDictionary"], -]; +/// Returns `true` if a call path is a [PEP 593] generic (e.g. `Annotated`). 
+/// +/// See: <https://docs.python.org/3/library/typing.html> +/// +/// [PEP 593]: https://peps.python.org/pep-0593/ +pub fn is_pep_593_generic_type(call_path: &[&str]) -> bool { + matches!(call_path, ["typing" | "typing_extensions", "Annotated"]) +} -// See: https://docs.python.org/3/library/typing.html -pub const PEP_593_SUBSCRIPTS: &[&[&str]] = &[ - // `typing` - &["typing", "Annotated"], - // `typing_extensions` - &["typing_extensions", "Annotated"], -]; +/// Returns `true` if a name matches that of a generic from the Python standard library (e.g. +/// `list` or `Set`). +/// +/// See: <https://docs.python.org/3/library/typing.html> +pub fn is_standard_library_generic_member(member: &str) -> bool { + // Constructed by taking every pattern from `is_standard_library_generic`, removing all but + // the last element in each pattern, and de-duplicating the values. + matches!( + member, + "dict" + | "AbstractAsyncContextManager" + | "AbstractContextManager" + | "AbstractSet" + | "AsyncContextManager" + | "AsyncGenerator" + | "AsyncIterable" + | "AsyncIterator" + | "Awaitable" + | "BinaryIO" + | "ByteString" + | "Callable" + | "ChainMap" + | "ClassVar" + | "Collection" + | "Concatenate" + | "Container" + | "ContextManager" + | "Coroutine" + | "Counter" + | "DefaultDict" + | "Deque" + | "Dict" + | "Final" + | "FrozenSet" + | "Generator" + | "Generic" + | "IO" + | "ItemsView" + | "Iterable" + | "Iterator" + | "KeysView" + | "List" + | "Mapping" + | "MappingView" + | "Match" + | "MutableMapping" + | "MutableSequence" + | "MutableSet" + | "Optional" + | "OrderedDict" + | "Pattern" + | "Reversible" + | "Sequence" + | "Set" + | "TextIO" + | "Tuple" + | "Type" + | "TypeGuard" + | "Union" + | "Unpack" + | "ValuesView" + | "WeakKeyDictionary" + | "WeakSet" + | "WeakValueDictionary" + | "defaultdict" + | "deque" + | "frozenset" + | "list" + | "set" + | "tuple" + | "type" + ) +} + +/// Returns `true` if a name matches that of a generic from [PEP 593] (e.g. `Annotated`). 
+/// +/// See: <https://docs.python.org/3/library/typing.html> +/// +/// [PEP 593]: https://peps.python.org/pep-0593/ +pub fn is_pep_593_generic_member(member: &str) -> bool { + // Constructed by taking every pattern from `is_pep_593_generic_type`, removing all but + // the last element in each pattern, and de-duplicating the values. + matches!(member, "Annotated") +} + +/// Returns `true` if a call path represents that of an immutable, non-generic type from the Python +/// standard library (e.g. `int` or `str`). +pub fn is_immutable_non_generic_type(call_path: &[&str]) -> bool { + matches!( + call_path, + ["collections", "abc", "Sized"] + | ["typing", "LiteralString" | "Sized"] + | [ + "", + "bool" + | "bytes" + | "complex" + | "float" + | "frozenset" + | "int" + | "object" + | "range" + | "str" + ] + ) +} + +/// Returns `true` if a call path represents that of an immutable, generic type from the Python +/// standard library (e.g. `tuple`). +pub fn is_immutable_generic_type(call_path: &[&str]) -> bool { + matches!( + call_path, + ["", "tuple"] + | [ + "collections", + "abc", + "ByteString" + | "Collection" + | "Container" + | "Iterable" + | "Mapping" + | "Reversible" + | "Sequence" + | "Set" + ] + | [ + "typing", + "AbstractSet" + | "ByteString" + | "Callable" + | "Collection" + | "Container" + | "FrozenSet" + | "Iterable" + | "Literal" + | "Mapping" + | "Never" + | "NoReturn" + | "Reversible" + | "Sequence" + | "Tuple" + ] + ) +} + +/// Returns `true` if a call path represents a function from the Python standard library that +/// returns a mutable value (e.g., `dict`). +pub fn is_mutable_return_type(call_path: &[&str]) -> bool { + matches!( + call_path, + ["", "dict" | "list" | "set"] + | [ + "collections", + "Counter" | "OrderedDict" | "defaultdict" | "deque" + ] + ) +} + +/// Returns `true` if a call path represents a function from the Python standard library that +/// returns an immutable value (e.g., `bool`). 
+pub fn is_immutable_return_type(call_path: &[&str]) -> bool { + matches!( + call_path, + ["datetime", "date" | "datetime" | "timedelta"] + | ["decimal", "Decimal"] + | ["fractions", "Fraction"] + | ["operator", "attrgetter" | "itemgetter" | "methodcaller"] + | ["pathlib", "Path"] + | ["types", "MappingProxyType"] + | ["re", "compile"] + | [ + "", + "bool" | "complex" | "float" | "frozenset" | "int" | "str" | "tuple" + ] + ) +} type ModuleMember = (&'static str, &'static str); -type SymbolReplacement = (ModuleMember, ModuleMember); +/// Given a typing member, returns the module and member name for a generic from the Python standard +/// library (e.g., `list` for `typing.List`), if such a generic was introduced by [PEP 585]. +/// +/// [PEP 585]: https://peps.python.org/pep-0585/ +pub fn as_pep_585_generic(module: &str, member: &str) -> Option { + match (module, member) { + ("typing", "Dict") => Some(("", "dict")), + ("typing", "FrozenSet") => Some(("", "frozenset")), + ("typing", "List") => Some(("", "list")), + ("typing", "Set") => Some(("", "set")), + ("typing", "Tuple") => Some(("", "tuple")), + ("typing", "Type") => Some(("", "type")), + ("typing_extensions", "Type") => Some(("", "type")), + ("typing", "Deque") => Some(("collections", "deque")), + ("typing_extensions", "Deque") => Some(("collections", "deque")), + ("typing", "DefaultDict") => Some(("collections", "defaultdict")), + ("typing_extensions", "DefaultDict") => Some(("collections", "defaultdict")), + _ => None, + } +} -// See: https://peps.python.org/pep-0585/ -pub const PEP_585_GENERICS: &[SymbolReplacement] = &[ - (("typing", "Dict"), ("", "dict")), - (("typing", "FrozenSet"), ("", "frozenset")), - (("typing", "List"), ("", "list")), - (("typing", "Set"), ("", "set")), - (("typing", "Tuple"), ("", "tuple")), - (("typing", "Type"), ("", "type")), - (("typing_extensions", "Type"), ("", "type")), - (("typing", "Deque"), ("collections", "deque")), - (("typing_extensions", "Deque"), ("collections", 
"deque")), - (("typing", "DefaultDict"), ("collections", "defaultdict")), - ( - ("typing_extensions", "DefaultDict"), - ("collections", "defaultdict"), - ), -]; +/// Given a typing member, returns `true` if a generic equivalent exists in the Python standard +/// library (e.g., `list` for `typing.List`), as introduced by [PEP 585]. +/// +/// [PEP 585]: https://peps.python.org/pep-0585/ +pub fn has_pep_585_generic(module: &str, member: &str) -> bool { + // Constructed by taking every pattern from `as_pep_585_generic`, removing all but + // the last element in each pattern, and de-duplicating the values. + matches!( + (module, member), + ("", "dict" | "frozenset" | "list" | "set" | "tuple" | "type") + | ("collections", "deque" | "defaultdict") + ) +} -// See: https://github.com/JelleZijlstra/autotyping/blob/0adba5ba0eee33c1de4ad9d0c79acfd737321dd9/autotyping/autotyping.py#L69-L91 -pub static SIMPLE_MAGIC_RETURN_TYPES: Lazy> = - Lazy::new(|| { - FxHashMap::from_iter([ - ("__str__", "str"), - ("__repr__", "str"), - ("__len__", "int"), - ("__length_hint__", "int"), - ("__init__", "None"), - ("__del__", "None"), - ("__bool__", "bool"), - ("__bytes__", "bytes"), - ("__format__", "str"), - ("__contains__", "bool"), - ("__complex__", "complex"), - ("__int__", "int"), - ("__float__", "float"), - ("__index__", "int"), - ("__setattr__", "None"), - ("__delattr__", "None"), - ("__setitem__", "None"), - ("__delitem__", "None"), - ("__set__", "None"), - ("__instancecheck__", "bool"), - ("__subclasscheck__", "bool"), - ]) - }); - -pub const IMMUTABLE_TYPES: &[&[&str]] = &[ - &["", "bool"], - &["", "bytes"], - &["", "complex"], - &["", "float"], - &["", "frozenset"], - &["", "int"], - &["", "object"], - &["", "range"], - &["", "str"], - &["collections", "abc", "Sized"], - &["typing", "LiteralString"], - &["typing", "Sized"], -]; - -pub const IMMUTABLE_GENERIC_TYPES: &[&[&str]] = &[ - &["", "tuple"], - &["collections", "abc", "ByteString"], - &["collections", "abc", "Collection"], - 
&["collections", "abc", "Container"], - &["collections", "abc", "Iterable"], - &["collections", "abc", "Mapping"], - &["collections", "abc", "Reversible"], - &["collections", "abc", "Sequence"], - &["collections", "abc", "Set"], - &["typing", "AbstractSet"], - &["typing", "ByteString"], - &["typing", "Callable"], - &["typing", "Collection"], - &["typing", "Container"], - &["typing", "FrozenSet"], - &["typing", "Iterable"], - &["typing", "Literal"], - &["typing", "Mapping"], - &["typing", "Never"], - &["typing", "NoReturn"], - &["typing", "Reversible"], - &["typing", "Sequence"], - &["typing", "Tuple"], -]; +/// Returns the expected return type for a magic method. +/// +/// See: <https://github.com/JelleZijlstra/autotyping/blob/0adba5ba0eee33c1de4ad9d0c79acfd737321dd9/autotyping/autotyping.py#L69-L91> +pub fn simple_magic_return_type(method: &str) -> Option<&'static str> { + match method { + "__str__" | "__repr__" | "__format__" => Some("str"), + "__bytes__" => Some("bytes"), + "__len__" | "__length_hint__" | "__int__" | "__index__" => Some("int"), + "__float__" => Some("float"), + "__complex__" => Some("complex"), + "__bool__" | "__contains__" | "__instancecheck__" | "__subclasscheck__" => Some("bool"), + "__init__" | "__del__" | "__setattr__" | "__delattr__" | "__setitem__" | "__delitem__" + | "__set__" => Some("None"), + _ => None, + } +}