
Commit f04aec7

jcaipvkuzo authored and committed
Add TTFT benchmarks + update sparsity benchmarks (#1140)
This PR adds TTFT (time-to-first-token) benchmarks to torchao, updates the benchmarking script to handle sparsity more cleanly and to use the 2:4 sparse checkpoints that are now available, and adds padding support for int8 dynamic quantization + 2:4 sparsity, which was missing before.
1 parent 65b885f commit f04aec7
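
For orientation, the quantization-plus-sparsity combination this commit wires up can be sketched as follows. This is a minimal illustration, not code from the diff: `model` is a placeholder for a Llama-style module whose feed-forward Linear layers have "feed_forward" in their names, and the imports mirror the ones generate.py uses below.

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
    from torchao.dtypes import SemiSparseLayout

    def ffn_only(mod, fqn):
        # match only feed-forward Linear layers, as generate.py does
        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

    # int8 dynamic activation quantization on top of a 2:4 semi-structured
    # sparse weight layout, applied only where the 2:4 kernels pay off
    quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)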

File tree

5 files changed: +136 -16 lines changed


scripts/prepare.sh

Lines changed: 4 additions & 0 deletions

@@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
 python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
+python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
+# neuralmagic doesn't come with tokenizer, so we need to copy it over
+mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4

test/prototype/test_sparse_api.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)
 
+        if compile:
+            model = torch.compile(model)
+
         torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
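
For context, a self-contained version of the dense-vs-sparse comparison this test performs might look like the sketch below. The layer sizes, the hand-rolled 2:4 pruning step, and the CUDA requirement are assumptions for illustration, not details taken from the test file.

    import torch
    from torchao.sparsity import sparsify_, semi_sparse_weight

    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).half().cuda()
    input = torch.rand(64, 128, dtype=torch.float16, device="cuda")

    # 2:4 kernels require the weights to already be 2:4 sparse, so keep the
    # two largest-magnitude entries in every group of four (illustrative pruning)
    with torch.no_grad():
        for mod in model.modules():
            if isinstance(mod, torch.nn.Linear):
                groups = mod.weight.view(-1, 4)
                keep = groups.abs().topk(2, dim=1).indices
                mask = torch.zeros_like(groups).scatter_(1, keep, 1.0)
                mod.weight.mul_(mask.view_as(mod.weight))

    dense_result = model(input)
    sparsify_(model, semi_sparse_weight())   # swap weights to semi-structured sparse tensors
    sparse_result = model(input)

    # same pruned values, different kernels, so loose tolerances suffice
    torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)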

torchao/_models/llama/benchmarks.sh

Lines changed: 19 additions & 2 deletions

@@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
 

@@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
 

@@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128
+
+# TTFT benchmarks
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured
+
+# 2:4 sparse model
+export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
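
The TTFT runs above pass --prefill_size 8000 so the prompt is long enough for prefill time to dominate measurement noise. As a reading aid for benchmark_results.txt, the three reported rates relate as in this sketch (all numbers invented for illustration):

    # Invented numbers, only to show how generate.py's reported metrics relate
    prefill_time = 1.6         # seconds for the 8000-token prefill; reported as ttft
    decode_time = 4.0          # seconds spent in autoregressive decode
    tokens_generated = 200     # new tokens produced after the prompt

    overall_tok_per_s = tokens_generated / (prefill_time + decode_time)  # tok/s column
    decode_tok_per_s = tokens_generated / decode_time                    # tok/s_decode column
    print(f"ttft={prefill_time:.4f}s tok/s={overall_tok_per_s:.2f} tok/s_decode={decode_tok_per_s:.2f}")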

torchao/_models/llama/generate.py

Lines changed: 104 additions & 12 deletions
@@ -17,6 +17,29 @@
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
+class HostEvent:
+    def __init__(self):
+        self.event_time = None
+
+    def record(self):
+        self.event_time = time.perf_counter()
+
+    def elapsed_time(self, other_event):
+        if self.event_time is None:
+            raise ValueError("Event not recorded!")
+        # return ms to match cuda event
+        return abs(other_event.event_time - self.event_time) * 1000
+
+def device_timer(device):
+    if "cuda" in device:
+        return torch.cuda.Event(enable_timing=True)
+    elif ("cpu" in device) or ("mps" in device):
+        return HostEvent()
+    else:
+        print(f"device={device} is not yet supported")
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
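
HostEvent deliberately mirrors the record/elapsed_time subset of the torch.cuda.Event API, so the benchmark can time prefill and decode the same way on every backend. A minimal usage sketch (the matmul workload is a stand-in, not code from this file):

    # device_timer returns a torch.cuda.Event on CUDA and a HostEvent on cpu/mps
    device = "cuda" if torch.cuda.is_available() else "cpu"
    start, end = device_timer(device), device_timer(device)

    start.record()
    torch.randn(1024, 1024, device=device) @ torch.randn(1024, 1024, device=device)
    end.record()

    if "cuda" in device:
        torch.cuda.synchronize()  # CUDA events record asynchronously; sync before reading
    print(f"elapsed: {start.elapsed_time(end):.3f} ms")  # both event types report ms
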
@@ -98,6 +121,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
+    prefill_start_event: Optional[torch.cuda.Event]=None,
+    prefill_end_event: Optional[torch.cuda.Event]=None,
+    decode_start_event: Optional[torch.cuda.Event]=None,
+    decode_end_event: Optional[torch.cuda.Event]=None,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
@@ -128,12 +155,21 @@ def generate(
     model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
 
     # execute prefill
+    if prefill_start_event is not None:
+        prefill_start_event.record()
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     seq[:, T] = next_token.squeeze()
+    if prefill_end_event is not None:
+        prefill_end_event.record()
+
     # execute token generation
+    if decode_start_event is not None:
+        decode_start_event.record()
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
     seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
+    if decode_end_event is not None:
+        decode_end_event.record()
 
     return seq

@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
+    prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
@@ -166,6 +203,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
@@ -181,6 +219,10 @@ def main(
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
 
+    if prefill_size is not None and prefill_size > 0:
+        # create prompt of prefill size
+        prompt = "prompt " * (int(prefill_size)-3)
+
     torchao.quantization.utils.recommended_inductor_config_setter()
 
     assert checkpoint_path.is_file(), checkpoint_path
@@ -205,6 +247,14 @@ def main(
 
     torch.manual_seed(1234)
 
+    def ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn
+
+    def not_ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)
+
+    def ffn_or_attn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn)
 
     if quantization:
         from torchao.quantization import (
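
These predicates are plugged into quantize_'s filter_fn argument in the hunks below: each candidate submodule is passed with its fully-qualified name, and only modules for which the predicate returns True are transformed. The FQNs here are illustrative Llama-style names, not values taken from the script:

    lin = torch.nn.Linear(8, 8)
    print(ffn_only(lin, "layers.0.feed_forward.w1"))          # True: FFN Linear
    print(ffn_only(lin, "layers.0.attention.wq"))             # False: attention Linear
    print(ffn_or_attn_only(lin, "layers.0.attention.wq"))     # True: FFN or attention
    print(not_ffn_only(lin, "output"))                        # True: any non-FFN Linear
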
@@ -228,9 +278,14 @@ def main(
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        elif "int8dq" in quantization:
-            quantize_(model, int8_dynamic_activation_int8_weight())
-        elif "int4wo" in quantization:
+        if "int8dq" in quantization:
+            if sparsity and "semi" in sparsity:
+                from torchao.dtypes import SemiSparseLayout
+                quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
+                quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
+            else:
+                quantize_(model, int8_dynamic_activation_int8_weight())
+        if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
             else:
@@ -250,9 +305,9 @@ def main(
                         layout=MarlinQQQLayout(),
                     ),
                 )
-            else:
+            elif "semi" in sparsity:
                 from torchao.dtypes import MarlinSparseLayout
-                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         elif "embed-int8wo" in quantization:
@@ -426,6 +481,13 @@ def main(
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(model)
 
+    # standalone sparsity
+    elif sparsity:
+        from torchao.sparsity import semi_sparse_weight, sparsify_
+        if "semi" in sparsity:
+            # TODO there is a bug here, need to fix
+            sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)
+
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
     if save:
@@ -451,6 +513,9 @@ def callback(x):
 
     aggregate_metrics = {
         'tokens_per_sec': [],
+        'time': [],
+        'decode_tokens_per_sec': [],
+        'prefill_time': [],
     }
     start = -1 if compile else 0

@@ -485,6 +550,8 @@ def callback(x):
         else:
             callback = lambda x : x
         t0 = time.perf_counter()
+        prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
+        decode_start_event, decode_end_event = device_timer(device), device_timer(device)
         import contextlib
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
@@ -504,6 +571,10 @@ def callback(x):
                 kv_cache_quantization=kv_cache_quantization,
                 cache_size=cache_size,
                 linear_causal_mask=linear_causal_mask,
+                prefill_start_event=prefill_start_event,
+                prefill_end_event=prefill_end_event,
+                decode_start_event=decode_start_event,
+                decode_end_event=decode_end_event,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -513,7 +584,7 @@ def callback(x):
         device_sync(device=device) # MKG
         t = time.perf_counter() - t0
 
-        if not interactive:
+        if not interactive and prefill_size is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
@@ -523,7 +594,14 @@ def callback(x):
         tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        aggregate_metrics['time'].append(t)
+        decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
+        decode_tokens_sec = tokens_generated / decode_time
+        aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
+        prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
+        aggregate_metrics['prefill_time'].append(prefill_time)
+        print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
+              f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
 
         if memory_profile and i==0:
@@ -544,8 +622,15 @@ def callback(x):
             break
     print("==========")
 
+    # ignore the first sample for warmup
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
+    ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
+    decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
+    mem = torch.cuda.max_memory_reserved() /1e9
+    print(f"Average overall tokens/sec: {tokpersec:.2f}")
+    print(f"Average decode tokens/sec: {decode_tokpersec:.2f}")
+    print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() /1e9
     elif device == "xpu":
@@ -557,15 +642,17 @@ def callback(x):
     print(f"Peak Memory Usage: {mem:.02f} GB")
     print(f"Model Size: {model_size:.02f} GB")
     if write_result:
-        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
+        result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
         result_txt += f"repro: python generate.py "
         result_txt += f"--quantization {quantization} " if quantization else ""
+        result_txt += f"--sparsity {sparsity} " if sparsity else ""
         result_txt += f"--checkpoint_path {checkpoint_path} "
         result_txt += f"--device {device} "
         result_txt += f"--precision {precision} "
         result_txt += f"--compile " if compile else ""
         result_txt += f"--compile_prefill " if compile_prefill else ""
+        result_txt += f"--prefill_size {prefill_size}" if prefill_size else ""
         result_txt += f"--profile {profile} " if profile else ""
         result_txt += f"--profile {memory_profile} " if memory_profile else ""
         result_txt += f"--interactive " if interactive else ""
@@ -587,7 +674,7 @@ def callback(x):
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
-
+    parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode')
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -603,6 +690,11 @@ def callback(x):
                 +'embed-int8wo, marlin_qqq'
             )
     )
+    parser.add_argument('-s', '--sparsity', type=str,
+        help=(
+            'Which sparsity techniques to apply: semi-structured'
+        )
+    )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -617,6 +709,6 @@ def callback(x):
 
     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
+        args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
