12
12
import asdl
13
13
14
14
TABSIZE = 4
15
- AUTOGEN_MESSAGE = "// File automatically generated by {}.\n "
15
+ AUTOGEN_MESSAGE = "// File automatically generated by {}.\n \n "
16
16
17
17
builtin_type_mapping = {
18
18
"identifier" : "Ident" ,
@@ -68,6 +68,7 @@ class TypeInfo:
68
68
enum_name : Optional [str ]
69
69
has_userdata : Optional [bool ]
70
70
has_attributes : bool
71
+ empty_field : bool
71
72
children : set
72
73
boxed : bool
73
74
product : bool
@@ -78,6 +79,7 @@ def __init__(self, name):
78
79
self .enum_name = None
79
80
self .has_userdata = None
80
81
self .has_attributes = False
82
+ self .empty_field = False
81
83
self .children = set ()
82
84
self .boxed = False
83
85
self .product = False
@@ -192,10 +194,9 @@ def visitSum(self, sum, name):
192
194
info .has_userdata = False
193
195
else :
194
196
for t in sum .types :
195
- if not t .fields :
196
- continue
197
197
t_info = TypeInfo (t .name )
198
198
t_info .enum_name = name
199
+ t_info .empty_field = not t .fields
199
200
self .typeinfo [t .name ] = t_info
200
201
self .add_children (t .name , t .fields )
201
202
if len (sum .types ) > 1 :
@@ -534,14 +535,140 @@ def gen_construction(self, header, fields, footer, depth):
534
535
class FoldModuleVisitor (EmitVisitor ):
535
536
def visitModule (self , mod ):
536
537
depth = 0
537
- self .emit ('#[cfg(feature = "fold")]' , depth )
538
- self .emit ("pub mod fold {" , depth )
539
- self .emit ("use super::*;" , depth + 1 )
540
- self .emit ("use crate::fold_helpers::Foldable;" , depth + 1 )
541
- FoldTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
542
- FoldImplVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
538
+ self .emit ("use crate::fold_helpers::Foldable;" , depth )
539
+ FoldTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth )
540
+ FoldImplVisitor (self .file , self .typeinfo ).visit (mod , depth )
541
+
542
+
543
+ class VisitorTraitDefVisitor (StructVisitor ):
544
+ def full_name (self , name ):
545
+ typeinfo = self .typeinfo [name ]
546
+ if typeinfo .enum_name :
547
+ return f"{ typeinfo .enum_name } _{ name } "
548
+ else :
549
+ return name
550
+
551
+ def node_type_name (self , name ):
552
+ typeinfo = self .typeinfo [name ]
553
+ if typeinfo .enum_name :
554
+ return f"{ get_rust_type (typeinfo .enum_name )} { get_rust_type (name )} "
555
+ else :
556
+ return get_rust_type (name )
557
+
558
+ def visitModule (self , mod , depth ):
559
+ self .emit ("pub trait Visitor<U=()> {" , depth )
560
+
561
+ for dfn in mod .dfns :
562
+ self .visit (dfn , depth + 1 )
563
+ self .emit ("}" , depth )
564
+
565
+ def visitType (self , type , depth = 0 ):
566
+ self .visit (type .value , type .name , depth )
567
+
568
+ def emit_visitor (self , nodename , depth , has_node = True ):
569
+ typeinfo = self .typeinfo [nodename ]
570
+ if has_node :
571
+ node_type = typeinfo .rust_sum_name
572
+ node_value = "node"
573
+ else :
574
+ node_type = "()"
575
+ node_value = "()"
576
+ self .emit (
577
+ f"fn visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" , depth
578
+ )
579
+ self .emit (f"self.generic_visit_{ typeinfo .sum_name } ({ node_value } )" , depth + 1 )
580
+ self .emit ("}" , depth )
581
+
582
+ def emit_generic_visitor_signature (self , nodename , depth , has_node = True ):
583
+ typeinfo = self .typeinfo [nodename ]
584
+ if has_node :
585
+ node_type = typeinfo .rust_sum_name
586
+ else :
587
+ node_type = "()"
588
+ self .emit (
589
+ f"fn generic_visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" ,
590
+ depth ,
591
+ )
592
+
593
+ def emit_empty_generic_visitor (self , nodename , depth ):
594
+ self .emit_generic_visitor_signature (nodename , depth )
543
595
self .emit ("}" , depth )
544
596
597
+ def simple_sum (self , sum , name , depth ):
598
+ self .emit_visitor (name , depth )
599
+ self .emit_empty_generic_visitor (name , depth )
600
+
601
+ def visit_match_for_type (self , nodename , rustname , type_ , depth ):
602
+ self .emit (f"{ rustname } ::{ type_ .name } " , depth )
603
+ if type_ .fields :
604
+ self .emit ("(data)" , depth )
605
+ data = "data"
606
+ else :
607
+ data = "()"
608
+ self .emit (f"=> self.visit_{ nodename } _{ type_ .name } ({ data } )," , depth )
609
+
610
+ def visit_sumtype (self , name , type_ , depth ):
611
+ self .emit_visitor (type_ .name , depth , has_node = type_ .fields )
612
+ self .emit_generic_visitor_signature (type_ .name , depth , has_node = type_ .fields )
613
+ for f in type_ .fields :
614
+ fieldname = rust_field (f .name )
615
+ fieldtype = self .typeinfo .get (f .type )
616
+ if not (fieldtype and fieldtype .has_userdata ):
617
+ continue
618
+
619
+ if f .opt :
620
+ self .emit (f"if let Some(value) = node.{ fieldname } {{" , depth + 1 )
621
+ elif f .seq :
622
+ iterable = f"node.{ fieldname } "
623
+ if type_ .name == "Dict" and f .name == "keys" :
624
+ iterable = f"{ iterable } .into_iter().flatten()"
625
+ self .emit (f"for value in { iterable } {{" , depth + 1 )
626
+ else :
627
+ self .emit ("{" , depth + 1 )
628
+ self .emit (f"let value = node.{ fieldname } ;" , depth + 2 )
629
+
630
+ variable = "value"
631
+ if fieldtype .boxed and (not f .seq or f .opt ):
632
+ variable = "*" + variable
633
+ typeinfo = self .typeinfo [fieldtype .name ]
634
+ self .emit (f"self.visit_{ typeinfo .sum_name } ({ variable } );" , depth + 2 )
635
+
636
+ self .emit ("}" , depth + 1 )
637
+
638
+ self .emit ("}" , depth )
639
+
640
+ def sum_with_constructors (self , sum , name , depth ):
641
+ if not sum .attributes :
642
+ return
643
+
644
+ rustname = enumname = get_rust_type (name )
645
+ if sum .attributes :
646
+ rustname = enumname + "Kind"
647
+ self .emit_visitor (name , depth )
648
+ self .emit_generic_visitor_signature (name , depth )
649
+ depth += 1
650
+ self .emit ("match node.node {" , depth )
651
+ for t in sum .types :
652
+ self .visit_match_for_type (name , rustname , t , depth + 1 )
653
+ self .emit ("}" , depth )
654
+ depth -= 1
655
+ self .emit ("}" , depth )
656
+
657
+ # Now for the visitors for the types
658
+ for t in sum .types :
659
+ self .visit_sumtype (name , t , depth )
660
+
661
+ def visitProduct (self , product , name , depth ):
662
+ self .emit_visitor (name , depth )
663
+ self .emit_empty_generic_visitor (name , depth )
664
+
665
+
666
+ class VisitorModuleVisitor (EmitVisitor ):
667
+ def visitModule (self , mod ):
668
+ depth = 0
669
+ self .emit ("#[allow(unused_variables, non_snake_case)]" , depth )
670
+ VisitorTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth )
671
+
545
672
546
673
class ClassDefVisitor (EmitVisitor ):
547
674
def visitModule (self , mod ):
@@ -799,23 +926,19 @@ def visit(self, object):
799
926
v .emit ("" , 0 )
800
927
801
928
802
- def write_generic_def (mod , typeinfo , f ):
803
- f .write (
804
- textwrap .dedent (
805
- """
806
- pub use crate::{Attributed, constant::*};
929
+ def write_ast_def (mod , typeinfo , f ):
930
+ StructVisitor (f , typeinfo ).visit (mod )
807
931
808
- type Ident = String;
809
- \n
810
- """
811
- )
812
- )
813
932
814
- c = ChainOfVisitors (StructVisitor (f , typeinfo ), FoldModuleVisitor (f , typeinfo ))
815
- c .visit (mod )
933
+ def write_fold_def (mod , typeinfo , f ):
934
+ FoldModuleVisitor (f , typeinfo ).visit (mod )
935
+
936
+
937
+ def write_visitor_def (mod , typeinfo , f ):
938
+ VisitorModuleVisitor (f , typeinfo ).visit (mod )
816
939
817
940
818
- def write_located_def (typeinfo , f ):
941
+ def write_located_def (mod , typeinfo , f ):
819
942
f .write (
820
943
textwrap .dedent (
821
944
"""
@@ -826,6 +949,8 @@ def write_located_def(typeinfo, f):
826
949
)
827
950
)
828
951
for info in typeinfo .values ():
952
+ if info .empty_field :
953
+ continue
829
954
if info .has_userdata :
830
955
generics = "::<SourceRange>"
831
956
else :
@@ -863,8 +988,7 @@ def write_ast_mod(mod, typeinfo, f):
863
988
864
989
def main (
865
990
input_filename ,
866
- generic_filename ,
867
- located_filename ,
991
+ ast_dir ,
868
992
module_filename ,
869
993
dump_module = False ,
870
994
):
@@ -879,34 +1003,34 @@ def main(
879
1003
typeinfo = {}
880
1004
FindUserdataTypesVisitor (typeinfo ).visit (mod )
881
1005
882
- with generic_filename .open ("w" ) as generic_file , located_filename .open (
883
- "w"
884
- ) as located_file :
885
- generic_file .write (auto_gen_msg )
886
- write_generic_def (mod , typeinfo , generic_file )
887
- located_file .write (auto_gen_msg )
888
- write_located_def (typeinfo , located_file )
1006
+ for filename , write in [
1007
+ ("generic" , write_ast_def ),
1008
+ ("fold" , write_fold_def ),
1009
+ ("located" , write_located_def ),
1010
+ ("visitor" , write_visitor_def ),
1011
+ ]:
1012
+ with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
1013
+ f .write (auto_gen_msg )
1014
+ write (mod , typeinfo , f )
889
1015
890
1016
with module_filename .open ("w" ) as module_file :
891
1017
module_file .write (auto_gen_msg )
892
1018
write_ast_mod (mod , typeinfo , module_file )
893
1019
894
- print (f"{ generic_filename } , { located_filename } , { module_filename } regenerated." )
1020
+ print (f"{ ast_dir } , { module_filename } regenerated." )
895
1021
896
1022
897
1023
if __name__ == "__main__" :
898
1024
parser = ArgumentParser ()
899
1025
parser .add_argument ("input_file" , type = Path )
900
- parser .add_argument ("-G" , "--generic-file" , type = Path , required = True )
901
- parser .add_argument ("-L" , "--located-file" , type = Path , required = True )
1026
+ parser .add_argument ("-A" , "--ast-dir" , type = Path , required = True )
902
1027
parser .add_argument ("-M" , "--module-file" , type = Path , required = True )
903
1028
parser .add_argument ("-d" , "--dump-module" , action = "store_true" )
904
1029
905
1030
args = parser .parse_args ()
906
1031
main (
907
1032
args .input_file ,
908
- args .generic_file ,
909
- args .located_file ,
1033
+ args .ast_dir ,
910
1034
args .module_file ,
911
1035
args .dump_module ,
912
1036
)
0 commit comments