From c790c1d957039067e48658512284c20388797073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikko=20Lepp=C3=A4nen?= Date: Fri, 9 Feb 2024 15:11:08 +0200 Subject: [PATCH] Implement missing await for coroutine `RUF065` --- .../resources/test/fixtures/ruff/RUF065.py | 179 ++++++++++++++++++ .../src/checkers/ast/analyze/expression.rs | 3 + crates/ruff_linter/src/codes.rs | 1 + crates/ruff_linter/src/rules/ruff/mod.rs | 1 + .../ruff/rules/missing_await_for_coroutine.rs | 141 ++++++++++++++ .../ruff_linter/src/rules/ruff/rules/mod.rs | 2 + ..._rules__ruff__tests__RUF065_RUF065.py.snap | 136 +++++++++++++ ruff.schema.json | 1 + 8 files changed, 464 insertions(+) create mode 100644 crates/ruff_linter/resources/test/fixtures/ruff/RUF065.py create mode 100644 crates/ruff_linter/src/rules/ruff/rules/missing_await_for_coroutine.rs create mode 100644 crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF065_RUF065.py.snap diff --git a/crates/ruff_linter/resources/test/fixtures/ruff/RUF065.py b/crates/ruff_linter/resources/test/fixtures/ruff/RUF065.py new file mode 100644 index 0000000000..8bfdacb019 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/ruff/RUF065.py @@ -0,0 +1,179 @@ +import asyncio + +# Violation cases: RUF065 + + +async def test_coroutine_without_await(): + async def coro(): + pass + + coro() # RUF065 + + +async def test_coroutine_without_await(): + async def coro(): + pass + + result = coro() # RUF065 + + +async def test_coroutine_without_await(): + def not_coro(): + pass + + async def coro(): + pass + + not_coro() + coro() # RUF065 + + +async def test_coroutine_without_await(): + async def coro(): + another_coro() # RUF065 + + async def another_coro(): + pass + + await coro() + + +async def test_asyncio_api_without_await(): + asyncio.sleep(0.5) # RUF065 + + +async def test_asyncio_api_without_await(): + async def coro(): + asyncio.sleep(0.5) # RUF065 + + await asyncio.wait(coro) + + +async def test_asyncio_api_without_await(): + async def coro(): + await asyncio.sleep(0.5) + + asyncio.wait_for(coro) # RUF065 + + +async def test_asyncio_api_without_await(): + async def coro1(): + await asyncio.sleep(0.5) + + async def coro2(): + await asyncio.sleep(0.5) + + tasks = [coro1(), coro2()] + asyncio.gather(*tasks) # RUF065 + + +# Non-violation cases: RUF065 + + +async def test_coroutine_with_await(): + async def coro(): + pass + + await coro() # OK + + +async def test_coroutine_with_await(): + def not_coro(): + pass + + async def coro(): + pass + + not_coro() + await coro() # OK + + +import asyncio + + +# define an asynchronous context manager +class AsyncContextManager: + # enter the async context manager + async def __aenter__(self): + await asyncio.sleep(0.5) + + async def __aexit__(self, exc_type, exc, tb): + await asyncio.sleep(0.5) + + +# define a simple coroutine +async def custom_coroutine(): + # create and use the asynchronous context manager + async with AsyncContextManager(): # OK + ... + + +async def test_coroutine_in_func_arg(): + async def another_coro(): + pass + + async def coro(cr): + await cr + + await coro(another_coro()) # OK + + +async def test_coroutine_with_yield(): + async def another_coro(): + pass + + async def coro(): + yield another_coro() + + await coro() # OK + + +async def test_coroutine_with_return(): + async def another_coro(): + pass + + async def coro(): + return another_coro() + + await coro() # OK + + +async def test_coroutine_with_async_iterator(): + class Counter: + def __init__(self): + pass + + def __aiter__(self): + return self + + async def __anext__(self): + pass + + async def main(): + async for c in Counter(): # OK + pass + + +async def test_asyncio_api_with_await(): + async def task_coro(value): + await asyncio.sleep(1) + return value * 10 + + # main coroutine + async def main(): + awaitables = [task_coro(i) for i in range(10)] + await asyncio.gather(*awaitables) # OK + + +async def test_coroutine_inside_collections(): + async def coro(): + pass + + [coro(), coro()] # OK + (coro(), coro()) # OK + {coro(), coro()} # OK + {"coro": coro()} # OK + + +async def test_func_used_in_arg_should_not_raise(func): + func() # OK diff --git a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs index 97dc1052db..5d2e3d6bf0 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs @@ -1297,6 +1297,9 @@ pub(crate) fn expression(expr: &Expr, checker: &Checker) { if checker.is_rule_enabled(Rule::NonOctalPermissions) { ruff::rules::non_octal_permissions(checker, call); } + if checker.is_rule_enabled(Rule::MissingAwaitForCoroutine) { + ruff::rules::missing_await_for_coroutine(checker, call); + } if checker.is_rule_enabled(Rule::AssertRaisesException) { flake8_bugbear::rules::assert_raises_exception_call(checker, call); } diff --git a/crates/ruff_linter/src/codes.rs b/crates/ruff_linter/src/codes.rs index eba1404f2b..8d2eb27a6c 100644 --- a/crates/ruff_linter/src/codes.rs +++ b/crates/ruff_linter/src/codes.rs @@ -1051,6 +1051,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> { (Ruff, "061") => (RuleGroup::Preview, rules::ruff::rules::LegacyFormPytestRaises), (Ruff, "063") => (RuleGroup::Preview, rules::ruff::rules::AccessAnnotationsFromClassDict), (Ruff, "064") => (RuleGroup::Preview, rules::ruff::rules::NonOctalPermissions), + (Ruff, "065") => (RuleGroup::Preview, rules::ruff::rules::MissingAwaitForCoroutine), (Ruff, "100") => (RuleGroup::Stable, rules::ruff::rules::UnusedNOQA), (Ruff, "101") => (RuleGroup::Stable, rules::ruff::rules::RedirectedNOQA), (Ruff, "102") => (RuleGroup::Preview, rules::ruff::rules::InvalidRuleCode), diff --git a/crates/ruff_linter/src/rules/ruff/mod.rs b/crates/ruff_linter/src/rules/ruff/mod.rs index 8f0007a344..4b41f67636 100644 --- a/crates/ruff_linter/src/rules/ruff/mod.rs +++ b/crates/ruff_linter/src/rules/ruff/mod.rs @@ -112,6 +112,7 @@ mod tests { #[test_case(Rule::LegacyFormPytestRaises, Path::new("RUF061_warns.py"))] #[test_case(Rule::LegacyFormPytestRaises, Path::new("RUF061_deprecated_call.py"))] #[test_case(Rule::NonOctalPermissions, Path::new("RUF064.py"))] + #[test_case(Rule::MissingAwaitForCoroutine, Path::new("RUF065.py"))] #[test_case(Rule::RedirectedNOQA, Path::new("RUF101_0.py"))] #[test_case(Rule::RedirectedNOQA, Path::new("RUF101_1.py"))] #[test_case(Rule::InvalidRuleCode, Path::new("RUF102.py"))] diff --git a/crates/ruff_linter/src/rules/ruff/rules/missing_await_for_coroutine.rs b/crates/ruff_linter/src/rules/ruff/rules/missing_await_for_coroutine.rs new file mode 100644 index 0000000000..774a974d15 --- /dev/null +++ b/crates/ruff_linter/src/rules/ruff/rules/missing_await_for_coroutine.rs @@ -0,0 +1,141 @@ +use ruff_python_ast::{ + AtomicNodeIndex, Expr, ExprAwait, ExprCall, ExprName, Stmt, StmtAssign, StmtExpr, + StmtFunctionDef, +}; +use ruff_text_size::{Ranged, TextRange}; + +use crate::{Edit, Fix, FixAvailability, Violation}; +use ruff_macros::{ViolationMetadata, derive_message_formats}; +use ruff_python_semantic::SemanticModel; + +use crate::checkers::ast::Checker; + +/// ## What it does +/// Checks for coroutines that are not awaited. This rule is only active in async contexts. +/// +/// ## Why is this bad? +/// Coroutines are not executed until they are awaited. If a coroutine is not awaited, it will +/// not be executed, and the program will not behave as expected. This is a common mistake when +/// using `asyncio.sleep` instead of `await asyncio.sleep`. +/// +/// Python's asyncio runtime will emit a warning when a coroutine is not awaited. +/// +/// ## Examples +/// ```python +/// async def foo(): +/// pass +/// +/// +/// async def bar(): +/// foo() +/// ``` +/// +/// Use instead: +/// ```python +/// async def foo(): +/// pass +/// +/// +/// async def bar(): +/// await foo() +/// ``` +#[derive(ViolationMetadata)] +pub(crate) struct MissingAwaitForCoroutine; + +impl Violation for MissingAwaitForCoroutine { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + + #[derive_message_formats] + fn message(&self) -> String { + "Coroutine is not awaited".to_string() + } + + fn fix_title(&self) -> Option { + Some("Coroutine is not awaited".to_string()) + } +} + +/// RUF065 +pub(crate) fn missing_await_for_coroutine(checker: &Checker, call: &ExprCall) { + // Only check for missing await in async context + if !checker.semantic().in_async_context() { + return; + } + + // Try to detect possible scenarios where await is missing and ignore other cases + // For example, if the call is not a direct child of an statement expression or assignment statement + // then it's not reliable to determine if await is missing. + // User might return coroutine object from a function or pass it as an argument + if !possibly_missing_await(call, checker.semantic()) { + return; + } + + let is_awaitable = is_awaitable_from_asyncio(call.func.as_ref(), checker.semantic()) + || is_awaitable_func(call.func.as_ref(), checker.semantic()); + + // If call does not originate from asyncio or is not an async function, then it's not awaitable + if !is_awaitable { + return; + } + + checker + .report_diagnostic(MissingAwaitForCoroutine, call.range()) + .set_fix(Fix::unsafe_edit(Edit::range_replacement( + checker.generator().expr(&generate_fix(call)), + call.range(), + ))); +} + +fn is_awaitable_from_asyncio(func: &Expr, semantic: &SemanticModel) -> bool { + if let Some(call_path) = semantic.resolve_qualified_name(func) { + return matches!( + call_path.segments(), + ["asyncio", "sleep" | "wait" | "wait_for" | "gather"] + ); + } + false +} + +fn is_awaitable_func(func: &Expr, semantic: &SemanticModel) -> bool { + let Expr::Name(ExprName { id, .. }) = func else { + return false; + }; + let Some(binding_id) = semantic.lookup_symbol(id) else { + return false; + }; + let binding = semantic.binding(binding_id); + if let Some(node_id) = binding.source { + let node = semantic.statement(node_id); + if let Stmt::FunctionDef(StmtFunctionDef { is_async, name, .. }) = node { + return *is_async && name.as_str() == id; + } + } + false +} + +fn possibly_missing_await(call: &ExprCall, semantic: &SemanticModel) -> bool { + if let Stmt::Expr(StmtExpr { value, .. }) = semantic.current_statement() { + if let Expr::Call(expr_call) = value.as_ref() { + return expr_call == call; + } + } + + if let Some(Stmt::Assign(StmtAssign { value, .. })) = semantic.current_statement_parent() { + if let Expr::Call(expr_call) = value.as_ref() { + return expr_call == call; + } + } + false +} + +/// Generate a [`Fix`] to add `await` for coroutine. +/// +/// For example: +/// - Given `asyncio.sleep(1)`, generate `await asyncio.sleep(1)`. +fn generate_fix(call: &ExprCall) -> Expr { + Expr::Await(ExprAwait { + node_index: AtomicNodeIndex::default(), + value: Box::new(Expr::Call(call.clone())), + range: TextRange::default(), + }) +} diff --git a/crates/ruff_linter/src/rules/ruff/rules/mod.rs b/crates/ruff_linter/src/rules/ruff/rules/mod.rs index 420b8c310a..85ed384722 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/mod.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/mod.rs @@ -24,6 +24,7 @@ pub(crate) use invalid_pyproject_toml::*; pub(crate) use invalid_rule_code::*; pub(crate) use legacy_form_pytest_raises::*; pub(crate) use map_int_version_parsing::*; +pub(crate) use missing_await_for_coroutine::*; pub(crate) use missing_fstring_syntax::*; pub(crate) use mutable_class_default::*; pub(crate) use mutable_dataclass_default::*; @@ -87,6 +88,7 @@ mod invalid_pyproject_toml; mod invalid_rule_code; mod legacy_form_pytest_raises; mod map_int_version_parsing; +mod missing_await_for_coroutine; mod missing_fstring_syntax; mod mutable_class_default; mod mutable_dataclass_default; diff --git a/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF065_RUF065.py.snap b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF065_RUF065.py.snap new file mode 100644 index 0000000000..6d926624c8 --- /dev/null +++ b/crates/ruff_linter/src/rules/ruff/snapshots/ruff_linter__rules__ruff__tests__RUF065_RUF065.py.snap @@ -0,0 +1,136 @@ +--- +source: crates/ruff_linter/src/rules/ruff/mod.rs +--- +RUF065 [*] Coroutine is not awaited + --> RUF065.py:10:5 + | + 8 | pass + 9 | +10 | coro() # RUF065 + | ^^^^^^ + | +help: Coroutine is not awaited +7 | async def coro(): +8 | pass +9 | + - coro() # RUF065 +10 + await coro() # RUF065 +11 | +12 | +13 | async def test_coroutine_without_await(): +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:28:5 + | +27 | not_coro() +28 | coro() # RUF065 + | ^^^^^^ + | +help: Coroutine is not awaited +25 | pass +26 | +27 | not_coro() + - coro() # RUF065 +28 + await coro() # RUF065 +29 | +30 | +31 | async def test_coroutine_without_await(): +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:33:9 + | +31 | async def test_coroutine_without_await(): +32 | async def coro(): +33 | another_coro() # RUF065 + | ^^^^^^^^^^^^^^ +34 | +35 | async def another_coro(): + | +help: Coroutine is not awaited +30 | +31 | async def test_coroutine_without_await(): +32 | async def coro(): + - another_coro() # RUF065 +33 + await another_coro() # RUF065 +34 | +35 | async def another_coro(): +36 | pass +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:42:5 + | +41 | async def test_asyncio_api_without_await(): +42 | asyncio.sleep(0.5) # RUF065 + | ^^^^^^^^^^^^^^^^^^ + | +help: Coroutine is not awaited +39 | +40 | +41 | async def test_asyncio_api_without_await(): + - asyncio.sleep(0.5) # RUF065 +42 + await asyncio.sleep(0.5) # RUF065 +43 | +44 | +45 | async def test_asyncio_api_without_await(): +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:47:9 + | +45 | async def test_asyncio_api_without_await(): +46 | async def coro(): +47 | asyncio.sleep(0.5) # RUF065 + | ^^^^^^^^^^^^^^^^^^ +48 | +49 | await asyncio.wait(coro) + | +help: Coroutine is not awaited +44 | +45 | async def test_asyncio_api_without_await(): +46 | async def coro(): + - asyncio.sleep(0.5) # RUF065 +47 + await asyncio.sleep(0.5) # RUF065 +48 | +49 | await asyncio.wait(coro) +50 | +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:56:5 + | +54 | await asyncio.sleep(0.5) +55 | +56 | asyncio.wait_for(coro) # RUF065 + | ^^^^^^^^^^^^^^^^^^^^^^ + | +help: Coroutine is not awaited +53 | async def coro(): +54 | await asyncio.sleep(0.5) +55 | + - asyncio.wait_for(coro) # RUF065 +56 + await asyncio.wait_for(coro) # RUF065 +57 | +58 | +59 | async def test_asyncio_api_without_await(): +note: This is an unsafe fix and may change runtime behavior + +RUF065 [*] Coroutine is not awaited + --> RUF065.py:67:5 + | +66 | tasks = [coro1(), coro2()] +67 | asyncio.gather(*tasks) # RUF065 + | ^^^^^^^^^^^^^^^^^^^^^^ + | +help: Coroutine is not awaited +64 | await asyncio.sleep(0.5) +65 | +66 | tasks = [coro1(), coro2()] + - asyncio.gather(*tasks) # RUF065 +67 + await asyncio.gather(*tasks) # RUF065 +68 | +69 | +70 | # Non-violation cases: RUF065 +note: This is an unsafe fix and may change runtime behavior diff --git a/ruff.schema.json b/ruff.schema.json index 97fec6ae32..cdad2ce9e7 100644 --- a/ruff.schema.json +++ b/ruff.schema.json @@ -4057,6 +4057,7 @@ "RUF061", "RUF063", "RUF064", + "RUF065", "RUF1", "RUF10", "RUF100",