
Conversation


@ynankani ynankani commented Oct 8, 2025

What does this PR do?

Type of change: new tests, accuracy benchmark metrics

Overview: Added perplexity and KL-divergence metrics as part of the accuracy benchmark

Usage

python compute_kl_divergence.py --hf_model "F:\shared\Llama-3.1-8B-Instruct" --ep cuda --path "G:\llama3.1-8B-genai-cuda-fp16" --ep directml --path "G:\shared\llama3.1-8B-genai-dml_fp16"  --output "comparison_results.json"

python run_perplexity.py --models  G:\llama3.1-8B-genai-cuda-fp16 --i 1024,2048,4096,8192,12288 --chunk_size=1024

Testing

Tested and collected scores for the models below:

  1. Llama-3.1-8B-Instruct
  2. DeepSeek-R1-Distill-Qwen-7B
  3. Qwen2.5-1.5B-Instruct

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: NA
  • Did you write any new necessary tests?: NA
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: NA

Additional Information

Summary by CodeRabbit

  • New Features

    • Added KL-divergence toolkit with unified CLI for cross-model and same-execution-provider comparisons, automatic per-provider environment handling, and optional debug mode.
    • Added logits extraction tools for ONNX and Hugging Face models, plus a perplexity evaluator and a batch runner that writes CSV results.
  • Documentation

    • Added README guides for the KL Divergence toolkit and Perplexity tool with usage, configuration, and result interpretation.
  • Chores

    • Added example-specific requirements to support benchmarking (datasets, transformers, onnx/genai, torch, pandas, etc.).

@ynankani ynankani requested a review from a team as a code owner October 8, 2025 14:04

coderabbitai bot commented Oct 8, 2025

Walkthrough

Adds a new Windows accuracy benchmarking toolkit: scripts to extract logits from Hugging Face and ONNX Runtime GenAI models, compute KL divergence (including a same-execution-provider optimized path), and run batch perplexity evaluations, plus READMEs and requirements files.

Changes

Cohort / File(s) | Summary of Changes

  • KL Divergence Orchestrator & Same-EP Tool
    examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py, examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py
    New orchestrator and same-execution-provider comparison script: subprocess-based logits extraction (HF/ONNX), per-EP environment management, delegation for two-ONNX same-EP comparisons, KL divergence computation utilities, CLI entry points, and JSON result serialization.
  • Logits Extraction (ONNX & HF)
    examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py, examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py
    New extraction scripts that load WikiText-2, tokenize, run deterministic forward passes in fixed-size sliding windows, collect per-chunk logits and chunk metadata, and save outputs (pickle). CLIs support provider/device selection and debug mode.
  • KL Divergence Docs & Requirements
    examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md, examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt
    New README describing toolkit purpose, usage modes, outputs, and examples; requirements file listing CUDA/DirectML/CPU-aware packages and related dependencies.
  • Perplexity Core & Runner
    examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py, examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py
    New ONNX Runtime GenAI perplexity evaluator (sliding-window, optional KV-chunking overlay) and a batch runner to evaluate multiple models across input lengths, producing CSV results and CLI-driven summaries.
  • Perplexity Docs & Requirements
    examples/windows/accuracy_benchmark/perplexity_metrics/README.md, examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt
    Adds a README for the perplexity workflow and a requirements file with dataset, transformers, onnxruntime-genai, and analysis dependencies.
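
For orientation, a minimal sketch of the fixed-size sliding-window chunking described for the extraction scripts; the chunk_token_ids helper and the metadata keys are illustrative assumptions, not the PR's actual code.

```python
# Minimal sketch of fixed-size sliding-window chunking over a tokenized stream.
# Illustrative only: chunk_token_ids and the metadata keys are hypothetical names,
# not the actual helpers in extract_logits.py / extract_logits_hf.py.
import numpy as np


def chunk_token_ids(input_ids: np.ndarray, window: int = 1024):
    """Yield (chunk_index, start, end, token_slice) for fixed-size windows."""
    seq_len = input_ids.shape[1]  # input_ids has shape (1, seq_len)
    for idx, start in enumerate(range(0, seq_len, window)):
        end = min(start + window, seq_len)
        yield idx, start, end, input_ids[:, start:end]


if __name__ == "__main__":
    # Fake token ids stand in for a tokenized WikiText-2 stream.
    ids = np.arange(4500).reshape(1, -1)
    chunks = []
    for idx, start, end, piece in chunk_token_ids(ids, window=1024):
        # A real extractor would run a deterministic forward pass here and store
        # the resulting logits alongside this metadata before pickling.
        chunks.append({"chunk_index": idx, "start": start, "end": end, "num_tokens": piece.shape[1]})
    print(chunks[-1])  # {'chunk_index': 4, 'start': 4096, 'end': 4500, 'num_tokens': 404}
```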

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant CLI as compute_kl_divergence.py
  participant HF as extract_logits_hf.py (subproc)
  participant ONNX as extract_logits.py (subproc)
  participant Env as EP Env Manager
  participant Comp as KL Computation
  User->>CLI: invoke with --hf_model? and (-ep, -path)*
  alt HF model provided
    CLI->>HF: spawn extractor -> HF logits.pkl
    HF-->>CLI: HF logits.pkl
  end
  loop each (ep,path)
    CLI->>Env: ensure onnxruntime package for EP
    Env-->>CLI: EP ready
    CLI->>ONNX: spawn extractor -> ONNX logits.pkl
    ONNX-->>CLI: ONNX logits.pkl
  end
  alt No HF and two ONNX with same EP
    CLI->>Env: setup chosen EP
    CLI->>CLI: delegate to KL_divergence_metrics_same_ep.py
    CLI-->>User: same-EP KL result
  else Unified comparison
    CLI->>Comp: compute_unified_comparison(logits[])
    Comp-->>CLI: results.json
    CLI-->>User: write results.json
  end
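
For orientation, here is a minimal sketch of the subprocess-plus-pickle handoff the diagram shows; the flag names, pickle layout, and run_extractor helper are illustrative assumptions rather than the actual CLI of the extraction scripts.

```python
# Minimal sketch of the orchestration pattern above: spawn an extractor as a
# subprocess, then load its pickled logits for comparison. Flag names and the
# pickle contents are assumed for illustration only.
import pickle
import subprocess
import sys
from pathlib import Path


def run_extractor(script: str, model_path: str, out_file: Path) -> dict:
    cmd = [sys.executable, script, "--model", model_path, "--output", str(out_file)]
    subprocess.run(cmd, check=True)  # raises CalledProcessError if extraction fails
    with out_file.open("rb") as f:
        return pickle.load(f)        # e.g. {"chunks": [...], "model_path": ...}


# Usage (paths taken from the PR description, flags assumed):
# hf_data = run_extractor("extract_logits_hf.py", r"F:\shared\Llama-3.1-8B-Instruct", Path("hf_logits.pkl"))
# onnx_data = run_extractor("extract_logits.py", r"G:\llama3.1-8B-genai-cuda-fp16", Path("onnx_logits.pkl"))
```
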
sequenceDiagram
  autonumber
  actor User
  participant Runner as run_perplexity.py
  participant Eval as perplexity_metrics.py
  participant ORT as onnxruntime-genai
  User->>Runner: --models dirs --i lengths [--chunk_size]
  loop per model
    Runner->>Eval: perplexity_eval(model_dir, input_len, chunk_size)
    Eval->>ORT: load model/tokenizer (+overlay)
    Eval->>Eval: tokenize Wikitext-2, chunk iterate
    Eval->>Eval: forward -> logits -> log-probs -> accumulate
    Eval-->>Runner: perplexity value or error
  end
  Runner-->>User: CSV results + console summary
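
For reference, a minimal sketch of the "forward -> logits -> log-probs -> accumulate" step in the diagram: perplexity is the exponential of the mean negative log-likelihood of each next token. The helper below works on a plain NumPy array and is not the perplexity_metrics.py implementation.

```python
# Minimal perplexity sketch: exp(mean negative log-likelihood of the next token).
import numpy as np


def perplexity_from_logits(logits: np.ndarray, token_ids: np.ndarray) -> float:
    """logits: (seq_len, vocab); token_ids: (seq_len,). Position t predicts token t+1."""
    # Numerically stable log-softmax over the vocabulary
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each actual next token
    next_token_log_probs = log_probs[np.arange(len(token_ids) - 1), token_ids[1:]]
    return float(np.exp(-next_token_log_probs.mean()))


# Tiny usage example with random logits (a uniform model would score about the vocab size):
rng = np.random.default_rng(0)
print(perplexity_from_logits(rng.normal(size=(16, 50)), rng.integers(0, 50, size=16)))
```
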

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble tokens, counting every hop,
Through chunks and models my whiskers never stop.
KL and perplexity, numbers neat and small,
I stash them in pickles and JSONs on the wall.
A rabbit of runtimes, metrics for the hall 🐇✨

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly summarizes the primary change by indicating that both perplexity and KL-divergence benchmark metrics have been added, which aligns directly with the new scripts and documentation introduced in this PR.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

🧹 Nitpick comments (23)
examples/windows/accuracy_benchmark/perplexity_metrics/README.md (7)

49-49: Header consistency.

Use an action-oriented header like others.

-### Multiple models
+### Evaluate multiple models

83-84: Tighten wording and punctuation for expected scores.

Remove stray space before comma and avoid implying a hard bound.

-The expected score is between 2 to 1000 , lower score means better model performance
+Expected scores often fall between 2 and 1000; lower is better. See ranges below.

87-93: Consistent naming for “KV chunking”.

Use consistent casing and code formatting for keys; verify these defaults match the implementation.

  • Prefer “KV chunking” (feature name) and keep JSON keys in backticks (chunk_size, search, genai_config.json) as you already do.
  • Confirm the script actually sets max_input_seq_length and stride as documented.

117-120: Prefer a CLI flag or env var for debug over editing source.

Exposing --debug (or DEBUG=1) improves UX and avoids modifying code.

If feasible, add a --debug flag to run_perplexity.py and forward it to perplexity_metrics.py.


123-126: Minor formatting/punctuation nits in ranges.

Normalize spacing and casing.

-- Excellent : 2-20
-- Good: 20-40
-- Ok: 40-80  
-- Poor: 100+
+- Excellent: 2-20
+- Good: 20-40
+- OK: 40-80
+- Poor: 100+

Note: There’s a gap between 80–100; if intentional, consider adding a label for that band or adjusting thresholds for continuity.


9-14: License attribution check (upstream MIT).

Since this tool is “based on” Microsoft’s script, ensure the actual Python files derived from it include the required MIT license/copyright notice.

I can draft the appropriate license header and attribution text if needed.


29-33: Reword CLI login note for WikiText-2
Apply the following update to soften the requirement:

- [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) login is required to access the WikiText-2 dataset:
+ [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) login may be required (e.g., for gated assets or higher rate limits). WikiText-2 is public and typically does not require login:
examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1)

1-19: Pin torch/torchvision/torchaudio to a compatible CUDA cu129 triplet and verify protobuf/onnx compatibility

  • Unpinned torch/torchvision/torchaudio frequently leads to ABI mismatches on Windows; please pin a compatible triplet from the cu129 index (e.g., torch==X.Y.Z+cu129, torchvision==A.B.C+cu129, torchaudio==D.E.F+cu129).
  • Consider pinning transformers/datasets for reproducibility.
  • Verify protobuf>=5.28.2 works with your onnx version; onnx often constrains protobuf ranges.

Example (adjust versions to your tested set):

- torch>=2.0.0
- torchaudio
- torchvision
+ torch==2.4.0+cu129
+ torchvision==0.19.0+cu129
+ torchaudio==2.4.0+cu129
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)

46-55: Avoid conflicting device placement with device_map='auto'

When device='cuda', you set device_map='auto' and then call model.to(device); this can fight HF’s placement. Only move the model when loading on CPU.

-    model = AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_path,
         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
         device_map="auto" if device == "cuda" else None,
     )
-
-    model = model.to(device)
+    if device == "cpu":
+        model = model.to("cpu")

58-61: Also set model.config.pad_token_id to avoid warnings/mismatched padding

     # Add padding token if it doesn't exist
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(getattr(model, "config", None), "pad_token_id", None) is None:
+        model.config.pad_token_id = tokenizer.pad_token_id

112-120: Serialize plain tuple shape (numpy) instead of torch.Size

-                        "shape": logits.shape,
+                        "shape": logits_numpy.shape,
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)

166-174: Be robust to provider outputs: prefer .as_numpy() when available

Some providers return non-ndarray outputs. Safely standardize before tensor conversion.

-            logits = generator.get_output("logits")
+            logits = generator.get_output("logits")
+            if hasattr(logits, "as_numpy"):
+                logits = logits.as_numpy()
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)

82-84: Nit: duplicate/ambiguous log message labels

Two identical “[INFO] Full input length:” lines (ref and target). Consider clarifying:

-        print(f"[INFO] Full input length: {seq_len_ref}")
-        print(f"[INFO] Full input length: {seq_len_tar}")
+        print(f"[INFO] Ref input length: {seq_len_ref}")
+        print(f"[INFO] Tar input length: {seq_len_tar}")
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)

76-79: Fix debug print; currently prints an og.Config object, not the chunk size

-    if DEBUG:
-        print(f"config['search']['chunk_size']: {og.Config(model_dir)}")
+    if DEBUG:
+        print(f"[CONFIG] Effective chunk_size: {effective_chunk_size if kv_chunking_enabled else 'disabled'}")
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (4)

1-6: Make import robust to working directory

Bare import will fail if the script isn’t run from its directory. Add script-dir to sys.path before importing.

 import argparse
-import os
+import os
+import sys
@@
-import pandas as pd
-from perplexity_metrics import perplexity_eval
+import pandas as pd
+# Ensure this directory is on sys.path for local imports
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+if SCRIPT_DIR not in sys.path:
+    sys.path.insert(0, SCRIPT_DIR)
+from perplexity_metrics import perplexity_eval

61-77: Don’t abort all lengths on first failure; handle per-length errors

An exception on one input length marks the entire model as Failed and skips remaining lengths. Catch per-length to continue others and record the failing length.

-            for input_len in i_list:
-                print(f"  Evaluating with input length: {input_len}")
-                if chunk_size is not None:
-                    print(f"  Using chunk_size: {chunk_size}")
-                    perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
-                else:
-                    perplexity = perplexity_eval(model_dir, str(input_len))
-                results.append(
-                    {
-                        "Model Path": model_dir,
-                        "Input Length": input_len,
-                        "Perplexity": f"{perplexity:.4f}",
-                        "Status": "Success",
-                        "Error": "None",
-                    }
-                )
+            for input_len in i_list:
+                try:
+                    print(f"  Evaluating with input length: {input_len}")
+                    if chunk_size is not None:
+                        print(f"  Using chunk_size: {chunk_size}")
+                        perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
+                    else:
+                        perplexity = perplexity_eval(model_dir, str(input_len))
+                    results.append(
+                        {
+                            "Model Path": model_dir,
+                            "Input Length": int(input_len),
+                            "Perplexity": float(perplexity),
+                            "Status": "Success",
+                            "Error": "None",
+                        }
+                    )
+                except Exception as e:
+                    print(f"  Error for input length {input_len}: {e!s}")
+                    results.append(
+                        {
+                            "Model Path": model_dir,
+                            "Input Length": int(input_len),
+                            "Perplexity": "N/A",
+                            "Status": "Failed",
+                            "Error": str(e),
+                        }
+                    )

70-77: Store numeric values for CSV; format only for display

Keep Perplexity as float in results; format on print. This eases downstream analysis.

-                        "Perplexity": f"{perplexity:.4f}",
+                        "Perplexity": float(perplexity),

103-107: Include input length in summary output

Disambiguates multiple rows per model when different input lengths are evaluated.

-        for _, row in successful.iterrows():
-            print(f"  {os.path.basename(row['Model Path'])}: {row['Perplexity']}")
+        for _, row in successful.iterrows():
+            print(
+                f"  {os.path.basename(row['Model Path'])} [i={row.get('Input Length','?')}]: "
+                f"{row['Perplexity']:.4f}" if isinstance(row['Perplexity'], (int, float)) else row['Perplexity']
+            )
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)

326-329: Use min vocab size across models per chunk

Protects against vocab size mismatches and is clearer than assuming first model’s size.

-        vocab_size = getattr(chunk_logits[0][1], "shape", [None, None, 0])[2]
+        vocab_size = min(getattr(logits, "shape", [None, None, 0])[2] for _, logits in chunk_logits)
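
For context, a hedged sketch of how that per-chunk vocab alignment could look; chunk_logits is assumed to be a list of (model_name, logits) pairs with logits shaped (1, seq, vocab), matching the diff above, and align_to_min_vocab is an illustrative helper rather than code from the PR.

```python
import numpy as np


def align_to_min_vocab(chunk_logits):
    """Truncate every model's logits to the smallest vocab size in this chunk."""
    vocab_size = min(logits.shape[2] for _, logits in chunk_logits)
    return [(name, logits[:, :, :vocab_size]) for name, logits in chunk_logits]


# Usage with toy arrays: one model has 8 extra (e.g. padded) vocab entries.
a = ("ref", np.zeros((1, 4, 128)))
b = ("target", np.zeros((1, 4, 136)))
aligned = align_to_min_vocab([a, b])
print([t.shape for _, t in aligned])  # [(1, 4, 128), (1, 4, 128)]
```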

270-276: Operational caution: pip uninstall/install in active env

The script uninstalls/installs onnxruntime packages in the running environment. Prefer isolated envs per EP (venv/conda) or a flag to skip package mgmt and assume preconfigured env.

examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (3)

34-41: Fix torch install command formatting

Remove angle brackets around the index URL so the command can be copy-pasted.

-   Note: Install torch with cuda for faster inference "pip install torch torchvision torchaudio --index-url <https://download.pytorch.org/whl/cu129>"
+   Note: Install torch with CUDA for faster inference:
+   "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129"

270-275: Document environment side effects (auto package switching)

Warn users that compute_kl_divergence.py will uninstall/install onnxruntime packages; recommend using a virtual environment.

 ### Automatic Package Management
@@
 - The script automatically installs/uninstalls the correct ONNX Runtime packages based on execution provider
 - Minimizes package switching by reusing environments when possible
 - Handles CUDA, DirectML, and CPU providers seamlessly
+
+Warning: This mutates your Python environment (pip uninstall/install). Run inside an isolated virtualenv/conda env to avoid impacting other projects.

281-283: Fix minor typo

Double space in sentence.

-- HF model is  optional - you can compare ONNX models directly
+- HF model is optional - you can compare ONNX models directly
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512e96f and 4bdfb92.

📒 Files selected for processing (10)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/README.md (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
  • debug_print (38-41)
  • main (419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)
  • debug_print (19-22)
  • get_wikitext2 (25-35)
  • main (140-184)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (4)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
  • get_wikitext2 (28-37)
  • main (198-245)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
  • get_wikitext2 (25-35)
  • main (140-184)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
  • get_wikitext2 (13-27)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)
  • main (419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
  • main (198-245)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
  • main (140-184)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
  • main (168-195)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
  • debug_print (38-41)
  • main (419-709)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (3)
  • debug_print (22-25)
  • get_wikitext2 (28-37)
  • main (198-245)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
  • perplexity_eval (30-248)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
  • get_wikitext2 (23-37)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
  • get_wikitext2 (28-37)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
  • get_wikitext2 (25-35)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)

380-393: Remove unnecessary guard: max_context_length and model_path are guaranteed by the extract functions.

Likely an incorrect or invalid review comment.

examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1)

1-16: Lock versions, verify protobuf/ONNX/ORT‑GenAI compatibility, and trim non‑runtime deps

  • Reproducibility: Please pin or constrain critical libs (at least transformers, torch, datasets, onnx, onnxruntime-genai, protobuf) or reference a constraints file. Unpinned installs on Windows frequently drift/break.
  • Compatibility: protobuf>=5.28.2 may conflict with specific onnx/onnxruntime-genai builds. Align versions to what you validated and document them.
  • Runtime vs dev: pytest looks like a dev/test dependency; consider moving it out of this runtime example file. flatbuffers and sympy are typically pulled transitively by onnx; consider omitting unless directly imported to avoid over‑constraining.
  • Platforms/backends: If you support CPU/CUDA/DirectML, consider split files (requirements.cpu.txt/requirements.cuda.txt/requirements.dml.txt) or add clear instructions. For torch on Windows, note GPU builds require custom index; default pip gives CPU.
  • huggingface_hub[cli]: Keep only if the scripts use the CLI; otherwise use huggingface_hub without extras.
  • Document Python version bounds tested (e.g., markers like ; python_version < "3.13") to avoid resolver surprises.

Option: introduce a constraints file and reference it here:

  • Add examples/windows/accuracy_benchmark/constraints-windows.txt with the exact versions you validated.
  • Top of this file, add: -c ../../constraints-windows.txt

Run locally to confirm resolver health and surface conflicts quickly (Windows shell):

Consider these adjustments once versions are chosen and validated:

  • Move pytest to a dev/ci requirements file.
  • Drop flatbuffers/sympy unless directly imported by your scripts.
  • Replace huggingface_hub[cli] with huggingface_hub if CLI isn’t used.
  • Add a constraints file and reference it from here for reproducible Windows installs.

Comment on lines +229 to +305
    kl_divergence = 0.0
    for i in range(log_probs_ref.shape[0]):
        log_probs_ref_i = np.array(log_probs_ref[i])
        log_probs_tar_i = np.array(log_probs_tar[i])
        prob_ref_i = np.exp(log_probs_ref_i)
        kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
    kl_divergence = kl_divergence / log_probs_ref.shape[0]
    debug_print(f"KL divergence computed: {kl_divergence}")
⚠️ Potential issue | 🔴 Critical

Fix KL divergence formula (remove absolute value)

KL(P||Q) = Σ P * (log P − log Q). The abs distorts the metric.

-    kl_divergence = 0.0
-    for i in range(log_probs_ref.shape[0]):
+    kl_divergence = 0.0
+    for i in range(log_probs_ref.shape[0]):
         log_probs_ref_i = np.array(log_probs_ref[i])
         log_probs_tar_i = np.array(log_probs_tar[i])
         prob_ref_i = np.exp(log_probs_ref_i)
-        kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
+        kl_divergence += np.sum(prob_ref_i * (log_probs_ref_i - log_probs_tar_i))
     kl_divergence = kl_divergence / log_probs_ref.shape[0]
🤖 Prompt for AI Agents
In
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py
around lines 229 to 236, the KL divergence computation incorrectly uses an
absolute value; change the accumulation to use the signed difference
(log_probs_ref_i - log_probs_tar_i) so KL = sum(P * (log P - log Q)) per
example, then average; also ensure probabilities are computed via
exp(log_probs_ref_i) and consider clipping/log-safety if needed to avoid NaNs.
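
For a quick sanity check of the suggested fix, a standalone NumPy snippet (independent of the PR's helpers) comparing the proper KL formula with the abs() variant:

```python
# With the signed difference the result is the true KL(P||Q) = sum(P * (log P - log Q)),
# which is non-negative and zero when P == Q; the abs() variant always inflates it.
import numpy as np

rng = np.random.default_rng(0)
p = rng.random(8); p /= p.sum()
q = rng.random(8); q /= q.sum()
log_p, log_q = np.log(p), np.log(q)

kl_correct = np.sum(p * (log_p - log_q))    # proper KL divergence, >= 0
kl_abs = np.sum(p * np.abs(log_p - log_q))  # what the original code computes

print(f"KL(P||Q)    = {kl_correct:.6f}")
print(f"abs variant = {kl_abs:.6f}  (always >= KL, so it overstates the divergence)")
print(f"KL(P||P)    = {np.sum(p * (log_p - log_p)):.6f}")  # 0 for identical distributions
```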


- Any ONNX Runtime GenAI model exported with a compatible `genai_config.json` and tokenizer.
- Supported architectures include: Gemma, Llama, Mistral, Phi (language + vision), Qwen.
- Supported execution providers: CPU, DirectML, CUDA, NvTensorRtRtx.

⚠️ Potential issue | 🟡 Minor

Fix EP name (“TensorRT”), keep common ORT EP names.

Replace non-standard "NvTensorRtRtx" with "TensorRT" and keep ordering consistent.

-- Supported execution providers: CPU, DirectML, CUDA, NvTensorRtRtx.
+- Supported execution providers: CPU, CUDA, TensorRT, DirectML.
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around line
39, the execution provider name "NvTensorRtRtx" is non-standard and should be
replaced with "TensorRT" and keep the ordering consistent with other ORT docs;
update the line to list supported execution providers as "CPU, DirectML, CUDA,
TensorRT" (or reorder to match project convention) so the EP name is correct and
consistent.

Comment on lines 57 to 62
You can specify the input sequence length(s) to evaluate using the `--i` argument.
For example, to evaluate with input lengths:
Note: higher isl is only supported when model has Kv chunking enabled in genai config

```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288
```

⚠️ Potential issue | 🟡 Minor

Clarify ISL note, define format, and casing.

Define “input sequence length(s)” clearly, note comma-separated list, and use consistent “KV chunking” and genai_config.json naming.

-You can specify the input sequence length(s) to evaluate using the `--i` argument.  
-For example, to evaluate with input lengths:
-Note: higher isl is only supported when model has Kv chunking enabled in genai config
+You can specify input sequence length(s) with the `--i` argument as a comma‑separated list.
+Note: Higher input sequence lengths require KV chunking to be enabled in `genai_config.json`.
+Example:
 
 ```bash
 python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288
 ```

Optional: consider a more descriptive flag alias (e.g., `--input_lengths`) while keeping `--i` for backward compatibility.

🤖 Prompt for AI Agents

In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around lines
57-63, clarify that "input sequence length(s)" means one or more integer token
lengths provided as a comma-separated list (no spaces) and show the flag usage
consistently; update the KV-related note to use consistent casing "KV chunking"
and refer to the configuration file as genai_config.json, and optionally mention
that a more descriptive alias like --input_lengths could be introduced while
keeping --i for backward compatibility.


Comment on lines 67 to 71
You can specify the prefill chunk size to evaluate using the `--chunk_size` argument.
For example:
Note: higher isl is only supported when model has Kv chunking enabled in genai config

```bash
python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 --chunk_size=1024
```

⚠️ Potential issue | 🟡 Minor

Fix misplaced note; tie chunk_size to KV chunking.

The “higher ISL” note is unrelated here. Clarify when --chunk_size applies.

-You can specify the prefill chunk size to evaluate using the `--chunk_size` argument.  
-For example:
-Note: higher isl is only supported when model has Kv chunking enabled in genai config
+You can specify the prefill chunk size with the `--chunk_size` argument.
+Note: `--chunk_size` is applicable only when KV chunking is enabled in `genai_config.json`.
+Example:
 
 ```bash
 python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 --chunk_size=1024
 ```

Also consider using a consistent CLI style (`--chunk_size 1024`) across examples.

🤖 Prompt for AI Agents

In examples/windows/accuracy_benchmark/perplexity_metrics/README.md around lines
67-73, the "higher ISL" note is misplaced and the --chunk_size usage needs
clarification and consistent CLI style; move the KV-chunking note so it
specifically follows the explanation of --chunk_size, state that --chunk_size
applies only when the model's genai config enables KV chunking, and update the
example to use consistent CLI spacing (e.g., --chunk_size 1024) to match other
examples.



codecov bot commented Oct 8, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.41%. Comparing base (8745a3c) to head (9e177a4).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #411   +/-   ##
=======================================
  Coverage   73.41%   73.41%           
=======================================
  Files         180      180           
  Lines       18077    18077           
=======================================
  Hits        13271    13271           
  Misses       4806     4806           


@ynankani ynankani force-pushed the ynankani/benchmark_metrics branch from 4bdfb92 to 6f0ecd3 on October 9, 2025 08:23

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

♻️ Duplicate comments (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)

24-48: Correct the KL divergence formula (drop the absolute value and avoid in-place mutation).

KL(P‖Q) must be Σ P * (log P − log Q); the current abs(logP-logQ) breaks the metric and still mutates the input arrays. Please switch to the standard definition and add minimal numeric guarding while keeping the inputs untouched.

Apply this diff:

@@
-    Note:
-        Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length
+    Note:
+        Formula: KL(P||Q) = sum(P(x) * (log(P(x)) - log(Q(x)))) averaged over sequence length.
     """
-    kl_divergence = 0.0
-    for i in range(log_probs_ref.shape[0]):
-        log_probs_ref[i] = np.array(log_probs_ref[i])
-        log_probs_tar[i] = np.array(log_probs_tar[i])
-        prob_ref = np.exp(log_probs_ref[i])
-        kl_divergence += np.sum(prob_ref * abs(log_probs_ref[i] - log_probs_tar[i]))
-    kl_divergence = kl_divergence / log_probs_ref.shape[0]
-    return kl_divergence
+    log_probs_ref = np.asarray(log_probs_ref, dtype=np.float64)
+    log_probs_tar = np.asarray(log_probs_tar, dtype=np.float64)
+    eps = 1e-12
+    log_probs_tar = np.clip(log_probs_tar, np.log(eps), None)
+    probs_ref = np.exp(log_probs_ref)
+    per_position = np.sum(probs_ref * (log_probs_ref - log_probs_tar), axis=1)
+    return float(np.mean(per_position))
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1)

292-307: Fix the KL divergence computation (remove absolute value).

KL(P‖Q) must be Σ P * (log P − log Q); the absolute difference destroys the metric’s meaning.

Suggested change:

-    kl_divergence = 0.0
-    for i in range(log_probs_ref.shape[0]):
-        log_probs_ref_i = np.array(log_probs_ref[i])
-        log_probs_tar_i = np.array(log_probs_tar[i])
-        prob_ref_i = np.exp(log_probs_ref_i)
-        kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i))
-    kl_divergence = kl_divergence / log_probs_ref.shape[0]
+    log_probs_ref = np.asarray(log_probs_ref, dtype=np.float64)
+    log_probs_tar = np.asarray(log_probs_tar, dtype=np.float64)
+    eps = 1e-12
+    log_probs_tar = np.clip(log_probs_tar, np.log(eps), None)
+    probs_ref = np.exp(log_probs_ref)
+    per_position = np.sum(probs_ref * (log_probs_ref - log_probs_tar), axis=1)
+    kl_divergence = np.mean(per_position)
     debug_print(f"KL divergence computed: {kl_divergence}")
     return kl_divergence
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)

93-97: Move warning outside the loop to reduce verbosity.

The warning is printed once per input_len, which creates repetitive output when evaluating multiple lengths. Print it once per model instead.

Apply this diff to print the warning once per model:

         # For each input length, run perplexity_eval and record results
+        if chunk_size is None:
+            print(
+                "  Note: input length is ignored unless chunk_size is set or "
+                "config.search.chunk_size is present."
+            )
         for input_len in i_list:
             try:
                 print(f"  Evaluating with input length: {input_len}")
-                if chunk_size is None:
-                    print(
-                        "  Note: input length is ignored unless chunk_size is set or "
-                        "config.search.chunk_size is present."
-                    )
                 if chunk_size is not None:
🧹 Nitpick comments (6)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (2)

45-69: Consider consolidating duplicate get_wikitext2 implementations.

This function is duplicated in at least three other files: KL_divergence_metrics_same_ep.py, extract_logits.py, and extract_logits_hf.py. Consolidating into a shared utility module would improve maintainability.


296-298: Consider logging when chunk_seq_len <= 1.

This edge case produces empty log probabilities, which should be rare in practice but could indicate an issue. Adding a warning would aid debugging.

Apply this diff to add a warning:

         else:
+            if DEBUG:
+                print("[WARNING] chunk_seq_len <= 1, skipping this chunk")
             valid_log_probs = np.array([])
             mask_flat = np.array([], dtype=bool)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (4)

93-97: Clarify the default behavior in the warning message.

The warning correctly alerts users that input_len is ignored without chunking, but should mention that it defaults to 1024 for clarity.

Apply this diff to improve the warning message:

                 if chunk_size is None:
                     print(
-                        "  Note: input length is ignored unless chunk_size is set or "
-                        "config.search.chunk_size is present."
+                        "  Note: input length is ignored (defaults to 1024) unless chunk_size is set or "
+                        "config.search.chunk_size is present in the model config."
                     )

124-124: Remove leading space from message.

The success message has an unnecessary leading space.

Apply this diff:

-            print(" Perplexity evaluation completed successfully")
+            print("  Perplexity evaluation completed successfully")

150-156: Simplify the perplexity formatting logic.

The ternary expression for formatting perplexity values is unnecessarily complex and harder to read.

Apply this diff to simplify:

     if len(successful) > 0:
         print("\nPerplexity Results:")
         for _, row in successful.iterrows():
+            perplexity_str = (
+                f"{row['Perplexity']:.4f}"
+                if isinstance(row["Perplexity"], (int, float))
+                else row["Perplexity"]
+            )
             print(
                 f"  {os.path.basename(row['Model Path'])} [i={row.get('Input Length', '?')}]: "
-                f"{row['Perplexity']:.4f}"
-                if isinstance(row["Perplexity"], (int, float))
-                else row["Perplexity"]
+                f"{perplexity_str}"
             )

204-214: Use os.path.isdir() for more precise validation.

Using os.path.isdir() is more explicit about expecting directory paths and will correctly reject files.

Apply this diff:

     # Validate that all model directories exist
     valid_models = []
     for model_dir in args.models:
-        if os.path.exists(model_dir):
+        if os.path.isdir(model_dir):
             valid_models.append(model_dir)
         else:
-            print(f"Warning: Model directory does not exist: {model_dir}")
+            print(f"Warning: Not a valid directory: {model_dir}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4bdfb92 and 6f0ecd3.

📒 Files selected for processing (10)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1 hunks)
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/README.md (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt (1 hunks)
  • examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/windows/accuracy_benchmark/perplexity_metrics/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt
  • examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (6)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (3)
  • debug_print (22-30)
  • get_wikitext2 (33-50)
  • main (222-276)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (2)
  • debug_print (38-46)
  • main (518-808)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (3)
  • debug_print (19-27)
  • get_wikitext2 (30-48)
  • main (163-214)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
  • get_wikitext2 (33-50)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)
  • perplexity_eval (72-327)
examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (2)
  • debug_print (22-30)
  • main (222-276)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (2)
  • debug_print (19-27)
  • main (163-214)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
  • main (218-261)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
  • get_wikitext2 (51-75)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
  • get_wikitext2 (33-50)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
  • get_wikitext2 (30-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (14)
examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py (7)

1-11: LGTM!

The sys.path setup and imports are correctly structured for a standalone script that needs to import from the same directory.


161-220: LGTM!

The CLI interface is well-structured with appropriate argument parsing and validation before invoking the evaluation function.


126-130: Remove redundant outer exception handler.

This outer exception handler is unreachable because all model-level exceptions (directory/config validation) are already caught explicitly in lines 57-81, and per-length exceptions are caught in lines 112-122.

Apply this diff to remove the redundant handler:

         print(" Perplexity evaluation completed successfully")
 
-    except Exception as e:
-        print(f"Error during perplexity evaluation: {e!s}")
-        results.append(
-            {"Model Path": model_dir, "Perplexity": "N/A", "Status": "Failed", "Error": str(e)}
-        )
-
     # Create results DataFrame and save to CSV

Likely an incorrect or invalid review comment.


100-102: Fix incorrect type conversion for input_len.

Converting input_len to string before passing to perplexity_eval is incorrect. The function signature expects an integer, not a string.

Apply this diff to pass the integer directly:

                 if chunk_size is not None:
                     print(f"  Using chunk_size: {chunk_size}")
-                    perplexity = perplexity_eval(model_dir, str(input_len), chunk_size)
+                    perplexity = perplexity_eval(model_dir, input_len, chunk_size)
                 else:
-                    perplexity = perplexity_eval(model_dir, str(input_len))
+                    perplexity = perplexity_eval(model_dir, input_len)

Likely an incorrect or invalid review comment.


1-11: LGTM!

The import setup and path manipulation are appropriate for ensuring local module imports work regardless of where the script is executed from. The noqa: E402 annotation correctly suppresses the late import warning.


112-112: Exception handling in loop is necessary for fault tolerance.

The noqa: PERF203 comment acknowledges the linter warning about exception handling in loops. In this context, per-iteration exception handling is essential for robust batch processing, allowing the script to continue evaluating remaining models even if one fails.


222-223: LGTM!

Standard entry point pattern is correctly implemented.

examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (7)

115-125: LGTM!

The genai_config.json loading now has proper error handling with fallback to defaults, addressing the previous review concern about handling missing or malformed config files.


164-174: LGTM!

The context_length retrieval now safely accesses nested configuration with appropriate fallback, addressing the previous review concern about handling missing config keys.


300-302: LGTM!

The token counting now correctly uses the size of gathered log probabilities instead of the mask sum, addressing the previous review concern about off-by-one errors under KV chunking.


1-42: LGTM!

The license headers properly attribute both Microsoft (original) and NVIDIA (modifications), and the imports are appropriate for the perplexity evaluation task. The global DEBUG flag provides a simple mechanism for verbose logging.


45-69: LGTM!

The WikiText-2 dataset loading implementation is correct and consistent with similar implementations in related KL-divergence scripts within the same PR. The DEBUG logging is appropriately placed for troubleshooting.


115-125: Robust error handling for missing config file.

The try-except block correctly handles missing or malformed genai_config.json files with appropriate fallback to defaults. This addresses the previous review concern about potential crashes from missing configuration.


300-302: Token counting correctly uses gathered logits size.

The token counting now correctly uses valid_log_probs.size as the authoritative count, which ensures the token tally matches the summed log-probability entries. This addresses the previous review concern about token counting mismatches under KV chunking.

Comment on lines +457 to +485
    # Calculate average KL divergences
    num_chunks = len(chunk_results)
    pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}

    debug_print("\nFinal KL divergence totals:")
    for pair, total in pairwise_totals.items():
        debug_print(f" {pair}: total={total:.6f}, avg={pairwise_averages[pair]:.6f}")

    # Prepare results
    results = {
        "models": {model_name: str(data["model_path"]) for model_name, data in model_logits_list},
        "total_chunks": int(num_chunks),
        "sequence_length": int(seq_len),
        "max_context_length": int(reference_data["max_context_length"]),
        "kl_divergences": {
            pair: {"total": float(pairwise_totals[pair]), "average": float(pairwise_averages[pair])}
            for pair in pairwise_totals
        },
        "chunk_results": chunk_results,
        "computation_timestamp": datetime.now().isoformat(),
        "summary": {
            "interpretation": "Lower KL divergence indicates more similar model outputs",
            "baseline_reference": model_logits_list[0][0],
            "pairwise_averages": {pair: float(avg) for pair, avg in pairwise_averages.items()},
            "chunks_processed": int(num_chunks),
        },
    }

    return results

⚠️ Potential issue | 🟠 Major

Handle zero processed chunks before averaging.

When extraction yields zero chunks (e.g., upstream generator error), pairwise_averages = total/num_chunks divides by zero and the whole run crashes. Bail out (or raise) when num_chunks == 0.

Patch example:

-    num_chunks = len(chunk_results)
-    pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}
+    num_chunks = len(chunk_results)
+    if num_chunks == 0:
+        raise RuntimeError("No chunks produced by any model; KL comparison cannot proceed.")
+    pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()}

Comment on lines +144 to +215
total_kl_divergence = 0.0
total_batch = 0
for begin_loc in range(0, seq_len_ref, max_context_length):
end_loc = min(begin_loc + max_context_length, seq_len_ref)
# Extract the current chunk of input tokens
input_ids_chunk_ref = input_ids_ref[:, begin_loc:end_loc].clone()
input_ids_chunk_tar = input_ids_tar[:, begin_loc:end_loc].clone()
if DEBUG:
print(f"input_ids_chunk_ref.shape: {input_ids_chunk_ref.shape}")
print(f"input_ids_chunk_tar.shape: {input_ids_chunk_tar.shape}")
# Set up generator parameters for deterministic generation (no sampling)
params_ref = og.GeneratorParams(ref_model)
params_tar = og.GeneratorParams(tar_model)
params_ref.set_search_options(
max_length=int(input_ids_chunk_ref.shape[1]), do_sample=False, early_stopping=False
)
params_tar.set_search_options(
max_length=int(input_ids_chunk_tar.shape[1]), do_sample=False, early_stopping=False
)
# Create generator and append input tokens
generator_ref = og.Generator(ref_model, params_ref)
generator_ref.append_tokens(input_ids_chunk_ref.numpy())
generator_tar = og.Generator(tar_model, params_tar)
generator_tar.append_tokens(input_ids_chunk_tar.numpy())

# Run the model forward pass without gradient calculation
with torch.no_grad():
if DEBUG:
print("[INFER] Running model forward pass ...")
try:
generator_ref.generate_next_token()
generator_tar.generate_next_token()
except Exception as e:
print(f"[INFER] .generate_next_token() failed: {e}")
break # Fatal error
# Get logits output from the model
logits_ref = generator_ref.get_output("logits")
logits_tar = generator_tar.get_output("logits")
if DEBUG:
print(f"logits_ref.shape: {logits_ref.shape}")
print(f"logits_tar.shape: {logits_tar.shape}")
# Convert numpy arrays to torch tensors
logits_ref = torch.tensor(logits_ref, dtype=torch.float32)
logits_tar = torch.tensor(logits_tar, dtype=torch.float32)
# Compute log probabilities over vocabulary for each position
log_probs_ref = torch.nn.functional.log_softmax(logits_ref, dim=2).cpu().numpy()
log_probs_tar = torch.nn.functional.log_softmax(logits_tar, dim=2).cpu().numpy()
if DEBUG:
print(f"log_probs_ref.shape: {log_probs_ref.shape}")
print(f"log_probs_tar.shape: {log_probs_tar.shape}")
# Compute KL divergence
kl_divergence = 0.0
# Reshape log_probs_ref and log_probs_tar from (1, 1024, 128256) to (1024, 128256)
log_probs_ref = log_probs_ref.squeeze(0)
log_probs_tar = log_probs_tar.squeeze(0)

# log_probs_ref = torch.tensor(log_probs_ref, dtype=torch.float32)
# log_probs_tar = torch.tensor(log_probs_tar, dtype=torch.float32)
# kl_divergence = torch.nn.functional.kl_div(
# log_probs_ref, log_probs_tar, reduction='batchmean', log_target=True
# )
kl_divergence = get_kl_divergence(log_probs_ref, log_probs_tar)
total_kl_divergence += kl_divergence
total_batch += 1
if DEBUG:
print(f"KL divergence: {kl_divergence}")
avg_kl_divergence = total_kl_divergence / total_batch
if DEBUG:
print(f"Average KL divergence: {avg_kl_divergence}")
print(f"Total KL divergence: {total_kl_divergence}")
print(f"Total batch: {total_batch}")
print(f"Average KL divergence: {avg_kl_divergence}")

⚠️ Potential issue | 🟠 Major

Avoid dividing by zero when no chunk finishes.

If generate_next_token() fails on the first chunk (or the dataset is empty), total_batch stays 0 and we raise ZeroDivisionError at Line 210. Short-circuit when no batches succeed instead of crashing.

Consider:

@@
-    avg_kl_divergence = total_kl_divergence / total_batch
+    if total_batch == 0:
+        print("[ERROR] No chunks were processed successfully; aborting KL computation.")
+        return
+    avg_kl_divergence = total_kl_divergence / total_batch
🤖 Prompt for AI Agents
In
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py
around lines 144-215, the code divides by total_batch without guarding against
total_batch == 0 which can raise ZeroDivisionError if no chunk completes; change
the logic to check if total_batch is zero before computing avg_kl_divergence and
short-circuit: if total_batch == 0, log/print a clear message (or set
avg_kl_divergence to 0.0) and avoid the division, then proceed to print totals
or return/exit as appropriate so the script does not crash.

Comment on lines +280 to +497
            mask_flat = target_ids_shifted != -100
            if kv_chunking_enabled:
                trg_len = min(trg_len, stride)
                mask_flat = np.ones(trg_len, dtype=bool)
                valid_indices = np.arange(0, trg_len - 1)
                valid_targets = target_ids_shifted[-trg_len + 1 :]
            else:
                valid_indices = np.arange(len(target_ids_shifted))[mask_flat]
                valid_targets = target_ids_shifted[mask_flat]
            if DEBUG:
                print(f"[VALID_INDICES] Valid indices shape: {valid_indices.shape}")
                print(f"[VALID_TARGETS] Valid targets shape: {valid_targets.shape}")
            # Gather the log probabilities for the correct target tokens
            valid_log_probs = pred_log_probs[valid_indices, valid_targets]
Contributor

@coderabbitai coderabbitai bot Oct 9, 2025

⚠️ Potential issue | 🔴 Critical

Fix indexing mismatch in KV chunking branch.

Line 284 creates indices from 0 to trg_len-2 (first trg_len-1 predictions), but line 285 extracts the last trg_len-1 targets. When indexing pred_log_probs on line 293, this mismatch causes incorrect log probabilities to be gathered.

Apply this diff to align indices with targets:

             # Only include tokens with label != -100 (matching HF masking)
             mask_flat = target_ids_shifted != -100
             if kv_chunking_enabled:
                 trg_len = min(trg_len, stride)
                 mask_flat = np.ones(trg_len, dtype=bool)
-                valid_indices = np.arange(0, trg_len - 1)
+                valid_indices = np.arange(chunk_seq_len - trg_len, chunk_seq_len - 1)
                 valid_targets = target_ids_shifted[-trg_len + 1 :]

Committable suggestion skipped: line range outside the PR's diff.

Contributor Author

With prefill chunking enabled, the logits returned cover only the last chunk_seq_len positions, so trg_len is the same as chunk_seq_len.

Contributor

@ynankani, thank you for the clarification! I understand now—when prefill chunking is enabled, the returned logits are already scoped to the last chunk_seq_len positions, so trg_len matches chunk_seq_len and the indexing from 0 correctly aligns with those logits. I appreciate you explaining this implementation detail.
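
A small, hypothetical illustration of that alignment (toy values, not code from the PR): when the logits are already scoped to the last chunk_seq_len positions, row k of the logits predicts chunk token k+1, so rows 0..trg_len-2 pair with the last trg_len-1 targets.

import numpy as np

chunk_seq_len = 4                                  # logits already cover only the last chunk
trg_len = chunk_seq_len                            # per the author's note above
target_ids_shifted = np.array([11, 12, 13, 14])    # toy labels for the chunk

valid_indices = np.arange(0, trg_len - 1)          # rows 0, 1, 2 of the chunk-scoped logits
valid_targets = target_ids_shifted[-trg_len + 1:]  # tokens 12, 13, 14

for row, target in zip(valid_indices, valid_targets):
    print(f"logit row {row} predicts target token {target}")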


✏️ Learnings added
Learnt from: ynankani
PR: NVIDIA/TensorRT-Model-Optimizer#411
File: examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py:281-294
Timestamp: 2025-10-12T07:09:32.831Z
Learning: In `examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py`, when KV chunking (prefill chunking) is enabled in the perplexity_eval function, the logits returned from the ONNX Runtime GenAI model correspond to the last chunk_seq_len positions. Therefore, trg_len equals chunk_seq_len, and the indexing pattern `valid_indices = np.arange(0, trg_len - 1)` correctly aligns with `valid_targets = target_ids_shifted[-trg_len + 1:]` without needing offset adjustment.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines 83 to 88
    if isinstance(i, str):
        i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
    elif isinstance(i, (list, tuple)):
        i_list = [int(x) for x in i]
    else:
        i_list = [int(i)]
Contributor

🛠️ Refactor suggestion | 🟠 Major

Validate parsed input lengths.

The input parsing doesn't verify that parsed values are positive integers, which could lead to runtime errors when calling perplexity_eval.

Apply this diff to add validation:

     # Parse i as a comma-separated list of integers, pass each to perplexity_eval
     if isinstance(i, str):
-        i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
+        try:
+            i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
+            if any(val < 1 for val in i_list):
+                raise ValueError("Input lengths must be positive integers")
+        except ValueError as e:
+            print(f"Error: Invalid input length format: {e}")
+            return pd.DataFrame()
     elif isinstance(i, (list, tuple)):
-        i_list = [int(x) for x in i]
+        try:
+            i_list = [int(x) for x in i]
+            if any(val < 1 for val in i_list):
+                raise ValueError("Input lengths must be positive integers")
+        except (ValueError, TypeError) as e:
+            print(f"Error: Invalid input length format: {e}")
+            return pd.DataFrame()
     else:
-        i_list = [int(i)]
+        try:
+            i_list = [int(i)]
+            if i_list[0] < 1:
+                raise ValueError("Input length must be a positive integer")
+        except (ValueError, TypeError) as e:
+            print(f"Error: Invalid input length format: {e}")
+            return pd.DataFrame()
📝 Committable suggestion


Suggested change
    if isinstance(i, str):
        i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
    elif isinstance(i, (list, tuple)):
        i_list = [int(x) for x in i]
    else:
        i_list = [int(i)]
    # Parse i as a comma-separated list of integers, pass each to perplexity_eval
    if isinstance(i, str):
        try:
            i_list = [int(x.strip()) for x in i.split(",") if x.strip()]
            if any(val < 1 for val in i_list):
                raise ValueError("Input lengths must be positive integers")
        except ValueError as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
    elif isinstance(i, (list, tuple)):
        try:
            i_list = [int(x) for x in i]
            if any(val < 1 for val in i_list):
                raise ValueError("Input lengths must be positive integers")
        except (ValueError, TypeError) as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
    else:
        try:
            i_list = [int(i)]
            if i_list[0] < 1:
                raise ValueError("Input length must be a positive integer")
        except (ValueError, TypeError) as e:
            print(f"Error: Invalid input length format: {e}")
            return pd.DataFrame()
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py
around lines 83 to 88, the parsed input values are converted to ints but not
validated for positivity; update the parsing so after converting values in each
branch you filter/validate that every int is a positive integer (>0), raising a
ValueError (with a clear message including the original input) if any value is
non-positive or conversion fails; ensure the validation covers single-string,
comma-separated string, list/tuple and single-value cases and preserves the
resulting i_list for downstream perplexity_eval.

Comment on lines 112 to 215
            except Exception as e:  # noqa: PERF203
                print(f"  Error for input length {input_len}: {e!s}")
                results.append(
                    {
                        "Model Path": model_dir,
                        "Input Length": int(input_len),
                        "Perplexity": "N/A",
                        "Status": "Failed",
                        "Error": str(e),
                    }
                )
Contributor

🛠️ Refactor suggestion | 🟠 Major

Narrow exception handling scope.

Catching broad Exception with a loop-suppression comment (PERF203) makes debugging difficult. Catch specific exceptions and let unexpected errors propagate for proper diagnosis.

Apply this diff to catch specific exceptions:

-            except Exception as e:  # noqa: PERF203
+            except (RuntimeError, ValueError, OSError) as e:
                 print(f"  Error for input length {input_len}: {e!s}")
📝 Committable suggestion


Suggested change
            except Exception as e:  # noqa: PERF203
                print(f"  Error for input length {input_len}: {e!s}")
                results.append(
                    {
                        "Model Path": model_dir,
                        "Input Length": int(input_len),
                        "Perplexity": "N/A",
                        "Status": "Failed",
                        "Error": str(e),
                    }
                )
            except (RuntimeError, ValueError, OSError) as e:
                print(f"  Error for input length {input_len}: {e!s}")
                results.append(
                    {
                        "Model Path": model_dir,
                        "Input Length": int(input_len),
                        "Perplexity": "N/A",
                        "Status": "Failed",
                        "Error": str(e),
                    }
                )
🤖 Prompt for AI Agents
In examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py
around lines 112 to 122, the code currently catches a broad Exception which
hides unexpected errors; replace the broad except Exception with targeted except
blocks for the expected failure modes (for example except ValueError, OSError,
RuntimeError as e:) that append the failure result and error string, and
explicitly re-raise or omit catching critical/unexpected exceptions (e.g.,
KeyboardInterrupt, SystemExit) so they propagate; remove the PERF203 noqa and
ensure any additional specific exceptions you expect from the code path are
handled individually rather than using a blanket Exception catch.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)

281-294: Fix indexing mismatch in KV chunking branch.

Lines 285-286 create an index/target misalignment: valid_indices selects the first trg_len-1 positions of pred_log_probs (indices 0 to trg_len-2), while valid_targets extracts the last trg_len-1 elements of target_ids_shifted. When gathering log probabilities at line 294 with pred_log_probs[valid_indices, valid_targets], this compares predictions from the start of the chunk against targets from the end, producing incorrect perplexity scores.

Apply this diff to align indices with targets:

             # Only include tokens with label != -100 (matching HF masking)
             mask_flat = target_ids_shifted != -100
             if kv_chunking_enabled:
                 trg_len = min(trg_len, stride)
                 mask_flat = np.ones(trg_len, dtype=bool)
-                valid_indices = np.arange(0, trg_len - 1)
+                valid_indices = np.arange(chunk_seq_len - trg_len, chunk_seq_len - 1)
                 valid_targets = target_ids_shifted[-trg_len + 1 :]
🧹 Nitpick comments (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1)

139-145: Remove duplicate assignment.

Line 143 duplicates the assignment from line 142.

Apply this diff:

     elif model_cfg_json and "search" in model_cfg_json and "chunk_size" in model_cfg_json["search"]:
         # Use chunk_size from existing config file
         effective_chunk_size = model_cfg_json["search"]["chunk_size"]
         kv_chunking_enabled = True
-        kv_chunking_enabled = True
         if DEBUG:
             print(f"[CONFIG] Using config file chunk_size: {effective_chunk_size}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6f0ecd3 and f0c8093.

📒 Files selected for processing (1)
  • examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (3)
examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py (1)
  • get_wikitext2 (51-75)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py (1)
  • get_wikitext2 (33-50)
examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py (1)
  • get_wikitext2 (30-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py (6)

1-69: LGTM! License, imports, and dataset loader are well-structured.

The MIT license header properly attributes both Microsoft (original) and NVIDIA (modifications). The get_wikitext2() function is consistent with identical implementations across the benchmark suite (extract_logits.py, extract_logits_hf.py, KL_divergence_metrics_same_ep.py), which maintains good code consistency.
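
For readers without the diff open, a loader of this shape typically concatenates the raw WikiText-2 test split into one string before tokenization. A rough sketch using the Hugging Face datasets API (the PR's get_wikitext2 may differ in details such as the join separator):

from datasets import load_dataset

def get_wikitext2_text() -> str:
    # Load the raw test split and concatenate it into a single string
    # that can be tokenized once and sliced into fixed-size windows.
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    return "\n\n".join(test["text"])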


72-114: LGTM! Function signature and model loading are well-designed.

The function is thoroughly documented with clear parameter descriptions and return value guidance. The chunk_size overlay mechanism is properly implemented with appropriate error handling, supporting the PR's goal of configurable chunk sizes.
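
As a rough sketch of what such an overlay can look like (the "search"/"chunk_size" key path is taken from the review context above; the PR's actual implementation adds error handling and may restore the original file afterwards):

import json
from pathlib import Path

def overlay_chunk_size(model_dir: str, chunk_size: int) -> None:
    # Sketch: write chunk_size under the "search" section of genai_config.json
    # so ONNX Runtime GenAI picks it up on the next model load.
    cfg_path = Path(model_dir) / "genai_config.json"
    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    cfg.setdefault("search", {})["chunk_size"] = int(chunk_size)
    cfg_path.write_text(json.dumps(cfg, indent=4), encoding="utf-8")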


115-126: LGTM! Config file loading now properly handles errors.

The try-except wrapper around genai_config.json loading addresses the past review concern about missing or malformed config files. The fallback to defaults prevents crashes.


166-176: LGTM! Safe fallback for context_length.

The code safely extracts context_length from the config JSON with appropriate None checks and defaults, preventing crashes when the config structure varies.
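
The guarded lookup can be pictured roughly like this (the "model"/"context_length" key path and the 4096 default are assumptions for illustration only):

def get_context_length(model_cfg_json, default: int = 4096) -> int:
    # Fall back to a default when the config is missing or lacks the field.
    if not model_cfg_json:
        return default
    value = model_cfg_json.get("model", {}).get("context_length")
    return int(value) if value is not None else default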


178-262: LGTM! Dataset loading, tokenization, and inference are robust.

The code properly handles different tokenizer output formats (dict with "input_ids" key vs direct array), ensures correct tensor dimensions, and uses appropriate context managers. The try-except around generate_next_token() prevents cascading failures during inference.
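
The dual-format handling described here can be sketched as a small normalization helper (a simplification, not a verbatim excerpt from the PR):

from collections.abc import Mapping

import numpy as np

def to_input_ids(encodings) -> np.ndarray:
    # Accept either a mapping with an "input_ids" key or a bare id array,
    # and always return a 2-D (1, seq_len) array for the generator.
    ids = encodings["input_ids"] if isinstance(encodings, Mapping) else encodings
    ids = np.asarray(ids)
    if ids.ndim == 1:
        ids = ids[np.newaxis, :]
    return ids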


301-328: LGTM! Token counting and final computation are correct.

The token counting now correctly uses int(valid_log_probs.size) instead of summing the mask, ensuring the count matches the number of gathered log probabilities (addresses a past review concern). The final perplexity calculation follows standard methodology.
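
For reference, the standard methodology mentioned here reduces to exponentiating the mean negative log-likelihood of the gathered target-token log-probabilities; a minimal sketch with assumed variable names:

import numpy as np

def perplexity_from_log_probs(valid_log_probs: np.ndarray, total_tokens: int) -> float:
    # valid_log_probs holds log p(correct next token) for every counted position;
    # total_tokens is expected to equal valid_log_probs.size.
    nll = -float(np.sum(valid_log_probs)) / total_tokens
    return float(np.exp(nll))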

@ynankani ynankani force-pushed the ynankani/benchmark_metrics branch from f0c8093 to 9e177a4 on October 25, 2025 11:27