@@ -242,6 +242,29 @@ def _untie_weights_and_save_locally(model_id):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """

+_int8_int4_hqq_quant_code = """
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    ModuleFqnToConfig,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int8,
+    granularity=PerAxis(0),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
+quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
@@ -589,14 +612,8 @@ def quantize_and_upload(
     push_to_user_id: str,
     populate_model_card_template: bool,
 ):
-    _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
-        weight_dtype=torch.int4,
-        weight_granularity=PerGroup(32),
-    )
-    _int8_int4_embedding_config = IntxWeightOnlyConfig(
-        weight_dtype=torch.int8,
-        granularity=PerAxis(0),
-    )
+    is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]
+
     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
@@ -606,8 +623,28 @@ def quantize_and_upload(
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
-                "_default": _int8_int4_linear_config,
-                "model.embed_tokens": _int8_int4_embedding_config,
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                ),
+            }
+        ),
+        "INT8-INT4-HQQ": ModuleFqnToConfig(
+            {
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
             }
         ),
     }
@@ -616,12 +653,13 @@ def quantize_and_upload(
616653 "FP8" : _fp8_quant_code ,
617654 "INT4" : _int4_quant_code ,
618655 "INT8-INT4" : _int8_int4_quant_code ,
656+ "INT8-INT4-HQQ" : _int8_int4_hqq_quant_code ,
619657 "AWQ-INT4" : _awq_int4_quant_code ,
620658 }
621659
622660 # preparation
623661 model_to_quantize = model_id
624- if quant == "INT8-INT4" :
662+ if is_mobile :
625663 model_to_quantize = _untie_weights_and_save_locally (model_to_quantize )
626664
627665 # quantization
@@ -666,7 +704,7 @@ def quantize_and_upload(
     quant_config = quant_to_config[quant]

     torchao_config_kwargs = {}
-    if "INT8-INT4" in quant:
+    if is_mobile:
         torchao_config_kwargs["modules_to_not_convert"] = []
         torchao_config_kwargs["include_input_output_embeddings"] = True

@@ -688,7 +726,6 @@ def quantize_and_upload(
     save_to_user_id = username if push_to_user_id is None else push_to_user_id
     save_to = f"{save_to_user_id}/{MODEL_NAME}-{quant}"
     untied_model_path = 'f"{{MODEL_NAME}}-untied-weights"'
-    is_mobile = quant == "INT8-INT4"
     quantized_model_id = save_to
     # model card
     content = MODEL_CARD.format(
@@ -775,7 +812,7 @@ def quantize_and_upload(
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
     )
     parser.add_argument(
         "--tasks",