# Mirror of https://github.com/astral-sh/ruff (838 lines, 25 KiB, Python, executable file).
# spell-checker:words dfn dfns
|
|
|
|
# ! /usr/bin/env python
|
|
"""Generate Rust code from an ASDL description."""
|
|
|
|
import re
|
|
import sys
|
|
import textwrap
|
|
from argparse import ArgumentParser
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
import asdl
|
|
|
|
# Number of spaces per indentation level (`depth`) in emitted Rust code.
TABSIZE: int = 4
# Header written at the top of every generated file; `{}` receives the
# generator script's path.
AUTO_GEN_MESSAGE: str = "// File automatically generated by {}." + "\n" * 2
|
|
|
|
# Rust type name for each ASDL builtin type: the builtin's name, capitalized.
BUILTIN_TYPE_NAMES = {
    builtin: builtin.capitalize()
    for builtin in ("identifier", "string", "int", "constant")
}
# The table above must cover exactly the builtins the ASDL parser knows about.
assert BUILTIN_TYPE_NAMES.keys() == asdl.builtin_types
|
|
|
|
# ASDL `int` fields (keyed by field name) that map to a more specific Rust
# type than the generic `Int`.
BUILTIN_INT_NAMES = dict(
    simple="bool",
    is_async="bool",
    conversion="ConversionFlag",
)

# ASDL type names respelled with explicit word boundaries before being
# converted to Rust PascalCase / snake_case names.
RENAME_MAP = dict(
    cmpop="cmp_op",
    unaryop="unary_op",
    boolop="bool_op",
    excepthandler="except_handler",
    withitem="with_item",
)
|
|
|
|
# Identifiers that cannot be used verbatim in generated Rust code.
# `rust_field` appends `_` to field names found here, and the sum emitters
# prefix enum accessor names when a variant's snake_case name is found here.
# This covers Rust's strict keywords plus the reserved ones (including those
# reserved by the 2018/2024 editions). The previous list only held the
# handful that the current Python grammar happened to hit (`if`, `mod`,
# `type`, ...), so a new field or constructor such as `break`, `else` or
# `self` would have produced uncompilable Rust.
RUST_KEYWORDS = {
    # Strict keywords.
    "as",
    "async",
    "await",
    "break",
    "const",
    "continue",
    "crate",
    "dyn",
    "else",
    "enum",
    "extern",
    "false",
    "fn",
    "for",
    "if",
    "impl",
    "in",
    "let",
    "loop",
    "match",
    "mod",
    "move",
    "mut",
    "pub",
    "ref",
    "return",
    "self",
    "Self",
    "static",
    "struct",
    "super",
    "trait",
    "true",
    "type",
    "unsafe",
    "use",
    "where",
    "while",
    # Reserved keywords.
    "abstract",
    "become",
    "box",
    "do",
    "final",
    "gen",
    "macro",
    "override",
    "priv",
    "try",
    "typeof",
    "unsized",
    "virtual",
    "yield",
}
|
|
|
|
# The four ASDL location attributes, as `int` fields.
# NOTE(review): not referenced by the generator code in this file — confirm
# whether it is still needed.
attributes = [
    asdl.Field("int", attr_name)
    for attr_name in ("lineno", "col_offset", "end_lineno", "end_col_offset")
]
|
|
|
|
# Appended to the docs of the custom types below, which have no identical
# counterpart in Python's own AST.
ORIGINAL_NODE_WARNING = "NOTE: This type is different from original Python AST."

# A function parameter bundled with its optional default expression.
arg_with_default = asdl.Type(
    "arg_with_default",
    asdl.Product(
        [
            asdl.Field("arg", "def"),
            # order is important for cost-free borrow!
            asdl.Field("expr", "default", opt=True),
        ],
    ),
)
arg_with_default.doc = "\n".join(
    [
        "An alternative type of AST `arg`. This is used for each function argument that might have a default value.",
        "Used by `Arguments` original type.",
        "",
        ORIGINAL_NODE_WARNING,
    ]
)

