@@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     if sparsity.layout is torch.sparse_coo:
         assert sparse_dim >= 2 and blocksize is None
         trail_dim = batch_dim + sparse_dim - 1
-        coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim))
+        coords = ",".join(
+            f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
+        )
         sep = "," if sparse_dim > 2 else ""
         lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
     elif sparsity.layout is torch.sparse_csr:
@@ -467,8 +469,12 @@ def module_op(self) -> Operation:
         return self._m.operation

     def import_program(
-        self, prog: torch.export.ExportedProgram, *, func_name: str = "main"
-    ):
+        self,
+        prog: torch.export.ExportedProgram,
+        *,
+        func_name: str = "main",
+        func_visibility: Optional[str] = None,
+    ) -> Operation:
         """Imports an ExportedProgram according to our chosen canonical representation.

         This mechanism is the fully general solution for handling an ExportedProgram
@@ -628,7 +634,9 @@ def import_program(

         # Create the function.
         with loc:
-            func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip)
+            func_op = func_dialect.FuncOp(
+                func_name, ftype, ip=self._m_ip, visibility=func_visibility
+            )
             entry_block = Block.create_at_start(func_op.body, ftype.inputs)

             node_importer = GraphNodeImporter(
@@ -668,10 +676,15 @@ def import_program(
             )
             node_importer.return_node_values(loc, user_outputs)
         self.symbol_table.insert(func_op)
+        return func_op

     def import_frozen_program(
-        self, prog: torch.export.ExportedProgram, func_name: str = "main"
-    ):
+        self,
+        prog: torch.export.ExportedProgram,
+        *,
+        func_name: str = "main",
+        func_visibility: Optional[str] = None,
+    ) -> Operation:
         """Imports a consolidated torch.export.ExportedProgram instance.

         If using the new torch.export path (vs a lower level precursor), then this is
@@ -750,17 +763,25 @@ def import_frozen_program(
                 node.replace_all_uses_with(replacement)
                 g.erase_node(node)

-        self.import_stateless_graph(g, func_name)
+        return self.import_stateless_graph(
+            g, func_name=func_name, func_visibility=func_visibility
+        )

-    def import_graph_module(self, gm: GraphModule):
+    def import_graph_module(self, gm: GraphModule) -> Operation:
         """Low-level import of a GraphModule assuming that it has been functionalized.

         TODO: This mechanism is deprecated by the `import_program` entry-point and
         it should be removed when no longer required for backwards compatibility.
         """
-        self.import_stateless_graph(gm.graph)
+        return self.import_stateless_graph(gm.graph)

-    def import_stateless_graph(self, g: Graph, func_name: str = "main"):
+    def import_stateless_graph(
+        self,
+        g: Graph,
+        *,
+        func_name: str = "main",
+        func_visibility: Optional[str] = None,
+    ) -> Operation:
         """Low-level import of a functionalized, assumed stateless Graph as a func.

         TODO: This mechanism is deprecated by the `import_program` entry-point and
@@ -775,6 +796,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
                func_name,
                ftype,
                ip=self._m_ip,
+               visibility=func_visibility,
            )
            entry_block = Block.create_at_start(func.body, ftype.inputs)
        node_importer = GraphNodeImporter(
@@ -785,6 +807,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
        )
        node_importer.import_nodes(g.nodes)
        self.symbol_table.insert(func)
+        return func

    def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
        """Extracts function metadata from the Graph.
0 commit comments