From e07741e553039c7b4bff35e38626cc48598e93b2 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 28 Mar 2025 20:42:45 +0100 Subject: [PATCH] Add `as_group` methods to `AnyNodeRef` (#17048) ## Summary This PR adds `as_` methods to `AnyNodeRef` to e.g. convert an `AnyNodeRef` to an `ExprRef`. I need this for go to definition where the fallback is to test if `AnyNodeRef` is an expression and then call `inferred_type` (listing this mapping at every call site where we need to convert `AnyNodeRef` to an `ExprRef` is a bit painful ;)) Split out from https://github.com/astral-sh/ruff/pull/16901 ## Test Plan `cargo test` --- crates/ruff_python_ast/generate.py | 17 +++ crates/ruff_python_ast/src/generated.rs | 136 ++++++++++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/crates/ruff_python_ast/generate.py b/crates/ruff_python_ast/generate.py index 2279b7beb5..0b5e9ac354 100644 --- a/crates/ruff_python_ast/generate.py +++ b/crates/ruff_python_ast/generate.py @@ -525,6 +525,23 @@ def write_anynoderef(out: list[str], ast: Ast) -> None: } """) + # `as_*` methods to convert from `AnyNodeRef` to e.g. `ExprRef` + out.append(f""" + impl<'a> AnyNodeRef<'a> {{ + pub fn as_{to_snake_case(group.ref_enum_ty)}(self) -> Option<{group.ref_enum_ty}<'a>> {{ + match self {{ + """) + for node in group.nodes: + out.append( + f"Self::{node.name}(node) => Some({group.ref_enum_ty}::{node.variant}(node))," + ) + out.append(""" + _ => None, + } + } + } + """) + for node in ast.all_nodes: out.append(f""" impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{ diff --git a/crates/ruff_python_ast/src/generated.rs b/crates/ruff_python_ast/src/generated.rs index 0fe6a84224..691f8bf254 100644 --- a/crates/ruff_python_ast/src/generated.rs +++ b/crates/ruff_python_ast/src/generated.rs @@ -5050,6 +5050,17 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_mod_ref(self) -> Option> { + match self { + Self::ModModule(node) => Some(ModRef::Module(node)), + Self::ModExpression(node) => Some(ModRef::Expression(node)), + + _ => None, + } + } +} + impl<'a> From<&'a Stmt> for AnyNodeRef<'a> { fn from(node: &'a Stmt) -> AnyNodeRef<'a> { match node { @@ -5114,6 +5125,40 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_stmt_ref(self) -> Option> { + match self { + Self::StmtFunctionDef(node) => Some(StmtRef::FunctionDef(node)), + Self::StmtClassDef(node) => Some(StmtRef::ClassDef(node)), + Self::StmtReturn(node) => Some(StmtRef::Return(node)), + Self::StmtDelete(node) => Some(StmtRef::Delete(node)), + Self::StmtTypeAlias(node) => Some(StmtRef::TypeAlias(node)), + Self::StmtAssign(node) => Some(StmtRef::Assign(node)), + Self::StmtAugAssign(node) => Some(StmtRef::AugAssign(node)), + Self::StmtAnnAssign(node) => Some(StmtRef::AnnAssign(node)), + Self::StmtFor(node) => Some(StmtRef::For(node)), + Self::StmtWhile(node) => Some(StmtRef::While(node)), + Self::StmtIf(node) => Some(StmtRef::If(node)), + Self::StmtWith(node) => Some(StmtRef::With(node)), + Self::StmtMatch(node) => Some(StmtRef::Match(node)), + Self::StmtRaise(node) => Some(StmtRef::Raise(node)), + Self::StmtTry(node) => Some(StmtRef::Try(node)), + Self::StmtAssert(node) => Some(StmtRef::Assert(node)), + Self::StmtImport(node) => Some(StmtRef::Import(node)), + Self::StmtImportFrom(node) => Some(StmtRef::ImportFrom(node)), + Self::StmtGlobal(node) => Some(StmtRef::Global(node)), + Self::StmtNonlocal(node) => Some(StmtRef::Nonlocal(node)), + Self::StmtExpr(node) => Some(StmtRef::Expr(node)), + Self::StmtPass(node) => Some(StmtRef::Pass(node)), + Self::StmtBreak(node) => Some(StmtRef::Break(node)), + Self::StmtContinue(node) => Some(StmtRef::Continue(node)), + Self::StmtIpyEscapeCommand(node) => Some(StmtRef::IpyEscapeCommand(node)), + + _ => None, + } + } +} + impl<'a> From<&'a Expr> for AnyNodeRef<'a> { fn from(node: &'a Expr) -> AnyNodeRef<'a> { match node { @@ -5192,6 +5237,47 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_expr_ref(self) -> Option> { + match self { + Self::ExprBoolOp(node) => Some(ExprRef::BoolOp(node)), + Self::ExprNamed(node) => Some(ExprRef::Named(node)), + Self::ExprBinOp(node) => Some(ExprRef::BinOp(node)), + Self::ExprUnaryOp(node) => Some(ExprRef::UnaryOp(node)), + Self::ExprLambda(node) => Some(ExprRef::Lambda(node)), + Self::ExprIf(node) => Some(ExprRef::If(node)), + Self::ExprDict(node) => Some(ExprRef::Dict(node)), + Self::ExprSet(node) => Some(ExprRef::Set(node)), + Self::ExprListComp(node) => Some(ExprRef::ListComp(node)), + Self::ExprSetComp(node) => Some(ExprRef::SetComp(node)), + Self::ExprDictComp(node) => Some(ExprRef::DictComp(node)), + Self::ExprGenerator(node) => Some(ExprRef::Generator(node)), + Self::ExprAwait(node) => Some(ExprRef::Await(node)), + Self::ExprYield(node) => Some(ExprRef::Yield(node)), + Self::ExprYieldFrom(node) => Some(ExprRef::YieldFrom(node)), + Self::ExprCompare(node) => Some(ExprRef::Compare(node)), + Self::ExprCall(node) => Some(ExprRef::Call(node)), + Self::ExprFString(node) => Some(ExprRef::FString(node)), + Self::ExprStringLiteral(node) => Some(ExprRef::StringLiteral(node)), + Self::ExprBytesLiteral(node) => Some(ExprRef::BytesLiteral(node)), + Self::ExprNumberLiteral(node) => Some(ExprRef::NumberLiteral(node)), + Self::ExprBooleanLiteral(node) => Some(ExprRef::BooleanLiteral(node)), + Self::ExprNoneLiteral(node) => Some(ExprRef::NoneLiteral(node)), + Self::ExprEllipsisLiteral(node) => Some(ExprRef::EllipsisLiteral(node)), + Self::ExprAttribute(node) => Some(ExprRef::Attribute(node)), + Self::ExprSubscript(node) => Some(ExprRef::Subscript(node)), + Self::ExprStarred(node) => Some(ExprRef::Starred(node)), + Self::ExprName(node) => Some(ExprRef::Name(node)), + Self::ExprList(node) => Some(ExprRef::List(node)), + Self::ExprTuple(node) => Some(ExprRef::Tuple(node)), + Self::ExprSlice(node) => Some(ExprRef::Slice(node)), + Self::ExprIpyEscapeCommand(node) => Some(ExprRef::IpyEscapeCommand(node)), + + _ => None, + } + } +} + impl<'a> From<&'a ExceptHandler> for AnyNodeRef<'a> { fn from(node: &'a ExceptHandler) -> AnyNodeRef<'a> { match node { @@ -5208,6 +5294,16 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_except_handler_ref(self) -> Option> { + match self { + Self::ExceptHandlerExceptHandler(node) => Some(ExceptHandlerRef::ExceptHandler(node)), + + _ => None, + } + } +} + impl<'a> From<&'a FStringElement> for AnyNodeRef<'a> { fn from(node: &'a FStringElement) -> AnyNodeRef<'a> { match node { @@ -5226,6 +5322,17 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_f_string_element_ref(self) -> Option> { + match self { + Self::FStringExpressionElement(node) => Some(FStringElementRef::Expression(node)), + Self::FStringLiteralElement(node) => Some(FStringElementRef::Literal(node)), + + _ => None, + } + } +} + impl<'a> From<&'a Pattern> for AnyNodeRef<'a> { fn from(node: &'a Pattern) -> AnyNodeRef<'a> { match node { @@ -5256,6 +5363,23 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_pattern_ref(self) -> Option> { + match self { + Self::PatternMatchValue(node) => Some(PatternRef::MatchValue(node)), + Self::PatternMatchSingleton(node) => Some(PatternRef::MatchSingleton(node)), + Self::PatternMatchSequence(node) => Some(PatternRef::MatchSequence(node)), + Self::PatternMatchMapping(node) => Some(PatternRef::MatchMapping(node)), + Self::PatternMatchClass(node) => Some(PatternRef::MatchClass(node)), + Self::PatternMatchStar(node) => Some(PatternRef::MatchStar(node)), + Self::PatternMatchAs(node) => Some(PatternRef::MatchAs(node)), + Self::PatternMatchOr(node) => Some(PatternRef::MatchOr(node)), + + _ => None, + } + } +} + impl<'a> From<&'a TypeParam> for AnyNodeRef<'a> { fn from(node: &'a TypeParam) -> AnyNodeRef<'a> { match node { @@ -5276,6 +5400,18 @@ impl<'a> From> for AnyNodeRef<'a> { } } +impl<'a> AnyNodeRef<'a> { + pub fn as_type_param_ref(self) -> Option> { + match self { + Self::TypeParamTypeVar(node) => Some(TypeParamRef::TypeVar(node)), + Self::TypeParamTypeVarTuple(node) => Some(TypeParamRef::TypeVarTuple(node)), + Self::TypeParamParamSpec(node) => Some(TypeParamRef::ParamSpec(node)), + + _ => None, + } + } +} + impl<'a> From<&'a crate::ModModule> for AnyNodeRef<'a> { fn from(node: &'a crate::ModModule) -> AnyNodeRef<'a> { AnyNodeRef::ModModule(node)