# Replacement for Python's `arguments`: defaults live next to their parameter.
alt_arguments = asdl.Type(
    "alt:arguments",
    asdl.Product(
        [
            asdl.Field("arg_with_default", "posonlyargs", seq=True),
            asdl.Field("arg_with_default", "args", seq=True),
            asdl.Field("arg", "vararg", opt=True),
            asdl.Field("arg_with_default", "kwonlyargs", seq=True),
            asdl.Field("arg", "kwarg", opt=True),
        ]
    ),
)
alt_arguments.doc = "\n".join(
    [
        "An alternative type of AST `arguments`. This is parser-friendly and human-friendly definition of function arguments.",
        "This form also has advantage to implement pre-order traverse.",
        "`defaults` and `kw_defaults` fields are removed and the default values are placed under each `arg_with_default` typed argument.",
        "`vararg` and `kwarg` are still typed as `arg` because they never can have a default value.",
        "",
        "The matching Python style AST type is [PythonArguments]. While [PythonArguments] has ordered `kwonlyargs` fields by",
        "default existence, [Arguments] has location-ordered kwonlyargs fields.",
        "",
        ORIGINAL_NODE_WARNING,
    ]
)

# Must be used only for rust types, not python types
CUSTOM_TYPES = [alt_arguments, arg_with_default]

# Python ASDL type name -> custom type generated in its place.
CUSTOM_REPLACEMENTS = dict(arguments=alt_arguments)

# NOTE(review): not referenced by the generator code in this file — confirm
# whether it is still needed.
CUSTOM_ATTACHMENTS = [arg_with_default]
|
|
|
|
|
|
def maybe_custom(type):
    """Return the custom type registered for ``type.name``, or ``type`` unchanged."""
    if type.name in CUSTOM_REPLACEMENTS:
        return CUSTOM_REPLACEMENTS[type.name]
    return type
|
|
|
|
|
|
def rust_field_name(name):
    """Return the snake_case Rust spelling of the ASDL name ``name``.

    The name is first converted with :func:`rust_type_name` (PascalCase).
    An underscore is then inserted before every uppercase letter except a
    leading one, and the result is lowercased.
    """
    pascal = rust_type_name(name)
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", pascal)
    return snake.lower()
|
|
|
|
|
|
def rust_type_name(name):
    """Return the Rust type name for the ASDL name ``name``.

    Conversion steps:

    * names listed in ``RENAME_MAP`` are first respelled with explicit word
      boundaries (``cmpop`` -> ``cmp_op``);
    * ASDL builtin types map to their entry in ``BUILTIN_TYPE_NAMES``
      (``identifier`` -> ``Identifier``);
    * all-lowercase names are converted from snake_case to PascalCase
      (``cmp_op`` -> ``CmpOp``);
    * any other name (e.g. an already-PascalCase constructor) is returned
      unchanged.
    """
    name = RENAME_MAP.get(name, name)
    # BUILTIN_TYPE_NAMES is asserted at import time to have exactly the keys
    # of `asdl.builtin_types`, so it doubles as the membership test.
    if name in BUILTIN_TYPE_NAMES:
        return BUILTIN_TYPE_NAMES[name]
    if name.islower():
        return "".join(part.capitalize() for part in name.split("_"))
    return name
|
|
|
|
|
|
def is_simple(sum):
    """Return True if no constructor of the ASDL sum ``sum`` has fields.

    Such a sum is a plain enumeration, e.g.::

        unaryop = Invert | Not | UAdd | USub
    """
    return all(not constructor.fields for constructor in sum.types)
|
|
|
|
|
|
def asdl_of(name, obj):
    """Render ``obj`` back into ASDL-like text under the name ``name``.

    Products and constructors render as ``name(field, ...)``, or just
    ``name`` when they have no fields. Sums render as ``name = A | B``.
    For non-simple sums each constructor goes on its own line, with the
    ``|`` aligned under the ``=``.
    """
    if not isinstance(obj, (asdl.Product, asdl.Constructor)):
        if is_simple(obj):
            alternatives = " | ".join(alt.name for alt in obj.types)
        else:
            separator = "\n" + " " * (len(name) + 1) + "| "
            alternatives = separator.join(asdl_of(alt.name, alt) for alt in obj.types)
        return f"{name} = {alternatives}"

    rendered_fields = ", ".join(str(field) for field in obj.fields)
    return f"{name}({rendered_fields})" if rendered_fields else name
