Skip to content

Commit 41e9e72

Browse files
authored
Merge pull request #12 from youknowone/visitor
Visitor
2 parents 5cf85f0 + 17c8abc commit 41e9e72

File tree

7 files changed

+1891
-955
lines changed

7 files changed

+1891
-955
lines changed

Diff for: ast/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ constant-optimization = ["fold"]
1313
source-code = ["fold"]
1414
fold = []
1515
unparse = ["rustpython-literal"]
16+
visitor = []
1617

1718
[dependencies]
1819
rustpython-parser-core = { workspace = true }

Diff for: ast/asdl_rs.py

+160-36
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import asdl
1313

1414
TABSIZE = 4
15-
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n"
15+
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
1616

1717
builtin_type_mapping = {
1818
"identifier": "Ident",
@@ -68,6 +68,7 @@ class TypeInfo:
6868
enum_name: Optional[str]
6969
has_userdata: Optional[bool]
7070
has_attributes: bool
71+
empty_field: bool
7172
children: set
7273
boxed: bool
7374
product: bool
@@ -78,6 +79,7 @@ def __init__(self, name):
7879
self.enum_name = None
7980
self.has_userdata = None
8081
self.has_attributes = False
82+
self.empty_field = False
8183
self.children = set()
8284
self.boxed = False
8385
self.product = False
@@ -192,10 +194,9 @@ def visitSum(self, sum, name):
192194
info.has_userdata = False
193195
else:
194196
for t in sum.types:
195-
if not t.fields:
196-
continue
197197
t_info = TypeInfo(t.name)
198198
t_info.enum_name = name
199+
t_info.empty_field = not t.fields
199200
self.typeinfo[t.name] = t_info
200201
self.add_children(t.name, t.fields)
201202
if len(sum.types) > 1:
@@ -534,14 +535,140 @@ def gen_construction(self, header, fields, footer, depth):
534535
class FoldModuleVisitor(EmitVisitor):
535536
def visitModule(self, mod):
536537
depth = 0
537-
self.emit('#[cfg(feature = "fold")]', depth)
538-
self.emit("pub mod fold {", depth)
539-
self.emit("use super::*;", depth + 1)
540-
self.emit("use crate::fold_helpers::Foldable;", depth + 1)
541-
FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
542-
FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
538+
self.emit("use crate::fold_helpers::Foldable;", depth)
539+
FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth)
540+
FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth)
541+
542+
543+
class VisitorTraitDefVisitor(StructVisitor):
544+
def full_name(self, name):
545+
typeinfo = self.typeinfo[name]
546+
if typeinfo.enum_name:
547+
return f"{typeinfo.enum_name}_{name}"
548+
else:
549+
return name
550+
551+
def node_type_name(self, name):
552+
typeinfo = self.typeinfo[name]
553+
if typeinfo.enum_name:
554+
return f"{get_rust_type(typeinfo.enum_name)}{get_rust_type(name)}"
555+
else:
556+
return get_rust_type(name)
557+
558+
def visitModule(self, mod, depth):
559+
self.emit("pub trait Visitor<U=()> {", depth)
560+
561+
for dfn in mod.dfns:
562+
self.visit(dfn, depth + 1)
563+
self.emit("}", depth)
564+
565+
def visitType(self, type, depth=0):
566+
self.visit(type.value, type.name, depth)
567+
568+
def emit_visitor(self, nodename, depth, has_node=True):
569+
typeinfo = self.typeinfo[nodename]
570+
if has_node:
571+
node_type = typeinfo.rust_sum_name
572+
node_value = "node"
573+
else:
574+
node_type = "()"
575+
node_value = "()"
576+
self.emit(
577+
f"fn visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth
578+
)
579+
self.emit(f"self.generic_visit_{typeinfo.sum_name}({node_value})", depth + 1)
580+
self.emit("}", depth)
581+
582+
def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
583+
typeinfo = self.typeinfo[nodename]
584+
if has_node:
585+
node_type = typeinfo.rust_sum_name
586+
else:
587+
node_type = "()"
588+
self.emit(
589+
f"fn generic_visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{",
590+
depth,
591+
)
592+
593+
def emit_empty_generic_visitor(self, nodename, depth):
594+
self.emit_generic_visitor_signature(nodename, depth)
543595
self.emit("}", depth)
544596

