Skip to content

Commit 2f231f3

Browse files
Bump Onnx Version to 1.16.1 (#3515)
This commit adds the support for new data types: uint4, and int4 and uint8 tensor protos. Also, it moves some tests from failing to crashing. Fixes #3507 Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent 0e71a19 commit 2f231f3

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,8 +2572,6 @@
25722572
"SplitDimStaticModule_basic",
25732573
"SqrtIntConstantModule_basic",
25742574
"SqrtIntModule_basic",
2575-
"StdCorrectionEmptyDimModule_basic",
2576-
"StdDimEmptyDimModule_basic",
25772575
"SubFloatModule_basic",
25782576
"SubIntModule_basic",
25792577
"TanhBackward_basic",
@@ -2627,8 +2625,6 @@
26272625
"UpSampleNearest2dDynamicFactor_basic",
26282626
"UpSampleNearest2dStaticFactor_basic",
26292627
"UpSampleNearest2d_basic",
2630-
"VarCorrectionEmptyDimModule_basic",
2631-
"VarDimEmptyDimModule_basic",
26322628
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
26332629
"ViewCollapseModule_basic",
26342630
"ViewDynamicExpandCollapseModule_basic",
@@ -2797,6 +2793,10 @@
27972793
# Runtime crash: mismatched size for broadcast
27982794
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
27992795
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
2796+
"StdDimEmptyDimModule_basic",
2797+
"StdCorrectionEmptyDimModule_basic",
2798+
"VarCorrectionEmptyDimModule_basic",
2799+
"VarDimEmptyDimModule_basic",
28002800
}
28012801

28022802
FX_IMPORTER_TOSA_XFAIL_SET = {

python/torch_mlir/extras/onnx_importer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,8 @@ def get_operator_function(
10981098
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
10991099
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
11001100
onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
1101+
onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4),
1102+
onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4),
11011103
# Ommitted: STRING,
11021104
}
11031105

@@ -1134,6 +1136,9 @@ def get_operator_function(
11341136
),
11351137
signless=False,
11361138
),
1139+
onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get(
1140+
np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False
1141+
),
11371142
onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get(
11381143
np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False
11391144
),

python/torch_mlir/tools/import_onnx/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
8484
raw_model = onnx.load(args.input_file)
8585
else:
8686
raw_model = onnx.load(args.input_file, load_external_data=False)
87-
onnx.load_external_data_for_model(raw_model, args.data_dir)
87+
onnx.load_external_data_for_model(raw_model, str(args.data_dir))
8888

8989
if args.opset_version:
9090
raw_model = onnx.version_converter.convert_version(

test-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pillow
22
dill
33
multiprocess
4-
onnx==1.15.0
4+
onnx==1.16.1
55
mpmath==1.3.0

0 commit comments

Comments
 (0)