diff --git a/src/ast/operations.rs b/src/ast/operations.rs index 2a819a1c9a..b4e16aac65 100644 --- a/src/ast/operations.rs +++ b/src/ast/operations.rs @@ -85,6 +85,21 @@ pub fn on_conditional_branch(parent_stack: &[usize], parents: &[&Stmt]) -> bool false } +/// Check if a node is in a nested block. +pub fn in_nested_block(parent_stack: &[usize], parents: &[&Stmt]) -> bool { + for index in parent_stack.iter().rev() { + let parent = parents[*index]; + if matches!(parent.node, StmtKind::Try { .. }) + || matches!(parent.node, StmtKind::If { .. }) + || matches!(parent.node, StmtKind::With { .. }) + { + return true; + } + } + + false +} + /// Struct used to efficiently slice source code at (row, column) Locations. pub struct SourceCodeLocator<'a> { content: &'a str, diff --git a/src/check_ast.rs b/src/check_ast.rs index 7817575592..63ae997aec 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -98,6 +98,41 @@ where fn visit_stmt(&mut self, stmt: &'b Stmt) { self.push_parent(stmt); + // Track whether we've seen docstrings, non-imports, etc. + match &stmt.node { + StmtKind::Import { .. } => {} + StmtKind::ImportFrom { .. } => {} + StmtKind::Expr { value } => { + if !self.seen_docstring + && stmt.location.column() == 1 + && !operations::in_nested_block(&self.parent_stack, &self.parents) + { + if let ExprKind::Constant { + value: Constant::Str(_), + .. + } = &value.node + { + self.seen_docstring = true; + } + } + + if !self.seen_non_import + && stmt.location.column() == 1 + && !operations::in_nested_block(&self.parent_stack, &self.parents) + { + self.seen_non_import = true; + } + } + _ => { + if !self.seen_non_import + && stmt.location.column() == 1 + && !operations::in_nested_block(&self.parent_stack, &self.parents) + { + self.seen_non_import = true; + } + } + } + // Pre-visit. match &stmt.node { StmtKind::Global { names } | StmtKind::Nonlocal { names } => { @@ -357,7 +392,6 @@ where } } StmtKind::AugAssign { target, .. } => { - self.seen_non_import = true; self.handle_node_load(target); } StmtKind::If { test, .. } => { @@ -368,7 +402,6 @@ where } } StmtKind::Assert { test, .. } => { - self.seen_non_import = true; if self.settings.select.contains(CheckKind::AssertTuple.code()) { if let Some(check) = checks::check_assert_tuple(test, stmt.location) { self.checks.push(check); @@ -382,21 +415,7 @@ where } } } - StmtKind::Expr { value } => { - if !self.seen_docstring { - if let ExprKind::Constant { - value: Constant::Str(_), - .. - } = &value.node - { - self.seen_docstring = true; - } - } else { - self.seen_non_import = true; - } - } StmtKind::Assign { value, .. } => { - self.seen_non_import = true; if self.settings.select.contains(&CheckCode::E731) { if let Some(check) = checks::check_do_not_assign_lambda(value, stmt.location) { self.checks.push(check); @@ -404,7 +423,6 @@ where } } StmtKind::AnnAssign { value, .. } => { - self.seen_non_import = true; if self.settings.select.contains(&CheckCode::E731) { if let Some(value) = value { if let Some(check) = @@ -415,9 +433,7 @@ where } } } - StmtKind::Delete { .. } => { - self.seen_non_import = true; - } + StmtKind::Delete { .. } => {} _ => {} }