
torch.cond operator not supported on a simple example #4028

@JibAxelera

Description


Issue:

I am trying to implement a neural network that dynamically routes a sample based on some condition. I built a dummy example of what the network should look like, and I would like to export this model to MLIR. When I try to do so using torch-mlir, I get an error. I would like to know whether the torch.cond operator is unsupported or whether my implementation is simply wrong.

Steps to reproduce:

Run this code:

import torch
import torch.nn as nn
import copy
from torch_mlir.fx import export_and_import

class CondNetwork(nn.Module):

    def __init__(self):
        super().__init__()

        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):

        condition = torch.mean(x) > self.confidence_threshold

        def true_fn():

            feature = x.clone().flatten()
            return self.linear1(feature)

        def false_fn():

            feature = x.clone().flatten()
            return self.linear2(feature)

        return torch.cond(condition, true_fn, false_fn)


def torch_mlir_model_export(model):

    cond_model = copy.deepcopy(model)

    with torch.no_grad():
        cond_model.eval()
        module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
        # Use a context manager so the file is flushed and closed.
        with open("torchmlir_condmodel.mlir", "w") as f:
            f.write(str(module))

###-- Main
def main():

    model = CondNetwork()

    torch_mlir_model_export(model)

if __name__ == '__main__':
    main()

You should get this error:

 module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/fx.py", line 111, in export_and_import
    fx_importer.import_frozen_program(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 901, in import_frozen_program
    return self.import_stateless_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 947, in import_stateless_graph
    node_importer.import_nodes(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1462, in import_nodes
    self._import_hop(loc, node, target)
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1566, in _import_hop
    raise NotImplementedError(
NotImplementedError: Higher-order operation 'cond' not implemented in the FxImporter (tried '_import_hop_cond')
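The message suggests that the FxImporter looks up a handler method named `_import_hop_<name>` for each higher-order operation and raises when no such method exists, meaning this torch-mlir build has no `_import_hop_cond`. A minimal sketch of that dispatch pattern, inferred only from the error text; `FxImporterSketch` and `_import_hop_my_hop` are made up for illustration and are not torch-mlir's actual code:

```python
class FxImporterSketch:
    """Hypothetical stand-in for name-based higher-order-op dispatch."""

    def _import_hop_my_hop(self, node):
        # Hypothetical handler for an operation the importer does support.
        return f"imported {node}"

    def import_hop(self, hop_name, node):
        # Build the handler name from the operation name, then look it up.
        handler_name = f"_import_hop_{hop_name}"
        handler = getattr(self, handler_name, None)
        if handler is None:
            raise NotImplementedError(
                f"Higher-order operation '{hop_name}' not implemented in the "
                f"FxImporter (tried '{handler_name}')"
            )
        return handler(node)
```

Under this reading, supporting `cond` means adding a handler to the importer itself; nothing in the model code can supply it.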

Additional information

torch version: 2.7.0.dev20250210+cpu

torchvision version: 0.22.0.dev20250210+cpu

torch_mlir version: 20250127.357
