From e000b1c3042195f7d5994e95f5bcb26c27f87fe0 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 7 May 2023 23:19:39 +0900 Subject: [PATCH] Remove redundant types --- ast/asdl_rs.py | 139 ++++++++++++++++++++++--------------------------- 1 file changed, 61 insertions(+), 78 deletions(-) diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index 2750ea163b..d9367b2870 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -68,6 +68,7 @@ class TypeInfo: enum_name: Optional[str] has_userdata: Optional[bool] has_attributes: bool + empty_field: bool children: set boxed: bool product: bool @@ -78,6 +79,7 @@ class TypeInfo: self.enum_name = None self.has_userdata = None self.has_attributes = False + self.empty_field = False self.children = set() self.boxed = False self.product = False @@ -192,10 +194,9 @@ class FindUserdataTypesVisitor(asdl.VisitorBase): info.has_userdata = False else: for t in sum.types: - if not t.fields: - continue t_info = TypeInfo(t.name) t_info.enum_name = name + t_info.empty_field = not t.fields self.typeinfo[t.name] = t_info self.add_children(t.name, t.fields) if len(sum.types) > 1: @@ -543,43 +544,24 @@ class FoldModuleVisitor(EmitVisitor): self.emit("}", depth) -class VisitorStructsDefVisitor(StructVisitor): - def visitModule(self, mod, depth): - for dfn in mod.dfns: - self.visit(dfn, depth) - - def visitProduct(self, product, name, depth): - pass - - def visitSum(self, sum, name, depth): - if not is_simple(sum): - typeinfo = self.typeinfo[name] - if not sum.attributes: - return - for t in sum.types: - typename = t.name + "Node" - - has_userdata = any( - getattr(self.typeinfo.get(f.type), "has_userdata", False) - for f in t.fields - ) - self.emit( - f"pub struct {typename}Data<{'U=()' if has_userdata else ''}> {{", - depth, - ) - for f in t.fields: - self.visit(f, typeinfo, "pub ", depth + 1, t.name) - self.emit("}", depth) - self.emit( - f"pub type {typename} = Located<{typename}Data<{'U' if has_userdata else ''}>, U>;", - depth, - ) - self.emit("", depth) - - class VisitorTraitDefVisitor(StructVisitor): + def full_name(self, name): + typeinfo = self.typeinfo[name] + if typeinfo.enum_name: + return f"{typeinfo.enum_name}_{name}" + else: + return name + + def node_type_name(self, name): + typeinfo = self.typeinfo[name] + if typeinfo.enum_name: + return f"{get_rust_type(typeinfo.enum_name)}{get_rust_type(name)}" + else: + return get_rust_type(name) + def visitModule(self, mod, depth): self.emit("pub trait Visitor {", depth) + for dfn in mod.dfns: self.visit(dfn, depth + 1) self.emit("}", depth) @@ -587,46 +569,46 @@ class VisitorTraitDefVisitor(StructVisitor): def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) - def emit_visitor(self, nodename, rusttype, depth): - self.emit(f"fn visit_{nodename}(&mut self, node: {rusttype}) {{", depth) - self.emit(f"self.generic_visit_{nodename}(node);", depth + 1) + def emit_visitor(self, nodename, depth, has_node=True): + typeinfo = self.typeinfo[nodename] + if has_node: + node_type = typeinfo.rust_sum_name + node_value = "node" + else: + node_type = "()" + node_value = "()" + self.emit(f"fn visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth) + self.emit(f"self.generic_visit_{typeinfo.sum_name}({node_value})", depth + 1) self.emit("}", depth) - def emit_generic_visitor_signature(self, nodename, rusttype, depth): - self.emit(f"fn generic_visit_{nodename}(&mut self, node: {rusttype}) {{", depth) + def emit_generic_visitor_signature(self, nodename, depth, has_node=True): + typeinfo = self.typeinfo[nodename] + if has_node: + node_type = typeinfo.rust_sum_name + else: + node_type = "()" + self.emit(f"fn generic_visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth) - def emit_empty_generic_visitor(self, nodename, rusttype, depth): - self.emit_generic_visitor_signature(nodename, rusttype, depth) + def emit_empty_generic_visitor(self, nodename, depth): + self.emit_generic_visitor_signature(nodename, depth) self.emit("}", depth) def simple_sum(self, sum, name, depth): - enumname = get_rust_type(name) - self.emit_visitor(name, enumname, depth) - self.emit_empty_generic_visitor(name, enumname, depth) + self.emit_visitor(name, depth) + self.emit_empty_generic_visitor(name, depth) - def visit_match_for_type(self, enumname, rustname, type_, depth): + def visit_match_for_type(self, nodename, rustname, type_, depth): self.emit(f"{rustname}::{type_.name}", depth) if type_.fields: - self.emit(f"({enumname}{type_.name} {{", depth) - for field in type_.fields: - self.emit(f"{rust_field(field.name)},", depth + 1) - self.emit("})", depth) - self.emit(f"=> self.visit_{type_.name}(", depth) - self.emit(f"{type_.name}Node {{", depth + 2) - self.emit("location: node.location,", depth + 2) - self.emit("end_location: node.end_location,", depth + 2) - self.emit("custom: node.custom,", depth + 2) - self.emit(f"node: {type_.name}NodeData {{", depth + 2) - for field in type_.fields: - self.emit(f"{rust_field(field.name)},", depth + 3) - self.emit("},", depth + 2) - self.emit("}", depth + 1) - self.emit("),", depth) + self.emit("(data)", depth) + data = "data" + else: + data = "()" + self.emit(f"=> self.visit_{nodename}_{type_.name}({data}),", depth) - def visit_sumtype(self, type_, depth): - rustname = get_rust_type(type_.name) + "Node" - self.emit_visitor(type_.name , rustname, depth) - self.emit_generic_visitor_signature(type_.name, rustname, depth) + def visit_sumtype(self, name, type_, depth): + self.emit_visitor(type_.name, depth, has_node=type_.fields) + self.emit_generic_visitor_signature(type_.name, depth, has_node=type_.fields) for f in type_.fields: fieldname = rust_field(f.name) fieldtype = self.typeinfo.get(f.type) @@ -634,20 +616,21 @@ class VisitorTraitDefVisitor(StructVisitor): continue if f.opt: - self.emit(f"if let Some(value) = node.node.{fieldname} {{", depth + 1) + self.emit(f"if let Some(value) = node.{fieldname} {{", depth + 1) elif f.seq: - iterable = f"node.node.{fieldname}" + iterable = f"node.{fieldname}" if type_.name == "Dict" and f.name == "keys": iterable = f"{iterable}.into_iter().flatten()" self.emit(f"for value in {iterable} {{", depth + 1) else: self.emit("{", depth + 1) - self.emit(f"let value = node.node.{fieldname};", depth + 2) + self.emit(f"let value = node.{fieldname};", depth + 2) variable = "value" if fieldtype.boxed and (not f.seq or f.opt): variable = "*" + variable - self.emit(f"self.visit_{fieldtype.name}({variable});", depth + 2) + typeinfo = self.typeinfo[fieldtype.name] + self.emit(f"self.visit_{typeinfo.sum_name}({variable});", depth + 2) self.emit("}", depth + 1) @@ -660,24 +643,23 @@ class VisitorTraitDefVisitor(StructVisitor): rustname = enumname = get_rust_type(name) if sum.attributes: rustname = enumname + "Kind" - self.emit_visitor(name, enumname, depth) - self.emit_generic_visitor_signature(name, enumname, depth) + self.emit_visitor(name, depth) + self.emit_generic_visitor_signature(name, depth) depth += 1 self.emit("match node.node {", depth) for t in sum.types: - self.visit_match_for_type(enumname, rustname, t, depth + 1) + self.visit_match_for_type(name, rustname, t, depth + 1) self.emit("}", depth) depth -= 1 self.emit("}", depth) # Now for the visitors for the types for t in sum.types: - self.visit_sumtype(t, depth) + self.visit_sumtype(name, t, depth) def visitProduct(self, product, name, depth): - rusttype = get_rust_type(name) - self.emit_visitor(name, rusttype, depth) - self.emit_empty_generic_visitor(name, rusttype, depth) + self.emit_visitor(name, depth) + self.emit_empty_generic_visitor(name, depth) class VisitorModuleVisitor(EmitVisitor): @@ -687,7 +669,6 @@ class VisitorModuleVisitor(EmitVisitor): self.emit("#[allow(unused_variables, non_snake_case)]", depth) self.emit("pub mod visitor {", depth) self.emit("use super::*;", depth + 1) - VisitorStructsDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1) VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1) self.emit("}", depth) self.emit("", depth) @@ -980,6 +961,8 @@ def write_located_def(typeinfo, f): ) ) for info in typeinfo.values(): + if info.empty_field: + continue if info.has_userdata: generics = "::" else: