diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index a08efd8b27..f0f5844edc 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -258,7 +258,6 @@ class StructVisitor(EmitVisitor): def __init__(self, *args, **kw): super().__init__(*args, **kw) - self.rust_type_defs = [] def visitModule(self, mod): for dfn in mod.dfns: @@ -359,17 +358,17 @@ class StructVisitor(EmitVisitor): typ = f"{typ}" # don't box if we're doing Vec, but do box if we're doing Vec>> if ( - field_type - and field_type.boxed - and (not (parent.product or field.seq) or field.opt) + field_type + and field_type.boxed + and (not (parent.product or field.seq) or field.opt) ): typ = f"Box<{typ}>" if field.opt or ( - # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`), - # the expression to be unpacked goes in `values` with a `None` at the corresponding - # position in `keys`. To handle this, the type of `keys` needs to be `Option>`. - constructor == "Dict" - and field.name == "keys" + # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`), + # the expression to be unpacked goes in `values` with a `None` at the corresponding + # position in `keys`. To handle this, the type of `keys` needs to be `Option>`. + constructor == "Dict" + and field.name == "keys" ): typ = f"Option<{typ}>" if field.seq: @@ -579,9 +578,10 @@ class VisitorTraitDefVisitor(StructVisitor): def emit_visitor(self, nodename, depth, has_node=True): type_info = self.type_info[nodename] node_type = type_info.rust_sum_name - generic, = self.apply_generics(nodename, "R") + (generic,) = self.apply_generics(nodename, "R") self.emit( - f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{", depth + f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{", + depth, ) if has_node: self.emit(f"self.generic_visit_{type_info.sum_name}(node)", depth + 1) @@ -594,7 +594,7 @@ class VisitorTraitDefVisitor(StructVisitor): node_type = type_info.rust_sum_name else: node_type = "()" - generic, = self.apply_generics(nodename, "R") + (generic,) = self.apply_generics(nodename, "R") self.emit( f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{", depth, @@ -677,7 +677,7 @@ class VisitorModuleVisitor(EmitVisitor): VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth) -class class_defVisitor(EmitVisitor): +class StdlibClassDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) @@ -686,9 +686,9 @@ class class_defVisitor(EmitVisitor): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): - struct_name = "NodeKind" + rust_type_name(name) + struct_name = "Node" + rust_type_name(name) self.emit( - f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]', + f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "NodeAst")]', depth, ) self.emit(f"struct {struct_name};", depth) @@ -703,8 +703,12 @@ class class_defVisitor(EmitVisitor): def visitProduct(self, product, name, depth): self.gen_class_def(name, product.fields, product.attributes, depth) - def gen_class_def(self, name, fields, attrs, depth, base="AstNode"): - struct_name = "Node" + rust_type_name(name) + def gen_class_def(self, name, fields, attrs, depth, base=None): + if base is None: + base = "NodeAst" + struct_name = "Node" + rust_type_name(name) + else: + struct_name = base + rust_type_name(name) self.emit( f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]', depth, @@ -735,7 +739,7 @@ class class_defVisitor(EmitVisitor): self.emit("}", depth) -class ExtendModuleVisitor(EmitVisitor): +class StdlibExtendModuleVisitor(EmitVisitor): def visitModule(self, mod): depth = 0 self.emit( @@ -753,24 +757,24 @@ class ExtendModuleVisitor(EmitVisitor): def visitSum(self, sum, name, depth): rust_name = rust_type_name(name) - self.emit( - f"{json.dumps(name)} => NodeKind{rust_name}::make_class(&vm.ctx),", depth - ) + self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth) for cons in sum.types: - self.visit(cons, depth) + self.visit(cons, depth, rust_name) - def visitConstructor(self, cons, depth): - self.gen_extension(cons.name, depth) + def visitConstructor(self, cons, depth, rust_name): + self.gen_extension(cons.name, depth, rust_name) def visitProduct(self, product, name, depth): self.gen_extension(name, depth) - def gen_extension(self, name, depth): + def gen_extension(self, name, depth, base=""): rust_name = rust_type_name(name) - self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth) + self.emit( + f"{json.dumps(name)} => Node{base}{rust_name}::make_class(&vm.ctx),", depth + ) -class TraitImplVisitor(EmitVisitor): +class StdlibTraitImplVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) @@ -779,45 +783,87 @@ class TraitImplVisitor(EmitVisitor): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): - rust_name = enum_name = rust_type_name(name) - if sum.attributes: - rust_name = enum_name + "Kind" + rust_name = rust_type_name(name) self.emit(f"impl NamedNode for ast::located::{rust_name} {{", depth) self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1) self.emit("}", depth) + self.emit("// sum", depth) self.emit(f"impl Node for ast::located::{rust_name} {{", depth) self.emit( - "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1 + "fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef {", depth + 1 ) - self.emit("match self {", depth + 2) - for variant in sum.types: - self.constructor_to_object(variant, enum_name, rust_name, depth + 3) - self.emit("}", depth + 2) + simple = is_simple(sum) + if simple: + self.emit("let node_type = match self {", depth + 2) + for cons in sum.types: + self.emit( + f"ast::located::{rust_name}::{cons.name} => Node{rust_name}{cons.name}::static_type(),", + depth, + ) + self.emit("};", depth + 3) + self.emit("NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2) + else: + self.emit("match self {", depth + 2) + for cons in sum.types: + self.emit( + f"ast::located::{rust_name}::{cons.name}(cons) => cons.ast_to_object(vm),", + depth + 3, + ) + self.emit("}", depth + 2) + self.emit("}", depth + 1) self.emit( "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", depth + 1, ) - self.gen_sum_from_object(sum, name, enum_name, rust_name, depth + 2) + self.gen_sum_from_object(sum, name, rust_name, depth + 2) self.emit("}", depth + 1) self.emit("}", depth) - def constructor_to_object(self, cons, enum_name, rust_name, depth): - self.emit(f"ast::located::{rust_name}::{cons.name}", depth) - if cons.fields: - fields_pattern = self.make_pattern(cons.fields) - self.emit( - f"( ast::located::{enum_name}{cons.name} {{ {fields_pattern} }} )", - depth, - ) - self.emit(" => {", depth) - self.make_node(cons.name, cons.fields, depth + 1) + if not is_simple(sum): + for cons in sum.types: + self.visit(cons, sum, rust_name, depth) + + def visitConstructor(self, cons, sum, sum_rust_name, depth): + rust_name = rust_type_name(cons.name) + self.emit("// constructor", depth) + self.emit( + f"impl NamedNode for ast::located::{sum_rust_name}{rust_name} {{", depth + ) + self.emit(f"const NAME: &'static str = {json.dumps(cons.name)};", depth + 1) self.emit("}", depth) + self.emit(f"impl Node for ast::located::{sum_rust_name}{rust_name} {{", depth) + + fields_pattern = self.make_pattern(cons.fields) + + self.emit( + "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1 + ) + self.emit( + f"let ast::located::{sum_rust_name}{rust_name} {{ {fields_pattern} }} = self;", + depth, + ) + self.make_node(cons.name, sum, cons.fields, depth + 2, sum_rust_name) + + self.emit("}", depth + 1) + + self.emit( + "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", + depth + 1, + ) + + self.gen_product_from_object( + cons, cons.name, f"{sum_rust_name}{rust_name}", sum.attributes, depth + 2 + ) + self.emit("}", depth + 1) + + self.emit("}", depth + 1) def visitProduct(self, product, name, depth): struct_name = rust_type_name(name) + self.emit("// product", depth) self.emit(f"impl NamedNode for ast::located::{struct_name} {{", depth) self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1) self.emit("}", depth) @@ -827,53 +873,57 @@ class TraitImplVisitor(EmitVisitor): ) fields_pattern = self.make_pattern(product.fields) self.emit( - f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;", depth + 2 + f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;", + depth + 2, ) - self.make_node(name, product.fields, depth + 2) + self.make_node(name, product, product.fields, depth + 2) self.emit("}", depth + 1) self.emit( "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", depth + 1, ) - self.gen_product_from_object(product, name, struct_name, depth + 2) + self.gen_product_from_object( + product, name, struct_name, product.attributes, depth + 2 + ) self.emit("}", depth + 1) self.emit("}", depth) - def make_node(self, variant, fields, depth): + def make_node(self, variant, owner, fields, depth, base=""): rust_variant = rust_type_name(variant) self.emit( - f"let _node = AstNode.into_ref_with_type(_vm, Node{rust_variant}::static_type().to_owned()).unwrap();", + f"let node = NodeAst.into_ref_with_type(_vm, Node{base}{rust_variant}::static_type().to_owned()).unwrap();", depth, ) - if fields: - self.emit("let _dict = _node.as_object().dict().unwrap();", depth) + if fields or owner.attributes: + self.emit("let dict = node.as_object().dict().unwrap();", depth) for f in fields: self.emit( - f"_dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", + f"dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", depth, ) - self.emit("_node.into()", depth) + if owner.attributes: + self.emit("node_add_location(&dict, _range, _vm);", depth) + self.emit("node.into()", depth) def make_pattern(self, fields): - return ",".join(rust_field(f.name) for f in fields) + return "".join(f"{rust_field(f.name)}," for f in fields) + "range: _range" - def gen_sum_from_object(self, sum, sum_name, enum_name, rust_name, depth): + def gen_sum_from_object(self, sum, sum_name, rust_name, depth): # if sum.attributes: # self.extract_location(sum_name, depth) self.emit("let _cls = _object.class();", depth) self.emit("Ok(", depth) for cons in sum.types: - self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth) - if cons.fields: + self.emit( + f"if _cls.is(Node{rust_name}{cons.name}::static_type()) {{", depth + ) + self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1) + if not is_simple(sum): self.emit( - f"ast::located::{rust_name}::{cons.name} (ast::located::{enum_name}{cons.name} {{", + f"(ast::located::{rust_name}{cons.name}::ast_from_object(_vm, _object)?)", depth + 1, ) - self.gen_construction_fields(cons, sum_name, depth + 1) - self.emit("})", depth + 1) - else: - self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1) self.emit("} else", depth) self.emit("{", depth) @@ -881,12 +931,13 @@ class TraitImplVisitor(EmitVisitor): self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1) self.emit("})", depth) - def gen_product_from_object(self, product, product_name, struct_name, depth): - # if product.attributes: - # self.extract_location(product_name, depth) - + def gen_product_from_object( + self, product, product_name, struct_name, has_attributes, depth + ): self.emit("Ok(", depth) - self.gen_construction(struct_name, product, product_name, depth + 1) + self.gen_construction( + struct_name, product, product_name, has_attributes, depth + 1 + ) self.emit(")", depth) def gen_construction_fields(self, cons, name, depth): @@ -896,9 +947,13 @@ class TraitImplVisitor(EmitVisitor): depth + 1, ) - def gen_construction(self, cons_path, cons, name, depth): + def gen_construction(self, cons_path, cons, name, attributes, depth): self.emit(f"ast::located::{cons_path} {{", depth) self.gen_construction_fields(cons, name, depth + 1) + if attributes: + self.emit(f'range: range_from_object(_vm, _object, "{name}")?,', depth + 1) + else: + self.emit("range: Default::default(),", depth + 1) self.emit("}", depth) def extract_location(self, typename, depth): @@ -940,21 +995,26 @@ class RangedDefVisitor(EmitVisitor): for ty in sum.types: variant_info = self.type_info[ty.name] - sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range()," + sum_match_arms += ( + f" Self::{variant_info.rust_name}(node) => node.range()," + ) self.emit_ranged_impl(variant_info) if not info.no_cfg(self.type_info): self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - self.emit(f""" - impl Ranged for crate::{info.rust_sum_name} {{ - fn range(&self) -> TextRange {{ - match self {{ - {sum_match_arms} + self.emit( + f""" + impl Ranged for crate::{info.rust_sum_name} {{ + fn range(&self) -> TextRange {{ + match self {{ + {sum_match_arms} + }} }} }} - }} - """.lstrip(), 0) + """.lstrip(), + 0, + ) def visitProduct(self, product, name, depth): info = self.type_info[name] @@ -996,22 +1056,27 @@ class LocatedDefVisitor(EmitVisitor): for ty in sum.types: variant_info = self.type_info[ty.name] - sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range()," + sum_match_arms += ( + f" Self::{variant_info.rust_name}(node) => node.range()," + ) self.emit_type_alias(variant_info) self.emit_located_impl(variant_info) if not info.no_cfg(self.type_info): self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - self.emit(f""" - impl Located for {info.rust_sum_name} {{ - fn range(&self) -> SourceRange {{ - match self {{ - {sum_match_arms} + self.emit( + f""" + impl Located for {info.rust_sum_name} {{ + fn range(&self) -> SourceRange {{ + match self {{ + {sum_match_arms} + }} }} }} - }} - """.lstrip(), 0) + """.lstrip(), + 0, + ) def visitProduct(self, product, name, depth): info = self.type_info[name] @@ -1022,7 +1087,10 @@ class LocatedDefVisitor(EmitVisitor): def emit_type_alias(self, info): generics = "" if info.is_simple else "::" - self.emit(f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", 0) + self.emit( + f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", + 0, + ) self.emit("", 0) def emit_located_impl(self, info): @@ -1036,8 +1104,9 @@ class LocatedDefVisitor(EmitVisitor): self.range }} }} - """ - , 0) + """, + 0, + ) class ChainOfVisitors: @@ -1084,18 +1153,18 @@ def write_ast_mod(mod, type_info, f): ) c = ChainOfVisitors( - class_defVisitor(f, type_info), - TraitImplVisitor(f, type_info), - ExtendModuleVisitor(f, type_info), + StdlibClassDefVisitor(f, type_info), + StdlibTraitImplVisitor(f, type_info), + StdlibExtendModuleVisitor(f, type_info), ) c.visit(mod) def main( - input_filename, - ast_dir, - module_filename, - dump_module=False, + input_filename, + ast_dir, + module_filename, + dump_module=False, ): auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:])) mod = asdl.parse(input_filename) diff --git a/ast/src/impls.rs b/ast/src/impls.rs index a8cb34010f..e0c47bc12e 100644 --- a/ast/src/impls.rs +++ b/ast/src/impls.rs @@ -1,4 +1,4 @@ -use crate::{Constant, Excepthandler, Expr, Pattern, Stmt}; +use crate::{Constant, Expr}; impl Expr { /// Returns a short name for the node suitable for use in error messages. @@ -55,10 +55,10 @@ impl Expr { } #[cfg(target_arch = "x86_64")] -static_assertions::assert_eq_size!(Expr, [u8; 72]); +static_assertions::assert_eq_size!(crate::Expr, [u8; 72]); #[cfg(target_arch = "x86_64")] -static_assertions::assert_eq_size!(Stmt, [u8; 136]); +static_assertions::assert_eq_size!(crate::Stmt, [u8; 136]); #[cfg(target_arch = "x86_64")] -static_assertions::assert_eq_size!(Pattern, [u8; 96]); +static_assertions::assert_eq_size!(crate::Pattern, [u8; 96]); #[cfg(target_arch = "x86_64")] -static_assertions::assert_eq_size!(Excepthandler, [u8; 64]); +static_assertions::assert_eq_size!(crate::Excepthandler, [u8; 64]);