|
|
|
|
|
|
class TypeInfo:
    """Facts about one ASDL type or constructor, collected by FindUserDataTypesVisitor."""

    type: asdl.Type
    enum_name: Optional[str]
    has_user_data: Optional[bool]
    has_attributes: bool
    is_simple: bool
    children: set
    fields: Optional[Any]
    boxed: bool

    def __init__(self, type):
        self.type = type
        # Name of the enclosing sum when `type` is a constructor, else None.
        self.enum_name = None
        # Tri-state: None until resolved by `determine_user_data`.
        self.has_user_data = None
        self.has_attributes = False
        self.is_simple = False
        # (child type name, is_sequence) pairs referenced by this type's fields.
        self.children = set()
        self.fields = None
        self.boxed = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    @property
    def name(self):
        return self.type.name

    @property
    def is_type(self):
        return isinstance(self.type, asdl.Type)

    @property
    def is_product(self):
        if not self.is_type:
            return False
        return isinstance(self.type.value, asdl.Product)

    @property
    def is_sum(self):
        if not self.is_type:
            return False
        return isinstance(self.type.value, asdl.Sum)

    @property
    def has_expr(self):
        """True for products with at least one non-`identifier` field."""
        if not self.is_product:
            return False
        return any(field.type != "identifier" for field in self.type.value.fields)

    @property
    def is_custom(self):
        """True when this entry is one of the generator-defined CUSTOM_TYPES."""
        return any(custom.name == self.type.name for custom in CUSTOM_TYPES)

    @property
    def is_custom_replaced(self):
        """True when this Python type is replaced by a custom type in Rust."""
        return self.type.name in CUSTOM_REPLACEMENTS

    @property
    def custom(self):
        """The replacement type if one is registered, else this entry's own type."""
        return CUSTOM_REPLACEMENTS.get(self.type.name, self.type)

    def no_cfg(self, typeinfo):
        """Return whether the owning type has ASDL attributes.

        Constructors defer to their enclosing sum; everything else answers
        for itself.
        """
        owner = self
        if self.enum_name and not self.is_product:
            owner = typeinfo[self.enum_name]
        return owner.has_attributes

    @property
    def rust_name(self):
        return rust_type_name(self.name)

    @property
    def full_field_name(self):
        base = self.name[len("alt:"):] if self.name.startswith("alt:") else self.name
        if self.enum_name is None:
            return base
        return f"{self.enum_name}_{rust_field_name(base)}"

    @property
    def full_type_name(self):
        """Rust type name: optional `Python` prefix + enum name + own name."""
        base = self.name[len("alt:"):] if self.name.startswith("alt:") else self.name
        own_part = rust_type_name(base)
        enum_part = "" if self.enum_name is None else rust_type_name(self.enum_name)
        python_part = "Python" if self.is_custom_replaced else ""
        return python_part + enum_part + own_part

    def determine_user_data(self, type_info, stack):
        """Propagate `has_user_data = True` up from children; return the flag.

        ``stack`` holds the names currently being resolved. Re-entering one of
        them (a recursive type) returns None instead of looping forever.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child_name, _is_seq in self.children:
            if child_name in asdl.builtin_types:
                continue
            child_result = type_info[child_name].determine_user_data(type_info, stack)
            if child_result is True and self.has_user_data is None:
                self.has_user_data = True
        stack.remove(self.name)
        return self.has_user_data
|
|
|
|
|
|
class TypeInfoMixin:
    """Lookup helpers over a ``type_info`` table, shared by the emitting visitors."""

    # ASDL type/constructor name -> TypeInfo; assigned by the concrete visitor.
    type_info: Dict[str, TypeInfo]

    def customized_type_info(self, type_name):
        """Return the TypeInfo of ``type_name``'s custom replacement, if any, else its own."""
        replacement = self.type_info[type_name].custom
        return self.type_info[replacement.name]

    def has_user_data(self, typ):
        """Return the recorded ``has_user_data`` flag of ``typ``."""
        return self.type_info[typ].has_user_data

    def apply_generics(self, typ, *generics):
        """Return ``"<g>"`` for each generic, or empty strings if ``typ`` is a simple sum."""
        if self.type_info[typ].is_simple:
            return ["" for _ in generics]
        return [f"<{g}>" for g in generics]
|
|
|
|
|
|
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
    """Base visitor that writes generated lines to ``file`` at a given indent depth."""

    def __init__(self, file, type_info):
        self.file = file
        self.type_info = type_info
        # Names already emitted by `emit_identifier`, to avoid duplicates.
        self.identifiers = set()
        super().__init__()

    def emit_identifier(self, name):
        """Emit a `_Py_IDENTIFIER(name);` line once per distinct name."""
        name = str(name)
        if name not in self.identifiers:
            self.emit(f"_Py_IDENTIFIER({name});", 0)
            self.identifiers.add(name)

    def emit(self, line, depth):
        """Write ``line`` dedented and indented by ``depth`` levels, plus a newline.

        Empty lines are written as a bare newline without indentation.
        """
        if line:
            indent = " " * (TABSIZE * depth)
            line = indent + textwrap.dedent(line)
        self.file.write(f"{line}\n")
|
|
|
|
|
|
class FindUserDataTypesVisitor(asdl.VisitorBase):
    """First pass over the ASDL module that fills ``type_info``.

    It records a TypeInfo for every type and constructor, including
    ``CUSTOM_TYPES``, together with their children and their simple/boxed/
    attribute flags. It then resolves ``has_user_data`` for every entry.
    """

    def __init__(self, type_info):
        self.type_info = type_info
        super().__init__()

    def visitModule(self, mod):
        for definition in mod.dfns + CUSTOM_TYPES:
            self.visit(definition)
        # Shared cycle guard for the recursive `determine_user_data` walk.
        in_progress = set()
        for info in self.type_info.values():
            info.determine_user_data(self.type_info, in_progress)

    def visitType(self, type):
        info = TypeInfo(type)
        self.type_info[type.name] = info
        self.visit(type.value, info)

    def visitSum(self, sum, info):
        sum_type = info.type
        info.is_simple = is_simple(sum)
        for constructor in sum.types:
            self.visit(constructor, sum_type, info.is_simple)

        if info.is_simple:
            # Fieldless (simple) sums carry no user data.
            info.has_user_data = False
            return

        for constructor in sum.types:
            self.add_children(constructor.name, constructor.fields)

        if len(sum.types) > 1:
            info.boxed = True
        if sum.attributes:
            # attributes means located, which has the `range: R` field
            info.has_user_data = True
            info.has_attributes = True

        for constructor in sum.types:
            self.add_children(sum_type.name, constructor.fields)

    def visitConstructor(self, cons, type, simple):
        info = TypeInfo(cons)
        info.enum_name = type.name
        info.is_simple = simple
        self.type_info[cons.name] = info

    def visitProduct(self, product, info):
        if product.attributes:
            # attributes means located, which has the `range: R` field
            info.has_user_data = True
            info.has_attributes = True
        if len(product.fields) > 2:
            info.boxed = True
        self.add_children(info.type.name, product.fields)

    def add_children(self, name, fields):
        """Record ``(field.type, field.seq)`` for each field as a child of ``name``."""
        children = self.type_info[name].children
        for field in fields:
            children.add((field.type, field.seq))
|
|
|
|
|
|
def rust_field(field_name):
    """Return ``field_name``, with a ``_`` suffix if it is a Rust keyword."""
    return f"{field_name}_" if field_name in RUST_KEYWORDS else field_name
