
torch.cond operator not supported on a simple example #4028

Open
@JibAxelera


Issue:

I am trying to implement, in a neural network, logic that dynamically routes a sample based on some condition. I built a dummy example of what the network should look like, and I would like to export this model to MLIR. When I try to do so with torch-mlir, I get an error. I would like to know whether the torch.cond operator is unsupported or whether my implementation is simply wrong.

Steps to reproduce:

Run this code:

import torch
import torch.nn as nn
import copy
from torch_mlir.fx import export_and_import

class CondNetwork(nn.Module):

    def __init__(self):
        super().__init__()

        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):

        condition = torch.mean(x) > self.confidence_threshold

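        # Each branch receives the operands passed to torch.cond.
        def true_fn(x):

            feature = x.clone().flatten()
            return self.linear1(feature)

        def false_fn(x):

            feature = x.clone().flatten()
            return self.linear2(feature)

        # Pass x explicitly as an operand rather than capturing it by closure.
        return torch.cond(condition, true_fn, false_fn, (x,))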


def torch_mlir_model_export(model):

    cond_model = copy.deepcopy(model)

    with torch.no_grad():
        cond_model.eval()
        module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
        # Use a context manager so the output file is closed after writing.
        with open("torchmlir_condmodel.mlir", "w") as f:
            f.write(str(module))

###-- Main
def main():

    model = CondNetwork()

    torch_mlir_model_export(model)

if __name__ == '__main__':
    main()

You should get this error:

 module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/fx.py", line 111, in export_and_import
    fx_importer.import_frozen_program(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 901, in import_frozen_program
    return self.import_stateless_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 947, in import_stateless_graph
    node_importer.import_nodes(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1462, in import_nodes
    self._import_hop(loc, node, target)
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1566, in _import_hop
    raise NotImplementedError(
NotImplementedError: Higher-order operation 'cond' not implemented in the FxImporter (tried '_import_hop_cond')

Additional information

torch version: 2.7.0.dev20250210+cpu

torchvision version: 0.22.0.dev20250210+cpu

torch_mlir version: 20250127.357