597+
def simple_sum(self, sum, name, depth):
598+
self.emit_visitor(name, depth)
599+
self.emit_empty_generic_visitor(name, depth)
600+
601+
def visit_match_for_type(self, nodename, rustname, type_, depth):
602+
self.emit(f"{rustname}::{type_.name}", depth)
603+
if type_.fields:
604+
self.emit("(data)", depth)
605+
data = "data"
606+
else:
607+
data = "()"
608+
self.emit(f"=> self.visit_{nodename}_{type_.name}({data}),", depth)
609+
610+
def visit_sumtype(self, name, type_, depth):
611+
self.emit_visitor(type_.name, depth, has_node=type_.fields)
612+
self.emit_generic_visitor_signature(type_.name, depth, has_node=type_.fields)
613+
for f in type_.fields:
614+
fieldname = rust_field(f.name)
615+
fieldtype = self.typeinfo.get(f.type)
616+
if not (fieldtype and fieldtype.has_userdata):
617+
continue
618+
619+
if f.opt:
620+
self.emit(f"if let Some(value) = node.{fieldname} {{", depth + 1)
621+
elif f.seq:
622+
iterable = f"node.{fieldname}"
623+
if type_.name == "Dict" and f.name == "keys":
624+
iterable = f"{iterable}.into_iter().flatten()"
625+
self.emit(f"for value in {iterable} {{", depth + 1)
626+
else:
627+
self.emit("{", depth + 1)
628+
self.emit(f"let value = node.{fieldname};", depth + 2)
629+
630+
variable = "value"
631+
if fieldtype.boxed and (not f.seq or f.opt):
632+
variable = "*" + variable
633+
typeinfo = self.typeinfo[fieldtype.name]
634+
self.emit(f"self.visit_{typeinfo.sum_name}({variable});", depth + 2)
635+
636+
self.emit("}", depth + 1)
637+
638+
self.emit("}", depth)
639+
640+
def sum_with_constructors(self, sum, name, depth):
641+
if not sum.attributes:
642+
return
643+
644+
rustname = enumname = get_rust_type(name)
645+
if sum.attributes:
646+
rustname = enumname + "Kind"
647+
self.emit_visitor(name, depth)
648+
self.emit_generic_visitor_signature(name, depth)
649+
depth += 1
650+
self.emit("match node.node {", depth)
651+
for t in sum.types:
652+
self.visit_match_for_type(name, rustname, t, depth + 1)
653+
self.emit("}", depth)
654+
depth -= 1
655+
self.emit("}", depth)
656+
657+
# Now for the visitors for the types
658+
for t in sum.types:
659+
self.visit_sumtype(name, t, depth)
660+
661+
def visitProduct(self, product, name, depth):
662+
self.emit_visitor(name, depth)
663+
self.emit_empty_generic_visitor(name, depth)
664+
665+
666+
class VisitorModuleVisitor(EmitVisitor):
667+
def visitModule(self, mod):
668+
depth = 0
669+
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
670+
VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth)
671+
545672

546673
class ClassDefVisitor(EmitVisitor):
547674
def visitModule(self, mod):
@@ -799,23 +926,19 @@ def visit(self, object):
799926
v.emit("", 0)
800927

801928

802-
def write_generic_def(mod, typeinfo, f):
803-
f.write(
804-
textwrap.dedent(
805-
"""
806-
pub use crate::{Attributed, constant::*};
929+
def write_ast_def(mod, typeinfo, f):
930+
StructVisitor(f, typeinfo).visit(mod)
807931

808-
type Ident = String;
809-
\n
810-
"""
811-
)
812-
)
813932

814-
c = ChainOfVisitors(StructVisitor(f, typeinfo), FoldModuleVisitor(f, typeinfo))
815-
c.visit(mod)
933+
def write_fold_def(mod, typeinfo, f):
934+
FoldModuleVisitor(f, typeinfo).visit(mod)
935+
936+
937+
def write_visitor_def(mod, typeinfo, f):
938+
VisitorModuleVisitor(f, typeinfo).visit(mod)
816939

817940

818-
def write_located_def(typeinfo, f):
941+
def write_located_def(mod, typeinfo, f):
819942
f.write(
820943
textwrap.dedent(
821944
"""
@@ -826,6 +949,8 @@ def write_located_def(typeinfo, f):
826949
)
827950
)
828951
for info in typeinfo.values():
952+
if info.empty_field:
953+
continue
829954
if info.has_userdata:
830955
generics = "::<SourceRange>"
831956
else:
@@ -863,8 +988,7 @@ def write_ast_mod(mod, typeinfo, f):
863988

864989
def main(
865990
input_filename,
866-
generic_filename,
867-
located_filename,
991+
ast_dir,
868992
module_filename,
869993
dump_module=False,
870994
):
@@ -879,34 +1003,34 @@ def main(
8791003
typeinfo = {}
8801004
FindUserdataTypesVisitor(typeinfo).visit(mod)
8811005

882-
with generic_filename.open("w") as generic_file, located_filename.open(
883-
"w"
884-
) as located_file:
885-
generic_file.write(auto_gen_msg)
886-
write_generic_def(mod, typeinfo, generic_file)
887-
located_file.write(auto_gen_msg)
888-
write_located_def(typeinfo, located_file)
1006+
for filename, write in [
1007+
("generic", write_ast_def),
1008+
("fold", write_fold_def),
1009+
("located", write_located_def),
1010+
("visitor", write_visitor_def),
1011+
]:
1012+
with (ast_dir / f"{filename}.rs").open("w") as f:
1013+
f.write(auto_gen_msg)
1014+
write(mod, typeinfo, f)
8891015

8901016
with module_filename.open("w") as module_file:
8911017
module_file.write(auto_gen_msg)
8921018
write_ast_mod(mod, typeinfo, module_file)
8931019

894-
print(f"{generic_filename}, {located_filename}, {module_filename} regenerated.")
1020+
print(f"{ast_dir}, {module_filename} regenerated.")
8951021

8961022

8971023
if __name__ == "__main__":
8981024
parser = ArgumentParser()
8991025
parser.add_argument("input_file", type=Path)
900-
parser.add_argument("-G", "--generic-file", type=Path, required=True)
901-
parser.add_argument("-L", "--located-file", type=Path, required=True)
1026+
parser.add_argument("-A", "--ast-dir", type=Path, required=True)
9021027
parser.add_argument("-M", "--module-file", type=Path, required=True)
9031028
parser.add_argument("-d", "--dump-module", action="store_true")
9041029

9051030
args = parser.parse_args()
9061031
main(
9071032
args.input_file,
908-
args.generic_file,
909-
args.located_file,
1033+
args.ast_dir,
9101034
args.module_file,
9111035
args.dump_module,
9121036
)

0 commit comments

Comments
 (0)