|
|
|
|
|
|
class StructVisitor(EmitVisitor):
    """Visitor to generate type-defs for AST.

    Emits the top-level `Ast` enum, a Rust enum or struct for every ASDL
    type (plus the custom types), and the `Node`/`From` impls that connect
    them.
    """

    def emit_attrs(self, depth):
        """Emit the `derive` attribute shared by every generated type."""
        self.emit("#[derive(Clone, Debug, PartialEq)]", depth)

    def emit_range(self, has_attributes, depth):
        """Emit the `range` field of a struct that is declared at ``depth``.

        The field is written one level deeper (``depth + 1``), next to the
        struct's other fields. Types with ASDL attributes get a plain `R`;
        all others get an `OptionalRange<R>`.
        """
        if has_attributes:
            self.emit("pub range: R,", depth + 1)
        else:
            self.emit("pub range: OptionalRange<R>,", depth + 1)

    def visitModule(self, mod):
        """Emit the `Ast` enum, its `Node`/`From` impls, then every type definition."""
        self.emit_attrs(0)
        self.emit(
            """
            #[derive(is_macro::Is)]
            pub enum Ast<R=TextRange> {
            """,
            0,
        )
        for dfn in mod.dfns:
            info = self.customized_type_info(dfn.name)
            dfn = info.custom
            rust_name = info.full_type_name
            generics = "" if self.type_info[dfn.name].is_simple else "<R>"
            if dfn.name == "mod":
                # This is an exception to the rule used for other enums.
                # It is justified because `Mod` is only used as the top node
                # of a parsing result and is never a child of other nodes.
                # Because it is rarely used, and only in very particular
                # applications, an "ast_" prefix everywhere seems less useful.
                self.emit('#[is(name = "module")]', 1)
            self.emit(f"{rust_name}({rust_name}{generics}),", 1)
        self.emit(
            """
            }
            impl<R> Node for Ast<R> {
                const NAME: &'static str = "AST";
                const FIELD_NAMES: &'static [&'static str] = &[];
            }
            """,
            0,
        )

        for dfn in mod.dfns:
            info = self.customized_type_info(dfn.name)
            rust_name = info.full_type_name
            generics = "" if self.type_info[dfn.name].is_simple else "<R>"
            self.emit(
                f"""
                impl<R> From<{rust_name}{generics}> for Ast<R> {{
                    fn from(node: {rust_name}{generics}) -> Self {{
                        Ast::{rust_name}(node)
                    }}
                }}
                """,
                0,
            )

        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        """Emit the doc comment for ``type``, then its definition."""
        if hasattr(type, "doc"):
            doc = "/// " + type.doc.replace("\n", "\n/// ") + "\n"
        else:
            doc = f"/// See also [{type.name}](https://docs.python.org/3/library/ast.html#ast.{type.name})"
        self.emit(doc, depth)
        self.visit(type.value, type, depth)

    def visitSum(self, sum, type, depth):
        """Emit a sum as a simple or payload-carrying enum, plus its `Node` impl."""
        if is_simple(sum):
            self.simple_sum(sum, type, depth)
        else:
            self.sum_with_constructors(sum, type, depth)

        (generics_applied,) = self.apply_generics(type.name, "R")
        self.emit(
            f"""
            impl{generics_applied} Node for {rust_type_name(type.name)}{generics_applied} {{
                const NAME: &'static str = "{type.name}";
                const FIELD_NAMES: &'static [&'static str] = &[];
            }}
            """,
            depth,
        )

    def simple_sum(self, sum, type, depth):
        """Emit a fieldless sum as a `Copy` enum plus one unit struct per variant."""
        rust_name = rust_type_name(type.name)
        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth)
        self.emit(f"pub enum {rust_name} {{", depth)
        for cons in sum.types:
            self.emit(f"{cons.name},", depth + 1)
        self.emit("}", depth)
        self.emit(f"impl {rust_name} {{", depth)
        # If any accessor name would be a Rust keyword, prefix all of them
        # with the enum's snake_case name.
        needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
        if needs_escape:
            prefix = rust_field_name(type.name) + "_"
        else:
            prefix = ""
        for cons in sum.types:
            self.emit(
                f"""
                #[inline]
                pub const fn {prefix}{rust_field_name(cons.name)}(&self) -> Option<{rust_name}{cons.name}> {{
                    match self {{
                        {rust_name}::{cons.name} => Some({rust_name}{cons.name}),
                        _ => None,
                    }}
                }}
                """,
                depth,
            )
        self.emit("}", depth)
        self.emit("", depth)

        for cons in sum.types:
            self.emit(
                f"""
                pub struct {rust_name}{cons.name};
                impl From<{rust_name}{cons.name}> for {rust_name} {{
                    fn from(_: {rust_name}{cons.name}) -> Self {{
                        {rust_name}::{cons.name}
                    }}
                }}
                impl<R> From<{rust_name}{cons.name}> for Ast<R> {{
                    fn from(_: {rust_name}{cons.name}) -> Self {{
                        {rust_name}::{cons.name}.into()
                    }}
                }}
                impl Node for {rust_name}{cons.name} {{
                    const NAME: &'static str = "{cons.name}";
                    const FIELD_NAMES: &'static [&'static str] = &[];
                }}
                impl std::cmp::PartialEq<{rust_name}> for {rust_name}{cons.name} {{
                    #[inline]
                    fn eq(&self, other: &{rust_name}) -> bool {{
                        matches!(other, {rust_name}::{cons.name})
                    }}
                }}
                """,
                0,
            )

    def sum_with_constructors(self, sum, type, depth):
        """Emit a sum whose variants carry payload structs, then each payload struct."""
        type_info = self.type_info[type.name]
        rust_name = rust_type_name(type.name)

        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is)]", depth)
        self.emit(f"pub enum {rust_name}<R = TextRange> {{", depth)
        needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
        for t in sum.types:
            if needs_escape:
                self.emit(
                    f'#[is(name = "{rust_field_name(t.name)}_{rust_name.lower()}")]',
                    depth + 1,
                )
            self.emit(f"{t.name}({rust_name}{t.name}<R>),", depth + 1)
        self.emit("}", depth)
        self.emit("", depth)

        for t in sum.types:
            self.sum_subtype_struct(type_info, t, rust_name, depth)

    def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
        """Emit the payload struct for constructor ``t`` and its `Node`/`From` impls."""
        self.emit(
            f"""/// See also [{t.name}](https://docs.python.org/3/library/ast.html#ast.{t.name})""",
            depth,
        )
        self.emit_attrs(depth)
        payload_name = f"{rust_name}{t.name}"
        self.emit(f"pub struct {payload_name}<R = TextRange> {{", depth)
        self.emit_range(sum_type_info.has_attributes, depth)
        for f in t.fields:
            self.visit(f, sum_type_info, "pub ", depth + 1, t.name)

        assert sum_type_info.has_attributes == self.type_info[t.name].no_cfg(
            self.type_info
        )

        self.emit("}", depth)
        field_names = [f'"{f.name}"' for f in t.fields]
        self.emit(
            f"""
            impl<R> Node for {payload_name}<R> {{
                const NAME: &'static str = "{t.name}";
                const FIELD_NAMES: &'static [&'static str] = &[{', '.join(field_names)}];
            }}
            impl<R> From<{payload_name}<R>> for {rust_name}<R> {{
                fn from(payload: {payload_name}<R>) -> Self {{
                    {rust_name}::{t.name}(payload)
                }}
            }}
            impl<R> From<{payload_name}<R>> for Ast<R> {{
                fn from(payload: {payload_name}<R>) -> Self {{
                    {rust_name}::from(payload).into()
                }}
            }}
            """,
            depth,
        )

        self.emit("", depth)

    def visitConstructor(self, cons, parent, depth):
        """Emit an inline enum variant, with a field block if it has fields."""
        if cons.fields:
            self.emit(f"{cons.name} {{", depth)
            for f in cons.fields:
                self.visit(f, parent, "", depth + 1, cons.name)
            self.emit("},", depth)
        else:
            self.emit(f"{cons.name},", depth)

    def visitField(self, field, parent, vis, depth, constructor=None):
        """Emit one struct field with its Rust type (boxed/optional/vec as needed)."""
        try:
            field_type = self.customized_type_info(field.type)
            typ = field_type.full_type_name
        except KeyError:
            # Builtin ASDL types have no TypeInfo entry.
            field_type = None
            typ = rust_type_name(field.type)
        if field_type and not field_type.is_simple:
            typ = f"{typ}<R>"
        # don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
        if (
            field_type
            and field_type.boxed
            and (not (parent.is_product or field.seq) or field.opt)
        ):
            typ = f"Box<{typ}>"
        if field.opt or (
            # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
            # the expression to be unpacked goes in `values` with a `None` at the corresponding
            # position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
            constructor == "Dict"
            and field.name == "keys"
        ):
            typ = f"Option<{typ}>"
        if field.seq:
            typ = f"Vec<{typ}>"
        if typ == "Int":
            typ = BUILTIN_INT_NAMES.get(field.name, typ)
        name = rust_field(field.name)

        # Use a String, rather than an Identifier, for the `id` field of `Expr::Name`.
        # Names already include a range, so there's no need to duplicate the span.
        if name == "id":
            typ = "String"

        self.emit(f"{vis}{name}: {typ},", depth)

    def visitProduct(self, product, type, depth):
        """Emit a product type as a struct, plus its `Node` impl."""
        type_info = self.type_info[type.name]
        product_name = type_info.full_type_name
        self.emit_attrs(depth)
        self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
        # `emit_range` already indents one level past the struct; passing
        # `depth + 1` here used to push `range` two levels deep.
        self.emit_range(bool(product.attributes), depth)
        for f in product.fields:
            self.visit(f, type_info, "pub ", depth + 1)
        assert bool(product.attributes) == type_info.no_cfg(self.type_info)
        self.emit("}", depth)

        field_names = [f'"{f.name}"' for f in product.fields]
        self.emit(
            f"""
            impl<R> Node for {product_name}<R> {{
                const NAME: &'static str = "{type.name}";
                const FIELD_NAMES: &'static [&'static str] = &[
                    {', '.join(field_names)}
                ];
            }}
            """,
            depth,
        )
