diff --git a/generate.py b/generate.py
index 8446d115..5accfbd6 100644
--- a/generate.py
+++ b/generate.py
@@ -224,7 +224,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
 
-    if "int4" in str(checkpoint_path):
+    if "int4-hqq" in str(checkpoint_path):
+        print("Using int4 weight-only HQQ quantization.")
+        from quantize import WeightOnlyInt4HqqQuantHandler
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        quantizer = WeightOnlyInt4HqqQuantHandler(model, groupsize=groupsize)
+        model = quantizer._convert_for_runtime()
+    elif "int4" in str(checkpoint_path):
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-3].startswith("g")
diff --git a/quantize.py b/quantize.py
index 69fa9c9c..802231cb 100644
--- a/quantize.py
+++ b/quantize.py
@@ -519,6 +519,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )
 
+# TODO a hacky placeholder class
+class WeightOnlyInt4HqqQuantHandler:
+    def __init__(self, mod, groupsize):
+        self.mod = mod
+        self.groupsize = groupsize
+
+    def _create_quantized_state_dict(self):
+        from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
+        for m in self.mod.modules():
+            for name, child in m.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    child.weight = torch.nn.Parameter(
+                        Quantizer.dequantize(
+                            *Quantizer.quantize(
+                                child.weight,
+                                nbits=4,
+                                group_size=self.groupsize,
+                                axis=1,
+                            )
+                        )
+                    )
+
+        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+
+    def _convert_for_runtime(self):
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
 
 def quantize(
     checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
@@ -592,6 +619,18 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+
+    elif mode == 'int4-hqq':
+        print("Quantizing model weights for int4 using HQQ")
+        quant_handler = WeightOnlyInt4HqqQuantHandler(model, groupsize)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+
+        quantized_state_dict = quant_handler._create_quantized_state_dict()
+        dir_name = checkpoint_path.parent
+        base_name = checkpoint_path.name
+        new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
 
@@ -606,7 +645,7 @@ def quantize(
    import argparse
    parser = argparse.ArgumentParser(description='Quantize a model.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'int4-hqq'], help='type of quantization to perform')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
diff --git a/run.sh b/run.sh
new file mode 100644
index 00000000..5879818f
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,27 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# broken
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
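For context on what the placeholder `WeightOnlyInt4HqqQuantHandler` does, below is a minimal sketch of the HQQ round-trip it relies on: weights are quantized to 4 bits with hqq's `Quantizer`, immediately dequantized back to floating point, and the dequantized tensor is then handed to the existing `WeightOnlyInt4QuantHandler` packing path. Only the `Quantizer.quantize`/`Quantizer.dequantize` calls that already appear in the diff are used; the tensor shape, dtype, and the error print are illustrative assumptions, not part of the change.

```python
# Sketch of the quantize -> dequantize round-trip used by
# WeightOnlyInt4HqqQuantHandler._create_quantized_state_dict above.
import torch
from hqq.core.quantize import Quantizer

groupsize = 32
weight = torch.randn(4096, 4096, dtype=torch.float16)  # stand-in for a Linear weight

# 4-bit quantization with per-group parameters along the input dim (axis=1),
# followed by immediate dequantization back to floating point.
w_q, meta = Quantizer.quantize(weight, nbits=4, group_size=groupsize, axis=1)
weight_dq = Quantizer.dequantize(w_q, meta)

# weight_dq is what the existing int4 packing path
# (WeightOnlyInt4QuantHandler.create_quantized_state_dict) then consumes.
# .to(...) guards against hqq returning a different device/dtype.
err = (weight - weight_dq.to(weight.device, weight.dtype)).abs().mean()
print("mean abs error:", err.item())
```

Because the state dict is packed by the plain int4 handler after dequantization, `_convert_for_runtime` can presumably delegate to the existing GPTQ/int4 runtime conversion unchanged.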
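And a small illustration of the checkpoint-name convention that ties quantize.py, generate.py, and run.sh together: mode, group size, and device are encoded in the file name, and `_load_model` recovers them by splitting on dots. The example path below is hypothetical but follows the pattern run.sh evaluates.

```python
# How _load_model in generate.py parses a quantized checkpoint name
# such as the one produced for --mode int4-hqq with the default groupsize.
from pathlib import Path

checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int4-hqq.g32.cuda.pth")

path_comps = checkpoint_path.name.split(".")  # ['model_int4-hqq', 'g32', 'cuda', 'pth']
assert path_comps[-3].startswith("g")         # group-size component, e.g. 'g32'
groupsize = int(path_comps[-3][1:])           # -> 32
device_tag = path_comps[-2]                   # 'cuda'; must match the runtime device
print(groupsize, device_tag)
```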