diff --git a/crates/red_knot_python_semantic/src/node_key.rs b/crates/red_knot_python_semantic/src/node_key.rs index 9683b0e7fa..0935a1f839 100644 --- a/crates/red_knot_python_semantic/src/node_key.rs +++ b/crates/red_knot_python_semantic/src/node_key.rs @@ -1,18 +1,12 @@ -use ruff_python_ast::{AnyNodeRef, Identifier, NodeKind}; +use ruff_python_ast::{AnyNodeRef, NodeKind}; use ruff_text_size::{Ranged, TextRange}; -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub(super) enum Kind { - Node(NodeKind), - Identifier, -} - /// Compact key for a node for use in a hash map. /// /// Compares two nodes by their kind and text range. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub(super) struct NodeKey { - kind: Kind, + kind: NodeKind, range: TextRange, } @@ -23,15 +17,8 @@ impl NodeKey { { let node = node.into(); NodeKey { - kind: Kind::Node(node.kind()), + kind: node.kind(), range: node.range(), } } - - pub(super) fn from_identifier(identifier: &Identifier) -> Self { - NodeKey { - kind: Kind::Identifier, - range: identifier.range(), - } - } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 75c95a4bd5..537a17c8c1 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -462,6 +462,6 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey { impl From<&ast::Identifier> for DefinitionNodeKey { fn from(identifier: &ast::Identifier) -> Self { - Self(NodeKey::from_identifier(identifier)) + Self(NodeKey::from_node(identifier)) } } diff --git a/crates/ruff_linter/src/rules/ruff/rules/invalid_formatter_suppression_comment.rs b/crates/ruff_linter/src/rules/ruff/rules/invalid_formatter_suppression_comment.rs index 08e88e67de..8fcbb285b9 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/invalid_formatter_suppression_comment.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/invalid_formatter_suppression_comment.rs @@ -336,6 +336,7 @@ const fn is_valid_enclosing_node(node: AnyNodeRef) -> bool { | AnyNodeRef::TypeParamParamSpec(_) | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) - | AnyNodeRef::BytesLiteral(_) => false, + | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) => false, } } diff --git a/crates/ruff_python_ast/src/node.rs b/crates/ruff_python_ast/src/node.rs index d1ba8de8e0..ca3bbf27f3 100644 --- a/crates/ruff_python_ast/src/node.rs +++ b/crates/ruff_python_ast/src/node.rs @@ -126,6 +126,7 @@ pub enum AnyNode { FString(ast::FString), StringLiteral(ast::StringLiteral), BytesLiteral(ast::BytesLiteral), + Identifier(ast::Identifier), } impl AnyNode { @@ -226,6 +227,7 @@ impl AnyNode { | AnyNode::FString(_) | AnyNode::StringLiteral(_) | AnyNode::BytesLiteral(_) + | AnyNode::Identifier(_) | AnyNode::ElifElseClause(_) => None, } } @@ -323,6 +325,7 @@ impl AnyNode { | AnyNode::FString(_) | AnyNode::StringLiteral(_) | AnyNode::BytesLiteral(_) + | AnyNode::Identifier(_) | AnyNode::ElifElseClause(_) => None, } } @@ -420,6 +423,7 @@ impl AnyNode { | AnyNode::FString(_) | AnyNode::StringLiteral(_) | AnyNode::BytesLiteral(_) + | AnyNode::Identifier(_) | AnyNode::ElifElseClause(_) => None, } } @@ -517,6 +521,7 @@ impl AnyNode { | AnyNode::FString(_) | AnyNode::StringLiteral(_) | AnyNode::BytesLiteral(_) + | AnyNode::Identifier(_) | AnyNode::ElifElseClause(_) => None, } } @@ -634,6 +639,7 @@ impl AnyNode { Self::StringLiteral(node) => AnyNodeRef::StringLiteral(node), Self::BytesLiteral(node) => AnyNodeRef::BytesLiteral(node), Self::ElifElseClause(node) => AnyNodeRef::ElifElseClause(node), + Self::Identifier(node) => AnyNodeRef::Identifier(node), } } @@ -4884,6 +4890,47 @@ impl AstNode for ast::BytesLiteral { } } +impl AstNode for ast::Identifier { + type Ref<'a> = &'a Self; + + fn cast(kind: AnyNode) -> Option + where + Self: Sized, + { + if let AnyNode::Identifier(node) = kind { + Some(node) + } else { + None + } + } + + fn cast_ref(kind: AnyNodeRef<'_>) -> Option> { + if let AnyNodeRef::Identifier(node) = kind { + Some(node) + } else { + None + } + } + + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Identifier) + } + + fn as_any_node_ref(&self) -> AnyNodeRef { + AnyNodeRef::from(self) + } + + fn into_any_node(self) -> AnyNode { + AnyNode::from(self) + } + + fn visit_source_order<'a, V>(&'a self, _visitor: &mut V) + where + V: SourceOrderVisitor<'a> + ?Sized, + { + } +} + impl AstNode for Stmt { type Ref<'a> = StatementRef<'a>; @@ -4980,6 +5027,7 @@ impl AstNode for Stmt { | AnyNode::FString(_) | AnyNode::StringLiteral(_) | AnyNode::BytesLiteral(_) + | AnyNode::Identifier(_) | AnyNode::ElifElseClause(_) => None, } } @@ -5078,6 +5126,7 @@ impl AstNode for Stmt { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => None, } } @@ -5177,6 +5226,7 @@ impl AstNode for Stmt { | NodeKind::TypeParamParamSpec | NodeKind::FString | NodeKind::StringLiteral + | NodeKind::Identifier | NodeKind::BytesLiteral => false, } } @@ -5983,6 +6033,12 @@ impl From for AnyNode { } } +impl From for AnyNode { + fn from(node: ast::Identifier) -> Self { + AnyNode::Identifier(node) + } +} + impl Ranged for AnyNode { fn range(&self) -> TextRange { match self { @@ -6077,6 +6133,7 @@ impl Ranged for AnyNode { AnyNode::StringLiteral(node) => node.range(), AnyNode::BytesLiteral(node) => node.range(), AnyNode::ElifElseClause(node) => node.range(), + AnyNode::Identifier(node) => node.range(), } } } @@ -6174,6 +6231,7 @@ pub enum AnyNodeRef<'a> { StringLiteral(&'a ast::StringLiteral), BytesLiteral(&'a ast::BytesLiteral), ElifElseClause(&'a ast::ElifElseClause), + Identifier(&'a ast::Identifier), } impl<'a> AnyNodeRef<'a> { @@ -6270,6 +6328,7 @@ impl<'a> AnyNodeRef<'a> { AnyNodeRef::StringLiteral(node) => NonNull::from(*node).cast(), AnyNodeRef::BytesLiteral(node) => NonNull::from(*node).cast(), AnyNodeRef::ElifElseClause(node) => NonNull::from(*node).cast(), + AnyNodeRef::Identifier(node) => NonNull::from(*node).cast(), } } @@ -6372,6 +6431,7 @@ impl<'a> AnyNodeRef<'a> { AnyNodeRef::StringLiteral(_) => NodeKind::StringLiteral, AnyNodeRef::BytesLiteral(_) => NodeKind::BytesLiteral, AnyNodeRef::ElifElseClause(_) => NodeKind::ElifElseClause, + AnyNodeRef::Identifier(_) => NodeKind::Identifier, } } @@ -6468,6 +6528,7 @@ impl<'a> AnyNodeRef<'a> { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => false, } } @@ -6565,6 +6626,7 @@ impl<'a> AnyNodeRef<'a> { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => false, } } @@ -6661,6 +6723,7 @@ impl<'a> AnyNodeRef<'a> { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => false, } } @@ -6758,6 +6821,7 @@ impl<'a> AnyNodeRef<'a> { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => false, } } @@ -6855,6 +6919,7 @@ impl<'a> AnyNodeRef<'a> { | AnyNodeRef::FString(_) | AnyNodeRef::StringLiteral(_) | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::ElifElseClause(_) => false, } } @@ -6966,6 +7031,7 @@ impl<'a> AnyNodeRef<'a> { AnyNodeRef::StringLiteral(node) => node.visit_source_order(visitor), AnyNodeRef::BytesLiteral(node) => node.visit_source_order(visitor), AnyNodeRef::ElifElseClause(node) => node.visit_source_order(visitor), + AnyNodeRef::Identifier(node) => node.visit_source_order(visitor), } } @@ -7804,6 +7870,11 @@ impl<'a> From<&'a MatchCase> for AnyNodeRef<'a> { AnyNodeRef::MatchCase(node) } } +impl<'a> From<&'a ast::Identifier> for AnyNodeRef<'a> { + fn from(node: &'a ast::Identifier) -> Self { + AnyNodeRef::Identifier(node) + } +} impl Ranged for AnyNodeRef<'_> { fn range(&self) -> TextRange { @@ -7899,6 +7970,7 @@ impl Ranged for AnyNodeRef<'_> { AnyNodeRef::FString(node) => node.range(), AnyNodeRef::StringLiteral(node) => node.range(), AnyNodeRef::BytesLiteral(node) => node.range(), + AnyNodeRef::Identifier(node) => node.range(), } } } @@ -7999,6 +8071,7 @@ pub enum NodeKind { FString, StringLiteral, BytesLiteral, + Identifier, } // FIXME: The `StatementRef` here allows us to implement `AstNode` for `Stmt` which otherwise wouldn't be possible diff --git a/crates/ruff_python_formatter/src/range.rs b/crates/ruff_python_formatter/src/range.rs index 72035410e7..3541f46e4c 100644 --- a/crates/ruff_python_formatter/src/range.rs +++ b/crates/ruff_python_formatter/src/range.rs @@ -703,6 +703,7 @@ impl Format> for FormatEnclosingNode<'_> { | AnyNodeRef::TypeParamTypeVar(_) | AnyNodeRef::TypeParamTypeVarTuple(_) | AnyNodeRef::TypeParamParamSpec(_) + | AnyNodeRef::Identifier(_) | AnyNodeRef::BytesLiteral(_) => { panic!("Range formatting only supports formatting logical lines") }