
Converting to ONNX #154

Open

@cjenkins5614

Description

Hello,

Thanks for the great work. I'm trying to convert this model to ONNX, but I've run into a few issues.

The `mv` and `dot` operators used by PyTorch's `spectral_norm` were one of them. Following onnx/onnx#3006 (comment), I converted them to `matmul` in my own implementation of `spectral_norm`, and that issue went away (sketch of the rewrite below).
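
For anyone hitting the same thing, here is a minimal sketch of that rewrite, assuming the standard power-iteration step in `torch.nn.utils.spectral_norm` (the function name is illustrative, not the actual patch from onnx/onnx#3006):

    # mv(A, v) becomes a matmul with an unsqueezed column vector;
    # dot(u, w) becomes a 1x1 matmul. Both export cleanly to ONNX.
    import torch
    import torch.nn.functional as F

    def power_iteration_step(weight_mat, u):
        # was: v = F.normalize(torch.mv(weight_mat.t(), u), dim=0)
        v = F.normalize(torch.matmul(weight_mat.t(), u.unsqueeze(1)).squeeze(1), dim=0)
        # was: u = F.normalize(torch.mv(weight_mat, v), dim=0)
        u = F.normalize(torch.matmul(weight_mat, v.unsqueeze(1)).squeeze(1), dim=0)
        # was: sigma = torch.dot(u, torch.mv(weight_mat, v))
        sigma = torch.matmul(u.unsqueeze(0), torch.matmul(weight_mat, v.unsqueeze(1))).squeeze()
        return u, v, sigma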

Now it's complaining:

Traceback (most recent call last):
    out = torch.onnx.export(model, input_dict["image"], "model.onnx", verbose=False, opset_version=11,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 271, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 694, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 206, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 309, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 997, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 148, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1285, in batch_norm
    if weight is None or sym_help._is_none(weight):
RuntimeError: Unsupported: ONNX export of batch_norm for unknown channel size.
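
For context, the check that raises here is in the `batch_norm` symbolic of `torch/onnx/symbolic_opset9.py`: when the norm layer has no affine `weight`/`bias`, the exporter has to synthesize constant ones/zeros of length `C`, which requires a statically known channel dimension. Roughly (a paraphrase of the logic, not the exact source):

    # Paraphrased sketch of the exporter check that fails above:
    input_sizes = input.type().sizes()  # e.g. [1, 1024, 16, 16]; entries are unknown (*) if dynamic
    if weight is None or sym_help._is_none(weight):
        if input_sizes is None or input_sizes[1] is None:
            raise RuntimeError(
                "Unsupported: ONNX export of batch_norm for unknown channel size.")
        # otherwise: weight = onnx::Constant(ones(input_sizes[1])), and likewise for bias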

The code to convert this is:

    # Imports needed to run this snippet; the Pix2PixModel import path is
    # repo-specific, so adjust it to your checkout layout.
    import torch
    from easydict import EasyDict
    from models.pix2pix_model import Pix2PixModel

    opt = EasyDict(aspect_ratio=1.0,
                checkpoints_dir='Face_Enhancement/checkpoints',
                contain_dontcare_label=False,
                crop_size=256,
                gpu_ids=[0],
                init_type='xavier',
                init_variance=0.02,
                injection_layer='all',
                isTrain=False,
                label_nc=18,
                load_size=256,
                model='pix2pix',
                name='Setting_9_epoch_100',
                nef=16,
                netG='spade',
                ngf=64,
                no_flip=True,
                no_instance=True,
                no_parsing_map=True,
                norm_D='spectralinstance',
                norm_E='spectralinstance',
                # norm_G='spectralspadebatch3x3',
                norm_G='spectralspadesyncbatch3x3',
                num_upsampling_layers='normal',
                output_nc=3,
                preprocess_mode='resize',
                semantic_nc=18,
                use_vae=False,
                which_epoch='latest',
                z_dim=256)

    model = Pix2PixModel(opt)
    model.eval()

    input_dict = {
        "label": torch.zeros((1, 18, 256, 256)),
        "image": torch.randn(1, 3, 256, 256),
        "path": None,
    }

    # from torchsummary import summary
    # summary(model, (3, 256, 256))
    out = torch.onnx.export(model, input_dict, "model.onnx", verbose=False, opset_version=11,
                      input_names = ['input'],
                      output_names = ['output'])
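
As an aside, `torch.onnx.export` traces the model with positional tensor arguments, so wrapping the dict-taking model in a thin adapter may be cleaner. A sketch, where `ExportWrapper` is my own name and the assumption is that `Pix2PixModel.forward` accepts exactly the `input_dict` layout above:

    # Hypothetical adapter so the exporter sees plain tensor inputs; the dict
    # keys mirror input_dict above, but the exact forward signature is assumed.
    class ExportWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, label, image):
            return self.model({"label": label, "image": image, "path": None})

    wrapped = ExportWrapper(model)
    torch.onnx.export(wrapped, (input_dict["label"], input_dict["image"]),
                      "model.onnx", opset_version=11,
                      input_names=["label", "image"], output_names=["output"])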

I printed out the graph g from https://github.com/pytorch/pytorch/blob/e56d3b023818f54553f2dc5d30b6b7aaf6b6a325/torch/onnx/symbolic_opset9.py#L1337

...
  %450 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPULongType{2} ]]()
  %451 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %452 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 0  0 [ CPULongType{2} ]]()
  %453 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %454 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %455 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %456 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %457 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %458 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%436, %447, %netG.head_0.conv_1.bias) # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:395:0
  %459 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %460 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Add(%266, %458) # /workdir/Face_Enhancement/models/networks/architecture.py:56:0
  %461 : None = prim::Constant()
  %462 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %463 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPUFloatType{2} ]]()
  %464 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %465 : Float(4, strides=[1], device=cpu) = onnx::Concat[axis=0](%463, %464)
  %466 : Float(0, strides=[1], device=cpu) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
  %468 : None = prim::Constant()
  %469 : None = prim::Constant()
  %470 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %471 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={0.1}]()
  %472 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={1e-05}]()
  %473 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  return ()

ipdb> input
467 defined in (%467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
)
ipdb> weight
468 defined in (%468 : None = prim::Constant()
)
ipdb> bias
469 defined in (%469 : None = prim::Constant()
)

The `Float(*, *, *, *)` type stood out to me, but I'm not sure how to interpret it.
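
As far as I understand, `*` in a TorchScript type means the tracer lost the static size of that dimension: the `onnx::Resize` generated by `F.interpolate(..., scale_factor=2)` produces a fully dynamic shape (`Float(*, *, *, *)`), so the following `batch_norm`, whose `weight` is `None` because SPADE's base norm is parameter-free, cannot recover the channel count. One possible workaround is to pass an explicit `size` to `interpolate` so the traced shape stays static. A sketch, where `up` stands in for the upsampling call in the SPADE generator:

    # An explicit size keeps the Resize output shape static under tracing,
    # so the downstream batch_norm sees a known channel dimension.
    import torch.nn.functional as F

    def up(x):
        n, c, h, w = x.shape  # concrete ints when tracing a fixed example input
        return F.interpolate(x, size=(h * 2, w * 2), mode="nearest")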