|
|
|
|
|
|
class RangedDefVisitor(EmitVisitor):
    """Visitor that emits `Ranged` trait impls for every AST node type."""

    # Type-alias emission is currently disabled. Set to True to emit
    # `pub type X = crate::generic::X::<TextRange>;` aliases again.
    EMIT_TYPE_ALIASES = False

    def visitModule(self, mod):
        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit `Ranged` for each variant payload, then for the enum itself.

        The enum's impl dispatches to the variant payloads. Simple (fieldless)
        sums get no `Ranged` impl.
        """
        info = self.type_info[name]

        self.emit_type_alias(info)

        if info.is_simple:
            for ty in sum.types:
                self.emit_type_alias(self.type_info[ty.name])
            return

        match_arms = []
        for ty in sum.types:
            variant_info = self.type_info[ty.name]
            match_arms.append(
                f"Self::{variant_info.rust_name}(node) => node.range(),"
            )
            self.emit_type_alias(variant_info)
            self.emit_ranged_impl(variant_info)

        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        # Build the impl line by line so every arm gets its own line.
        lines = [
            f"impl Ranged for crate::{info.full_type_name} {{",
            "    fn range(&self) -> TextRange {",
            "        match self {",
            *(f"            {arm}" for arm in match_arms),
            "        }",
            "    }",
            "}",
        ]
        self.emit("\n".join(lines), 0)

    def visitProduct(self, product, name, depth):
        info = self.type_info[name]

        self.emit_type_alias(info)
        self.emit_ranged_impl(info)

    def emit_type_alias(self, info):
        """Emit a `TextRange`-specialized type alias (no-op unless EMIT_TYPE_ALIASES)."""
        if not self.EMIT_TYPE_ALIASES:
            return

        generics = "" if info.is_simple else "::<TextRange>"

        self.emit(
            f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};",
            0,
        )
        self.emit("", 0)

    def emit_ranged_impl(self, info):
        """Emit `Ranged` for a struct that stores its span in a `range` field."""
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        # Go through `emit` (not `file.write`) so the impl ends with a newline
        # and the next item does not start on the same line.
        self.emit(
            f"impl Ranged for crate::generic::{info.full_type_name}::<TextRange> {{\n"
            "    fn range(&self) -> TextRange {\n"
            "        self.range\n"
            "    }\n"
            "}",
            0,
        )
|
|
|
|
|
|
def write_ast_def(mod, type_info, f):
    """Write the generic AST type definitions for ``mod`` to ``f``."""
    # Terminate the `use` line so the first generated item starts on its own line.
    f.write("use crate::text_size::TextRange;\n")
    StructVisitor(f, type_info).visit(mod)


def write_ranged_def(mod, type_info, f):
    """Write the `Ranged` trait implementations for ``mod`` to ``f``."""
    RangedDefVisitor(f, type_info).visit(mod)
|
|
|
|
|
|
def write_parse_def(mod, type_info, f):
    """Write a `Parse` impl for every `expr`/`stmt` constructor to ``f``.

    Each impl delegates lexing and parsing to the enclosing enum's
    `lex_starts_at`/`parse_tokens`. It then unwraps the expected variant,
    or returns a `ParseErrorType::InvalidToken` error at the parsed node's
    start offset. ``mod`` is not used.
    """
    parseable = (
        info for info in type_info.values() if info.enum_name in ("expr", "stmt")
    )
    for info in parseable:
        enum_rust_name = rust_type_name(info.enum_name)
        variant_rust_name = rust_type_name(info.name)

        f.write(
            f"""
            impl Parse for ast::{info.full_type_name} {{
                fn lex_starts_at(
                    source: &str,
                    offset: TextSize,
                ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {{
                    ast::{enum_rust_name}::lex_starts_at(source, offset)
                }}
                fn parse_tokens(
                    lxr: impl IntoIterator<Item = LexResult>,
                    source_path: &str,
                ) -> Result<Self, ParseError> {{
                    let node = ast::{enum_rust_name}::parse_tokens(lxr, source_path)?;
                    match node {{
                        ast::{enum_rust_name}::{variant_rust_name}(node) => Ok(node),
                        node => Err(ParseError {{
                            error: ParseErrorType::InvalidToken,
                            offset: node.range().start(),
                            source_path: source_path.to_owned(),
                        }}),
                    }}
                }}
            }}
            """
        )
|
|
|
|
|
|
def main(
    input_filename,
    ast_dir,
    parser_dir,
    dump_module=False,
):
    """Parse ``input_filename`` as ASDL and regenerate the Rust sources.

    Writes ``generic.rs`` and ``ranged.rs`` into ``ast_dir``, and
    ``parse.rs`` into ``parser_dir``. Each file starts with the
    auto-generated header. Exits with status 1 if ``asdl.check`` rejects
    the module.

    ``ast_dir`` and ``parser_dir`` may be ``str`` or ``pathlib.Path``.
    """
    from functools import partial

    ast_dir = Path(ast_dir)
    parser_dir = Path(parser_dir)

    auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
    mod = asdl.parse(input_filename)
    if dump_module:
        print("Parsed Module:")
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    type_info = {}
    FindUserDataTypesVisitor(type_info).visit(mod)

    outputs = [
        (ast_dir / "generic.rs", partial(write_ast_def, mod, type_info)),
        (ast_dir / "ranged.rs", partial(write_ranged_def, mod, type_info)),
        (parser_dir / "parse.rs", partial(write_parse_def, mod, type_info)),
    ]
    for path, write in outputs:
        # Write UTF-8 explicitly; the default encoding depends on the
        # platform locale.
        with path.open("w", encoding="utf-8") as f:
            f.write(auto_gen_msg)
            write(f)

    print(f"{ast_dir} regenerated.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: <input_file> -A <ast dir> -P <parser dir> [-d]
    cli = ArgumentParser()
    cli.add_argument("input_file", type=Path)
    cli.add_argument("-A", "--ast-dir", type=Path, required=True)
    cli.add_argument("-P", "--parser-dir", type=Path, required=True)
    cli.add_argument("-d", "--dump-module", action="store_true")

    cli_args = cli.parse_args()
    main(
        cli_args.input_file,
        cli_args.ast_dir,
        cli_args.parser_dir,
        cli_args.dump_module,
    )
|