
Commit 47c04d6

[5676209][ONNX][Autocast] Add check for input bs vs calibration data bs (#652)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Autocast crashes if the input batch size in the ONNX model differs from the batch size of the calibration data. For example, the calibration data has shape `[10, 6, 3, 480, 800]` and the ONNX model input has shape `[1, 6, 3, 480, 800]`. The quantization workflow interprets this as 10 calibration samples, so ideally Autocast would interpret it the same way. This PR does not add that behavior; it only makes Autocast exit gracefully with a descriptive error message.

## Usage

```sh
$ python -m modelopt.onnx.autocast --onnx_path=$MODEL_NAME.onnx --calibration_data=calib_data_10.npz
```

## Testing

See bug 5676209.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

Original error:

```sh
polygraphy.exception.exception.PolygraphyException: Input tensor: image | Received incompatible shape: (10, 6, 3, 480, 800).
Note: Expected a shape compatible with: BoundedShape([1, 6, 3, 480, 800], min=None, max=None)
```

Autocast error:

```sh
ValueError: Input shape from 'image' does not match provided input shape: [1, 6, 3, 480, 800] vs [10, 6, 3, 480, 800]. Please make sure that your calibration data matches the ONNX input shapes.
```

---------

Signed-off-by: gcunhase <[email protected]>
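The overview notes that the quantization workflow reads a `[10, 6, 3, 480, 800]` array as 10 samples for a `[1, 6, 3, 480, 800]` input. A minimal sketch of that interpretation follows. The helper `count_calibration_samples` and its `batch_axis` parameter are hypothetical names used for illustration; this is not ModelOpt code and not something this PR implements.

```python
def count_calibration_samples(model_shape, data_shape, batch_axis=0):
    """Return how many model-sized samples fit in data_shape along batch_axis.

    Hypothetical helper: raises ValueError if the shapes differ anywhere other
    than the batch axis, or if the data batch is not a whole multiple of the
    model batch.
    """
    model_shape, data_shape = list(model_shape), list(data_shape)
    if len(model_shape) != len(data_shape):
        raise ValueError(f"Rank mismatch: {model_shape} vs {data_shape}")

    # Every non-batch dimension must agree exactly.
    for axis, (m, d) in enumerate(zip(model_shape, data_shape)):
        if axis != batch_axis and m != d:
            raise ValueError(f"Mismatch on axis {axis}: {model_shape} vs {data_shape}")

    model_bs, data_bs = model_shape[batch_axis], data_shape[batch_axis]
    if model_bs <= 0 or data_bs % model_bs != 0:
        raise ValueError(
            f"Data batch size {data_bs} is not a multiple of model batch size {model_bs}"
        )
    return data_bs // model_bs


# Shapes taken from the example in the overview.
print(count_calibration_samples([1, 6, 3, 480, 800], (10, 6, 3, 480, 800)))  # 10
```

With a count like this, a runner could in principle slice the array along the batch axis and feed one model-sized sample at a time instead of raising.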
1 parent f265f8d commit 47c04d6

2 files changed: 13 additions, 1 deletion

docs/source/guides/8_autocast.rst

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ Best Practices
 #. **Validate with Real Data**:
 
    - Provide representative input data using the ``calibration_data`` option for more accurate node classification.
+   - The input names and shapes in ``calibration_data`` should match the ones in the given ONNX model.
 
 #. **Control Reduction Depth**:
    - Use ``max_depth_of_reduction`` to limit the depth of reduction operations that can be converted to low precision.
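
The new best-practice line asks users to keep input names and shapes in sync with the model. One way to check this before launching Autocast is sketched below. `find_calibration_mismatches` is a hypothetical helper, and both arguments are plain name-to-shape mappings that the caller is assumed to have built already (for instance from the model's graph inputs and from the arrays stored in the `.npz` file).

```python
def find_calibration_mismatches(model_shapes, calib_shapes):
    """Return a list of human-readable problems; an empty list means a match.

    Hypothetical helper. Both arguments map input name -> shape (list or tuple).
    """
    problems = []

    # Name check: every model input needs data, and no extra arrays are allowed.
    missing = sorted(set(model_shapes) - set(calib_shapes))
    extra = sorted(set(calib_shapes) - set(model_shapes))
    if missing:
        problems.append(f"missing inputs: {missing}")
    if extra:
        problems.append(f"unexpected inputs: {extra}")

    # Shape check for the names present on both sides.
    for name in sorted(set(model_shapes) & set(calib_shapes)):
        expected, actual = list(model_shapes[name]), list(calib_shapes[name])
        if expected != actual:
            problems.append(f"'{name}': expected {expected}, got {actual}")
    return problems


model_shapes = {"image": [1, 6, 3, 480, 800]}
calib_shapes = {"image": (10, 6, 3, 480, 800)}
for problem in find_calibration_mismatches(model_shapes, calib_shapes):
    print(problem)  # 'image': expected [1, 6, 3, 480, 800], got [10, 6, 3, 480, 800]
```

Unlike the check added in this commit, the sketch collects every mismatch rather than stopping at the first one, which can be convenient for models with many inputs.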

modelopt/onnx/autocast/referencerunner.py

Lines changed: 12 additions & 1 deletion
@@ -44,6 +44,10 @@ def __init__(
         """Initialize with ONNX model path."""
         self.model = model
         self.input_names = [input.name for input in self.model.graph.input]
+        self.input_shapes = {
+            input.name: [s.dim_value for s in input.type.tensor_type.shape.dim]
+            for input in self.model.graph.input
+        }
         self.providers = self._prepare_ep_list_with_trt_plugin_path(providers, trt_plugins)
 
     def _prepare_ep_list_with_trt_plugin_path(self, providers, trt_plugins):
@@ -69,12 +73,19 @@ def _load_inputs_from_npz(self, input_data_path):
         return [np.load(input_data_path)]
 
     def _validate_inputs(self, data_loader):
-        """Validate that input names match the model."""
+        """Validate that input names and shapes match the model."""
         if isinstance(data_loader, list) and (
             isinstance(data_loader[0], (dict, np.lib.npyio.NpzFile))
         ):
             if sorted(self.input_names) != sorted(data_loader[0].keys()):
                 raise ValueError("Input names from ONNX model do not match provided input names.")
+            for inp_name, inp_shape in data_loader[0].items():
+                if self.input_shapes[inp_name] != list(inp_shape.shape):
+                    raise ValueError(
+                        f"Input shape from '{inp_name}' does not match provided input shape: "
+                        f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "
+                        f"Please make sure that your calibration data matches the ONNX input shapes."
+                    )
         else:
             raise ValueError("Invalid input file.")
