diff --git a/src/ast/types.rs b/src/ast/types.rs index 59d89f4ec7..85ebf61329 100644 --- a/src/ast/types.rs +++ b/src/ast/types.rs @@ -103,6 +103,23 @@ impl<'a> Scope<'a> { } } +// Pyflakes defines the following binding hierarchy (via inheritance): +// Binding +// ExportBinding +// Annotation +// Argument +// Assignment +// NamedExprAssignment +// Definition +// FunctionDefinition +// ClassDefinition +// Builtin +// Importation +// SubmoduleImportation +// ImportationFrom +// StarImportation +// FutureImportation + #[derive(Clone, Debug)] pub enum BindingKind<'a> { Annotation, @@ -123,10 +140,12 @@ pub enum BindingKind<'a> { SubmoduleImportation(&'a str, &'a str), } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Binding<'a> { pub kind: BindingKind<'a>, pub range: Range, + /// The context in which the binding was created. + pub context: ExecutionContext, /// The statement in which the [`Binding`] was defined. pub source: Option>, /// Tuple of (scope index, range) indicating the scope and range at which @@ -143,33 +162,16 @@ pub struct Binding<'a> { } #[derive(Copy, Clone)] -pub enum UsageContext { +pub enum ExecutionContext { Runtime, Typing, } -// Pyflakes defines the following binding hierarchy (via inheritance): -// Binding -// ExportBinding -// Annotation -// Argument -// Assignment -// NamedExprAssignment -// Definition -// FunctionDefinition -// ClassDefinition -// Builtin -// Importation -// SubmoduleImportation -// ImportationFrom -// StarImportation -// FutureImportation - impl<'a> Binding<'a> { - pub fn mark_used(&mut self, scope: usize, range: Range, context: UsageContext) { + pub fn mark_used(&mut self, scope: usize, range: Range, context: ExecutionContext) { match context { - UsageContext::Runtime => self.runtime_usage = Some((scope, range)), - UsageContext::Typing => self.typing_usage = Some((scope, range)), + ExecutionContext::Runtime => self.runtime_usage = Some((scope, range)), + ExecutionContext::Typing => self.typing_usage = Some((scope, range)), } } diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index e7892345d0..a7044d3672 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -20,8 +20,8 @@ use crate::ast::helpers::{binding_range, collect_call_path, extract_handler_name use crate::ast::operations::extract_all_names; use crate::ast::relocate::relocate_expr; use crate::ast::types::{ - Binding, BindingKind, CallPath, ClassDef, FunctionDef, Lambda, Node, Range, RefEquality, Scope, - ScopeKind, UsageContext, + Binding, BindingKind, CallPath, ClassDef, ExecutionContext, FunctionDef, Lambda, Node, Range, + RefEquality, Scope, ScopeKind, }; use crate::ast::visitor::{walk_excepthandler, Visitor}; use crate::ast::{branch_detection, cast, helpers, operations, visitor}; @@ -102,7 +102,6 @@ pub struct Checker<'a> { except_handlers: Vec>>, // Check-specific state. pub(crate) flake8_bugbear_seen: Vec<&'a Expr>, - pub(crate) type_checking_blocks: Vec<&'a Stmt>, } impl<'a> Checker<'a> { @@ -163,7 +162,6 @@ impl<'a> Checker<'a> { except_handlers: vec![], // Check-specific state. flake8_bugbear_seen: vec![], - type_checking_blocks: vec![], } } @@ -332,6 +330,7 @@ where let ranges = helpers::find_names(stmt, self.locator); if scope_index != GLOBAL_SCOPE_INDEX { // Add the binding to the current scope. + let context = self.execution_context(); let scope = &mut self.scopes[scope_index]; let usage = Some((scope.id, Range::from_located(stmt))); for (name, range) in names.iter().zip(ranges.iter()) { @@ -343,6 +342,7 @@ where typing_usage: None, range: *range, source: Some(RefEquality(stmt)), + context, }); scope.values.insert(name, index); } @@ -359,6 +359,7 @@ where let scope_index = *self.scope_stack.last().expect("No current scope found"); let ranges = helpers::find_names(stmt, self.locator); if scope_index != GLOBAL_SCOPE_INDEX { + let context = self.execution_context(); let scope = &mut self.scopes[scope_index]; let usage = Some((scope.id, Range::from_located(stmt))); for (name, range) in names.iter().zip(ranges.iter()) { @@ -371,6 +372,7 @@ where typing_usage: None, range: *range, source: Some(RefEquality(stmt)), + context, }); scope.values.insert(name, index); } @@ -676,6 +678,8 @@ where for expr in &args.defaults { self.visit_expr(expr); } + + let context = self.execution_context(); self.add_binding( name, Binding { @@ -685,6 +689,7 @@ where typing_usage: None, range: Range::from_located(stmt), source: Some(self.current_stmt().clone()), + context, }, ); } @@ -839,6 +844,7 @@ where typing_usage: None, range: Range::from_located(alias), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); } else { @@ -878,6 +884,7 @@ where typing_usage: None, range: Range::from_located(alias), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); } @@ -1120,6 +1127,7 @@ where typing_usage: None, range: Range::from_located(alias), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); @@ -1156,6 +1164,7 @@ where typing_usage: None, range: Range::from_located(stmt), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); @@ -1230,6 +1239,7 @@ where typing_usage: None, range, source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); } @@ -1751,6 +1761,7 @@ where typing_usage: None, range: Range::from_located(stmt), source: Some(RefEquality(stmt)), + context: self.execution_context(), }); self.scopes[GLOBAL_SCOPE_INDEX].values.insert(name, index); } @@ -1814,6 +1825,7 @@ where typing_usage: None, range: Range::from_located(stmt), source: Some(RefEquality(stmt)), + context: self.execution_context(), }); self.scopes[GLOBAL_SCOPE_INDEX].values.insert(name, index); } @@ -1877,7 +1889,6 @@ where if self.settings.rules.enabled(&Rule::EmptyTypeCheckingBlock) { flake8_type_checking::rules::empty_type_checking_block(self, test, body); } - self.type_checking_blocks.push(stmt); let prev_in_type_checking_block = self.in_type_checking_block; self.in_type_checking_block = true; @@ -1909,6 +1920,7 @@ where typing_usage: None, range: Range::from_located(stmt), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); } @@ -3612,6 +3624,7 @@ where typing_usage: None, range: Range::from_located(arg), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); @@ -3713,6 +3726,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, source: None, + context: ExecutionContext::Runtime, }); scope.values.insert(builtin, index); } @@ -3749,6 +3763,18 @@ impl<'a> Checker<'a> { .map(|index| &self.scopes[*index]) } + pub fn execution_context(&self) -> ExecutionContext { + if self.in_type_checking_block + || self.in_annotation + || self.in_deferred_string_type_definition + || self.in_deferred_type_definition + { + ExecutionContext::Typing + } else { + ExecutionContext::Runtime + } + } + fn add_binding<'b>(&mut self, name: &'b str, binding: Binding<'a>) where 'b: 'a, @@ -3866,17 +3892,8 @@ impl<'a> Checker<'a> { } if let Some(index) = scope.values.get(&id.as_str()) { - let context = if self.in_type_checking_block - || self.in_annotation - || self.in_deferred_string_type_definition - || self.in_deferred_type_definition - { - UsageContext::Typing - } else { - UsageContext::Runtime - }; - // Mark the binding as used. + let context = self.execution_context(); self.bindings[*index].mark_used(scope_id, Range::from_located(expr), context); if matches!(self.bindings[*index].kind, BindingKind::Annotation) @@ -4059,6 +4076,7 @@ impl<'a> Checker<'a> { typing_usage: None, range: Range::from_located(expr), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); return; @@ -4078,6 +4096,7 @@ impl<'a> Checker<'a> { typing_usage: None, range: Range::from_located(expr), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); return; @@ -4093,6 +4112,7 @@ impl<'a> Checker<'a> { typing_usage: None, range: Range::from_located(expr), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); return; @@ -4145,6 +4165,7 @@ impl<'a> Checker<'a> { typing_usage: None, range: Range::from_located(expr), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); return; @@ -4160,6 +4181,7 @@ impl<'a> Checker<'a> { typing_usage: None, range: Range::from_located(expr), source: Some(self.current_stmt().clone()), + context: self.execution_context(), }, ); } @@ -4394,10 +4416,7 @@ impl<'a> Checker<'a> { .values() .map(|index| &self.bindings[*index]) .filter(|binding| { - flake8_type_checking::helpers::is_valid_runtime_import( - binding, - &self.type_checking_blocks, - ) + flake8_type_checking::helpers::is_valid_runtime_import(binding) }) .collect::>() }) @@ -4562,17 +4581,13 @@ impl<'a> Checker<'a> { let binding = &self.bindings[*index]; if let Some(diagnostic) = - flake8_type_checking::rules::runtime_import_in_type_checking_block( - binding, - &self.type_checking_blocks, - ) + flake8_type_checking::rules::runtime_import_in_type_checking_block(binding) { diagnostics.push(diagnostic); } if let Some(diagnostic) = flake8_type_checking::rules::typing_only_runtime_import( binding, - &self.type_checking_blocks, &runtime_imports, self.package, self.settings, diff --git a/src/rules/flake8_type_checking/helpers.rs b/src/rules/flake8_type_checking/helpers.rs index 36babeacd2..bd112cbef6 100644 --- a/src/rules/flake8_type_checking/helpers.rs +++ b/src/rules/flake8_type_checking/helpers.rs @@ -1,6 +1,6 @@ -use rustpython_ast::{Expr, Stmt}; +use rustpython_ast::Expr; -use crate::ast::types::{Binding, BindingKind, Range}; +use crate::ast::types::{Binding, BindingKind, ExecutionContext}; use crate::checkers::ast::Checker; pub fn is_type_checking_block(checker: &Checker, test: &Expr) -> bool { @@ -9,18 +9,15 @@ pub fn is_type_checking_block(checker: &Checker, test: &Expr) -> bool { }) } -pub fn is_valid_runtime_import(binding: &Binding, blocks: &[&Stmt]) -> bool { +pub fn is_valid_runtime_import(binding: &Binding) -> bool { if matches!( binding.kind, BindingKind::Importation(..) | BindingKind::FromImportation(..) | BindingKind::SubmoduleImportation(..) ) { - if binding.runtime_usage.is_some() { - return !blocks - .iter() - .any(|block| Range::from_located(block).contains(&binding.range)); - } + binding.runtime_usage.is_some() && matches!(binding.context, ExecutionContext::Runtime) + } else { + false } - false } diff --git a/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs b/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs index 552b4b2ec9..1f9b3718c8 100644 --- a/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs +++ b/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs @@ -1,7 +1,6 @@ use ruff_macros::derive_message_formats; -use rustpython_ast::Stmt; -use crate::ast::types::{Binding, BindingKind, Range}; +use crate::ast::types::{Binding, BindingKind, ExecutionContext}; use crate::define_violation; use crate::registry::Diagnostic; use crate::violation::Violation; @@ -23,10 +22,7 @@ impl Violation for RuntimeImportInTypeCheckingBlock { } /// TCH004 -pub fn runtime_import_in_type_checking_block( - binding: &Binding, - blocks: &[&Stmt], -) -> Option { +pub fn runtime_import_in_type_checking_block(binding: &Binding) -> Option { let full_name = match &binding.kind { BindingKind::Importation(.., full_name) => full_name, BindingKind::FromImportation(.., full_name) => full_name.as_str(), @@ -34,19 +30,14 @@ pub fn runtime_import_in_type_checking_block( _ => return None, }; - let defined_in_type_checking = blocks - .iter() - .any(|block| Range::from_located(block).contains(&binding.range)); - if defined_in_type_checking { - if binding.runtime_usage.is_some() { - return Some(Diagnostic::new( - RuntimeImportInTypeCheckingBlock { - full_name: full_name.to_string(), - }, - binding.range, - )); - } + if matches!(binding.context, ExecutionContext::Typing) && binding.runtime_usage.is_some() { + Some(Diagnostic::new( + RuntimeImportInTypeCheckingBlock { + full_name: full_name.to_string(), + }, + binding.range, + )) + } else { + None } - - None } diff --git a/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs b/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs index 808d4b02ae..d6a81d4409 100644 --- a/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs +++ b/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs @@ -1,9 +1,8 @@ use std::path::Path; use ruff_macros::derive_message_formats; -use rustpython_ast::Stmt; -use crate::ast::types::{Binding, BindingKind, Range}; +use crate::ast::types::{Binding, BindingKind, ExecutionContext}; use crate::define_violation; use crate::registry::Diagnostic; use crate::rules::isort::{categorize, ImportType}; @@ -114,7 +113,6 @@ fn is_exempt(name: &str, exempt_modules: &[&str]) -> bool { /// TCH001 pub fn typing_only_runtime_import( binding: &Binding, - blocks: &[&Stmt], runtime_imports: &[&Binding], package: Option<&Path>, settings: &Settings, @@ -148,58 +146,48 @@ pub fn typing_only_runtime_import( return None; } - let defined_in_type_checking = blocks - .iter() - .any(|block| Range::from_located(block).contains(&binding.range)); - if !defined_in_type_checking { - if binding.typing_usage.is_some() - && binding.runtime_usage.is_none() - && binding.synthetic_usage.is_none() - { - // Extract the module base and level from the full name. - // Ex) `foo.bar.baz` -> `foo`, `0` - // Ex) `.foo.bar.baz` -> `foo`, `1` - let module_base = full_name.split('.').next().unwrap(); - let level = full_name.chars().take_while(|c| *c == '.').count(); + if matches!(binding.context, ExecutionContext::Runtime) + && binding.typing_usage.is_some() + && binding.runtime_usage.is_none() + && binding.synthetic_usage.is_none() + { + // Extract the module base and level from the full name. + // Ex) `foo.bar.baz` -> `foo`, `0` + // Ex) `.foo.bar.baz` -> `foo`, `1` + let module_base = full_name.split('.').next().unwrap(); + let level = full_name.chars().take_while(|c| *c == '.').count(); - // Categorize the import. - match categorize( - module_base, - Some(&level), - &settings.src, - package, - &settings.isort.known_first_party, - &settings.isort.known_third_party, - &settings.isort.extra_standard_library, - ) { - ImportType::LocalFolder | ImportType::FirstParty => { - return Some(Diagnostic::new( - TypingOnlyFirstPartyImport { - full_name: full_name.to_string(), - }, - binding.range, - )); - } - ImportType::ThirdParty => { - return Some(Diagnostic::new( - TypingOnlyThirdPartyImport { - full_name: full_name.to_string(), - }, - binding.range, - )); - } - ImportType::StandardLibrary => { - return Some(Diagnostic::new( - TypingOnlyStandardLibraryImport { - full_name: full_name.to_string(), - }, - binding.range, - )); - } - ImportType::Future => unreachable!("`__future__` imports should be marked as used"), - } + // Categorize the import. + match categorize( + module_base, + Some(&level), + &settings.src, + package, + &settings.isort.known_first_party, + &settings.isort.known_third_party, + &settings.isort.extra_standard_library, + ) { + ImportType::LocalFolder | ImportType::FirstParty => Some(Diagnostic::new( + TypingOnlyFirstPartyImport { + full_name: full_name.to_string(), + }, + binding.range, + )), + ImportType::ThirdParty => Some(Diagnostic::new( + TypingOnlyThirdPartyImport { + full_name: full_name.to_string(), + }, + binding.range, + )), + ImportType::StandardLibrary => Some(Diagnostic::new( + TypingOnlyStandardLibraryImport { + full_name: full_name.to_string(), + }, + binding.range, + )), + ImportType::Future => unreachable!("`__future__` imports should be marked as used"), } + } else { + None } - - None }