73 changes: 56 additions & 17 deletions modelopt/onnx/quantization/graph_utils.py
@@ -40,6 +40,9 @@
     save_onnx,
 )
 
+DEFAULT_GATHER_BLOCK_SIZE = 32
+DEFAULT_GATHER_QUANTIZE_AXIS = None
+
 
 def is_const_input(tensor: Tensor) -> bool:
     """Returns whether the given tensor is an initializer or produced by const-foldable nodes."""
@@ -718,6 +721,8 @@ def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     precision_pattern_8bit: str | None = None,
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    block_size: int = 128,
Review thread on `block_size: int = 128`:

Contributor: Should we use default 32 here?

Contributor (Author): Default `block_size` is set to 128 in int4.py if `block_size` is passed as None.
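For context, a minimal sketch of the fallback the author describes, assuming a simple resolver helper (the name `_resolve_block_size` is illustrative, not the actual int4.py code):

# Hypothetical sketch of the int4.py fallback described above;
# the real code may differ, but the behavior is: None -> 128.
def _resolve_block_size(block_size: int | None) -> int:
    return 128 if block_size is None else block_size

assert _resolve_block_size(None) == 128
assert _resolve_block_size(32) == 32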

+    quantize_axis: int = 0,
 ):
     """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model.
 
Expand Down Expand Up @@ -746,7 +751,7 @@ def get_layer_precision_mapping(
matmul_nodes = [
node
for node in onnx_model.graph.node
if node.op_type == "MatMul" and "lm_head" not in node.name
if node.op_type in ["Gemm", "MatMul"] and "lm_head" not in node.name
]

# Only include nodes matching the specified patterns for all layers present in the model
@@ -808,47 +813,81 @@ def layer_idx(name):
             layers_8bit_set.add(names_sorted[i])
     layers_list_8bit = list(layers_8bit_set)
 
-    # NEW: Create precision info mapping
-    precision_info = {}
+    # NEW: Create layer info mapping with precision, block_size, and axis
+    layer_info = {}
     for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack):
         weight_name = weight_tensor.name
         if should_quantize_to_8bit(weight_name, layers_list_8bit):
-            precision_info[weight_name] = 8
+            layer_info[weight_name] = {
+                "precision": 8,
+                "block_size": -1,  # Per-channel for 8-bit
+                "axis": 0,
+            }
         else:
-            precision_info[weight_name] = 4
-    return precision_info
+            layer_info[weight_name] = {
+                "precision": 4,
+                "block_size": block_size,  # Default block size for 4-bit
+                "axis": quantize_axis,
+            }
+
+    return layer_info
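For reference, a sketch of the structure the rewritten function now returns; the weight tensor names below are hypothetical, not from a real model:

# Illustrative result of get_layer_precision_mapping with
# block_size=128, quantize_axis=0; weight names are made up.
example_layer_info = {
    "layers.0.attn.qkv.MatMul.weight": {"precision": 8, "block_size": -1, "axis": 0},
    "layers.0.mlp.fc1.MatMul.weight": {"precision": 4, "block_size": 128, "axis": 0},
}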


-def get_precision_info(
+def get_layer_info(
     onnx_model: onnx.ModelProto,
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    block_size: int = 128,
+    quantize_axis: int = 0,
     **kwargs: Any,
 ):
-    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
+    """Generate a mapping of weight tensor names to their quantization configuration.
 
-    This function determines the quantization precision for each weight tensor in the ONNX model,
-    based on the provided configuration. If mixed quantization is enabled, it uses the layer
-    precision mapping; otherwise, it returns None.
+    This function determines the quantization configuration (precision, block_size, axis) for each
+    weight tensor in the ONNX model, based on the provided configuration. If mixed quantization
+    is enabled, it uses the layer precision mapping; if a Gather quantization axis is set, Gather
+    inputs are added to the mapping as well.
 
     Args:
         onnx_model (onnx.ModelProto): The ONNX model to analyze.
         nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
+        block_size (int): Default block size for 4-bit quantization.
+        quantize_axis (int): Default quantization axis for 4-bit quantization.
         **kwargs: Additional keyword arguments, such as:
             - enable_mixed_quant (bool): Whether to enable mixed quantization.
             - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bit.
+            - gather_block_size (int): Default block size for Gather quantization.
+            - gather_quantize_axis (int | None): Quantization axis for Gather inputs; None disables
+              Gather quantization.
 
     Returns:
-        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
-        or None if mixed quantization is not enabled.
+        dict[str, dict[str, Any]] | None: A mapping from weight tensor names to their quantization
+        configuration (with keys: precision, block_size, axis), or None if neither mixed quantization
+        nor Gather quantization is enabled.
     """
-    precision_info = None
+    layer_info = None
     enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
     layers_8bit = kwargs.get("layers_8bit")
+    gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
+    gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
     if enable_mixed_quant:
-        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
+        layer_info = get_layer_precision_mapping(
+            onnx_model,
+            layers_8bit,
+            nodes_to_exclude,
+            block_size,
+            quantize_axis,
+        )
     else:
-        precision_info = None
-    return precision_info
+        layer_info = None
+
+    if gather_quantize_axis is not None:
+        if layer_info is None:
+            layer_info = {}
+        for node in onnx_model.graph.node:
+            if node.op_type == "Gather":
+                layer_info[node.input[0]] = {
+                    "precision": 4,
+                    "block_size": gather_block_size,
+                    "axis": gather_quantize_axis,
+                }
+    return layer_info


def expand_node_names_from_patterns(
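To illustrate the new entry point end to end, a minimal usage sketch; the model path and 8-bit layer patterns are illustrative assumptions, not values from this PR:

import onnx

# Load a model and build the per-tensor quantization config.
model = onnx.load("decoder.onnx")  # illustrative path
layer_info = get_layer_info(
    model,
    nodes_to_exclude=[r"/lm_head"],
    block_size=128,                   # 4-bit block size forwarded to the mapping
    quantize_axis=0,
    enable_mixed_quant=True,
    layers_8bit="layers.0,layers.1",  # hypothetical 8-bit layer patterns
    gather_quantize_axis=None,        # None skips Gather (embedding) entries
)
# layer_info maps weight tensor names to {"precision", "block_size", "axis"}.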