-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconvert_to_onnx.py
25 lines (22 loc) · 932 Bytes
/
convert_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import subprocess
import sys

import monai
import onnxruntime
import onnxruntime.tools.convert_onnx_models_to_ort
import torch
# Build a 2D MONAI DenseNet-121 (3 input channels, 3 output classes) with
# pretrained=True, export it to ONNX, then convert the ONNX file to ONNX
# Runtime's ORT format.
densenet = monai.networks.nets.densenet121
model = densenet(spatial_dims=2,
                 in_channels=3,
                 out_channels=3,
                 dropout_prob=0.1,
                 pretrained=True)
device = 'cpu'
onnx_path = "dummy_model.onnx"
# Example input used only to trace the graph: (batch, channels, height, width).
dummy_input = torch.randn(1, 3, 96, 96, device=device)
# Put dropout / batch-norm into inference mode before tracing.
model.eval()
# The channel axis is deliberately NOT dynamic: the first convolution is built
# for exactly in_channels=3, so any other channel count would fail at inference.
# Batch size and spatial size may vary.
input_dynamic_axes = {0: 'batch', 2: 'height', 3: 'width'}
# DenseNet121's classification head yields logits of shape
# (batch, out_channels); only the batch axis varies. Reusing the 4D input
# mapping here named non-existent axes 2/3 and mislabelled the class axis.
output_dynamic_axes = {0: 'batch'}
dynamic_axes = {'input': input_dynamic_axes, 'output': output_dynamic_axes}
torch.onnx.export(model, dummy_input, onnx_path,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes=dynamic_axes)

# Convert to ORT format. Three choices here:
# - Run the tool with the interpreter executing this script; a bare "python"
#   on PATH may belong to another environment without onnxruntime.
# - Fail loudly on a non-zero exit status.
# - Convert only the file just exported; passing "." makes the tool pick up
#   every .onnx model it finds under the working directory.
ort_command = [
    sys.executable, "-m", "onnxruntime.tools.convert_onnx_models_to_ort",
    onnx_path, "--optimization_level", "basic",
]
# NNAPI variant: append "--use_nnapi" to ort_command.
# NOTE(review): that flag is not available in every onnxruntime release —
# confirm it exists in the installed version before enabling.
subprocess.run(ort_command, check=True)