Skip to content

Commit e288735

Browse files
committed
update
1 parent 50a1f01 commit e288735

File tree

4 files changed: +9 additions, −10 deletions

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
9191
observer = getattr(module, f"{base_name}_observer")
9292
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
9393

94-
if hasattr(module, "input_global_scale"):
94+
if base_name == "input" and hasattr(module, "input_global_scale"):
9595
update_parameter_data(module, updated_scale, "input_global_scale")
9696
else:
9797
# update scale and zero point

src/llmcompressor/observers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def get_qparams(
9393
observed, calculate_global_scale=True
9494
)
9595

96-
if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
96+
elif self.quantization_args.strategy == QuantizationStrategy.TENSOR:
9797
# re-calculate scale and zero point, update the stored value
9898
self._scale, self._zero_point = self.calculate_qparams(observed)
9999

src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ def calculate_qparams(
9494
max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
9595
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
9696
global_scale = FP8_E4M3_DATA.max * FP4_E2M1_DATA.max / max_val_pos
97-
print("global_scale")
98-
return global_scale, None
97+
return global_scale.to(torch.float32), None
9998

10099
return calculate_qparams(
101100
min_vals=updated_min_val,

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ def infer_quantization_format(
6161
)
6262
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
6363

64-
if is_weight_only: # w4a16 and w8a16
65-
if (
66-
weight_args[0].num_bits == 4
67-
and weight_args[0].type == QuantizationType.FLOAT.value
68-
):
69-
return CompressionFormat.nvfp4_pack_quantized
64+
if (
65+
weight_args[0].num_bits == 4
66+
and weight_args[0].type == QuantizationType.FLOAT.value
67+
):
68+
return CompressionFormat.nvfp4_pack_quantized
7069

70+
if is_weight_only: # w4a16 and w8a16
7171
is_valid_pack = all(
7272
weight_arg.num_bits in [4, 8]
7373
and weight_arg.type == QuantizationType.INT.value

Comments (0)