diff --git a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs index ddb832a4a5..12e686facb 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs @@ -4,9 +4,9 @@ use rustc_hash::FxHashMap; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::statement_visitor::StatementVisitor; -use ruff_python_ast::types::RefEquality; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{statement_visitor, visitor}; +use ruff_python_semantic::StatementKey; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -131,7 +131,7 @@ impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> { #[derive(Default)] struct ReferenceVisitor<'a> { parent: Option<&'a Stmt>, - references: FxHashMap, Vec<&'a str>>, + references: FxHashMap>, } impl<'a> Visitor<'a> for ReferenceVisitor<'a> { @@ -148,7 +148,7 @@ impl<'a> Visitor<'a> for ReferenceVisitor<'a> { if matches!(ctx, ExprContext::Load | ExprContext::Del) { if let Some(parent) = self.parent { self.references - .entry(RefEquality(parent)) + .entry(StatementKey::from(parent)) .or_default() .push(id); } @@ -177,9 +177,9 @@ pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt: &Stmt) { for item in yields { // If any of the bound names are used outside of the loop, don't rewrite. - if references.iter().any(|(stmt, names)| { - stmt != &RefEquality(item.stmt) - && stmt != &RefEquality(item.body) + if references.iter().any(|(statement, names)| { + *statement != StatementKey::from(item.stmt) + && *statement != StatementKey::from(item.body) && item.names.iter().any(|name| names.contains(name)) }) { continue; diff --git a/crates/ruff_python_ast/src/types.rs b/crates/ruff_python_ast/src/types.rs index 5243f94b42..32732ed6cb 100644 --- a/crates/ruff_python_ast/src/types.rs +++ b/crates/ruff_python_ast/src/types.rs @@ -1,5 +1,3 @@ -use std::ops::Deref; - use crate::{Expr, Stmt}; #[derive(Clone)] @@ -7,78 +5,3 @@ pub enum Node<'a> { Stmt(&'a Stmt), Expr(&'a Expr), } - -#[derive(Debug)] -pub struct RefEquality<'a, T>(pub &'a T); - -impl<'a, T> RefEquality<'a, T> { - // More specific implementation that keeps the `'a` lifetime. - // It's otherwise the same as [`AsRef::as_ref`] - #[allow(clippy::should_implement_trait)] - pub fn as_ref(&self) -> &'a T { - self.0 - } -} - -impl<'a, T> AsRef for RefEquality<'a, T> { - fn as_ref(&self) -> &T { - self.0 - } -} - -impl<'a, T> Clone for RefEquality<'a, T> { - fn clone(&self) -> Self { - *self - } -} - -impl<'a, T> Copy for RefEquality<'a, T> {} - -impl<'a, T> std::hash::Hash for RefEquality<'a, T> { - fn hash(&self, state: &mut H) - where - H: std::hash::Hasher, - { - (self.0 as *const T).hash(state); - } -} - -impl<'a, 'b, T> PartialEq> for RefEquality<'a, T> { - fn eq(&self, other: &RefEquality<'b, T>) -> bool { - std::ptr::eq(self.0, other.0) - } -} - -impl<'a, T> Eq for RefEquality<'a, T> {} - -impl<'a, T> Deref for RefEquality<'a, T> { - type Target = T; - - fn deref(&self) -> &T { - self.0 - } -} - -impl<'a> From<&RefEquality<'a, Stmt>> for &'a Stmt { - fn from(r: &RefEquality<'a, Stmt>) -> Self { - r.0 - } -} - -impl<'a> From<&RefEquality<'a, Expr>> for &'a Expr { - fn from(r: &RefEquality<'a, Expr>) -> Self { - r.0 - } -} - -impl<'a> From> for &'a Stmt { - fn from(r: RefEquality<'a, Stmt>) -> Self { - r.0 - } -} - -impl<'a> From> for &'a Expr { - fn from(r: RefEquality<'a, Expr>) -> Self { - r.0 - } -} diff --git a/crates/ruff_python_semantic/src/statements.rs b/crates/ruff_python_semantic/src/statements.rs index 8388273bd7..8295a04a7d 100644 --- a/crates/ruff_python_semantic/src/statements.rs +++ b/crates/ruff_python_semantic/src/statements.rs @@ -3,8 +3,8 @@ use std::ops::Index; use rustc_hash::FxHashMap; use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast::types::RefEquality; -use ruff_python_ast::Stmt; +use ruff_python_ast::{Ranged, Stmt}; +use ruff_text_size::TextSize; /// Id uniquely identifying a statement AST node. /// @@ -30,7 +30,19 @@ struct StatementWithParent<'a> { #[derive(Debug, Default)] pub struct Statements<'a> { statements: IndexVec>, - statement_to_id: FxHashMap, StatementId>, + statement_to_id: FxHashMap, +} + +/// A unique key for a statement AST node. No two statements can appear at the same location +/// in the source code, since compound statements must be delimited by _at least_ one character +/// (a colon), so the starting offset is a cheap and sufficient unique identifier. +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct StatementKey(TextSize); + +impl From<&Stmt> for StatementKey { + fn from(statement: &Stmt) -> Self { + Self(statement.start()) + } } impl<'a> Statements<'a> { @@ -43,7 +55,10 @@ impl<'a> Statements<'a> { parent: Option, ) -> StatementId { let next_id = self.statements.next_index(); - if let Some(existing_id) = self.statement_to_id.insert(RefEquality(statement), next_id) { + if let Some(existing_id) = self + .statement_to_id + .insert(StatementKey::from(statement), next_id) + { panic!("Statements already exists with ID: {existing_id:?}"); } self.statements.push(StatementWithParent { @@ -56,7 +71,9 @@ impl<'a> Statements<'a> { /// Returns the [`StatementId`] of the given statement. #[inline] pub fn statement_id(&self, statement: &'a Stmt) -> Option { - self.statement_to_id.get(&RefEquality(statement)).copied() + self.statement_to_id + .get(&StatementKey::from(statement)) + .copied() } /// Return the [`StatementId`] of the parent statement.