Commit 40a78a6

[4975376][5541172]perplexity and kl-divergence benchmark metrics
Signed-off-by: unknown <[email protected]>
1 parent bc54694 commit 40a78a6

File tree

10 files changed: +2209 −0 lines changed
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
import argparse
import os

import numpy as np
import onnxruntime_genai as og
import torch
from datasets import load_dataset

DEBUG = False


def get_kl_divergence(log_probs_ref, log_probs_tar):
    # Average over sequence positions of sum_x p_ref(x) * |log p_ref(x) - log p_tar(x)|.
    # Note: because of the abs(), this is an absolute-log-ratio variant of the standard
    # KL divergence KL(P_ref || P_tar) = sum_x p_ref(x) * (log p_ref(x) - log p_tar(x)).
    kl_divergence = 0.0
    for i in range(log_probs_ref.shape[0]):
        prob_ref = np.exp(log_probs_ref[i])
        kl_divergence += np.sum(prob_ref * np.abs(log_probs_ref[i] - log_probs_tar[i]))
    kl_divergence = kl_divergence / log_probs_ref.shape[0]
    return kl_divergence
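

# For reference only: a vectorized equivalent of get_kl_divergence() above
# (a minimal sketch, not wired into the benchmark run). It computes the same
# abs-based variant, assuming log_probs_ref and log_probs_tar are numpy arrays
# of shape (seq_len, vocab_size).
def get_kl_divergence_vectorized(log_probs_ref, log_probs_tar):
    # Reference-model probabilities at every position
    prob_ref = np.exp(log_probs_ref)
    # Sum over the vocabulary axis, then average over sequence positions
    per_position = np.sum(prob_ref * np.abs(log_probs_ref - log_probs_tar), axis=-1)
    return float(per_position.mean())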


def get_wikitext2():
    # Load the Wikitext-2 test split using HuggingFace datasets
    print("\n[INFO] Loading Wikitext-2 'test' split ...")
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    if DEBUG:
        print(f"[DATASET] Number of raw samples: {len(test)}")
        for i in range(3):
            print(f"[DATASET] Sample[{i}]: {repr(test[i]['text'])[:200]} ...")
    # Concatenate all text samples into a single string, separated by double newlines
    result = "\n\n".join(text for text in test["text"])
    if DEBUG:
        print(
            f"[DATASET] Concatenated text preview: {result[:512]!r} ... [total chars: {len(result)}]"
        )
    return result
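

# Quick sanity check for get_wikitext2() above (a hedged sketch, not part of
# the benchmark run; the first call downloads the dataset from the HuggingFace
# Hub, so it needs network access):
#
#   text = get_wikitext2()
#   assert isinstance(text, str) and len(text) > 0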


def run_kl_divergence_on_models(reference_model, target_model):
    ref_model = og.Model(reference_model)
    tar_model = og.Model(target_model)
    tokenizer_ref = og.Tokenizer(ref_model)
    tokenizer_tar = og.Tokenizer(tar_model)
    max_context_length = 1024
    dataset = get_wikitext2()

    input_ids_ref = tokenizer_ref.encode_batch([dataset])
    input_ids_tar = tokenizer_tar.encode_batch([dataset])
    # Handle possible dict output from the tokenizer
    if isinstance(input_ids_ref, dict) and "input_ids" in input_ids_ref:
        input_ids_ref = input_ids_ref["input_ids"]
    # Convert to numpy if needed
    if hasattr(input_ids_ref, "as_numpy"):
        input_ids_ref = input_ids_ref.as_numpy()
        if DEBUG:
            print("[TOKENIZER] Used as_numpy()")
    if isinstance(input_ids_tar, dict) and "input_ids" in input_ids_tar:
        input_ids_tar = input_ids_tar["input_ids"]
    if hasattr(input_ids_tar, "as_numpy"):
        input_ids_tar = input_ids_tar.as_numpy()
        if DEBUG:
            print("[TOKENIZER] Used as_numpy()")
    input_ids_ref = np.array(input_ids_ref)
    input_ids_tar = np.array(input_ids_tar)

    # Ensure input_ids is 2D (batch, seq_len)
    if input_ids_ref.ndim == 1:
        input_ids_ref = np.expand_dims(input_ids_ref, 0)
        if DEBUG:
            print(f"[SHAPE] Expanded dims, now: {input_ids_ref.shape}")
    if input_ids_tar.ndim == 1:
        input_ids_tar = np.expand_dims(input_ids_tar, 0)
        if DEBUG:
            print(f"[SHAPE] Expanded dims, now: {input_ids_tar.shape}")
    # Convert input_ids to torch tensors
    input_ids_ref = torch.tensor(input_ids_ref, dtype=torch.long)
    input_ids_tar = torch.tensor(input_ids_tar, dtype=torch.long)
    seq_len_ref = int(input_ids_ref.shape[1])
    seq_len_tar = int(input_ids_tar.shape[1])
    if DEBUG:
        print(f"[INFO] Full reference input length: {seq_len_ref}")
        print(f"[INFO] Full target input length: {seq_len_tar}")

    if seq_len_ref != seq_len_tar:
        print(
            f"Error: Input tokenizer lengths for reference and target models do not match: "
            f"{seq_len_ref} != {seq_len_tar}"
        )
        return
    if DEBUG:
        print(f"[INFO] Input lengths match: {seq_len_ref}")
    # Slide a window over the input and accumulate the KL divergence in chunks
    total_kl_divergence = 0.0
    total_batch = 0
    for begin_loc in range(0, seq_len_ref, max_context_length):
        end_loc = min(begin_loc + max_context_length, seq_len_ref)
        # Extract the current chunk of input tokens
        input_ids_chunk_ref = input_ids_ref[:, begin_loc:end_loc].clone()
        input_ids_chunk_tar = input_ids_tar[:, begin_loc:end_loc].clone()
        if DEBUG:
            print(f"input_ids_chunk_ref.shape: {input_ids_chunk_ref.shape}")
            print(f"input_ids_chunk_tar.shape: {input_ids_chunk_tar.shape}")
        # Set up generator parameters for deterministic generation (no sampling)
        params_ref = og.GeneratorParams(ref_model)
        params_tar = og.GeneratorParams(tar_model)
        params_ref.set_search_options(
            max_length=int(input_ids_chunk_ref.shape[1]), do_sample=False, early_stopping=False
        )
        params_tar.set_search_options(
            max_length=int(input_ids_chunk_tar.shape[1]), do_sample=False, early_stopping=False
        )
        # Create generators and append input tokens
        generator_ref = og.Generator(ref_model, params_ref)
        generator_ref.append_tokens(input_ids_chunk_ref.numpy())
        generator_tar = og.Generator(tar_model, params_tar)
        generator_tar.append_tokens(input_ids_chunk_tar.numpy())

        # Run the model forward pass without gradient calculation
        with torch.no_grad():
            if DEBUG:
                print("[INFER] Running model forward pass ...")
            try:
                generator_ref.generate_next_token()
                generator_tar.generate_next_token()
            except Exception as e:
                print(f"[INFER] .generate_next_token() failed: {e}")
                break  # Fatal error
            # Get logits output from the models
            logits_ref = generator_ref.get_output("logits")
            logits_tar = generator_tar.get_output("logits")
            if DEBUG:
                print(f"logits_ref.shape: {logits_ref.shape}")
                print(f"logits_tar.shape: {logits_tar.shape}")
            # Convert numpy arrays to torch tensors
            logits_ref = torch.tensor(logits_ref, dtype=torch.float32)
            logits_tar = torch.tensor(logits_tar, dtype=torch.float32)
            # Compute log probabilities over the vocabulary for each position
            log_probs_ref = torch.nn.functional.log_softmax(logits_ref, dim=2).cpu().numpy()
            log_probs_tar = torch.nn.functional.log_softmax(logits_tar, dim=2).cpu().numpy()
            if DEBUG:
                print(f"log_probs_ref.shape: {log_probs_ref.shape}")
                print(f"log_probs_tar.shape: {log_probs_tar.shape}")
            # Reshape log probs from (1, seq_len, vocab_size) to (seq_len, vocab_size)
            log_probs_ref = log_probs_ref.squeeze(0)
            log_probs_tar = log_probs_tar.squeeze(0)

            # log_probs_ref = torch.tensor(log_probs_ref, dtype=torch.float32)
            # log_probs_tar = torch.tensor(log_probs_tar, dtype=torch.float32)
            # kl_divergence = torch.nn.functional.kl_div(
            #     log_probs_ref, log_probs_tar, reduction='batchmean', log_target=True
            # )
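            # NOTE (a hedged correction, not part of this run): with log_target=True,
            # torch.nn.functional.kl_div(input, target) computes KL(target || input),
            # so the argument order in the block above would measure KL(P_tar || P_ref).
            # Getting KL(P_ref || P_tar) would need the operands swapped, e.g.:
            #
            #   kl_divergence = torch.nn.functional.kl_div(
            #       torch.from_numpy(log_probs_tar), torch.from_numpy(log_probs_ref),
            #       reduction="batchmean", log_target=True,
            #   ).item()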
            # Compute the KL divergence for this chunk and accumulate it
            kl_divergence = get_kl_divergence(log_probs_ref, log_probs_tar)
            total_kl_divergence += kl_divergence
            total_batch += 1
            if DEBUG:
                print(f"KL divergence: {kl_divergence}")
    avg_kl_divergence = total_kl_divergence / total_batch
    if DEBUG:
        print(f"Average KL divergence: {avg_kl_divergence}")
    print(f"Total KL divergence: {total_kl_divergence}")
    print(f"Total batch: {total_batch}")
    print(f"Average KL divergence: {avg_kl_divergence}")


def main():
    parser = argparse.ArgumentParser(
        description="Run KL divergence evaluation on ONNX Runtime GenAI models"
    )
    parser.add_argument(
        "--reference_model", required=True, help="Path to reference model directory"
    )
    parser.add_argument("--target_model", required=True, help="Path to target model directory")
    args = parser.parse_args()

    # Validate that both model directories exist
    valid_models = []
    if os.path.exists(args.reference_model):
        valid_models.append(args.reference_model)
    else:
        print(f"Warning: Reference model directory does not exist: {args.reference_model}")
    if os.path.exists(args.target_model):
        valid_models.append(args.target_model)
    else:
        print(f"Warning: Target model directory does not exist: {args.target_model}")
    if len(valid_models) != 2:
        print("Error: Both model directories must exist")
        return

    print(
        f"Running KL divergence evaluation on reference model={valid_models[0]} and target model={valid_models[1]}"
    )
    run_kl_divergence_on_models(valid_models[0], valid_models[1])


if __name__ == "__main__":
    main()
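
An example invocation of the script (a hedged sketch: the file name
kl_divergence.py and the model paths are placeholders, since the diff view
does not show the file path; each directory should contain an ONNX Runtime
GenAI model):

python kl_divergence.py --reference_model ./models/reference --target_model ./models/target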
