Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig #2474

Open · wants to merge 1 commit into base: jerryzh168/stack/9
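
At a glance, the two quantize_ entry points this PR routes onto the fbgemm-backed tensor subclasses, shown as a minimal sketch: the toy model is illustrative only, and the config names and arguments are taken from the test changes in the diff below.

import torch

from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# toy bfloat16 model; any module with nn.Linear children is handled the same way
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# int4 weight-only, using the preshuffled fbgemm packing
quantize_(model, Int4WeightOnlyConfig(group_size=128, packing_format="preshuffled"))

# alternatively: float8 dynamic activation + int4 weight on the same packing
# quantize_(model, Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled"))
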
1 change: 1 addition & 0 deletions docs/source/api_ref_quantization.rst
@@ -24,6 +24,7 @@ Inference APIs for quantize\_
:nosignatures:

Int4WeightOnlyConfig
Float8ActivationInt4WeightConfig
Float8DynamicActivationFloat8WeightConfig
Float8WeightOnlyConfig
Float8StaticActivationFloat8WeightConfig
62 changes: 62 additions & 0 deletions test/integration/test_serialization_bc.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90

_MODEL_NAMES = [
"torchao-testing/opt-125m-float8dq-row-0.13-dev",
"torchao-testing/opt-125m-int4wo-preshuffled-0.13-dev",
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestSerializationBC(TestCase):
"""Test we can still load and run serialized model in previous AO versions
we commit to have BC for 3 pytorch releases
"""

@common_utils.parametrize("model_name", _MODEL_NAMES)
def test_load_model_and_run(self, model_name):
if "fbgemm" in model_name and not _is_fbgemm_genai_gpu_available():
# TODO: this is not enabled in CI, enable this after new fbgemm releases
print("can't run fbgemm model without fbgemm_genai_gpu installed")
return
# Load and quantize model
quantized_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="bfloat16",
device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = ("Hello, my name is",)

inputs = tokenizer(
prompt,
return_tensors="pt",
).to("cuda")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
# make sure it runs
_ = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)


common_utils.instantiate_parametrized_tests(TestSerializationBC)

if __name__ == "__main__":
run_tests()
@@ -15,9 +15,9 @@
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
Float8ActivationInt4WeightConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -27,44 +27,15 @@
is_sm_at_least_90,
)

if TORCH_VERSION_AT_LEAST_2_8:
BF16_ACT_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

BF16_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

FP8_ACT_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

FP8_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

else:
BF16_ACT_CONFIG = None
BF16_ACT_BMM_CONFIG = None
FP8_ACT_CONFIG = None
FP8_ACT_BMM_CONFIG = None
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="preshuffled",
)

FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
group_size=128,
packing_format="preshuffled",
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +61,7 @@ def test_linear(self, config):

# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
def test_bmm(self, bmm_config):
class M(torch.nn.Module):
def __init__(self, weight):
@@ -13,7 +13,7 @@
)

from torchao.quantization import (
FbgemmConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -26,19 +26,11 @@
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestFbgemmInt4Tensor(TestCase):
class TestInt4Tensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
self.config = Int4WeightOnlyConfig(
group_size=128,
packing_format="plain",
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -68,13 +60,9 @@ def test_slice(self):
quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(
weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
)
self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
self.assertEqual(
weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
)
self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

# check for sliced weight, before and after float8 quantization
@@ -100,12 +88,10 @@ def test_slice_and_copy_(self):
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert (
param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
)
assert param.data._data.data_ptr() == param_data._data.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data.packed_weight[0][0].item()
orig_value = param.data._data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
@@ -116,7 +102,7 @@ def test_slice_and_copy_(self):
param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value
assert param.data._data[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
@@ -135,7 +121,7 @@ def forward(self, x):
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantize_(m, self.config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

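For reference, the config migration this file's tests encode, side by side as a hedged sketch; both forms are copied from the removed and added lines above, and FbgemmConfig is slated for removal per the review discussion further down.

import torch

from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig

# before: the fbgemm-specific config
old_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
)

# after (this PR): the same plain int4 path expressed through Int4WeightOnlyConfig
new_config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="plain",
)
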
3 changes: 0 additions & 3 deletions torchao/dtypes/__init__.py
@@ -9,7 +9,6 @@
to_affine_quantized_intx_static,
)
from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -64,8 +63,6 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
"Int8DynamicActInt4WeightCPULayout",
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -44,6 +44,7 @@
from .quant_api import (
CutlassInt4PackedLayout,
FbgemmConfig,
Contributor:
What's the plan for FbgemmConfig? Looks like it was added only ~1.5 months ago, but it's technically public API. Do we know if anyone's using it already? I don't think it's released yet, so I wonder if it's OK to just remove it.

Contributor Author:
We'll remove it; it is used in some internal scripts, but we'll update those as well.

Float8ActivationInt4WeightConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
@@ -90,6 +91,7 @@
from .quantize_.workflows import (
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int4WeightOnlyConfig",
"Float8ActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
"ModuleFqnToConfig",
"FbgemmConfig",
# tensor subclasses
"Int4Tensor",
"Int4PreshuffledTensor",
"Float8Tensor",
# smooth quant - subject to change