[4975376][5541172]perplexity and kl-divergence benchmark metrics #411
Conversation
Walkthrough
Adds a new Windows accuracy benchmarking toolkit: scripts to extract logits from Hugging Face and ONNX Runtime GenAI models, compute KL divergence (including a same-execution-provider optimized path), and run batch perplexity evaluations, plus READMEs and requirements files.

Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant CLI as compute_kl_divergence.py
participant HF as extract_logits_hf.py (subproc)
participant ONNX as extract_logits.py (subproc)
participant Env as EP Env Manager
participant Comp as KL Computation
User->>CLI: invoke with --hf_model? and (-ep, -path)*
alt HF model provided
CLI->>HF: spawn extractor -> HF logits.pkl
HF-->>CLI: HF logits.pkl
end
loop each (ep,path)
CLI->>Env: ensure onnxruntime package for EP
Env-->>CLI: EP ready
CLI->>ONNX: spawn extractor -> ONNX logits.pkl
ONNX-->>CLI: ONNX logits.pkl
end
alt No HF and two ONNX with same EP
CLI->>Env: setup chosen EP
CLI->>CLI: delegate to KL_divergence_metrics_same_ep.py
CLI-->>User: same-EP KL result
else Unified comparison
CLI->>Comp: compute_unified_comparison(logits[])
Comp-->>CLI: results.json
CLI-->>User: write results.json
end
sequenceDiagram
autonumber
actor User
participant Runner as run_perplexity.py
participant Eval as perplexity_metrics.py
participant ORT as onnxruntime-genai
User->>Runner: --models dirs --i lengths [--chunk_size]
loop per model
Runner->>Eval: perplexity_eval(model_dir, input_len, chunk_size)
Eval->>ORT: load model/tokenizer (+overlay)
Eval->>Eval: tokenize Wikitext-2, chunk iterate
Eval->>Eval: forward -> logits -> log-probs -> accumulate
Eval-->>Runner: perplexity value or error
end
Runner-->>User: CSV results + console summary
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 14
🧹 Nitpick comments (23)
examples/windows/accuracy_benchmark/perplexity_metrics/README.md (7)
49-49: Header consistency. Use an action-oriented header like the others.
-### Multiple models
+### Evaluate multiple models

83-84: Tighten wording and punctuation for expected scores. Remove the stray space before the comma and avoid implying a hard bound.
-The expected score is between 2 to 1000 , lower score means better model performance
+Expected scores often fall between 2 and 1000; lower is better. See ranges below.
87-93: Consistent naming for “KV chunking”. Use consistent casing and code formatting for keys; verify these defaults match the implementation.
- Prefer “KV chunking” (feature name) and keep JSON keys in backticks (`chunk_size`, `search`, `genai_config.json`) as you already do.
- Confirm the script actually sets `max_input_seq_length` and `stride` as documented.
117-120: Prefer a CLI flag or env var for debug over editing source. Exposing `--debug` (or `DEBUG=1`) improves UX and avoids modifying code. If feasible, add a `--debug` flag to `run_perplexity.py` and forward it to `perplexity_metrics.py`.
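A minimal sketch of how such a flag could be wired, assuming `run_perplexity.py` uses argparse and that `perplexity_metrics` keeps a module-level DEBUG switch (as described later in this review); the argument names here are illustrative, not the PR's actual CLI.

```python
import argparse

import perplexity_metrics  # assumed local module exposing a module-level DEBUG flag


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Batch perplexity evaluation")
    parser.add_argument("--models", nargs="+", required=True, help="Model directories to evaluate")
    parser.add_argument("--debug", action="store_true", help="Enable verbose debug logging")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Forward the flag into the metrics module instead of asking users to edit source.
    perplexity_metrics.DEBUG = args.debug
```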
123-126: Minor formatting/punctuation nits in ranges. Normalize spacing and casing.
-- Excellent : 2-20
-- Good: 20-40
-- Ok: 40-80
-- Poor: 100+
+- Excellent: 2-20
+- Good: 20-40
+- OK: 40-80
+- Poor: 100+
Note: There’s a gap between 80 and 100; if intentional, consider adding a label for that band or adjusting thresholds for continuity.
9-14: License attribution check (upstream MIT). Since this tool is “based on” Microsoft’s script, ensure the actual Python files derived from it include the required MIT license/copyright notice.
I can draft the appropriate license header and attribution text if needed.
29-33: Reword CLI login note for WikiText-2
Apply the following update to soften the requirement:
- [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) login is required to access the WikiText-2 dataset:
+ [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) login may be required (e.g., for gated assets or higher rate limits). WikiText-2 is public and typically does not require login:

examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1)
1-19: Pin torch/torchvision/torchaudio to a compatible CUDA cu129 triplet and verify protobuf/onnx compatibility
- Unpinned torch/torchvision/torchaudio frequently leads to ABI mismatches on Windows; please pin a compatible triplet from the cu129 index (e.g., torch==X.Y.Z+cu129, torchvision==A.B.C+cu129, torchaudio==D.E.F+cu129).
- Consider pinning transformers/datasets for reproducibility.
- Verify protobuf>=5.28.2 works with your onnx version; onnx often constrains protobuf ranges.
Example (adjust versions to your tested set):
- torch>=2.0.0
- torchaudio
- torchvision
+ torch==2.4.0+cu129
+ torchvision==0.19.0+cu129
+ torchaudio==2.4.0+cu129

examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)
46-55: Avoid conflicting device placement with device_map='auto'. When device='cuda', you set device_map='auto' and then call model.to(device); this can fight HF's placement. Only move the model when loading on CPU.
- model = AutoModelForCausalLM.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
      model_path,
      torch_dtype=torch.float16 if device == "cuda" else torch.float32,
      device_map="auto" if device == "cuda" else None,
  )
-
- model = model.to(device)
+ if device == "cpu":
+     model = model.to("cpu")
58-61: Also set model.config.pad_token_id to avoid warnings/mismatched padding.
  # Add padding token if it doesn't exist
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token
+ if getattr(getattr(model, "config", None), "pad_token_id", None) is None:
+     model.config.pad_token_id = tokenizer.pad_token_id
112-120: Serialize a plain tuple shape (numpy) instead of torch.Size.
- "shape": logits.shape,
+ "shape": logits_numpy.shape,

examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
166-174: Be robust to provider outputs: prefer `.as_numpy()` when available. Some providers return non-ndarray outputs; safely standardize before tensor conversion.
- logits = generator.get_output("logits")
+ logits = generator.get_output("logits")
+ if hasattr(logits, "as_numpy"):
+     logits = logits.as_numpy()

examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
82-84: Nit: duplicate/ambiguous log message labels. Two identical “[INFO] Full input length:” lines (ref and target). Consider clarifying:
- print(f"[INFO] Full input length: {seq_len_ref}")
- print(f"[INFO] Full input length: {seq_len_tar}")
+ print(f"[INFO] Ref input length: {seq_len_ref}")
+ print(f"[INFO] Tar input length: {seq_len_tar}")

examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
76-79: Fix debug print; it currently prints an og.Config object, not the chunk size.
- if DEBUG:
-     print(f"config['search']['chunk_size']: {og.Config(model_dir)}")
+ if DEBUG:
+     print(f"[CONFIG] Effective chunk_size: {effective_chunk_size if kv_chunking_enabled else 'disabled'}")

examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (4)
1-6: Make import robust to working directory. A bare import will fail if the script isn’t run from its directory. Add the script dir to sys.path before importing.
  import argparse
- import os
+ import os
+ import sys
@@
- import pandas as pd
- from perplexity_metrics import perplexity_eval
+ import pandas as pd
+ # Ensure this directory is on sys.path for local imports
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ if SCRIPT_DIR not in sys.path:
+     sys.path.insert(0, SCRIPT_DIR)
+ from perplexity_metrics import perplexity_eval
61-77: Don’t abort all lengths on first failure; handle per-length errors. An exception on one input length marks the entire model as Failed and skips remaining lengths. Catch per-length to continue others and record the failing length.
- for input_len in i_list:
-     print(f"  Evaluating with input length: {input_len}")
-     if chunk_size is not None:
-         print(f"  Using chunk_size: {chunk_size}")
-         perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
-     else:
-         perplexity = perplexity_eval(model_dir, str(input_len))
-     results.append(
-         {
-             "Model Path": model_dir,
-             "Input Length": input_len,
-             "Perplexity": f"{perplexity:.4f}",
-             "Status": "Success",
-             "Error": "None",
-         }
-     )
+ for input_len in i_list:
+     try:
+         print(f"  Evaluating with input length: {input_len}")
+         if chunk_size is not None:
+             print(f"  Using chunk_size: {chunk_size}")
+             perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
+         else:
+             perplexity = perplexity_eval(model_dir, str(input_len))
+         results.append(
+             {
+                 "Model Path": model_dir,
+                 "Input Length": int(input_len),
+                 "Perplexity": float(perplexity),
+                 "Status": "Success",
+                 "Error": "None",
+             }
+         )
+     except Exception as e:
+         print(f"  Error for input length {input_len}: {e!s}")
+         results.append(
+             {
+                 "Model Path": model_dir,
+                 "Input Length": int(input_len),
+                 "Perplexity": "N/A",
+                 "Status": "Failed",
+                 "Error": str(e),
+             }
+         )
70-77: Store numeric values for CSV; format only for display. Keep Perplexity as a float in results and format on print. This eases downstream analysis.
- "Perplexity": f"{perplexity:.4f}",
+ "Perplexity": float(perplexity),
103-107: Include input length in summary output. This disambiguates multiple rows per model when different input lengths are evaluated.
- for _, row in successful.iterrows():
-     print(f"  {os.path.basename(row['Model Path'])}: {row['Perplexity']}")
+ for _, row in successful.iterrows():
+     print(
+         f"  {os.path.basename(row['Model Path'])} [i={row.get('Input Length','?')}]: "
+         f"{row['Perplexity']:.4f}" if isinstance(row['Perplexity'], (int, float)) else row['Perplexity']
+     )

examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
326-329: Use the min vocab size across models per chunk. This protects against vocab size mismatches and is clearer than assuming the first model’s size.
- vocab_size = getattr(chunk_logits[0][1], "shape", [None, None, 0])[2]
+ vocab_size = min(getattr(logits, "shape", [None, None, 0])[2] for _, logits in chunk_logits)
270-276: Operational caution: pip uninstall/install in the active env. The script uninstalls/installs onnxruntime packages in the running environment. Prefer isolated envs per EP (venv/conda) or a flag to skip package management and assume a preconfigured env.
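As a sketch of that isolated-environment approach (not part of the PR), one venv per execution provider keeps the active interpreter untouched; the EP-to-package mapping below uses the usual PyPI names (onnxruntime, onnxruntime-gpu, onnxruntime-directml) and should be adjusted to whatever the scripts actually require.

```python
import subprocess
import sys
import venv
from pathlib import Path

# Assumed mapping from execution provider to its ONNX Runtime package.
EP_PACKAGES = {
    "cpu": "onnxruntime",
    "cuda": "onnxruntime-gpu",
    "dml": "onnxruntime-directml",
}


def ensure_ep_env(ep: str, root: Path = Path("ep_envs")) -> Path:
    """Create (if needed) a venv dedicated to one execution provider; return its python."""
    env_dir = root / ep
    python = env_dir / ("Scripts/python.exe" if sys.platform == "win32" else "bin/python")
    if not python.exists():
        venv.EnvBuilder(with_pip=True).create(env_dir)
        subprocess.run([str(python), "-m", "pip", "install", EP_PACKAGES[ep]], check=True)
    return python


# Example: run the logits extractor with the CUDA environment instead of
# re-installing packages into the interpreter that launched the comparison.
# cuda_python = ensure_ep_env("cuda")
# subprocess.run([str(cuda_python), "extract_logits.py"], check=True)
```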
examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (3)
34-41: Fix torch install command formatting. Remove the angle brackets around the index URL so the command can be copy-pasted.
- Note: Install torch with cuda for faster inference "pip install torch torchvision torchaudio --index-url <https://download.pytorch.org/whl/cu129>"
+ Note: Install torch with CUDA for faster inference:
+ "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129"
270-275: Document environment side effects (auto package switching). Warn users that compute_kl_divergence.py will uninstall/install onnxruntime packages; recommend using a virtual environment.
  ### Automatic Package Management
@@
  - The script automatically installs/uninstalls the correct ONNX Runtime packages based on execution provider
  - Minimizes package switching by reusing environments when possible
  - Handles CUDA, DirectML, and CPU providers seamlessly
+
+ Warning: This mutates your Python environment (pip uninstall/install). Run inside an isolated virtualenv/conda env to avoid impacting other projects.
281-283: Fix minor typo. Double space in the sentence.
- - HF model is optional -  you can compare ONNX models directly
+ - HF model is optional - you can compare ONNX models directly
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/README.md (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
debug_print (38-41), main (419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)
debug_print (19-22), get_wikitext2 (25-35), main (140-184)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (4)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
get_wikitext2 (28-37), main (198-245)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
get_wikitext2 (25-35), main (140-184)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
get_wikitext2 (13-27)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)
main(419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
main (198-245)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
main (140-184)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
main(168-195)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
debug_print (38-41), main (419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (3)
debug_print (22-25), get_wikitext2 (28-37), main (198-245)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
perplexity_eval(30-248)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
get_wikitext2 (23-37)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
get_wikitext2 (28-37)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
get_wikitext2(25-35)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)
380-393: Remove unnecessary guard: `max_context_length` and `model_path` are guaranteed by the extract functions. Likely an incorrect or invalid review comment.
examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1)
1-16: Lock versions, verify protobuf/ONNX/ORT‑GenAI compatibility, and trim non‑runtime deps
- Reproducibility: Please pin or constrain critical libs (at least transformers, torch, datasets, onnx, onnxruntime-genai, protobuf) or reference a constraints file. Unpinned installs on Windows frequently drift/break.
- Compatibility: protobuf>=5.28.2 may conflict with specific onnx/onnxruntime-genai builds. Align versions to what you validated and document them.
- Runtime vs dev: pytest looks like a dev/test dependency; consider moving it out of this runtime example file. flatbuffers and sympy are typically pulled transitively by onnx; consider omitting unless directly imported to avoid over‑constraining.
- Platforms/backends: If you support CPU/CUDA/DirectML, consider split files (requirements.cpu.txt/requirements.cuda.txt/requirements.dml.txt) or add clear instructions. For torch on Windows, note GPU builds require custom index; default pip gives CPU.
- huggingface_hub[cli]: Keep only if the scripts use the CLI; otherwise use huggingface_hub without extras.
- Document Python version bounds tested (e.g., markers like ; python_version < "3.13") to avoid resolver surprises.
Option: introduce a constraints file and reference it here:
- Add examples/windows/accuracy_benchmark/constraints-windows.txt with the exact versions you validated.
- Top of this file, add: -c ../../constraints-windows.txt
Run locally to confirm resolver health and surface conflicts quickly (Windows shell):
Consider these adjustments once versions are chosen and validated:
- Move pytest to a dev/ci requirements file.
- Drop flatbuffers/sympy unless directly imported by your scripts.
- Replace huggingface_hub[cli] with huggingface_hub if CLI isn’t used.
- Add a constraints file and reference it from here for reproducible Windows installs.
    kl_divergence = 0.0
    for i in range(log_probs_ref.shape[0]):
        log_probs_ref_i = np.array(log_probs_ref[i])
        log_probs_tar_i = np.array(log_probs_tar[i])
        prob_ref_i = np.exp(log_probs_ref_i)
        kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
    kl_divergence = kl_divergence / log_probs_ref.shape[0]
    debug_print(f"KL divergence computed: {kl_divergence}")
Fix KL divergence formula (remove absolute value)
KL(P||Q) = Σ P * (log P − log Q). The abs distorts the metric.
- kl_divergence = 0.0
- for i in range(log_probs_ref.shape[0]):
+ kl_divergence = 0.0
+ for i in range(log_probs_ref.shape[0]):
log_probs_ref_i = np.array(log_probs_ref[i])
log_probs_tar_i = np.array(log_probs_tar[i])
prob_ref_i = np.exp(log_probs_ref_i)
- kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
+ kl_divergence += np.sum(prob_ref_i * (log_probs_ref_i - log_probs_tar_i))
+ kl_divergence = kl_divergence / log_probs_ref.shape[0]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    kl_divergence = 0.0
    for i in range(log_probs_ref.shape[0]):
        log_probs_ref_i = np.array(log_probs_ref[i])
        log_probs_tar_i = np.array(log_probs_tar[i])
        prob_ref_i = np.exp(log_probs_ref_i)
        kl_divergence += np.sum(prob_ref_i * (log_probs_ref_i - log_probs_tar_i))
    kl_divergence = kl_divergence / log_probs_ref.shape[0]
    debug_print(f"KL divergence computed: {kl_divergence}")
🤖 Prompt for AI Agents
In
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py
around lines 229 to 236, the KL divergence computation incorrectly uses an
absolute value; change the accumulation to use the signed difference
(log_probs_ref_i - log_probs_tar_i) so KL = sum(P * (log P - log Q)) per
example, then average; also ensure probabilities are computed via
exp(log_probs_ref_i) and consider clipping/log-safety if needed to avoid NaNs.
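For illustration, here is a minimal standalone sketch of the corrected metric with the log-safety guard the prompt mentions; the function and array names are illustrative and not tied to the script's variables.

```python
import numpy as np


def kl_from_log_probs(log_probs_ref: np.ndarray, log_probs_tar: np.ndarray) -> float:
    """KL(P||Q) averaged over positions; inputs are (seq_len, vocab) log-probabilities."""
    eps = 1e-12
    log_q = np.clip(log_probs_tar, np.log(eps), None)  # guard against -inf in the target
    p = np.exp(log_probs_ref)
    per_position = np.sum(p * (log_probs_ref - log_q), axis=-1)  # signed difference, no abs
    return float(np.mean(per_position))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(4, 8))
    # Manual log-softmax so each row is a valid log-probability distribution.
    log_p = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    assert abs(kl_from_log_probs(log_p, log_p)) < 1e-9  # KL of a distribution with itself is ~0
```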
- Any ONNX Runtime GenAI model exported with a compatible `genai_config.json` and tokenizer.
- Supported architectures include: Gemma, Llama, Mistral, Phi (language + vision), Qwen.
- Supported execution providers: CPU, DirectML, CUDA, NvTensorRtRtx.
Fix EP name (“TensorRT”), keep common ORT EP names.
Replace non-standard "NvTensorRtRtx" with "TensorRT" and keep ordering consistent.
-- Supported execution providers: CPU, DirectML, CUDA, NvTensorRtRtx.
+- Supported execution providers: CPU, CUDA, TensorRT, DirectML.
📝 Committable suggestion
- Supported execution providers: CPU, CUDA, TensorRT, DirectML.
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around line
39, the execution provider name "NvTensorRtRtx" is non-standard and should be
replaced with "TensorRT" and keep the ordering consistent with other ORT docs;
update the line to list supported execution providers as "CPU, DirectML, CUDA,
TensorRT" (or reorder to match project convention) so the EP name is correct and
consistent.
You can specify the input sequence length(s) to evaluate using the `--i` argument.
For example, to evaluate with input lengths:
Note: higher isl is only supported when model has Kv chunking enabled in genai config

```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288
```
Clarify ISL note, define format, and casing.
Define “input sequence length(s)” clearly, note comma-separated list, and use consistent “KV chunking” and genai_config.json naming.
-You can specify the input sequence length(s) to evaluate using the `--i` argument.
-For example, to evaluate with input lengths:
-Note: higher isl is only supported when model has Kv chunking enabled in genai config
+You can specify input sequence length(s) with the `--i` argument as a comma‑separated list.
+Note: Higher input sequence lengths require KV chunking to be enabled in `genai_config.json`.
+Example:
```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288
```

Optional: consider a more descriptive flag alias (e.g., `--input_lengths`) while keeping `--i` for backward compatibility.
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around lines
57-63, clarify that "input sequence length(s)" means one or more integer token
lengths provided as a comma-separated list (no spaces) and show the flag usage
consistently; update the KV-related note to use consistent casing "KV chunking"
and refer to the configuration file as genai_config.json, and optionally mention
that a more descriptive alias like --input_lengths could be introduced while
keeping --i for backward compatibility.
You can specify the prefill chunk size to evaluate using the `--chunk_size` argument.
For example:
Note: higher isl is only supported when model has Kv chunking enabled in genai config

```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 --chunk_size=1024
```
Fix misplaced note; tie chunk_size to KV chunking.
The “higher ISL” note is unrelated here. Clarify when --chunk_size applies.
-You can specify the prefill chunk size to evaluate using the `--chunk_size` argument.
-For example:
-Note: higher isl is only supported when model has Kv chunking enabled in genai config
+You can specify the prefill chunk size with the `--chunk_size` argument.
+Note: `--chunk_size` is applicable only when KV chunking is enabled in `genai_config.json`.
+Example:
```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 --chunk_size=1024
```

Also consider using a consistent CLI style (`--chunk_size 1024`) across examples.
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around lines
67-73, the "higher ISL" note is misplaced and the --chunk_size usage needs
clarification and consistent CLI style; move the KV-chunking note so it
specifically follows the explanation of --chunk_size, state that
--chunk_size applies only when the model's genai config enables KV chunking,
and update the example to use consistent CLI spacing (e.g., --chunk_size 1024)
to match other examples.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #411 +/- ##
=======================================
Coverage 73.41% 73.41%
=======================================
Files 180 180
Lines 18077 18077
=======================================
Hits 13271 13271
Misses 4806 4806

☔ View full report in Codecov by Sentry.
4bdfb92 to 6f0ecd3 (Compare)
Actionable comments posted: 7
♻️ Duplicate comments (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
24-48: Correct the KL divergence formula (drop the absolute value and avoid in-place mutation).
`KL(P‖Q)` must be `Σ P * (log P − log Q)`; the current `abs(logP - logQ)` breaks the metric and still mutates the input arrays. Please switch to the standard definition and add minimal numeric guarding while keeping the inputs untouched. Apply this diff:
@@
-     Note:
-         Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length
+     Note:
+         Formula: KL(P||Q) = sum(P(x) * (log(P(x)) - log(Q(x)))) averaged over sequence length.
      """
-     kl_divergence = 0.0
-     for i in range(log_probs_ref.shape[0]):
-         log_probs_ref[i] = np.array(log_probs_ref[i])
-         log_probs_tar[i] = np.array(log_probs_tar[i])
-         prob_ref = np.exp(log_probs_ref[i])
-         kl_divergence += np.sum(prob_ref * abs(log_probs_ref[i] - log_probs_tar[i]))
-     kl_divergence = kl_divergence / log_probs_ref.shape[0]
-     return kl_divergence
+     log_probs_ref = np.asarray(log_probs_ref, dtype=np.float64)
+     log_probs_tar = np.asarray(log_probs_tar, dtype=np.float64)
+     eps = 1e-12
+     log_probs_tar = np.clip(log_probs_tar, np.log(eps), None)
+     probs_ref = np.exp(log_probs_ref)
+     per_position = np.sum(probs_ref * (log_probs_ref - log_probs_tar), axis=1)
+     return float(np.mean(per_position))

examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)
292-307: Fix the KL divergence computation (remove absolute value).
`KL(P‖Q)` must be `Σ P * (log P − log Q)`; the absolute difference destroys the metric’s meaning. Suggested change:
- kl_divergence = 0.0
- for i in range(log_probs_ref.shape[0]):
-     log_probs_ref_i = np.array(log_probs_ref[i])
-     log_probs_tar_i = np.array(log_probs_tar[i])
-     prob_ref_i = np.exp(log_probs_ref_i)
-     kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
- kl_divergence = kl_divergence / log_probs_ref.shape[0]
+ log_probs_ref = np.asarray(log_probs_ref, dtype=np.float64)
+ log_probs_tar = np.asarray(log_probs_tar, dtype=np.float64)
+ eps = 1e-12
+ log_probs_tar = np.clip(log_probs_tar, np.log(eps), None)
+ probs_ref = np.exp(log_probs_ref)
+ per_position = np.sum(probs_ref * (log_probs_ref - log_probs_tar), axis=1)
+ kl_divergence = np.mean(per_position)
  debug_print(f"KL divergence computed: {kl_divergence}")
  return kl_divergence

examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)
93-97: Move warning outside the loop to reduce verbosity. The warning is printed once per input_len, which creates repetitive output when evaluating multiple lengths. Print it once per model instead.
Apply this diff to print the warning once per model:
  # For each input length, run perplexity_eval and record results
+ if chunk_size is None:
+     print(
+         "  Note: input length is ignored unless chunk_size is set or "
+         "config.search.chunk_size is present."
+     )
  for input_len in i_list:
      try:
          print(f"  Evaluating with input length: {input_len}")
-         if chunk_size is None:
-             print(
-                 "  Note: input length is ignored unless chunk_size is set or "
-                 "config.search.chunk_size is present."
-             )
          if chunk_size is not None:
🧹 Nitpick comments (6)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (2)
45-69: Consider consolidating duplicate get_wikitext2 implementations. This function is duplicated in at least three other files: KL_divergence_metrics_same_ep.py, extract_logits.py, and extract_logits_hf.py. Consolidating into a shared utility module would improve maintainability.
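As a rough sketch of that consolidation (the module name and the exact join/return convention are assumptions, not part of the PR), a single shared helper could look like this:

```python
# wikitext_utils.py (hypothetical shared module)
from datasets import load_dataset


def get_wikitext2_text(split: str = "test") -> str:
    """Return the WikiText-2 raw split joined into one string for tokenization."""
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    return "\n\n".join(dataset["text"])
```

Each script would then import this helper instead of keeping its own copy.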
296-298: Consider logging when chunk_seq_len <= 1. This edge case produces empty log probabilities, which should be rare in practice but could indicate an issue. Adding a warning would aid debugging.
Apply this diff to add a warning:
  else:
+     if DEBUG:
+         print("[WARNING] chunk_seq_len <= 1, skipping this chunk")
      valid_log_probs = np.array([])
      mask_flat = np.array([], dtype=bool)

examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (4)
93-97: Clarify the default behavior in the warning message. The warning correctly alerts users that `input_len` is ignored without chunking, but it should mention that it defaults to 1024 for clarity.
Apply this diff to improve the warning message:
  if chunk_size is None:
      print(
-         "  Note: input length is ignored unless chunk_size is set or "
-         "config.search.chunk_size is present."
+         "  Note: input length is ignored (defaults to 1024) unless chunk_size is set or "
+         "config.search.chunk_size is present in the model config."
      )
124-124: Remove leading space from message. The success message has an unnecessary leading space.
Apply this diff:
- print(" Perplexity evaluation completed successfully")
+ print(" Perplexity evaluation completed successfully")
150-156: Simplify the perplexity formatting logic. The ternary expression for formatting perplexity values is unnecessarily complex and harder to read.
Apply this diff to simplify:
  if len(successful) > 0:
      print("\nPerplexity Results:")
      for _, row in successful.iterrows():
+         perplexity_str = (
+             f"{row['Perplexity']:.4f}"
+             if isinstance(row["Perplexity"], (int, float))
+             else row["Perplexity"]
+         )
          print(
              f"  {os.path.basename(row['Model Path'])} [i={row.get('Input Length', '?')}]: "
-             f"{row['Perplexity']:.4f}"
-             if isinstance(row["Perplexity"], (int, float))
-             else row["Perplexity"]
+             f"{perplexity_str}"
          )
204-214: Use `os.path.isdir()` for more precise validation. Using `os.path.isdir()` is more explicit about expecting directory paths and will correctly reject files.
Apply this diff:
  # Validate that all model directories exist
  valid_models = []
  for model_dir in args.models:
-     if os.path.exists(model_dir):
+     if os.path.isdir(model_dir):
          valid_models.append(model_dir)
      else:
-         print(f"Warning: Model directory does not exist: {model_dir}")
+         print(f"Warning: Not a valid directory: {model_dir}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1 hunks)
- examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/README.md (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1 hunks)
- examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/windows/accuracy_benchmark/perplexity_metrics/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt
- examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (6)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (3)
debug_print (22-30), get_wikitext2 (33-50), main (222-276)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
debug_print (38-46), main (518-808)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)
debug_print (19-27), get_wikitext2 (30-48), main (163-214)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
get_wikitext2(33-50)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
perplexity_eval(72-327)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
debug_print (22-30), main (222-276)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
debug_print (19-27), main (163-214)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
main(218-261)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
get_wikitext2 (51-75)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
get_wikitext2 (33-50)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
get_wikitext2(30-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (14)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (7)
1-11: LGTM! The sys.path setup and imports are correctly structured for a standalone script that needs to import from the same directory.
161-220: LGTM! The CLI interface is well-structured with appropriate argument parsing and validation before invoking the evaluation function.
126-130: Remove redundant outer exception handler. This outer exception handler is unreachable because all model-level exceptions (directory/config validation) are already caught explicitly in lines 57-81, and per-length exceptions are caught in lines 112-122.
Apply this diff to remove the redundant handler:
      print("  Perplexity evaluation completed successfully")
- except Exception as e:
-     print(f"Error during perplexity evaluation: {e!s}")
-     results.append(
-         {"Model Path": model_dir, "Perplexity": "N/A", "Status": "Failed", "Error": str(e)}
-     )
-
  # Create results DataFrame and save to CSV

Likely an incorrect or invalid review comment.
100-102: Fix incorrect type conversion for input_len. Converting input_len to a string before passing to perplexity_eval is incorrect; the function signature expects an integer, not a string.
Apply this diff to pass the integer directly:
  if chunk_size is not None:
      print(f"  Using chunk_size: {chunk_size}")
-     perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
+     perplexity = perplexity_eval(model_dir, input_len, chunk_size)
  else:
-     perplexity = perplexity_eval(model_dir, str(input_len))
+     perplexity = perplexity_eval(model_dir, input_len)

Likely an incorrect or invalid review comment.
1-11: LGTM! The import setup and path manipulation are appropriate for ensuring local module imports work regardless of where the script is executed from. The `noqa: E402` annotation correctly suppresses the late import warning.
112-112: Exception handling in the loop is necessary for fault tolerance. The `noqa: PERF203` comment acknowledges the linter warning about exception handling in loops. In this context, per-iteration exception handling is essential for robust batch processing, allowing the script to continue evaluating remaining models even if one fails.
222-223: LGTM! Standard entry point pattern is correctly implemented.
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (7)
115-125: LGTM! The genai_config.json loading now has proper error handling with fallback to defaults, addressing the previous review concern about handling missing or malformed config files.
164-174: LGTM! The context_length retrieval now safely accesses nested configuration with appropriate fallback, addressing the previous review concern about handling missing config keys.
300-302: LGTM! The token counting now correctly uses the size of gathered log probabilities instead of the mask sum, addressing the previous review concern about off-by-one errors under KV chunking.
1-42: LGTM! The license headers properly attribute both Microsoft (original) and NVIDIA (modifications), and the imports are appropriate for the perplexity evaluation task. The global DEBUG flag provides a simple mechanism for verbose logging.
45-69: LGTM! The WikiText-2 dataset loading implementation is correct and consistent with similar implementations in related KL-divergence scripts within the same PR. The DEBUG logging is appropriately placed for troubleshooting.
115-125: Robust error handling for missing config file. The try-except block correctly handles missing or malformed `genai_config.json` files with appropriate fallback to defaults. This addresses the previous review concern about potential crashes from missing configuration.
300-302: Token counting correctly uses gathered logits size. The token counting now correctly uses `valid_log_probs.size` as the authoritative count, which ensures the token tally matches the summed log-probability entries. This addresses the previous review concern about token counting mismatches under KV chunking.
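For readers unfamiliar with the flow, here is a compressed sketch of the accumulation being discussed (variable names mirror the review, not necessarily the script's exact code): each chunk contributes the negative sum of the gathered target log-probabilities plus the number of gathered entries, and perplexity is the exponential of the average negative log-likelihood.

```python
import numpy as np


def chunk_nll_and_count(
    pred_log_probs: np.ndarray, valid_indices: np.ndarray, valid_targets: np.ndarray
) -> tuple[float, int]:
    """Negative log-likelihood contribution and token count for one chunk."""
    valid_log_probs = pred_log_probs[valid_indices, valid_targets]
    return -float(valid_log_probs.sum()), int(valid_log_probs.size)  # size is the authoritative count


# After accumulating (total_nll, total_tokens) over all chunks:
# perplexity = float(np.exp(total_nll / total_tokens))
```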
    # Calculate average KL divergences
    num_chunks = len(chunk_results)
    pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}

    debug_print("\nFinal KL divergence totals:")
    for pair, total in pairwise_totals.items():
        debug_print(f"  {pair}: total={total:.6f}, avg={pairwise_averages[pair]:.6f}")

    # Prepare results
    results = {
        "models": {model_name: str(data["model_path"]) for model_name, data in model_logits_list},
        "total_chunks": int(num_chunks),
        "sequence_length": int(seq_len),
        "max_context_length": int(reference_data["max_context_length"]),
        "kl_divergences": {
            pair: {"total": float(pairwise_totals[pair]), "average": float(pairwise_averages[pair])}
            for pair in pairwise_totals
        },
        "chunk_results": chunk_results,
        "computation_timestamp": datetime.now().isoformat(),
        "summary": {
            "interpretation": "Lower KL divergence indicates more similar model outputs",
            "baseline_reference": model_logits_list[0][0],
            "pairwise_averages": {pair: float(avg) for pair, avg in pairwise_averages.items()},
            "chunks_processed": int(num_chunks),
        },
    }

    return results
Handle zero processed chunks before averaging.
When extraction yields zero chunks (e.g., upstream generator error), pairwise_averages = total/num_chunks divides by zero and the whole run crashes. Bail out (or raise) when num_chunks == 0.
Patch example:
- num_chunks = len(chunk_results)
- pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}
+ num_chunks = len(chunk_results)
+ if num_chunks == 0:
+ raise RuntimeError("No chunks produced by any model; KL comparison cannot proceed.")
+ pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}
📝 Committable suggestion
    # Calculate average KL divergences
    num_chunks = len(chunk_results)
    if num_chunks == 0:
        raise RuntimeError("No chunks produced by any model; KL comparison cannot proceed.")
    pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}

    debug_print("\nFinal KL divergence totals:")
    for pair, total in pairwise_totals.items():
        debug_print(f"  {pair}: total={total:.6f}, avg={pairwise_averages[pair]:.6f}")

    # Prepare results
    results = {
        "models": {model_name: str(data["model_path"]) for model_name, data in model_logits_list},
        "total_chunks": int(num_chunks),
        "sequence_length": int(seq_len),
        "max_context_length": int(reference_data["max_context_length"]),
        "kl_divergences": {
            pair: {"total": float(pairwise_totals[pair]), "average": float(pairwise_averages[pair])}
            for pair in pairwise_totals
        },
        "chunk_results": chunk_results,
        "computation_timestamp": datetime.now().isoformat(),
        "summary": {
            "interpretation": "Lower KL divergence indicates more similar model outputs",
            "baseline_reference": model_logits_list[0][0],
            "pairwise_averages": {pair: float(avg) for pair, avg in pairwise_averages.items()},
            "chunks_processed": int(num_chunks),
        },
    }

    return results
    total_kl_divergence = 0.0
    total_batch = 0
    for begin_loc in range(0, seq_len_ref, max_context_length):
        end_loc = min(begin_loc + max_context_length, seq_len_ref)
        # Extract the current chunk of input tokens
        input_ids_chunk_ref = input_ids_ref[:, begin_loc:end_loc].clone()
        input_ids_chunk_tar = input_ids_tar[:, begin_loc:end_loc].clone()
        if DEBUG:
            print(f"input_ids_chunk_ref.shape: {input_ids_chunk_ref.shape}")
            print(f"input_ids_chunk_tar.shape: {input_ids_chunk_tar.shape}")
        # Set up generator parameters for deterministic generation (no sampling)
        params_ref = og.GeneratorParams(ref_model)
        params_tar = og.GeneratorParams(tar_model)
        params_ref.set_search_options(
            max_length=int(input_ids_chunk_ref.shape[1]), do_sample=False, early_stopping=False
        )
        params_tar.set_search_options(
            max_length=int(input_ids_chunk_tar.shape[1]), do_sample=False, early_stopping=False
        )
        # Create generator and append input tokens
        generator_ref = og.Generator(ref_model, params_ref)
        generator_ref.append_tokens(input_ids_chunk_ref.numpy())
        generator_tar = og.Generator(tar_model, params_tar)
        generator_tar.append_tokens(input_ids_chunk_tar.numpy())

        # Run the model forward pass without gradient calculation
        with torch.no_grad():
            if DEBUG:
                print("[INFER] Running model forward pass ...")
            try:
                generator_ref.generate_next_token()
                generator_tar.generate_next_token()
            except Exception as e:
                print(f"[INFER] .generate_next_token() failed: {e}")
                break  # Fatal error
        # Get logits output from the model
        logits_ref = generator_ref.get_output("logits")
        logits_tar = generator_tar.get_output("logits")
        if DEBUG:
            print(f"logits_ref.shape: {logits_ref.shape}")
            print(f"logits_tar.shape: {logits_tar.shape}")
        # Convert numpy arrays to torch tensors
        logits_ref = torch.tensor(logits_ref, dtype=torch.float32)
        logits_tar = torch.tensor(logits_tar, dtype=torch.float32)
        # Compute log probabilities over vocabulary for each position
        log_probs_ref = torch.nn.functional.log_softmax(logits_ref, dim=2).cpu().numpy()
        log_probs_tar = torch.nn.functional.log_softmax(logits_tar, dim=2).cpu().numpy()
        if DEBUG:
            print(f"log_probs_ref.shape: {log_probs_ref.shape}")
            print(f"log_probs_tar.shape: {log_probs_tar.shape}")
        # Compute KL divergence
        kl_divergence = 0.0
        # Reshape log_probs_ref and log_probs_tar from (1, 1024, 128256) to (1024, 128256)
        log_probs_ref = log_probs_ref.squeeze(0)
        log_probs_tar = log_probs_tar.squeeze(0)

        # log_probs_ref = torch.tensor(log_probs_ref, dtype=torch.float32)
        # log_probs_tar = torch.tensor(log_probs_tar, dtype=torch.float32)
        # kl_divergence = torch.nn.functional.kl_div(
        #     log_probs_ref, log_probs_tar, reduction='batchmean', log_target=True
        # )
        kl_divergence = get_kl_divergence(log_probs_ref, log_probs_tar)
        total_kl_divergence += kl_divergence
        total_batch += 1
        if DEBUG:
            print(f"KL divergence: {kl_divergence}")
    avg_kl_divergence = total_kl_divergence / total_batch
    if DEBUG:
        print(f"Average KL divergence: {avg_kl_divergence}")
    print(f"Total KL divergence: {total_kl_divergence}")
    print(f"Total batch: {total_batch}")
    print(f"Average KL divergence: {avg_kl_divergence}")
Avoid dividing by zero when no chunk finishes.
If generate_next_token() fails on the first chunk (or the dataset is empty), total_batch stays 0 and we raise ZeroDivisionError at Line 210. Short-circuit when no batches succeed instead of crashing.
Consider:
@@
- avg_kl_divergence = total_kl_divergence / total_batch
+ if total_batch == 0:
+ print("[ERROR] No chunks were processed successfully; aborting KL computation.")
+ return
+ avg_kl_divergence = total_kl_divergence / total_batch
📝 Committable suggestion
    # After processing all chunks, compute average KL divergence
    if total_batch == 0:
        print("[ERROR] No chunks were processed successfully; aborting KL computation.")
        return
    avg_kl_divergence = total_kl_divergence / total_batch
    if DEBUG:
        print(f"Average KL divergence: {avg_kl_divergence}")
    print(f"Total KL divergence: {total_kl_divergence}")
    print(f"Total batch: {total_batch}")
    print(f"Average KL divergence: {avg_kl_divergence}")
🤖 Prompt for AI Agents
In
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py
around lines 144-215, the code divides by total_batch without guarding against
total_batch == 0 which can raise ZeroDivisionError if no chunk completes; change
the logic to check if total_batch is zero before computing avg_kl_divergence and
short-circuit: if total_batch == 0, log/print a clear message (or set
avg_kl_divergence to 0.0) and avoid the division, then proceed to print totals
or return/exit as appropriate so the script does not crash.
        mask_flat = target_ids_shifted != -100
        if kv_chunking_enabled:
            trg_len = min(trg_len, stride)
            mask_flat = np.ones(trg_len, dtype=bool)
            valid_indices = np.arange(0, trg_len - 1)
            valid_targets = target_ids_shifted[-trg_len + 1 :]
        else:
            valid_indices = np.arange(len(target_ids_shifted))[mask_flat]
            valid_targets = target_ids_shifted[mask_flat]
        if DEBUG:
            print(f"[VALID_INDICES] Valid indices shape: {valid_indices.shape}")
            print(f"[VALID_TARGETS] Valid targets shape: {valid_targets.shape}")
        # Gather the log probabilities for the correct target tokens
        valid_log_probs = pred_log_probs[valid_indices, valid_targets]
Fix indexing mismatch in KV chunking branch.
Line 284 creates indices from 0 to trg_len-2 (first trg_len-1 predictions), but line 285 extracts the last trg_len-1 targets. When indexing pred_log_probs on line 293, this mismatch causes incorrect log probabilities to be gathered.
Apply this diff to align indices with targets:
# Only include tokens with label != -100 (matching HF masking)
mask_flat = target_ids_shifted != -100
if kv_chunking_enabled:
trg_len = min(trg_len, stride)
mask_flat = np.ones(trg_len, dtype=bool)
- valid_indices = np.arange(0, trg_len - 1)
+ valid_indices = np.arange(chunk_seq_len - trg_len, chunk_seq_len - 1)
valid_targets = target_ids_shifted[-trg_len + 1 :]Committable suggestion skipped: line range outside the PR's diff.
As part of prefill chunking being enabled, the logits returned are for the last chunk_seq_len positions, so trg_len is the same as chunk_seq_len.
@ynankani, thank you for the clarification! I understand now—when prefill chunking is enabled, the returned logits are already scoped to the last chunk_seq_len positions, so trg_len matches chunk_seq_len and the indexing from 0 correctly aligns with those logits. I appreciate you explaining this implementation detail.
✏️ Learnings added
Learnt from: ynankani
PR: NVIDIA/TensorRT-Model-Optimizer#411
File: examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py:281-294
Timestamp: 2025-10-12T07:09:32.831Z
Learning: In `examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py`, when KV chunking (prefill chunking) is enabled in the perplexity_eval function, the logits returned from the ONNX Runtime GenAI model correspond to the last chunk_seq_len positions. Therefore, trg_len equals chunk_seq_len, and the indexing pattern `valid_indices = np.arange(0, trg_len - 1)` correctly aligns with `valid_targets = target_ids_shifted[-trg_len + 1:]` without needing offset adjustment.
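A toy sketch of the alignment that learning describes (not from the PR; the sizes and token values are made up). With prefill chunking, the returned logits and the shifted targets refer to the same final chunk_seq_len positions, so indexing from 0 lines up:

import numpy as np

# Hypothetical sizes for illustration only.
chunk_seq_len = 4                                   # logits cover the last 4 positions
trg_len = chunk_seq_len                             # equal when prefill chunking is on
pred_log_probs = np.log(np.full((chunk_seq_len, 10), 0.1))  # [positions, vocab]
target_ids_shifted = np.array([7, 3, 5, 9])         # labels for those same positions

valid_indices = np.arange(0, trg_len - 1)           # positions 0..2 predict the next token
valid_targets = target_ids_shifted[-trg_len + 1:]   # tokens at positions 1..3
valid_log_probs = pred_log_probs[valid_indices, valid_targets]
print(valid_log_probs.shape)                        # (3,)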
    if isinstance(i, str):
        i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
    elif isinstance(i, (list, tuple)):
        i_list = [int(x) for x in i]
    else:
        i_list = [int(i)]
🛠️ Refactor suggestion | 🟠 Major
Validate parsed input lengths.
The input parsing doesn't verify that parsed values are positive integers, which could lead to runtime errors when calling perplexity_eval.
Apply this diff to add validation:
# Parse i as a comma-separated list of integers, pass each to perplexity_eval
if isinstance(i, str):
- i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
+ try:
+ i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
+ if any(val < 1 for val in i_list):
+ raise ValueError("Input lengths must be positive integers")
+ except ValueError as e:
+ print(f"Error: Invalid input length format: {e}")
+ return pd.DataFrame()
elif isinstance(i, (list, tuple)):
- i_list = [int(x) for x in i]
+ try:
+ i_list = [int(x) for x in i]
+ if any(val < 1 for val in i_list):
+ raise ValueError("Input lengths must be positive integers")
+ except (ValueError, TypeError) as e:
+ print(f"Error: Invalid input length format: {e}")
+ return pd.DataFrame()
else:
- i_list = [int(i)]
+ try:
+ i_list = [int(i)]
+ if i_list[0] < 1:
+ raise ValueError("Input length must be a positive integer")
+ except (ValueError, TypeError) as e:
+ print(f"Error: Invalid input length format: {e}")
+ return pd.DataFrame()📝 Committable suggestion
Original:
    if isinstance(i, str):
        i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
    elif isinstance(i, (list, tuple)):
        i_list = [int(x) for x in i]
    else:
        i_list = [int(i)]
Suggested:
    # Parse i as a comma-separated list of integers, pass each to perplexity_eval
    if isinstance(i, str):
        try:
            i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
            if any(val < 1 for val in i_list):
                raise ValueError("Input lengths must be positive integers")
        except ValueError as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
    elif isinstance(i, (list, tuple)):
        try:
            i_list = [int(x) for x in i]
            if any(val < 1 for val in i_list):
                raise ValueError("Input lengths must be positive integers")
        except (ValueError, TypeError) as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
    else:
        try:
            i_list = [int(i)]
            if i_list[0] < 1:
                raise ValueError("Input length must be a positive integer")
        except (ValueError, TypeError) as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py
around lines 83 to 88, the parsed input values are converted to ints but not
validated for positivity; update the parsing so after converting values in each
branch you filter/validate that every int is a positive integer (>0), raising a
ValueError (with a clear message including the original input) if any value is
non-positive or conversion fails; ensure the validation covers single-string,
comma-separated string, list/tuple and single-value cases and preserves the
resulting i_list for downstream perplexity_eval.
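As a side note, the three near-identical branches in the suggestion could also be collapsed into one small helper. The sketch below is hypothetical (not part of the suggestion or the PR); the caller would wrap the single call in one try/except and report the error once:

def parse_input_lengths(i):
    """Normalize i (int, 'a,b,c' string, or sequence) into a list of positive ints."""
    if isinstance(i, str):
        raw = [x.strip() for x in i.split(",") if x.strip()]
    elif isinstance(i, (list, tuple)):
        raw = list(i)
    else:
        raw = [i]
    lengths = [int(x) for x in raw]  # raises ValueError/TypeError on bad input
    if not lengths or any(v < 1 for v in lengths):
        raise ValueError(f"Input lengths must be positive integers, got {i!r}")
    return lengths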
        except Exception as e:  # noqa: PERF203
            print(f" Error for input length {input_len}: {e!s}")
            results.append(
                {
                    "Model Path": model_dir,
                    "Input Length": int(input_len),
                    "Perplexity": "N/A",
                    "Status": "Failed",
                    "Error": str(e),
                }
            )
🛠️ Refactor suggestion | 🟠 Major
Narrow exception handling scope.
Catching broad Exception with a loop-suppression comment (PERF203) makes debugging difficult. Catch specific exceptions and let unexpected errors propagate for proper diagnosis.
Apply this diff to catch specific exceptions:
- except Exception as e: # noqa: PERF203
+ except (RuntimeError, ValueError, OSError) as e:
print(f" Error for input length {input_len}: {e!s}")📝 Committable suggestion
Original:
        except Exception as e:  # noqa: PERF203
            print(f" Error for input length {input_len}: {e!s}")
            results.append(
                {
                    "Model Path": model_dir,
                    "Input Length": int(input_len),
                    "Perplexity": "N/A",
                    "Status": "Failed",
                    "Error": str(e),
                }
            )
Suggested:
        except (RuntimeError, ValueError, OSError) as e:
            print(f" Error for input length {input_len}: {e!s}")
            results.append(
                {
                    "Model Path": model_dir,
                    "Input Length": int(input_len),
                    "Perplexity": "N/A",
                    "Status": "Failed",
                    "Error": str(e),
                }
            )
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py
around lines 112 to 122, the code currently catches a broad Exception which
hides unexpected errors; replace the broad except Exception with targeted except
blocks for the expected failure modes (for example except ValueError, OSError,
RuntimeError as e:) that append the failure result and error string, and
explicitly re-raise or omit catching critical/unexpected exceptions (e.g.,
KeyboardInterrupt, SystemExit) so they propagate; remove the PERF203 noqa and
ensure any additional specific exceptions you expect from the code path are
handled individually rather than using a blanket Exception catch.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
281-294: Fix indexing mismatch in KV chunking branch. Lines 285-286 create an index/target misalignment: valid_indices selects the first trg_len - 1 positions of pred_log_probs (indices 0 to trg_len - 2), while valid_targets extracts the last trg_len - 1 elements of target_ids_shifted. When gathering log probabilities at line 294 with pred_log_probs[valid_indices, valid_targets], this compares predictions from the start of the chunk against targets from the end, producing incorrect perplexity scores. Apply this diff to align indices with targets:
  # Only include tokens with label != -100 (matching HF masking)
  mask_flat = target_ids_shifted != -100
  if kv_chunking_enabled:
      trg_len = min(trg_len, stride)
      mask_flat = np.ones(trg_len, dtype=bool)
-     valid_indices = np.arange(0, trg_len - 1)
+     valid_indices = np.arange(chunk_seq_len - trg_len, chunk_seq_len - 1)
      valid_targets = target_ids_shifted[-trg_len + 1 :]
🧹 Nitpick comments (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
139-145: Remove duplicate assignment. Line 143 duplicates the assignment from line 142.
Apply this diff:
  elif model_cfg_json and "search" in model_cfg_json and "chunk_size" in model_cfg_json["search"]:
      # Use chunk_size from existing config file
      effective_chunk_size = model_cfg_json["search"]["chunk_size"]
      kv_chunking_enabled = True
-     kv_chunking_enabled = True
      if DEBUG:
          print(f"[CONFIG] Using config file chunk_size: {effective_chunk_size}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1): get_wikitext2 (51-75)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1): get_wikitext2 (33-50)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1): get_wikitext2 (30-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (6)
1-69: LGTM! License, imports, and dataset loader are well-structured. The MIT license header properly attributes both Microsoft (original) and NVIDIA (modifications). The get_wikitext2() function is consistent with the identical implementations across the benchmark suite (extract_logits.py, extract_logits_hf.py, KL_divergence_metrics_same_ep.py), which maintains good code consistency.
72-114: LGTM! Function signature and model loading are well-designed. The function is thoroughly documented with clear parameter descriptions and return value guidance. The chunk_size overlay mechanism is properly implemented with appropriate error handling, supporting the PR's goal of configurable chunk sizes.
115-126: LGTM! Config file loading now properly handles errors. The try-except wrapper around genai_config.json loading addresses the past review concern about missing or malformed config files. The fallback to defaults prevents crashes.
166-176: LGTM! Safe fallback for context_length. The code safely extracts context_length from the config JSON with appropriate None checks and defaults, preventing crashes when the config structure varies.
178-262: LGTM! Dataset loading, tokenization, and inference are robust. The code properly handles different tokenizer output formats (dict with "input_ids" key vs direct array), ensures correct tensor dimensions, and uses appropriate context managers. The try-except around generate_next_token() prevents cascading failures during inference.
301-328: LGTM! Token counting and final computation are correct. The token counting now correctly uses int(valid_log_probs.size) instead of summing the mask, ensuring the count matches the number of gathered log probabilities (addresses a past review concern). The final perplexity calculation follows standard methodology (see the sketch after this list).
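For reference, the standard methodology mentioned in the last item amounts to exponentiating the mean negative token log-probability. A minimal sketch with made-up numbers, not the PR's code:

import numpy as np

# Accumulated over all scored tokens during evaluation (hypothetical values).
total_log_prob = -1234.5   # sum of log p(token | context)
token_count = 512          # number of gathered log probabilities

perplexity = float(np.exp(-total_log_prob / token_count))
print(f"Perplexity: {perplexity:.3f}")   # exp(1234.5 / 512) ≈ 11.15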
Signed-off-by: unknown <[email protected]>
Signed-off-by: unknown <[email protected]>
Signed-off-by: unknown <[email protected]>
Signed-off-by: unknown <[email protected]>
Force-pushed: f0c8093 to 9e177a4 (Compare)
What does this PR do?
Type of change: new tests, accuracy benchmark metrics
Overview: Added perplexity and KL-divergence metrics as part of the accuracy benchmark.
Usage
Testing
Tested and collected scores for the models below.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Chores