
Commit 6f0ecd3

[4975376][5541172]Handle review comments
Signed-off-by: unknown <[email protected]>
1 parent 40a78a6 commit 6f0ecd3

File tree: 9 files changed (+449, -77 lines)

examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py

Lines changed: 68 additions & 2 deletions
@@ -1,3 +1,15 @@
+"""
+Optimized KL divergence comparison for ONNX Runtime GenAI models with the same execution provider.
+
+This script efficiently compares two ONNX Runtime GenAI models by computing KL divergence
+between their output distributions without package switching overhead.
+
+Usage:
+    python KL_divergence_metrics_same_ep.py \\
+        --reference_model "path/to/reference/model" \\
+        --target_model "path/to/target/model"
+"""
+
 import argparse
 import os
 
@@ -10,6 +22,22 @@
 
 
 def get_kl_divergence(log_probs_ref, log_probs_tar):
+    """
+    Compute Kullback-Leibler divergence between two log probability distributions.
+
+    KL divergence measures how one probability distribution diverges from a reference
+    distribution. Lower values indicate more similar distributions.
+
+    Args:
+        log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size).
+        log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size).
+
+    Returns:
+        float: Average KL divergence across all positions.
+
+    Note:
+        Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length
+    """
     kl_divergence = 0.0
     for i in range(log_probs_ref.shape[0]):
         log_probs_ref[i] = np.array(log_probs_ref[i])
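
For orientation, the formula documented above can also be written without the per-position loop. The following is a minimal vectorized sketch of that same formula (an illustration, not the code added in this commit), assuming both arguments are numpy arrays of log probabilities with shape (seq_len, vocab_size):

```python
import numpy as np

def kl_divergence_vectorized(log_probs_ref, log_probs_tar):
    # KL(P||Q) = sum(P(x) * |log P(x) - log Q(x)|), averaged over sequence positions,
    # matching the formula stated in the docstring above.
    p_ref = np.exp(log_probs_ref)  # (seq_len, vocab_size)
    per_position = (p_ref * np.abs(log_probs_ref - log_probs_tar)).sum(axis=-1)
    return float(per_position.mean())
```
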
@@ -21,6 +49,16 @@ def get_kl_divergence(log_probs_ref, log_probs_tar):
 
 
 def get_wikitext2():
+    """
+    Load and concatenate the WikiText-2 test dataset.
+
+    Returns:
+        str: Concatenated text from all samples in the WikiText-2 test split,
+            with samples separated by double newlines.
+
+    Note:
+        Requires HuggingFace CLI authentication to access the dataset.
+    """
     # Load the Wikitext-2 test split using HuggingFace datasets
     print("\n[INFO] Loading Wikitext-2 'test' split ...")
     test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
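
The loading step the docstring describes amounts to joining the raw test split with blank lines. A small sketch of that step (assuming the `datasets` package is installed and the dataset is accessible):

```python
from datasets import load_dataset

def get_wikitext2_sketch():
    # Load the WikiText-2 raw test split and join the samples with double newlines,
    # as described in the docstring above.
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    return "\n\n".join(test["text"])
```
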
@@ -38,6 +76,18 @@ def get_wikitext2():
 
 
 def run_kl_divergence_on_models(reference_model, target_model):
+    """
+    Compute KL divergence between two ONNX Runtime GenAI models on WikiText-2 dataset.
+
+    This function loads both models, processes the WikiText-2 dataset in chunks, and
+    computes the KL divergence between their output distributions for each chunk.
+    The results are averaged across all chunks.
+
+    Args:
+        reference_model (str): Path to the reference ONNX Runtime GenAI model directory.
+        target_model (str): Path to the target ONNX Runtime GenAI model directory.
+
+    """
     ref_model = og.Model(reference_model)
     tar_model = og.Model(target_model)
     tokenizer_ref = og.Tokenizer(ref_model)
@@ -79,8 +129,8 @@ def run_kl_divergence_on_models(reference_model, target_model):
     seq_len_ref = int(input_ids_ref.shape[1])
     seq_len_tar = int(input_ids_tar.shape[1])
     if DEBUG:
-        print(f"[INFO] Full input length: {seq_len_ref}")
-        print(f"[INFO] Full input length: {seq_len_tar}")
+        print(f"[INFO] Ref input length: {seq_len_ref}")
+        print(f"[INFO] Tar input length: {seq_len_tar}")
 
     if seq_len_ref != seq_len_tar:
         print(
@@ -166,6 +216,22 @@ def run_kl_divergence_on_models(reference_model, target_model):
 
 
 def main():
+    """
+    Command-line entry point for optimized KL divergence comparison of same-EP models.
+
+    This script is optimized for comparing two ONNX Runtime GenAI models that use
+    the same execution provider, avoiding package switching overhead. It computes
+    KL divergence between model outputs on the WikiText-2 dataset.
+
+    Command-line Arguments:
+        --reference_model: Path to reference model directory (required)
+        --target_model: Path to target model directory (required)
+
+    Example:
+        $ python KL_divergence_metrics_same_ep.py \\
+            --reference_model "G:\\models\\cuda_fp16" \\
+            --target_model "G:\\models\\cuda_int4"
+    """
     parser = argparse.ArgumentParser(
         description="Run KL divergence evaluation on ONNX Runtime GenAI models"
     )
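
For context, the two required arguments listed in the main() docstring map onto a straightforward argparse setup; the sketch below is illustrative and omits the rest of the script:

```python
import argparse

def parse_args_sketch():
    parser = argparse.ArgumentParser(
        description="Run KL divergence evaluation on ONNX Runtime GenAI models"
    )
    parser.add_argument("--reference_model", required=True,
                        help="Path to the reference ONNX Runtime GenAI model directory")
    parser.add_argument("--target_model", required=True,
                        help="Path to the target ONNX Runtime GenAI model directory")
    return parser.parse_args()
```
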

examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md

Lines changed: 5 additions & 2 deletions
@@ -35,7 +35,8 @@ The toolkit includes several Python scripts for:
    pip install -r requirements.txt
    ```
 
-   Note: Install torch with cuda for faster inference "pip install torch torchvision torchaudio --index-url <https://download.pytorch.org/whl/cu129>"
+   Note: Install torch with CUDA for faster inference:
+   "pip install torch torchvision torchaudio --index-url <https://download.pytorch.org/whl/cu129>"
 
 2. **Install execution provider-specific packages** (as needed):
 
@@ -273,11 +274,13 @@ JSON files containing:
 - Minimizes package switching by reusing environments when possible
 - Handles CUDA, DirectML, and CPU providers seamlessly
 
+Warning: This mutates your Python environment (pip uninstall/install). Run inside an isolated virtualenv/conda env to avoid impacting other projects.
+
 ## Notes
 
 - The comparison uses the Wikitext-2 dataset for evaluation
 - Processing is done in chunks (1024 tokens) to handle memory constraints
 - The script automatically handles package installation/uninstallation for different providers
 - Results are deterministic (no sampling) for consistent comparisons
 - All pairwise comparisons are computed for multi-model scenarios
-- HF model is optional - you can compare ONNX models directly
+- HF model is optional - you can compare ONNX models directly

examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py

Lines changed: 120 additions & 21 deletions

@@ -4,7 +4,7 @@
 against multiple ONNX Runtime GenAI models with different execution providers.
 
 Usage:
-    python compare_models_generic.py --hf_model "F:\shared\Llama-3.1-8B-Instruct"
+    python compute_kl_divergence.py --hf_model "F:\shared\Llama-3.1-8B-Instruct"
         --ep cuda --path "G:\models\cuda_model"
         --ep directml --path "G:\models\directml_model"
         --output "comparison_results.json"
@@ -36,24 +36,40 @@
 
 
 def debug_print(message):
-    """Print debug message only if DEBUG is True"""
+    """
+    Print debug message only if DEBUG flag is enabled.
+
+    Args:
+        message (str): Debug message to print.
+    """
     if DEBUG:
         print(f"[DEBUG] {message}")
 
 
 def run_command(cmd, description="", capture_output=True):
-    """Run a command and handle errors"""
+    """
+    Execute a subprocess command with error handling.
+
+    Args:
+        cmd (list[str]): Command and arguments to execute.
+        description (str, optional): Description of the command for logging. Defaults to "".
+        capture_output (bool, optional): Whether to capture stdout/stderr or show in real-time.
+            Defaults to True.
+
+    Returns:
+        bool: True if command succeeded, False otherwise.
+    """
     debug_print(f"[INFO] {description}")
     debug_print(f"Running: {' '.join(cmd)}")
 
     try:
         if capture_output:
-            result = subprocess.run(cmd, check=True, capture_output=True, text=True, shell=True)
+            result = subprocess.run(cmd, check=True, capture_output=True, text=True, shell=False)
             if result.stdout and DEBUG:
                 print(f"[OUT] {result.stdout}")
         else:
             # Real-time output - shows prints as they happen
-            result = subprocess.run(cmd, check=True, shell=True)
+            result = subprocess.run(cmd, check=True, shell=False)
         return True
     except subprocess.CalledProcessError as e:
         print(f"[ERROR] Command failed: {e}")
@@ -70,7 +86,12 @@ def get_python_executable():
 
 
 def uninstall_onnxruntime_packages():
-    """Uninstall all ONNX Runtime packages"""
+    """
+    Uninstall all ONNX Runtime and ONNX Runtime GenAI packages.
+
+    This ensures a clean environment before installing provider-specific packages
+    to avoid version conflicts.
+    """
     packages_to_remove = [
         "onnxruntime",
         "onnxruntime-genai",
@@ -88,7 +109,15 @@ def uninstall_onnxruntime_packages():
 
 
 def install_package(package_name):
-    """Install a specific package"""
+    """
+    Install a specific Python package using pip.
+
+    Args:
+        package_name (str): Name of the package to install.
+
+    Returns:
+        bool: True if installation succeeded, False otherwise.
+    """
     debug_print(f"Installing package: {package_name}")
     python_exe = get_python_executable()
     debug_print(f"Python executable: {python_exe}")
@@ -101,19 +130,33 @@
 
 
 def extract_hf_logits_subprocess(model_path, device="cuda"):
-    """Extract logits from Hugging Face model using subprocess"""
+    """
+    Extract logits from a Hugging Face transformer model using a subprocess.
+
+    Runs extract_logits_hf.py in a separate process to avoid package conflicts.
+    Uses temporary file for data transfer between processes.
+
+    Args:
+        model_path (str): Path to the Hugging Face model directory.
+        device (str, optional): Device for inference ('cuda' or 'cpu'). Defaults to "cuda".
+
+    """
     print("[INFO] Extracting logits from Hugging Face baseline model...")
     debug_print(f"Model path: {model_path}, Device: {device}")
 
     # Create temporary output file
-    output_file = f"temp_logits_hf_{int(time.time())}.pkl"
+    import tempfile
+
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    with tempfile.NamedTemporaryFile(prefix="temp_logits_hf_", suffix=".pkl", delete=False) as tmp:
+        output_file = tmp.name
     debug_print(f"Temporary output file: {output_file}")
 
     try:
         python_exe = get_python_executable()
         cmd = [
             python_exe,
-            "extract_logits_hf.py",
+            os.path.join(script_dir, "extract_logits_hf.py"),
             "--model_path",
             model_path,
             "--output_file",
@@ -158,19 +201,33 @@ def extract_hf_logits_subprocess(model_path, device="cuda"):
 
 
 def extract_onnx_logits_subprocess(model_path, provider):
-    """Extract logits from ONNX Runtime GenAI model using subprocess"""
+    """
+    Extract logits from an ONNX Runtime GenAI model using a subprocess.
+
+    Runs extract_logits.py in a separate process with the appropriate ONNX Runtime
+    package for the specified execution provider. Uses temporary file for data transfer.
+
+    Args:
+        model_path (str): Path to the ONNX Runtime GenAI model directory.
+        provider (str): Execution provider ('cuda', 'directml', or 'cpu').
+
+    """
     print(f"[INFO] Extracting logits from {provider.upper()} model...")
     debug_print(f"Model path: {model_path}, Provider: {provider}")
 
     # Create temporary output file
-    output_file = f"temp_logits_{provider}_{int(time.time())}.pkl"
+    import tempfile
+
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    with tempfile.NamedTemporaryFile(prefix="temp_logits_", suffix=".pkl", delete=False) as tmp:
+        output_file = tmp.name
     debug_print(f"Temporary output file: {output_file}")
 
     try:
         python_exe = get_python_executable()
         cmd = [
             python_exe,
-            "extract_logits.py",
+            os.path.join(script_dir, "extract_logits.py"),
             "--model_path",
             model_path,
             "--output_file",
@@ -220,8 +277,20 @@ def extract_onnx_logits_subprocess(model_path, provider):
 
 def compute_kl_divergence_from_logits(log_probs_ref, log_probs_tar):
     """
-    Compute KL divergence between two log probability distributions.
-    Same logic as in compute_kl_divergence.py
+    Compute Kullback-Leibler divergence between two log probability distributions.
+
+    KL divergence measures how one probability distribution diverges from a reference
+    distribution. Lower values indicate more similar distributions.
+
+    Args:
+        log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size).
+        log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size).
+
+    Returns:
+        float: Average KL divergence across all positions.
+
+    Note:
+        Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length
     """
     debug_print(
         f"Computing KL divergence - log_probs shapes: ref={log_probs_ref.shape}, tar={log_probs_tar.shape}"
@@ -239,7 +308,13 @@ def compute_kl_divergence_from_logits(log_probs_ref, log_probs_tar):
 
 def to_serializable(obj):
     """
-    Recursively convert numpy types and torch types to native Python types for JSON serialization.
+    Recursively convert numpy and torch types to native Python types for JSON serialization.
+
+    Args:
+        obj: Object to convert (dict, list, tuple, np.ndarray, torch.Tensor, etc.).
+
+    Returns:
+        Converted object with native Python types (int, float, list, dict, tuple).
     """
     if isinstance(obj, dict):
         return {k: to_serializable(v) for k, v in obj.items()}
@@ -261,8 +336,23 @@ def to_serializable(obj):
 
 def compute_unified_comparison(model_logits_list, output_file):
     """
-    Compute KL divergence comparison between all models in a unified way
-    model_logits_list: List of tuples (model_name, model_data)
+    Compute pairwise KL divergence between all models and save results to JSON.
+
+    This function performs an all-vs-all comparison of the provided models by computing
+    KL divergence for each chunk and averaging across all chunks. Results are saved
+    in a structured JSON format.
+
+    Args:
+        model_logits_list (list): List of tuples (model_name, model_data) where:
+            - model_name (str): Identifier for the model (e.g., "hf_baseline", "cuda_1")
+            - model_data (dict): Dictionary containing:
+                - 'logits': List of numpy arrays (one per chunk)
+                - 'total_chunks': Number of chunks
+                - 'seq_len': Sequence length
+                - 'model_path': Path to model
+                - 'chunk_info': Chunk position info
+        output_file (str): Path to save the JSON results file.
+
     """
     print("\n[INFO] Computing unified KL divergence comparison...")
     debug_print(f"Number of models to compare: {len(model_logits_list)}")
@@ -325,7 +415,7 @@ def compute_unified_comparison(model_logits_list, output_file):
         # Find minimum sequence length for this chunk
         min_seq_len = min(getattr(logits, "shape", [None, 0])[1] for _, logits in chunk_logits)
         # Assume all have same vocab size
-        vocab_size = getattr(chunk_logits[0][1], "shape", [None, None, 0])[2]
+        vocab_size = min(getattr(logits, "shape", [None, None, 0])[2] for _, logits in chunk_logits)
         debug_print(f"  Min seq len: {min_seq_len}, Vocab size: {vocab_size}")
 
         # Trim all logits to matching dimensions
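
Taking the minimum vocabulary size across all models (rather than the first model's) means every chunk can be trimmed to a common shape before the divergence is computed. A hedged sketch of that trimming step, assuming logits shaped (batch, seq_len, vocab_size):

```python
def trim_to_common_shape_sketch(chunk_logits):
    # chunk_logits: list of (model_name, logits) with logits shaped (batch, seq_len, vocab_size).
    min_seq_len = min(logits.shape[1] for _, logits in chunk_logits)
    min_vocab = min(logits.shape[2] for _, logits in chunk_logits)
    return [(name, logits[:, :min_seq_len, :min_vocab]) for name, logits in chunk_logits]
```
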
@@ -396,7 +486,16 @@
 
 
 def validate_inputs(hf_model, ep_path_pairs):
-    """Validate that all input paths exist and EPs are supported"""
+    """
+    Validate that all model paths exist and execution providers are supported.
+
+    Args:
+        hf_model (str or None): Path to Hugging Face model (optional).
+        ep_path_pairs (list): List of (execution_provider, model_path) tuples.
+
+    Returns:
+        bool: True if all inputs are valid, False otherwise.
+    """
     # Check HF model path (only if provided)
     if hf_model and not os.path.exists(hf_model):
         print(f"[ERROR] Hugging Face model path does not exist: {hf_model}")
@@ -556,7 +655,7 @@ def main():
     ]
 
     print("\n[INFO] Running KL_divergence_metrics_same_ep.py...")
-    result = subprocess.run(cmd, shell=True)
+    result = subprocess.run(cmd, shell=False)
 
     if result.returncode == 0:
         print("\n[SUCCESS] KL divergence computation completed successfully")
