[fx] Use module Operations instead of Module. #3046

Open · wants to merge 3 commits into main
python/torch_mlir/extras/fx_importer.py (38 additions, 22 deletions)
@@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
if sparsity.layout is torch.sparse_coo:
assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
coords = ",".join(
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
@@ -415,7 +417,7 @@ class FxImporter:
__slots__ = [
"_c",
"_cc",
"_m",
"_m_op",
"_m_ip",
"_py_attr_tracker",
"_hooks",
@@ -425,28 +427,31 @@ class FxImporter:
def __init__(
self,
*,
module: Optional[Module] = None,
module_op: Optional[Operation] = None,
context: Optional[Context] = None,
config_check: bool = True,
py_attr_tracker: Optional["RefTracker"] = None,
hooks: Optional[FxImporterHooks] = None,
):
if module is not None:
assert context is None, "If configuring with a Module, context must be None"
self._m = module
self._c = self.module.context
if module_op is not None:
assert (
context is None
), "If configuring with a module op, context must be None"
self._m_op = module_op
self._c = self._m_op.context
else:
self._c = context if context else Context()
self._m = Module.create(Location.unknown(self._c))
self._m_op = Module.create(Location.unknown(self._c)).operation
body = self._m_op.regions[0].blocks[0]
if config_check:
# Production code can disable this for a bit of a boost.
self._config_check()
self._py_attr_tracker = py_attr_tracker or RefTracker()
self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
self._m_ip = InsertionPoint(self._m.body)
self._m_ip = InsertionPoint(body)
self._hooks = hooks or FxImporterHooks()
self.symbol_table = SymbolTable(self._m.operation)
self._hooks.prepare_module(self._m.operation)
self.symbol_table = SymbolTable(self._m_op)
self._hooks.prepare_module(self._m_op)

def _config_check(self):
for dname in REQUIRED_DIALCTS:
@@ -458,17 +463,17 @@ def _config_check(self):
f"The MLIR context {self._c} is missing required dialect '{dname}'"
)

@property
def module(self) -> Module:
return self._m

@property
def module_op(self) -> Operation:
return self._m.operation
return self._m_op

def import_program(
self, prog: torch.export.ExportedProgram, *, func_name: str = "main"
):
self,
prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports an ExportedProgram according to our chosen canonical representation.

This mechanism is the fully general solution for handling an ExportedProgram
@@ -490,6 +495,8 @@ def import_program(
It is recommended that integrators subclass and override the `resolve_literal`
method to control access to mutable buffers and parameters. Without that, the
default policy is to capture them as frozen values.

Returns the created entry function as a generic Operation.
"""
# Create lookaside table of placeholders/outputs.
placeholder_nodes: Dict[str, Node] = {}
@@ -628,7 +635,9 @@ def import_program(

# Create the function.
with loc:
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip)
func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
entry_block = Block.create_at_start(func_op.body, ftype.inputs)

node_importer = GraphNodeImporter(
@@ -668,9 +677,13 @@ def import_program(
)
node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op)
return func_op.operation

def import_frozen_program(
self, prog: torch.export.ExportedProgram, func_name: str = "main"
self,
prog: torch.export.ExportedProgram,
func_name: str = "main",
func_visibility: Optional[str] = None,
):
"""Imports a consolidated torch.export.ExportedProgram instance.

@@ -750,7 +763,7 @@ def import_frozen_program(
node.replace_all_uses_with(replacement)
g.erase_node(node)

self.import_stateless_graph(g, func_name)
self.import_stateless_graph(g, func_name, func_visibility=func_visibility)

def import_graph_module(self, gm: GraphModule):
"""Low-level import of a GraphModule assuming that it has been functionalized.
@@ -760,7 +773,9 @@ def import_graph_module(self, gm: GraphModule):
"""
self.import_stateless_graph(gm.graph)

def import_stateless_graph(self, g: Graph, func_name: str = "main"):
def import_stateless_graph(
self, g: Graph, func_name: str = "main", func_visibility: Optional[str] = None
):
"""Low-level import of a functionalized, assumed stateless Graph as a func.

TODO: This mechanism is deprecated by the `import_program` entry-point and
@@ -775,6 +790,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
func_name,
ftype,
ip=self._m_ip,
visibility=func_visibility,
)
entry_block = Block.create_at_start(func.body, ftype.inputs)
node_importer = GraphNodeImporter(
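To make the new surface concrete, here is a minimal usage sketch (not part of this diff) of how a caller might drive the updated `FxImporter`. The context and dialect setup mirrors what `torch_mlir.fx.export_and_import` does; `Net`, its example input, and the choice of `"private"` visibility are hypothetical.

```python
# Hedged usage sketch for the updated FxImporter API (not part of the diff).
import torch
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_importer import FxImporter


class Net(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# Mirror the context/dialect setup used by torch_mlir.fx.export_and_import.
context = ir.Context()
torch_d.register_dialect(context)
importer = FxImporter(context=context)

prog = torch.export.export(Net(), (torch.randn(3, 4),))

# New in this change: the entry func's symbol visibility can be set, and the
# created entry func is returned as a generic Operation.
func_op = importer.import_program(prog, func_name="main", func_visibility="private")

# The importer now exposes its module as an Operation rather than a Module.
module_op = importer.module_op
print(module_op)
```

The `func_visibility` argument is forwarded to the created `func.func`'s symbol visibility, which can be useful when the entry function should not be public in a larger module.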
python/torch_mlir/fx.py (2 additions, 2 deletions)
@@ -41,7 +41,7 @@ def export_and_import(
else:
fx_importer.import_frozen_program(prog, func_name=func_name)

return fx_importer.module
return fx_importer.module_op


def stateless_fx_import(
@@ -55,4 +55,4 @@ def stateless_fx_import(
if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks)
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
return fx_importer.module
return fx_importer.module_op
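Callers that previously received a `Module` from these helpers now get a generic `Operation`. A small, hedged adaptation sketch follows; the model and input shapes are placeholders, and `verify()` and printing are standard MLIR Python binding methods on `Operation`, not something added by this PR.

```python
# Hedged sketch of a downstream caller adapting to the Operation return type.
import torch
from torch_mlir import fx


class Net(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


# export_and_import now returns the imported module as an Operation
# (module_op) instead of a Module.
module_op = fx.export_and_import(Net(), torch.randn(2, 3))

# Operations still print as textual IR and can be verified, so callers that
# only serialized or inspected the old Module typically need no other changes.
module_op.verify()
print(module_op)
```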
test/python/fx_importer/sparse_test.py (1 addition, 1 deletion)
@@ -130,7 +130,7 @@ def export_and_import(f, *args, **kwargs):
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
fx_importer.import_frozen_program(prog)
return fx_importer.module
return fx_importer.module_op


def sparse_jit(f, *args, **kwargs):