Skip to content

Commit 17eeac8

Browse files
[fx] Accept func_visibility= and return created func op. (#3054)
This is a partial landing of #3046 while waiting for an upstream change for the rest of it.
1 parent 9ae33e4 commit 17eeac8

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
302302
if sparsity.layout is torch.sparse_coo:
303303
assert sparse_dim >= 2 and blocksize is None
304304
trail_dim = batch_dim + sparse_dim - 1
305-
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
305+
coords = ",".join(
306+
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
307+
)
306308
sep = "," if sparse_dim > 2 else ""
307309
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
308310
elif sparsity.layout is torch.sparse_csr:
@@ -467,8 +469,12 @@ def module_op(self) -> Operation:
467469
return self._m.operation
468470

469471
def import_program(
470-
self, prog: torch.export.ExportedProgram, *, func_name: str = "main"
471-
):
472+
self,
473+
prog: torch.export.ExportedProgram,
474+
*,
475+
func_name: str = "main",
476+
func_visibility: Optional[str] = None,
477+
) -> Operation:
472478
"""Imports an ExportedProgram according to our chosen canonical representation.
473479
474480
This mechanism is the fully general solution for handling an ExportedProgram
@@ -628,7 +634,9 @@ def import_program(
628634

629635
# Create the function.
630636
with loc:
631-
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip)
637+
func_op = func_dialect.FuncOp(
638+
func_name, ftype, ip=self._m_ip, visibility=func_visibility
639+
)
632640
entry_block = Block.create_at_start(func_op.body, ftype.inputs)
633641

634642
node_importer = GraphNodeImporter(
@@ -668,10 +676,15 @@ def import_program(
668676
)
669677
node_importer.return_node_values(loc, user_outputs)
670678
self.symbol_table.insert(func_op)
679+
return func_op
671680

672681
def import_frozen_program(
673-
self, prog: torch.export.ExportedProgram, func_name: str = "main"
674-
):
682+
self,
683+
prog: torch.export.ExportedProgram,
684+
*,
685+
func_name: str = "main",
686+
func_visibility: Optional[str] = None,
687+
) -> Operation:
675688
"""Imports a consolidated torch.export.ExportedProgram instance.
676689
677690
If using the new torch.export path (vs a lower level precursor), then this is
@@ -750,17 +763,25 @@ def import_frozen_program(
750763
node.replace_all_uses_with(replacement)
751764
g.erase_node(node)
752765

753-
self.import_stateless_graph(g, func_name)
766+
return self.import_stateless_graph(
767+
g, func_name=func_name, func_visibility=func_visibility
768+
)
754769

755-
def import_graph_module(self, gm: GraphModule):
770+
def import_graph_module(self, gm: GraphModule) -> Operation:
756771
"""Low-level import of a GraphModule assuming that it has been functionalized.
757772
758773
TODO: This mechanism is deprecated by the `import_program` entry-point and
759774
it should be removed when no longer required for backwards compatibility.
760775
"""
761-
self.import_stateless_graph(gm.graph)
776+
return self.import_stateless_graph(gm.graph)
762777

763-
def import_stateless_graph(self, g: Graph, func_name: str = "main"):
778+
def import_stateless_graph(
779+
self,
780+
g: Graph,
781+
*,
782+
func_name: str = "main",
783+
func_visibility: Optional[str] = None,
784+
) -> Operation:
764785
"""Low-level import of a functionalized, assumed stateless Graph as a func.
765786
766787
TODO: This mechanism is deprecated by the `import_program` entry-point and
@@ -775,6 +796,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
775796
func_name,
776797
ftype,
777798
ip=self._m_ip,
799+
visibility=func_visibility,
778800
)
779801
entry_block = Block.create_at_start(func.body, ftype.inputs)
780802
node_importer = GraphNodeImporter(
@@ -785,6 +807,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
785807
)
786808
node_importer.import_nodes(g.nodes)
787809
self.symbol_table.insert(func)
810+
return func
788811

789812
def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
790813
"""Extracts function metadata from the Graph.

0 commit comments

Comments
 (0)