diff --git a/ast/Cargo.toml b/ast/Cargo.toml index a7fb168cfe..1c98911d6a 100644 --- a/ast/Cargo.toml +++ b/ast/Cargo.toml @@ -13,6 +13,7 @@ constant-optimization = ["fold"] source-code = ["fold"] fold = [] unparse = ["rustpython-literal"] +visitor = [] [dependencies] rustpython-parser-core = { workspace = true } diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index b8f5c9a127..e6cbe1ea20 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -543,6 +543,153 @@ class FoldModuleVisitor(EmitVisitor): self.emit("}", depth) +class VisitorStructsDefVisitor(StructVisitor): + def visitModule(self, mod, depth): + for dfn in mod.dfns: + self.visit(dfn, depth) + + def visitProduct(self, product, name, depth): + pass + + def visitSum(self, sum, name, depth): + if not is_simple(sum): + typeinfo = self.typeinfo[name] + if not sum.attributes: + return + for t in sum.types: + typename = t.name + "Node" + + has_userdata = any( + getattr(self.typeinfo.get(f.type), "has_userdata", False) + for f in t.fields + ) + self.emit( + f"pub struct {typename}Data<{'U=()' if has_userdata else ''}> {{", + depth, + ) + for f in t.fields: + self.visit(f, typeinfo, "pub ", depth + 1, t.name) + self.emit("}", depth) + self.emit( + f"pub type {typename} = Located<{typename}Data<{'U' if has_userdata else ''}>, U>;", + depth, + ) + self.emit("", depth) + + +class VisitorTraitDefVisitor(StructVisitor): + def visitModule(self, mod, depth): + self.emit("pub trait Visitor {", depth) + for dfn in mod.dfns: + self.visit(dfn, depth + 1) + self.emit("}", depth) + + def visitType(self, type, depth=0): + self.visit(type.value, type.name, depth) + + def emit_visitor(self, nodename, rusttype, depth): + self.emit(f"fn visit_{nodename}(&mut self, node: {rusttype}) {{", depth) + self.emit(f"self.generic_visit_{nodename}(node);", depth + 1) + self.emit("}", depth) + + def emit_generic_visitor_signature(self, nodename, rusttype, depth): + self.emit(f"fn generic_visit_{nodename}(&mut self, node: {rusttype}) {{", depth) + + def emit_empty_generic_visitor(self, nodename, rusttype, depth): + self.emit_generic_visitor_signature(nodename, rusttype, depth) + self.emit("}", depth) + + def simple_sum(self, sum, name, depth): + rustname = get_rust_type(name) + self.emit_visitor(name, rustname, depth) + self.emit_empty_generic_visitor(name, rustname, depth) + + def visit_match_for_type(self, enumname, type_, depth): + self.emit(f"{enumname}::{type_.name} {{", depth) + for field in type_.fields: + self.emit(f"{rust_field(field.name)},", depth + 1) + self.emit(f"}} => self.visit_{type_.name}(", depth) + self.emit(f"{type_.name}Node {{", depth + 2) + self.emit("location: node.location,", depth + 2) + self.emit("end_location: node.end_location,", depth + 2) + self.emit("custom: node.custom,", depth + 2) + self.emit(f"node: {type_.name}NodeData {{", depth + 2) + for field in type_.fields: + self.emit(f"{rust_field(field.name)},", depth + 3) + self.emit("},", depth + 2) + self.emit("}", depth + 1) + self.emit("),", depth) + + def visit_sumtype(self, type_, depth): + rustname = get_rust_type(type_.name) + "Node" + self.emit_visitor(type_.name, rustname, depth) + self.emit_generic_visitor_signature(type_.name, rustname, depth) + for f in type_.fields: + fieldname = rust_field(f.name) + fieldtype = self.typeinfo.get(f.type) + if not (fieldtype and fieldtype.has_userdata): + continue + + if f.opt: + self.emit(f"if let Some(value) = node.node.{fieldname} {{", depth + 1) + elif f.seq: + iterable = f"node.node.{fieldname}" + if type_.name == "Dict" and f.name == "keys": + iterable = f"{iterable}.into_iter().flatten()" + self.emit(f"for value in {iterable} {{", depth + 1) + else: + self.emit("{", depth + 1) + self.emit(f"let value = node.node.{fieldname};", depth + 2) + + variable = "value" + if fieldtype.boxed and (not f.seq or f.opt): + variable = "*" + variable + self.emit(f"self.visit_{fieldtype.name}({variable});", depth + 2) + + self.emit("}", depth + 1) + + self.emit("}", depth) + + def sum_with_constructors(self, sum, name, depth): + if not sum.attributes: + return + + rustname = enumname = get_rust_type(name) + if sum.attributes: + enumname += "Kind" + self.emit_visitor(name, rustname, depth) + self.emit_generic_visitor_signature(name, rustname, depth) + depth += 1 + self.emit("match node.node {", depth) + for t in sum.types: + self.visit_match_for_type(enumname, t, depth + 1) + self.emit("}", depth) + depth -= 1 + self.emit("}", depth) + + # Now for the visitors for the types + for t in sum.types: + self.visit_sumtype(t, depth) + + def visitProduct(self, product, name, depth): + rusttype = get_rust_type(name) + self.emit_visitor(name, rusttype, depth) + self.emit_empty_generic_visitor(name, rusttype, depth) + + +class VisitorModuleVisitor(EmitVisitor): + def visitModule(self, mod): + depth = 0 + self.emit('#[cfg(feature = "visitor")]', depth) + self.emit("#[allow(unused_variables, non_snake_case)]", depth) + self.emit("pub mod visitor {", depth) + self.emit("use super::*;", depth + 1) + VisitorStructsDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1) + VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1) + self.emit("}", depth) + self.emit("", depth) + + class ClassDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: @@ -811,7 +958,11 @@ def write_generic_def(mod, typeinfo, f): ) ) - c = ChainOfVisitors(StructVisitor(f, typeinfo), FoldModuleVisitor(f, typeinfo)) + c = ChainOfVisitors( + StructVisitor(f, typeinfo), + FoldModuleVisitor(f, typeinfo), + VisitorModuleVisitor(f, typeinfo), + ) c.visit(mod)