Skip to content

Commit

Permalink
Bump Onnx Version to 1.16.1 (#3515)
Browse files Browse the repository at this point in the history
This commit adds the support for new data types: uint4, and int4 and
uint8 tensor protos. Also, it moves some tests from failing to crashing.

Fixes #3507

Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Jul 1, 2024
1 parent 0e71a19 commit 2f231f3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
8 changes: 4 additions & 4 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2572,8 +2572,6 @@
"SplitDimStaticModule_basic",
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"StdCorrectionEmptyDimModule_basic",
"StdDimEmptyDimModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TanhBackward_basic",
Expand Down Expand Up @@ -2627,8 +2625,6 @@
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2d_basic",
"VarCorrectionEmptyDimModule_basic",
"VarDimEmptyDimModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewCollapseModule_basic",
"ViewDynamicExpandCollapseModule_basic",
Expand Down Expand Up @@ -2797,6 +2793,10 @@
# Runtime crash: mismatched size for broadcast
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"StdDimEmptyDimModule_basic",
"StdCorrectionEmptyDimModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarDimEmptyDimModule_basic",
}

FX_IMPORTER_TOSA_XFAIL_SET = {
Expand Down
5 changes: 5 additions & 0 deletions python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,8 @@ def get_operator_function(
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4),
onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4),
# Ommitted: STRING,
}

Expand Down Expand Up @@ -1134,6 +1136,9 @@ def get_operator_function(
),
signless=False,
),
onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False
),
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/tools/import_onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
raw_model = onnx.load(args.input_file)
else:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, args.data_dir)
onnx.load_external_data_for_model(raw_model, str(args.data_dir))

if args.opset_version:
raw_model = onnx.version_converter.convert_version(
Expand Down
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pillow
dill
multiprocess
onnx==1.15.0
onnx==1.16.1
mpmath==1.3.0

0 comments on commit 2f231f3

Please sign in to comment.