diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index c4dbdcc4a..238f8632a 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -537,7 +537,7 @@ def _get_scale_and_zp(
     node: onnx.NodeProto,
     initializers: dict[str, onnx.TensorProto],
     tensor_producers: dict[str, onnx.NodeProto],
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[onnx.TensorProto, onnx.TensorProto]:
     """Get scale and zero point tensors for a node.
 
     Args:
@@ -546,7 +546,7 @@ def _get_scale_and_zp(
         tensor_producers: Dictionary of tensor producers
 
     Returns:
-        Tuple of (scale_array, zero_point_array)
+        Tuple of (scale_tensor, zero_point_tensor)
 
     Raises:
         ValueError: If scale or zero point cannot be found
@@ -560,7 +560,6 @@ def _get_scale_and_zp(
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid scale producer for {scale_name}")
         scale = producer.attribute[0].t
-    scale_array = onnx.numpy_helper.to_array(scale)
 
     # Get zero point tensor
     zp_name = node.input[2]
@@ -571,9 +570,8 @@ def _get_scale_and_zp(
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid zero point producer for {zp_name}")
         zp = producer.attribute[0].t
-    zp_array = onnx.numpy_helper.to_array(zp)
 
-    return scale_array, zp_array
+    return scale, zp
 
 
 def _get_successive_consumers(
@@ -611,16 +609,16 @@ def _get_successive_consumers(
 
 def _convert_weight(
     weight_array: np.ndarray,
-    scale_array: np.ndarray,
-    zp_array: np.ndarray,
+    scale: onnx.TensorProto,
+    zp: onnx.TensorProto,
     quantized_node: onnx.NodeProto,
 ) -> np.ndarray:
     """Convert a weight tensor to INT8/FP8 format based on scale and zero point.
 
     Args:
         weight_array: The weight tensor to convert
-        scale_array: The scale tensor for quantization
-        zp_array: The zero point tensor for quantization
+        scale: The scale tensor for quantization
+        zp: The zero point tensor for quantization
         quantized_node: The operation node that will use the converted weight
 
     Returns:
@@ -637,6 +635,10 @@ def _convert_weight(
     weight_shape = weight_array.shape
     op_type = quantized_node.op_type
 
+    # Convert onnx tensors to numpy array
+    scale_array = onnx.numpy_helper.to_array(scale)
+    zp_array = onnx.numpy_helper.to_array(zp)
+
     # Dynamically determine transB for Gemm
     trans_b = 0
     if op_type == "Gemm":
@@ -668,7 +670,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp.data_type == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -709,7 +711,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
 def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
     """Create a FLOAT8E4M3FN tensor directly from numpy array."""
     fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+    tensor = onnx.numpy_helper.from_array(fp8_data, weight_name)
+    tensor.data_type = onnx_dtype_map["Float8"]
+    return tensor
 
 
 def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
@@ -761,16 +765,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             weight_array = onnx.numpy_helper.to_array(weight)
 
             # Get scale and zero point
-            scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
+            scale, zp = _get_scale_and_zp(node, initializers, tensor_producers)
 
             # Validate Q->DQ->Op pattern and get consumers
             dq_node, quantized_node = _get_successive_consumers(node, tensor_consumers)
 
             # Convert weight
-            scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
+            scaled = _convert_weight(weight_array, scale, zp, quantized_node)
 
             # Create and update new weight tensor
-            if zp_array.dtype == onnx_dtype_map["Float8"]:
+            if zp.data_type == onnx_dtype_map["Float8"]:
                 new_weight = _create_fp8_tensor(scaled, weight_name)
                 logger.debug(f"Converted {weight_name} to FP8")
             else:
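
Note on the `_create_fp8_tensor` change: numpy has no native FP8 dtype, so if `_cast_fp8` hands back the E4M3FN bit patterns in a uint8 buffer, `onnx.numpy_helper.from_array` types the resulting tensor as UINT8 and the FP8 semantics are lost; overwriting `data_type` afterwards restores them while keeping the same raw bytes. A minimal sketch of the failure mode and the fix, using hypothetical byte values and a uint8 stand-in rather than the library's actual `_cast_fp8`:

```python
import numpy as np
import onnx

# Hypothetical stand-in for _cast_fp8 output: raw E4M3FN bit patterns
# stored in a uint8 buffer, since numpy has no native FP8 dtype.
fp8_bits = np.array([0x38, 0xB8], dtype=np.uint8)  # 1.0 and -1.0 in E4M3FN

tensor = onnx.numpy_helper.from_array(fp8_bits, "weight_fp8")
print(tensor.data_type == onnx.TensorProto.UINT8)  # True: typed as plain uint8

# Re-declare the element type so consumers reinterpret the same bytes as
# FLOAT8E4M3FN, mirroring the tensor.data_type assignment in the diff.
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
```

The same type-vs-bytes distinction explains the `zp_array.dtype` to `zp.data_type` switch: once the zero point is kept as a `TensorProto`, its declared ONNX element type can be compared against `onnx_dtype_map["Float8"]` directly, instead of relying on the dtype of a numpy round-trip that cannot represent FP8.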