diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index e6cbe1ea20..2750ea163b 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -600,15 +600,18 @@ class VisitorTraitDefVisitor(StructVisitor): self.emit("}", depth) def simple_sum(self, sum, name, depth): - rustname = get_rust_type(name) - self.emit_visitor(name, rustname, depth) - self.emit_empty_generic_visitor(name, rustname, depth) + enumname = get_rust_type(name) + self.emit_visitor(name, enumname, depth) + self.emit_empty_generic_visitor(name, enumname, depth) - def visit_match_for_type(self, enumname, type_, depth): - self.emit(f"{enumname}::{type_.name} {{", depth) - for field in type_.fields: - self.emit(f"{rust_field(field.name)},", depth + 1) - self.emit(f"}} => self.visit_{type_.name}(", depth) + def visit_match_for_type(self, enumname, rustname, type_, depth): + self.emit(f"{rustname}::{type_.name}", depth) + if type_.fields: + self.emit(f"({enumname}{type_.name} {{", depth) + for field in type_.fields: + self.emit(f"{rust_field(field.name)},", depth + 1) + self.emit("})", depth) + self.emit(f"=> self.visit_{type_.name}(", depth) self.emit(f"{type_.name}Node {{", depth + 2) self.emit("location: node.location,", depth + 2) self.emit("end_location: node.end_location,", depth + 2) @@ -622,7 +625,7 @@ class VisitorTraitDefVisitor(StructVisitor): def visit_sumtype(self, type_, depth): rustname = get_rust_type(type_.name) + "Node" - self.emit_visitor(type_.name, rustname, depth) + self.emit_visitor(type_.name , rustname, depth) self.emit_generic_visitor_signature(type_.name, rustname, depth) for f in type_.fields: fieldname = rust_field(f.name) @@ -656,13 +659,13 @@ class VisitorTraitDefVisitor(StructVisitor): rustname = enumname = get_rust_type(name) if sum.attributes: - enumname += "Kind" - self.emit_visitor(name, rustname, depth) - self.emit_generic_visitor_signature(name, rustname, depth) + rustname = enumname + "Kind" + self.emit_visitor(name, enumname, depth) + self.emit_generic_visitor_signature(name, enumname, depth) depth += 1 self.emit("match node.node {", depth) for t in sum.types: - self.visit_match_for_type(enumname, t, depth + 1) + self.visit_match_for_type(enumname, rustname, t, depth + 1) self.emit("}", depth) depth -= 1 self.emit("}", depth)