[onnx_importer.py] Fix dim_value None not correctly processed and missing Float8E4M3FNUZType. #4037

Open · wants to merge 1 commit into base: main
22 changes: 10 additions & 12 deletions python/torch_mlir/extras/onnx_importer.py
@@ -55,6 +55,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float8E4M3FNUZType,
     Float8E4M3FNType,
     Float8E5M2FNUZType,
     Float8E5M2Type,
@@ -642,9 +643,7 @@ def get_list_element_type(self, tp: onnx.TypeProto) -> IrType:
         tt = tp.tensor_type
         if tt.elem_type:
             element_type = self.tensor_element_type(tt.elem_type)
-            dims = tuple(
-                (d.dim_value if not d.dim_param else None) for d in tt.shape.dim
-            )
+            dims = tuple((d.dim_value if d.dim_value else None) for d in tt.shape.dim)
             shape_asm = ",".join("?" if d is None else str(d) for d in dims)
             return f"vtensor<[{shape_asm}],{element_type}>"

@@ -655,9 +654,7 @@ def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType:
         tt = tp.tensor_type
         if tt.elem_type:
             element_type = self.tensor_element_type(tt.elem_type)
-            dims = tuple(
-                (d.dim_value if not d.dim_param else None) for d in tt.shape.dim
-            )
+            dims = tuple((d.dim_value if d.dim_value else None) for d in tt.shape.dim)
             shape_asm = ",".join("?" if d is None else str(d) for d in dims)
             return f"vtensor<[{shape_asm}],{element_type}>"

@@ -707,13 +704,14 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:

         tt = tp.tensor_type
         if tt.elem_type:
-            if not tt.shape:
-                raise OnnxImportError(
-                    f"Unsupported Tensor type without shape (run shape inference?): {tp}"
-                )
             element_type = self.tensor_element_type(tt.elem_type)
             dims = tuple(
-                (d.dim_value if not d.dim_param else None) for d in tt.shape.dim
+                # NOTE: dynamic dimension can either be denoted by d.dim_param being set
+                # or by neither d.dim_value nor d.dim_param being set. Also note that
+                # d.dim_value being 0 corresponds to the protobuf default when the field
+                # is not set.
+                d.dim_value if d.dim_value else None
+                for d in tt.shape.dim
             )
             return self.get_vtensor_type(dims, element_type)

Contributor Author commented on the removed `if not tt.shape:` check:
This is never None (protobuf default initializes), but rather an empty TensorShapeProto, which corresponds to a valid shape.
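For context, a minimal sketch (not part of the PR; it only needs the `onnx` package) of the three dimension encodings the new comprehension has to distinguish:

# Static dim (2), symbolic dim ("batch"), and a dim with neither field set (None).
import onnx
from onnx import TensorProto, helper

vi = helper.make_tensor_value_info("t", TensorProto.FLOAT, [2, "batch", None])
for d in vi.type.tensor_type.shape.dim:
    print(d.dim_value, repr(d.dim_param), d.HasField("dim_value"))
# 2 ''      True   -> static dimension of size 2
# 0 'batch' False  -> dynamic dimension named via dim_param
# 0 ''      False  -> dynamic dimension with neither field set; the old
#                     `if not d.dim_param` test wrongly mapped this case to 0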

@@ -1097,7 +1095,7 @@ def get_operator_function(
     onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()),
     onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(),
     onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(),
-    onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(),
+    onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E4M3FNUZType.get(),
     onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
     onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
     onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
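A hypothetical regression check for the fixed entry; the table name `ELEM_TYPE_TO_IR_TYPE_CB` and the availability of the float8 types in `torch_mlir.ir` are assumptions, since neither appears in this diff:

# Assumed names: ELEM_TYPE_TO_IR_TYPE_CB is the mapping edited above.
import onnx
from torch_mlir.ir import Context, Float8E4M3FNUZType
from torch_mlir.extras.onnx_importer import ELEM_TYPE_TO_IR_TYPE_CB

with Context():
    ir_type = ELEM_TYPE_TO_IR_TYPE_CB[onnx.TensorProto.DataType.FLOAT8E4M3FNUZ]()
    # Before this PR the entry returned Float8E5M2FNUZType.
    assert isinstance(ir_type, Float8E4M3FNUZType)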
5 changes: 4 additions & 1 deletion python/torch_mlir/tools/import_onnx/__main__.py
@@ -119,9 +119,12 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:

     # Model is too big for in-memory inference: do file-based shape inference
     # to a temp file.
+    # First need to save as model might have been changed (e.g. version conversion).
+    temp_raw_file = temp_dir / "raw.onnx"
     temp_inferred_file = temp_dir / "inferred.onnx"
+    onnx.save(raw_model, temp_raw_file, save_as_external_data=False)
Collaborator commented on the added onnx.save(...) line:
This section precisely occurs when the provided model is large, so saving a temp file would be expensive. I'd prefer to only do this if it had actually been modified, so perhaps add a bool to track if the model got modified by previous arg specifications, and only do this if so.

I'm also concerned about not saving external data in this case, since this is exactly when we would be exceeding the 2gb protobuf limit.

     onnx.shape_inference.infer_shapes_path(
-        args.input_file, temp_inferred_file, data_prop=args.data_prop
+        temp_raw_file, temp_inferred_file, data_prop=args.data_prop
     )

     # Sanity check the shape-inferred model to be sure we have a good model
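A minimal sketch of the Collaborator's suggestion above (the `modified` flag, the helper name, and the switch to `save_as_external_data=True` are illustrative assumptions, not part of this PR):

import argparse
from pathlib import Path

import onnx


def infer_shapes_to_file(
    raw_model: onnx.ModelProto,
    modified: bool,
    args: argparse.Namespace,
    temp_dir: Path,
) -> Path:
    """File-based shape inference that re-saves the model only if it changed."""
    temp_inferred_file = temp_dir / "inferred.onnx"
    if modified:
        # Keep external data: this path is taken precisely when the model
        # may exceed the 2GB protobuf limit.
        temp_raw_file = temp_dir / "raw.onnx"
        onnx.save(raw_model, temp_raw_file, save_as_external_data=True)
        inference_input = temp_raw_file
    else:
        # Unmodified model: run shape inference straight off the original file.
        inference_input = args.input_file
    onnx.shape_inference.infer_shapes_path(
        inference_input, temp_inferred_file, data_prop=args.data_prop
    )
    return temp_inferred_file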