32 changes: 18 additions & 14 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -537,7 +537,7 @@ def _get_scale_and_zp(
     node: onnx.NodeProto,
     initializers: dict[str, onnx.TensorProto],
     tensor_producers: dict[str, onnx.NodeProto],
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[onnx.TensorProto, onnx.TensorProto]:
     """Get scale and zero point tensors for a node.
 
     Args:
@@ -546,7 +546,7 @@ def _get_scale_and_zp(
         tensor_producers: Dictionary of tensor producers
 
     Returns:
-        Tuple of (scale_array, zero_point_array)
+        Tuple of (scale_tensor, zero_point_tensor)
 
     Raises:
         ValueError: If scale or zero point cannot be found
@@ -560,7 +560,6 @@
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid scale producer for {scale_name}")
         scale = producer.attribute[0].t
-    scale_array = onnx.numpy_helper.to_array(scale)
 
     # Get zero point tensor
     zp_name = node.input[2]
@@ -571,9 +570,8 @@
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid zero point producer for {zp_name}")
         zp = producer.attribute[0].t
-    zp_array = onnx.numpy_helper.to_array(zp)
 
-    return scale_array, zp_array
+    return scale, zp
 
 
 def _get_successive_consumers(
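With this change, `_get_scale_and_zp` hands back raw `onnx.TensorProto` objects and defers numpy conversion to the point of use. A minimal sketch of the new contract, using a hypothetical per-channel scale (the tensor name and values are illustrative, not from this PR):

```python
import numpy as np
import onnx
from onnx import numpy_helper

# Hypothetical scale initializer, as _get_scale_and_zp now returns it.
scale_tensor = numpy_helper.from_array(
    np.array([0.02, 0.01], dtype=np.float32), name="w_scale"
)

# The TensorProto keeps ONNX metadata that a plain numpy array loses,
# e.g. the data_type enum (FLOAT8E4M3FN has no numpy equivalent).
assert scale_tensor.data_type == onnx.TensorProto.FLOAT

# Callers convert on demand, mirroring the new code in _convert_weight.
scale_array = numpy_helper.to_array(scale_tensor)
print(scale_array)  # [0.02 0.01]
```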
@@ -611,16 +609,16 @@ def _get_successive_consumers(
 
 def _convert_weight(
     weight_array: np.ndarray,
-    scale_array: np.ndarray,
-    zp_array: np.ndarray,
+    scale: onnx.TensorProto,
+    zp: onnx.TensorProto,
     quantized_node: onnx.NodeProto,
 ) -> np.ndarray:
     """Convert a weight tensor to INT8/FP8 format based on scale and zero point.
 
     Args:
         weight_array: The weight tensor to convert
-        scale_array: The scale tensor for quantization
-        zp_array: The zero point tensor for quantization
+        scale: The scale tensor for quantization
+        zp: The zero point tensor for quantization
         quantized_node: The operation node that will use the converted weight
 
     Returns:
@@ -637,6 +635,10 @@ def _convert_weight(
     weight_shape = weight_array.shape
     op_type = quantized_node.op_type
 
+    # Convert the ONNX tensors to numpy arrays
+    scale_array = onnx.numpy_helper.to_array(scale)
+    zp_array = onnx.numpy_helper.to_array(zp)
+
     # Dynamically determine transB for Gemm
     trans_b = 0
     if op_type == "Gemm":
Expand Down Expand Up @@ -668,7 +670,7 @@ def _convert_weight(
zp_array = zp_array.reshape(*reshape_dims)

# Convert to INT8/FP8
if zp_array.dtype == onnx_dtype_map["Float8"]:
if zp.data_type == onnx_dtype_map["Float8"]:
scaled = np.asarray(weight_array / scale_array) + zp_array
else:
scaled = np.asarray((weight_array / scale_array).round())
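The switch from `zp_array.dtype` to `zp.data_type` is the substantive fix in this hunk: numpy has no native FP8 dtype, so a numpy dtype can never equal the ONNX `Float8` enum value, while the TensorProto's `data_type` field carries the true element type. A minimal sketch of the distinction, assuming an onnx build with FP8 support (1.14+); the tensor name and value are illustrative:

```python
import onnx

# Hypothetical FP8 zero point, as it would appear among a Q node's inputs.
zp = onnx.helper.make_tensor(
    name="w_zp", data_type=onnx.TensorProto.FLOAT8E4M3FN, dims=[1], vals=[0.0]
)

# The TensorProto field is the reliable place to test for FP8 ...
assert zp.data_type == onnx.TensorProto.FLOAT8E4M3FN

# ... whereas a numpy round-trip has no FP8 dtype to compare against:
# numpy_helper.to_array(zp).dtype is a numpy dtype, never equal to an
# ONNX TensorProto.DataType enum value such as FLOAT8E4M3FN (17).
```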
Expand Down Expand Up @@ -709,7 +711,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
"""Create a FLOAT8E4M3FN tensor directly from numpy array."""
fp8_data = _cast_fp8(scaled)
return onnx.numpy_helper.from_array(fp8_data, weight_name)
tensor = onnx.numpy_helper.from_array(fp8_data, weight_name)
tensor.data_type = onnx_dtype_map["Float8"]
return tensor


def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
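`numpy_helper.from_array` infers the ONNX element type from the numpy dtype, and numpy cannot represent FP8 natively, so the inferred type is whatever carrier dtype `_cast_fp8` returns; patching `data_type` afterwards relabels the same raw bytes as FLOAT8E4M3FN. A minimal sketch of the pattern, with an assumed uint8 carrier standing in for `_cast_fp8` output:

```python
import numpy as np
import onnx
from onnx import numpy_helper

# Assumed stand-in for _cast_fp8 output: E4M3 bit patterns in a uint8 carrier
# (under E4M3 with bias 7, 0x38 encodes 1.0 and 0x40 encodes 2.0).
fp8_bits = np.array([0x38, 0x40], dtype=np.uint8)

tensor = numpy_helper.from_array(fp8_bits, "weight_q")
print(onnx.TensorProto.DataType.Name(tensor.data_type))  # UINT8 (inferred)

# Relabel the raw bytes so consumers see a true FP8 initializer.
tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
print(onnx.TensorProto.DataType.Name(tensor.data_type))  # FLOAT8E4M3FN
```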
Expand Down Expand Up @@ -761,16 +765,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
weight_array = onnx.numpy_helper.to_array(weight)

# Get scale and zero point
scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
scale, zp = _get_scale_and_zp(node, initializers, tensor_producers)

# Validate Q->DQ->Op pattern and get consumers
dq_node, quantized_node = _get_successive_consumers(node, tensor_consumers)

# Convert weight
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
scaled = _convert_weight(weight_array, scale, zp, quantized_node)

# Create and update new weight tensor
if zp_array.dtype == onnx_dtype_map["Float8"]:
if zp.data_type == onnx_dtype_map["Float8"]:
new_weight = _create_fp8_tensor(scaled, weight_name)
logger.debug(f"Converted {weight_name} to FP8")
else:
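Taken together, the refactor keeps scale and zero point as `onnx.TensorProto` until `_convert_weight` needs numeric values, and makes the FP8 checks consistent with how the tensors are actually stored. A minimal usage sketch of the public entry point, with placeholder file paths:

```python
import onnx
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq

model = onnx.load("model_qdq.onnx")  # placeholder path to a QDQ model

# Fold each weight's QuantizeLinear into a precomputed INT8/FP8 initializer,
# leaving only the DequantizeLinear node in the graph.
dq_model = qdq_to_dq(model)
onnx.save(dq_model, "model_dq.onnx")  # placeholder output path
```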