 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
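+# HostEvent mirrors the record()/elapsed_time() interface of torch.cuda.Event using
+# time.perf_counter(), so the same benchmarking code paths work on CPU/MPS devices.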
+class HostEvent:
+    def __init__(self):
+        self.event_time = None
+
+    def record(self):
+        self.event_time = time.perf_counter()
+
+    def elapsed_time(self, other_event):
+        if self.event_time is None:
+            raise ValueError("Event not recorded!")
+        # return ms to match cuda event
+        return abs(other_event.event_time - self.event_time) * 1000
+
+def device_timer(device):
+    if "cuda" in device:
+        return torch.cuda.Event(enable_timing=True)
+    elif ("cpu" in device) or ("mps" in device):
+        return HostEvent()
+    else:
+        print(f"device={device} is not yet supported")
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
@@ -98,6 +121,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
+    prefill_start_event: Optional[torch.cuda.Event] = None,
+    prefill_end_event: Optional[torch.cuda.Event] = None,
+    decode_start_event: Optional[torch.cuda.Event] = None,
+    decode_end_event: Optional[torch.cuda.Event] = None,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
@@ -128,12 +155,21 @@ def generate(
     model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
 
     # execute prefill
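+    # the optional events bracket the prefill and decode phases separately, so TTFT and
+    # decode throughput can be reported independently of overall wall-clock time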
+    if prefill_start_event is not None:
+        prefill_start_event.record()
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     seq[:, T] = next_token.squeeze()
+    if prefill_end_event is not None:
+        prefill_end_event.record()
+
     # execute token generation
+    if decode_start_event is not None:
+        decode_start_event.record()
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens - 1, callback=callback, **sampling_kwargs)
     seq = torch.cat((seq[:, :T + 1], *generated_tokens), dim=-1)
+    if decode_end_event is not None:
+        decode_end_event.record()
 
     return seq
 
@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
+    prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
@@ -166,6 +203,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
@@ -181,6 +219,10 @@ def main(
181219 """Generates text samples based on a pre-trained Transformer model and tokenizer.
182220 """
183221
222+ if prefill_size is not None and prefill_size > 0 :
223+ # create prompt of prefill size
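+        # (rough heuristic: each repeated "prompt " is roughly one token, and the -3
+        # leaves a little headroom for special tokens)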
224+ prompt = "prompt " * (int (prefill_size )- 3 )
225+
184226 torchao .quantization .utils .recommended_inductor_config_setter ()
185227
186228 assert checkpoint_path .is_file (), checkpoint_path
@@ -205,6 +247,14 @@ def main(
 
     torch.manual_seed(1234)
 
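+    # filter_fn predicates for quantize_()/sparsify_(): each takes a module and its
+    # fully-qualified name (fqn) and selects which nn.Linear layers a config applies to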
+    def ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn
+
+    def not_ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)
+
+    def ffn_or_attn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn)
 
     if quantization:
         from torchao.quantization import (
@@ -228,9 +278,14 @@ def main(
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        elif "int8dq" in quantization:
-            quantize_(model, int8_dynamic_activation_int8_weight())
-        elif "int4wo" in quantization:
+        if "int8dq" in quantization:
+            if sparsity and "semi" in sparsity:
+                from torchao.dtypes import SemiSparseLayout
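+                # 2:4 semi-structured sparsity is combined with int8 dynamic quant on the
+                # feed-forward linears only; the remaining linears get plain int8dq below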
+                quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
+                quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
+            else:
+                quantize_(model, int8_dynamic_activation_int8_weight())
+        if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq = True
             else:
@@ -250,9 +305,9 @@ def main(
                         layout=MarlinQQQLayout(),
                     ),
                 )
-            else:
+            elif "semi" in sparsity:
                 from torchao.dtypes import MarlinSparseLayout
-                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         elif "embed-int8wo" in quantization:
@@ -426,6 +481,13 @@ def main(
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(model)
 
+    # standalone sparsity
+    elif sparsity:
+        from torchao.sparsity import semi_sparse_weight, sparsify_
+        if "semi" in sparsity:
+            # TODO: there is a bug here, need to fix
+            sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)
+
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
     if save:
@@ -451,6 +513,9 @@ def main(
 
     aggregate_metrics = {
         'tokens_per_sec': [],
+        'time': [],
+        'decode_tokens_per_sec': [],
+        'prefill_time': [],
     }
     start = -1 if compile else 0
 
@@ -485,6 +550,8 @@ def callback(x):
         else:
             callback = lambda x: x
         t0 = time.perf_counter()
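+        # one timer pair per phase: CUDA events on GPU, HostEvent (perf_counter) on CPU/MPS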
+        prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
+        decode_start_event, decode_end_event = device_timer(device), device_timer(device)
         import contextlib
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
@@ -504,6 +571,10 @@ def callback(x):
                 kv_cache_quantization=kv_cache_quantization,
                 cache_size=cache_size,
                 linear_causal_mask=linear_causal_mask,
+                prefill_start_event=prefill_start_event,
+                prefill_end_event=prefill_end_event,
+                decode_start_event=decode_start_event,
+                decode_end_event=decode_end_event,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -513,7 +584,7 @@ def callback(x):
         device_sync(device=device)  # MKG
         t = time.perf_counter() - t0
 
-        if not interactive:
+        if not interactive and prefill_size is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
@@ -523,7 +594,14 @@ def callback(x):
         tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        aggregate_metrics['time'].append(t)
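+        # elapsed_time() returns milliseconds for both CUDA events and HostEvent; convert to seconds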
+        decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
+        decode_tokens_sec = tokens_generated / decode_time
+        aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
+        prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
+        aggregate_metrics['prefill_time'].append(prefill_time)
+        print(f"Sample {i + 1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
+              f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
 
         if memory_profile and i == 0:
@@ -544,8 +622,15 @@ def callback(x):
             break
     print("==========")
 
+    # ignore first sample for warmup
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
+    ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
+    decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
631+ print (f"Average overall tokens/sec: { tokpersec :.2f} " )
632+ print (f"Average decode tokens/sec: { decode_tokens_sec :.04f} s" )
633+ print (f"Average TTFT: { ttft :.04f} s" )
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() / 1e9
     elif device == "xpu":
@@ -557,15 +642,17 @@ def callback(x):
557642 print (f"Peak Memory Usage: { mem :.02f} GB" )
558643 print (f"Model Size: { model_size :.02f} GB" )
559644 if write_result :
560- result_txt = f"\n { datetime .today ().strftime ('%Y%m%d%H%M%S' )} , tok/s={ tokpersec :6.2f} , mem/s={ bandwidth :7.2f} GB/s, peak_mem={ mem :5.2f} GB, model_size={ model_size :5.2f} GB "
561- result_txt += f"quant: { quantization } , mod: { checkpoint_path .parent .name } , kv_quant: { kv_cache_quantization } , compile: { compile } , compile_prefill: { compile_prefill } , dtype: { precision } , device: { device } "
645+ result_txt = f"\n { datetime .today ().strftime ('%Y%m%d%H%M%S' )} , tok/s={ tokpersec :6.2f} , tok/s_decode= { decode_tokpersec :6.2f } , ttft= { ttft :5.4f } , mem/s={ bandwidth :7.2f} GB/s, peak_mem={ mem :5.2f} GB, model_size={ model_size :5.2f} GB "
646+ result_txt += f"quant: { quantization } , sparse: { sparsity } , mod: { checkpoint_path .parent .name } , kv_quant: { kv_cache_quantization } , compile: { compile } , compile_prefill: { compile_prefill } , dtype: { precision } , device: { device } "
562647 result_txt += f"repro: python generate.py "
563648 result_txt += f"--quantization { quantization } " if quantization else ""
649+ result_txt += f"--sparsity { sparsity } " if sparsity else ""
564650 result_txt += f"--checkpoint_path { checkpoint_path } "
565651 result_txt += f"--device { device } "
566652 result_txt += f"--precision { precision } "
567653 result_txt += f"--compile " if compile else ""
568654 result_txt += f"--compile_prefill " if compile_prefill else ""
655+ result_txt += f"--prefill_size { prefill_size } " if prefill_size else ""
569656 result_txt += f"--profile { profile } " if profile else ""
         result_txt += f"--memory_profile {memory_profile} " if memory_profile else ""
         result_txt += f"--interactive " if interactive else ""
@@ -587,7 +674,7 @@ def callback(x):
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
-
+    parser.add_argument('--prefill_size', type=int, default=None, help='Prompt length (in tokens) to use for prefill/TTFT benchmarking')
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -603,6 +690,11 @@ def callback(x):
             + 'embed-int8wo, marlin_qqq'
         )
     )
+    parser.add_argument('-s', '--sparsity', type=str,
+        help=(
+            'Which sparsity techniques to apply: semi-structured'
+        )
+    )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -617,6 +709,6 @@ def callback(x):
 
     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
+        args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )