Open
Description
Issue:
I am trying to implement logic in a neural network that dynamically routes a sample based on some condition. I built a dummy example of what the network should look like, and I would like to export this model to MLIR. When I try to do so with torch-mlir, I get an error. I would like to know whether the `torch.cond` operator is unsupported or whether my implementation is wrong.
Steps to reproduce:
Run this code:
import torch
import torch.nn as nn
import copy
from torch_mlir.fx import export_and_import
class CondNetwork(nn.Module):
    """Toy model that routes its input to one of two linear heads.

    The head is chosen by comparing the mean of the *entire* input tensor
    (``torch.mean`` with no ``dim``) against ``confidence_threshold``. The
    decision is therefore made once for the whole batch, not per sample.

    Args:
        confidence_threshold: The input mean must be strictly greater than
            this value for ``linear1`` to be used; otherwise ``linear2`` is
            used. Defaults to 2.
        in_features: Number of elements in the fully flattened input.
            Defaults to 3072 (= 3 * 32 * 32, matching a (1, 3, 32, 32) input).
        out_features: Output size of each linear head. Defaults to 3.
        use_cond: If True (default), route with ``torch.cond``. If False,
            evaluate both heads and pick the result with ``torch.where``.
            The False path gives the same values without emitting the
            ``cond`` higher-order op. That is useful for exporters that
            reject it, such as a torch-mlir FxImporter raising
            "Higher-order operation 'cond' not implemented".
    """

    def __init__(self, confidence_threshold=2, in_features=3072,
                 out_features=3, *, use_cond=True):
        super().__init__()
        self.confidence_threshold = confidence_threshold
        self.use_cond = use_cond
        # Creation order (linear1, then linear2) is kept so that random
        # initialisation matches the original model for a given seed.
        self.linear1 = nn.Linear(in_features, out_features)
        self.linear2 = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply ``linear1`` if ``mean(x) > threshold`` else ``linear2``.

        ``flatten()`` with no arguments collapses every dimension, including
        the batch dimension. ``x.numel()`` must therefore equal
        ``in_features``, and the result is 1-D with ``out_features``
        elements. NOTE(review): use ``flatten(1)`` if batched inputs are
        intended.
        """
        # 0-d boolean tensor, which is the predicate form torch.cond expects.
        condition = torch.mean(x) > self.confidence_threshold

        def true_fn(inp):
            return self.linear1(inp.flatten())

        def false_fn(inp):
            return self.linear2(inp.flatten())

        if not self.use_cond:
            # Branch-free fallback: both heads run and the scalar predicate
            # broadcasts to select one result.
            return torch.where(condition, true_fn(x), false_fn(x))

        # Pass tensor inputs explicitly through ``operands`` rather than
        # capturing them by closure. This is the documented calling
        # convention and works on torch versions where ``operands`` is
        # required.
        return torch.cond(condition, true_fn, false_fn, (x,))
def torch_mlir_model_export(model, example_input=None,
                            output_path="torchmlir_condmodel.mlir"):
    """Export a copy of *model* to Torch-dialect MLIR and write it to a file.

    Args:
        model: The ``nn.Module`` to export. It is deep-copied first, so the
            ``.eval()`` call below does not change the caller's model.
        example_input: Sample input used for tracing. Defaults to
            ``torch.ones(1, 3, 32, 32)``.
        output_path: Destination file for the textual MLIR. Defaults to
            ``"torchmlir_condmodel.mlir"``.

    Returns:
        The module object returned by ``export_and_import``.
    """
    cond_model = copy.deepcopy(model)
    cond_model.eval()
    if example_input is None:
        example_input = torch.ones(1, 3, 32, 32)
    with torch.no_grad():
        module = export_and_import(cond_model, example_input, output_type="torch")
    # Use a context manager so the file handle is always closed and flushed.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(str(module))
    return module
###-- Main
def main():
    """Build a freshly initialised ``CondNetwork`` and export it to MLIR."""
    model = CondNetwork()
    torch_mlir_model_export(model)


if __name__ == "__main__":
    main()
You should get this error:
module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/fx.py", line 111, in export_and_import
fx_importer.import_frozen_program(
File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 901, in import_frozen_program
return self.import_stateless_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 947, in import_stateless_graph
node_importer.import_nodes(
File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1462, in import_nodes
self._import_hop(loc, node, target)
File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1566, in _import_hop
raise NotImplementedError(
NotImplementedError: Higher-order operation 'cond' not implemented in the FxImporter (tried '_import_hop_cond')
Additional information
torch version : 2.7.0.dev20250210+cpu
torchvision version : torchvision-0.22.0.dev20250210+cpu
torch_mlir version : 20250127.357
Metadata
Metadata
Assignees
Labels
No labels