From a72a2c516c6e96bd3367491ed19765ac7f2bb03a Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 9 Apr 2024 11:13:02 -0700
Subject: [PATCH 1/2] testing HQQ [not for land]

Summary:

for eval=5
wikitext: {'word_perplexity,none': 11.49343838017535, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6110947678444059, 'byte_perplexity_stderr,none':

for eval all
...

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 generate.py | 11 ++++++++++-
 quantize.py | 41 ++++++++++++++++++++++++++++++++++++++++-
 run.sh      | 27 +++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 2 deletions(-)
 create mode 100644 run.sh

diff --git a/generate.py b/generate.py
index 8446d115..5accfbd6 100644
--- a/generate.py
+++ b/generate.py
@@ -224,7 +224,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
 
-    if "int4" in str(checkpoint_path):
+    if "int4-hqq" in str(checkpoint_path):
+        print("Using int4 weight-only HQQ quantization.")
+        from quantize import WeightOnlyInt4HqqQuantHandler
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        quantizer = WeightOnlyInt4HqqQuantHandler(model, groupsize=groupsize)
+        model = quantizer._convert_for_runtime()
+    elif "int4" in str(checkpoint_path):
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-3].startswith("g")
diff --git a/quantize.py b/quantize.py
index 69fa9c9c..802231cb 100644
--- a/quantize.py
+++ b/quantize.py
@@ -519,6 +519,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )
 
+# TODO a hacky placeholder class
+class WeightOnlyInt4HqqQuantHandler:
+    def __init__(self, mod, groupsize):
+        self.mod = mod
+        self.groupsize = groupsize
+
+    def _create_quantized_state_dict(self):
+        from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
+        for m in self.mod.modules():
+            for name, child in m.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    child.weight = torch.nn.Parameter(
+                        Quantizer.dequantize(
+                            *Quantizer.quantize(
+                                child.weight,
+                                nbits=4,
+                                group_size=self.groupsize,
+                                axis=1,
+                            )
+                        )
+                    )
+
+        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+
+    def _convert_for_runtime(self):
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
 
 def quantize(
     checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
@@ -592,6 +619,18 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+
+    elif mode == 'int4-hqq':
+        print("Quantizing model weights for int4 using HQQ")
+        quant_handler = WeightOnlyInt4HqqQuantHandler(model, groupsize)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+
+        quantized_state_dict = quant_handler._create_quantized_state_dict()
+        dir_name = checkpoint_path.parent
+        base_name = checkpoint_path.name
+        new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
 
@@ -606,7 +645,7 @@ def quantize(
     import argparse
     parser = argparse.ArgumentParser(description='Quantize a model.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'int4-hqq'], help='type of quantization to perform')
     parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
diff --git a/run.sh b/run.sh
new file mode 100644
index 00000000..d0e6a202
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,27 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
+python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# broken
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
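
For context on the WeightOnlyInt4HqqQuantHandler added above: _create_quantized_state_dict quantizes each nn.Linear weight to 4 bits with HQQ and immediately dequantizes it, so the existing WeightOnlyInt4QuantHandler packing path then re-quantizes a weight that already sits on HQQ's chosen grid, and _convert_for_runtime reuses the GPTQ handler's module swap. Below is a minimal sketch of that per-weight round trip, assuming only that the hqq package's Quantizer behaves as it is used in the diff; the tensor shape and the error printout are illustrative, not taken from the patch.

    # Sketch of the HQQ round trip that _create_quantized_state_dict applies to
    # every nn.Linear weight. Assumes `pip install hqq`; mirrors the
    # Quantizer.quantize / Quantizer.dequantize calls from the patch above.
    import torch
    from hqq.core.quantize import Quantizer

    groupsize = 32
    weight = torch.randn(128, 64)  # stand-in for an nn.Linear weight

    # 4-bit quantization in groups of `groupsize` along the input dim (axis=1),
    # followed by an immediate dequantize back to floating point.
    quantized = Quantizer.quantize(weight, nbits=4, group_size=groupsize, axis=1)
    weight_hqq = Quantizer.dequantize(*quantized)

    # The dequantized weight carries HQQ's rounding, so the downstream int4
    # packing in quantize.py re-quantizes a tensor already on HQQ's grid.
    print("max abs round-trip error:", (weight.float() - weight_hqq.float()).abs().max().item())
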
From 8efb00dcd31edb94f2d60f7dca92f5147b74eefe Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 9 Apr 2024 11:21:25 -0700
Subject: [PATCH 2/2] Update on "testing HQQ [not for land]"

Summary:

for eval=5
wikitext: {'word_perplexity,none': 11.49343838017535, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6110947678444059, 'byte_perplexity_stderr,none':

for eval all
...

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 run.sh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/run.sh b/run.sh
index d0e6a202..5879818f 100644
--- a/run.sh
+++ b/run.sh
@@ -8,12 +8,12 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 
 
 python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
-python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
 python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext
 
-# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
 # python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
-# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext
 
 # python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
 # python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
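
A note on the filename convention that run.sh and generate.py rely on above: quantize.py writes checkpoints named like model_int4-hqq.g32.cuda.pth, and _load_model recovers the settings by splitting the name on dots, which is what the path_comps asserts check. A small sketch of that parsing under the same naming scheme follows; parse_quantized_name is a hypothetical helper, not part of the patch.

    # Hypothetical helper illustrating the checkpoint naming scheme used above:
    # <base>_<mode>.g<groupsize>.<device>.pth
    from pathlib import Path

    def parse_quantized_name(checkpoint_path: Path) -> tuple[int, str]:
        comps = checkpoint_path.name.split(".")
        # e.g. "model_int4-hqq.g32.cuda.pth" -> comps[-3] == "g32", comps[-2] == "cuda"
        assert comps[-3].startswith("g"), "expected a g<groupsize> component"
        return int(comps[-3][1:]), comps[-2]

    print(parse_quantized_name(Path("model_int4-hqq.g32.cuda.pth")))  # (32, 'cuda')
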