Address review comments

This commit is contained in:
Amethyst Reese 2025-09-22 19:35:46 -07:00
parent c790c1d957
commit 4447bcf468
1 changed files with 22 additions and 33 deletions

View File

@ -1,8 +1,5 @@
use ruff_python_ast::{ use ruff_python_ast::{Expr, ExprCall, ExprName, Stmt, StmtAssign, StmtExpr, StmtFunctionDef};
AtomicNodeIndex, Expr, ExprAwait, ExprCall, ExprName, Stmt, StmtAssign, StmtExpr, use ruff_text_size::Ranged;
StmtFunctionDef,
};
use ruff_text_size::{Ranged, TextRange};
use crate::{Edit, Fix, FixAvailability, Violation}; use crate::{Edit, Fix, FixAvailability, Violation};
use ruff_macros::{ViolationMetadata, derive_message_formats}; use ruff_macros::{ViolationMetadata, derive_message_formats};
@ -38,6 +35,13 @@ use crate::checkers::ast::Checker;
/// ///
/// async def bar(): /// async def bar():
/// await foo() /// await foo()
///
/// ## Limitations
///
/// If the call is not a direct child of an statement expression or assignment statement
/// then this rule may not reliably determine if await is missing. Functions that return
/// coroutine objects or pass them as arguments might not be flagged correctly.
///
/// ``` /// ```
#[derive(ViolationMetadata)] #[derive(ViolationMetadata)]
pub(crate) struct MissingAwaitForCoroutine; pub(crate) struct MissingAwaitForCoroutine;
@ -62,28 +66,21 @@ pub(crate) fn missing_await_for_coroutine(checker: &Checker, call: &ExprCall) {
return; return;
} }
// Try to detect possible scenarios where await is missing and ignore other cases
// For example, if the call is not a direct child of an statement expression or assignment statement
// then it's not reliable to determine if await is missing.
// User might return coroutine object from a function or pass it as an argument
if !possibly_missing_await(call, checker.semantic()) { if !possibly_missing_await(call, checker.semantic()) {
return; return;
} }
let is_awaitable = is_awaitable_from_asyncio(call.func.as_ref(), checker.semantic())
|| is_awaitable_func(call.func.as_ref(), checker.semantic());
// If call does not originate from asyncio or is not an async function, then it's not awaitable // If call does not originate from asyncio or is not an async function, then it's not awaitable
if !is_awaitable { if is_awaitable_from_asyncio(call.func.as_ref(), checker.semantic())
return; || is_awaitable_func(call.func.as_ref(), checker.semantic())
} {
checker checker
.report_diagnostic(MissingAwaitForCoroutine, call.range()) .report_diagnostic(MissingAwaitForCoroutine, call.range())
.set_fix(Fix::unsafe_edit(Edit::range_replacement( .set_fix(Fix::unsafe_edit(Edit::insertion(
checker.generator().expr(&generate_fix(call)), "await ".to_string(),
call.range(), call.start(),
))); )));
}
} }
fn is_awaitable_from_asyncio(func: &Expr, semantic: &SemanticModel) -> bool { fn is_awaitable_from_asyncio(func: &Expr, semantic: &SemanticModel) -> bool {
@ -113,6 +110,10 @@ fn is_awaitable_func(func: &Expr, semantic: &SemanticModel) -> bool {
false false
} }
/// Try to detect possible scenarios where await is missing and ignore other cases
/// If the call is not a direct child of an statement expression or assignment statement
/// then this rule may not reliably determine if await is missing. Functions that return
/// coroutine objects or pass them as arguments might not be flagged correctly.
fn possibly_missing_await(call: &ExprCall, semantic: &SemanticModel) -> bool { fn possibly_missing_await(call: &ExprCall, semantic: &SemanticModel) -> bool {
if let Stmt::Expr(StmtExpr { value, .. }) = semantic.current_statement() { if let Stmt::Expr(StmtExpr { value, .. }) = semantic.current_statement() {
if let Expr::Call(expr_call) = value.as_ref() { if let Expr::Call(expr_call) = value.as_ref() {
@ -127,15 +128,3 @@ fn possibly_missing_await(call: &ExprCall, semantic: &SemanticModel) -> bool {
} }
false false
} }
/// Generate a [`Fix`] to add `await` for coroutine.
///
/// For example:
/// - Given `asyncio.sleep(1)`, generate `await asyncio.sleep(1)`.
fn generate_fix(call: &ExprCall) -> Expr {
Expr::Await(ExprAwait {
node_index: AtomicNodeIndex::default(),
value: Box::new(Expr::Call(call.clone())),
range: TextRange::default(),
})
}