Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 24 additions & 93 deletions benchmarks/benchmark_aq.py
Copy link
Contributor Author

@namgyu-youn namgyu-youn Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about inlining _int8wo_api, _int8da_int8w_api, _int4wo_api ? They are used only once across codebase.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think that's fine if they're only used in benchmarks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also cc @jainapurva, can you take a look at the benchmark changes?

Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,6 @@
_replace_with_custom_fn_if_matches_filter,
quantize_,
)
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)


def _int8wo_api(mod, **kwargs):
quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)


def _int8da_int8w_api(mod, **kwargs):
quantize_(
mod,
Int8DynamicActivationInt8WeightConfig(**kwargs),
set_inductor_config=False,
)


def _int4wo_api(mod, **kwargs):
kwargs_copy = kwargs.copy()
if "groupsize" in kwargs_copy:
kwargs_copy["group_size"] = kwargs_copy["groupsize"]
del kwargs_copy["groupsize"]
quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)


class ToyLinearModel(torch.nn.Module):
Expand Down Expand Up @@ -68,34 +44,6 @@ def forward(self, x):
return x


def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import (
_get_subclass_inserter,
_is_linear,
)
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

def _in_features_greater_than_16(mod, *args):
return hasattr(mod, "in_features") and mod.in_features > 16

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
),
filter_fn,
)


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
Expand All @@ -117,38 +65,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
return _ref_change_linear_weights_to_woqtensors


_ref_change_linear_weights_to_int8_woqtensors = (
_get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
)
_ref_change_linear_weights_to_int4_woqtensors = (
_get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
)


torch._dynamo.config.cache_size_limit = 50000


@torch.no_grad
def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
if kwargs is None:
kwargs = {}

def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
m = ToyLinearModel(
M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
).eval()
m_bf16 = copy.deepcopy(m)
m_ref = copy.deepcopy(m)
example_inputs = m.example_inputs()

api(m, **kwargs)

# reference
ref_api(m_ref, **kwargs)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

assert torch.equal(res, ref)
api(m, config) # Pass both model and config

# perf comparison
from torchao.utils import benchmark_model
Expand All @@ -158,22 +86,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
RUNS = 100

torch._dynamo.reset()
m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
benchmark_model(m_ref, WARMUP, example_inputs)
ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
benchmark_model(m_bf16, WARMUP, example_inputs)
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

torch._dynamo.reset()
m = torch.compile(m, mode="max-autotune", fullgraph=True)
benchmark_model(m, WARMUP, example_inputs)
elapsed_time = benchmark_model(m, RUNS, example_inputs)

torch._dynamo.reset()
m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
benchmark_model(m_bf16, WARMUP, example_inputs)
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

print(
f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
)


Expand All @@ -182,24 +105,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
(20, 2048, 2048),
]

print("_int8da_int8w_api")

print("Int8DynamicActivationInt8WeightConfig")
for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily updated to use new APIs 2 times to fix CI, but maybe we can update _bench_quantized_tensor_subclass_perf to compare only original vs. new quantization flows?

_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
quantize_,
Int8DynamicActivationInt8WeightConfig(),
M,
N,
K,
)

print("_int8wo_api")

print("Int8WeightOnlyConfig")
for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(
_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
quantize_,
Int8WeightOnlyConfig(),
M,
N,
K,
)

print("_int4wo_api")
kwargs = {"groupsize": 32, "version": 1}

print("Int4WeightOnlyConfig")
for M, N, K in all_shapes:
_bench_quantized_tensor_subclass_perf(
_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
quantize_,
Int4WeightOnlyConfig(group_size=32),
M,
N,
K,
)
118 changes: 0 additions & 118 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@
smooth_fq_linear_to_inference,
swap_linear_with_smooth_fq_linear,
)
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import (
LoggingTensorMode,
_apply_logging_hook,
Expand Down Expand Up @@ -681,62 +676,6 @@ def _test_dequantize_impl(
f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}",
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype):
self._test_dequantize_impl(
Int8DynamicallyQuantizedLinearWeight.from_float,
device,
35,
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
self._test_dequantize_impl(
Int8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
for test_shape in [(16, 1024, 16)] + (
[(1, 1024, 8)] if device == "cuda" else []
):
self._test_dequantize_impl(
Int4WeightOnlyQuantizedLinearWeight.from_float,
device,
15,
test_shape=test_shape,
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
m_shapes = [16, 256] + ([1] if device == "cuda" else [])
n_shapes = [16] + ([8, 13] if device == "cuda" else [])
for groupsize in [256, 128]:
for inner_k_tiles in [8, 4, 2]:
for m in m_shapes:
for n in n_shapes:
self._test_dequantize_impl(
lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(
w, groupsize, inner_k_tiles
),
device,
15,
test_shape=[m, 256, n],
test_dtype=dtype,
)

@run_supported_device_dtype
def _test_lin_weight_subclass_impl(
self,
Expand Down Expand Up @@ -771,22 +710,6 @@ def _test_lin_weight_subclass_impl(
f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}",
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
Int8DynamicallyQuantizedLinearWeight.from_float,
device,
35,
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_weight_only_quant_subclass(self, device, dtype):
undo_recommended_configs()
self._test_lin_weight_subclass_impl(
Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
Expand Down Expand Up @@ -891,46 +814,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
for test_shape in [(16, 1024, 16)] + (
[(1, 1024, 8)] if device == "cuda" else []
):
self._test_lin_weight_subclass_impl(
Int4WeightOnlyQuantizedLinearWeight.from_float,
device,
10,
test_shape=test_shape,
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@skip_if_rocm("ROCm enablement in progress")
@unittest.skip("Skip to fix CI until we deprecate these APIs long term")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
m_shapes = [16, 256] + ([1] if device == "cuda" else [])
n_shapes = [16] + ([8, 13] if device == "cuda" else [])
for groupsize in [128, 64]:
for inner_k_tiles in [8, 4, 2]:
for m in m_shapes:
for n in n_shapes:
self._test_lin_weight_subclass_impl(
lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(
w, groupsize, inner_k_tiles
),
device,
10,
test_shape=[m, 256, n],
test_dtype=dtype,
)

@torch.no_grad()
@run_supported_device_dtype
def _test_lin_weight_subclass_api_impl(
Expand Down Expand Up @@ -1120,7 +1003,6 @@ def test_dynamic_quant(self):

sqnr = compute_error(y_ref, y_test)
self.assertGreater(sqnr, 40.0)
# self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear))


class TestWeightOnlyInt8Quant(unittest.TestCase):
Expand Down
Loading