From 87a0cd219ff888c8ed0e97f1391b6f42899825eb Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Thu, 21 Sep 2023 00:37:38 -0400 Subject: [PATCH] Detect `asyncio.get_running_loop` calls in RUF006 (#7562) ## Summary We can do a good enough job detecting this with our existing semantic model. Closes https://github.com/astral-sh/ruff/issues/3237. --- .../resources/test/fixtures/ruff/RUF006.py | 21 ++++++++++- .../flake8_bandit/rules/flask_debug_true.rs | 33 +++++------------ .../rules/ssh_no_host_key_verification.rs | 37 +++++-------------- .../rules/ruff/rules/asyncio_dangling_task.rs | 31 ++++++++++++---- ..._rules__ruff__tests__RUF006_RUF006.py.snap | 8 ++++ .../src/analyze/typing.rs | 33 +++++++++++++++++ 6 files changed, 103 insertions(+), 60 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/ruff/RUF006.py b/crates/ruff_linter/resources/test/fixtures/ruff/RUF006.py index 69c878bac3..eedce25631 100644 --- a/crates/ruff_linter/resources/test/fixtures/ruff/RUF006.py +++ b/crates/ruff_linter/resources/test/fixtures/ruff/RUF006.py @@ -63,11 +63,28 @@ def f(): tasks = [asyncio.create_task(task) for task in tasks] -# Ok (false negative) +# OK (false negative) def f(): task = asyncio.create_task(coordinator.ws_connect()) -# Ok (potential false negative) +# OK (potential false negative) def f(): do_nothing_with_the_task(asyncio.create_task(coordinator.ws_connect())) + + +# Error +def f(): + loop = asyncio.get_running_loop() + loop.create_task(coordinator.ws_connect()) # Error + + +# OK +def f(): + loop.create_task(coordinator.ws_connect()) + + +# OK +def f(): + loop = asyncio.get_running_loop() + loop.do_thing(coordinator.ws_connect()) diff --git a/crates/ruff_linter/src/rules/flake8_bandit/rules/flask_debug_true.rs b/crates/ruff_linter/src/rules/flake8_bandit/rules/flask_debug_true.rs index 35fedb70f2..a1517f26c0 100644 --- a/crates/ruff_linter/src/rules/flake8_bandit/rules/flask_debug_true.rs +++ b/crates/ruff_linter/src/rules/flake8_bandit/rules/flask_debug_true.rs @@ -1,7 +1,8 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::is_const_true; -use ruff_python_ast::{Expr, ExprAttribute, ExprCall, Stmt, StmtAssign}; +use ruff_python_ast::{Expr, ExprAttribute, ExprCall}; +use ruff_python_semantic::analyze::typing; use ruff_text_size::Ranged; use crate::checkers::ast::Checker; @@ -63,27 +64,11 @@ pub(crate) fn flask_debug_true(checker: &mut Checker, call: &ExprCall) { return; } - let Expr::Name(name) = value.as_ref() else { - return; - }; - - if let Some(binding_id) = checker.semantic().resolve_name(name) { - if let Some(Stmt::Assign(StmtAssign { value, .. })) = checker - .semantic() - .binding(binding_id) - .statement(checker.semantic()) - { - if let Expr::Call(ExprCall { func, .. }) = value.as_ref() { - if checker - .semantic() - .resolve_call_path(func) - .is_some_and(|call_path| matches!(call_path.as_slice(), ["flask", "Flask"])) - { - checker - .diagnostics - .push(Diagnostic::new(FlaskDebugTrue, debug_argument.range())); - } - } - } - }; + if typing::resolve_assignment(value, checker.semantic()) + .is_some_and(|call_path| matches!(call_path.as_slice(), ["flask", "Flask"])) + { + checker + .diagnostics + .push(Diagnostic::new(FlaskDebugTrue, debug_argument.range())); + } } diff --git a/crates/ruff_linter/src/rules/flake8_bandit/rules/ssh_no_host_key_verification.rs b/crates/ruff_linter/src/rules/flake8_bandit/rules/ssh_no_host_key_verification.rs index fac1b2093c..d653fdf658 100644 --- a/crates/ruff_linter/src/rules/flake8_bandit/rules/ssh_no_host_key_verification.rs +++ b/crates/ruff_linter/src/rules/flake8_bandit/rules/ssh_no_host_key_verification.rs @@ -1,6 +1,7 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::{Expr, ExprAttribute, ExprCall, Stmt, StmtAssign}; +use ruff_python_ast::{Expr, ExprAttribute, ExprCall}; +use ruff_python_semantic::analyze::typing; use ruff_text_size::Ranged; use crate::checkers::ast::Checker; @@ -68,30 +69,12 @@ pub(crate) fn ssh_no_host_key_verification(checker: &mut Checker, call: &ExprCal return; } - let Expr::Name(name) = value.as_ref() else { - return; - }; - - if let Some(binding_id) = checker.semantic().resolve_name(name) { - if let Some(Stmt::Assign(StmtAssign { value, .. })) = checker - .semantic() - .binding(binding_id) - .statement(checker.semantic()) - { - if let Expr::Call(ExprCall { func, .. }) = value.as_ref() { - if checker - .semantic() - .resolve_call_path(func) - .is_some_and(|call_path| { - matches!(call_path.as_slice(), ["paramiko", "client", "SSHClient"]) - }) - { - checker.diagnostics.push(Diagnostic::new( - SSHNoHostKeyVerification, - policy_argument.range(), - )); - } - } - } - }; + if typing::resolve_assignment(value, checker.semantic()).is_some_and(|call_path| { + matches!(call_path.as_slice(), ["paramiko", "client", "SSHClient"]) + }) { + checker.diagnostics.push(Diagnostic::new( + SSHNoHostKeyVerification, + policy_argument.range(), + )); + } } diff --git a/crates/ruff_linter/src/rules/ruff/rules/asyncio_dangling_task.rs b/crates/ruff_linter/src/rules/ruff/rules/asyncio_dangling_task.rs index e693774768..f6f03fc8d9 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/asyncio_dangling_task.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/asyncio_dangling_task.rs @@ -4,6 +4,7 @@ use ruff_python_ast::{self as ast, Expr}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_semantic::analyze::typing; use ruff_text_size::Ranged; use crate::checkers::ast::Checker; @@ -70,7 +71,8 @@ pub(crate) fn asyncio_dangling_task(checker: &mut Checker, expr: &Expr) { return; }; - let Some(method) = checker + // Ex) `asyncio.create_task(...)` + if let Some(method) = checker .semantic() .resolve_call_path(func) .and_then(|call_path| match call_path.as_slice() { @@ -78,14 +80,29 @@ pub(crate) fn asyncio_dangling_task(checker: &mut Checker, expr: &Expr) { ["asyncio", "ensure_future"] => Some(Method::EnsureFuture), _ => None, }) - else { + { + checker.diagnostics.push(Diagnostic::new( + AsyncioDanglingTask { method }, + expr.range(), + )); return; - }; + } - checker.diagnostics.push(Diagnostic::new( - AsyncioDanglingTask { method }, - expr.range(), - )); + // Ex) `loop = asyncio.get_running_loop(); loop.create_task(...)` + if let Expr::Attribute(ast::ExprAttribute { attr, value, .. }) = func.as_ref() { + if attr == "create_task" { + if typing::resolve_assignment(value, checker.semantic()).is_some_and(|call_path| { + matches!(call_path.as_slice(), ["asyncio", "get_running_loop"]) + }) { + checker.diagnostics.push(Diagnostic::new( + AsyncioDanglingTask { + method: Method::CreateTask, + }, + expr.range(), + )); + } + } + } } #[derive(Debug, PartialEq, Eq, Copy, Clone)] diff --git a/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF006_RUF006.py.snap b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF006_RUF006.py.snap index b91ea21408..11a0bababe 100644 --- a/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF006_RUF006.py.snap +++ b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF006_RUF006.py.snap @@ -17,4 +17,12 @@ RUF006.py:11:5: RUF006 Store a reference to the return value of `asyncio.ensure_ | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RUF006 | +RUF006.py:79:5: RUF006 Store a reference to the return value of `asyncio.create_task` + | +77 | def f(): +78 | loop = asyncio.get_running_loop() +79 | loop.create_task(coordinator.ws_connect()) # Error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RUF006 + | + diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index b7847f78da..707d55553c 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -521,3 +521,36 @@ fn find_parameter<'a>( .chain(parameters.kwonlyargs.iter()) .find(|arg| arg.parameter.name.range() == binding.range()) } + +/// Return the [`CallPath`] of the value to which the given [`Expr`] is assigned, if any. +/// +/// For example, given: +/// ```python +/// import asyncio +/// +/// loop = asyncio.get_running_loop() +/// loop.create_task(...) +/// ``` +/// +/// This function will return `["asyncio", "get_running_loop"]` for the `loop` binding. +pub fn resolve_assignment<'a>( + expr: &'a Expr, + semantic: &'a SemanticModel<'a>, +) -> Option> { + let name = expr.as_name_expr()?; + let binding_id = semantic.resolve_name(name)?; + let statement = semantic.binding(binding_id).statement(semantic)?; + match statement { + Stmt::Assign(ast::StmtAssign { value, .. }) => { + let ast::ExprCall { func, .. } = value.as_call_expr()?; + semantic.resolve_call_path(func) + } + Stmt::AnnAssign(ast::StmtAnnAssign { + value: Some(value), .. + }) => { + let ast::ExprCall { func, .. } = value.as_call_expr()?; + semantic.resolve_call_path(func) + } + _ => None, + } +}