diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py b/examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py new file mode 100644 index 000000000..43e5ec39c --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/KL_divergence_metrics_same_ep.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Optimized KL divergence comparison for ONNX Runtime GenAI models with the same execution provider. + +This script efficiently compares two ONNX Runtime GenAI models by computing KL divergence +between their output distributions without package switching overhead. + +Usage: + python KL_divergence_metrics_same_ep.py \\ + --reference_model "path/to/reference/model" \\ + --target_model "path/to/target/model" +""" + +import argparse +import os + +import numpy as np +import onnxruntime_genai as og +import torch +from datasets import load_dataset + +DEBUG = False + + +def get_kl_divergence(log_probs_ref, log_probs_tar): + """ + Compute Kullback-Leibler divergence between two log probability distributions. + + KL divergence measures how one probability distribution diverges from a reference + distribution. Lower values indicate more similar distributions. + + Args: + log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size). + log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size). + + Returns: + float: Average KL divergence across all positions. + + Note: + Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length + """ + kl_divergence = 0.0 + for i in range(log_probs_ref.shape[0]): + log_probs_ref[i] = np.array(log_probs_ref[i]) + log_probs_tar[i] = np.array(log_probs_tar[i]) + prob_ref = np.exp(log_probs_ref[i]) + kl_divergence += np.sum(prob_ref * abs(log_probs_ref[i] - log_probs_tar[i])) + kl_divergence = kl_divergence / log_probs_ref.shape[0] + return kl_divergence + + +def get_wikitext2(): + """ + Load and concatenate the WikiText-2 test dataset. + + Returns: + str: Concatenated text from all samples in the WikiText-2 test split, + with samples separated by double newlines. + + Note: + Requires HuggingFace CLI authentication to access the dataset. + """ + # Load the Wikitext-2 test split using HuggingFace datasets + print("\n[INFO] Loading Wikitext-2 'test' split ...") + test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + if DEBUG: + print(f"[DATASET] Number of raw samples: {len(test)}") + for i in range(3): + print(f"[DATASET] Sample[{i}]: {repr(test[i]['text'])[:200]} ...") + # Concatenate all text samples into a single string, separated by double newlines + result = "\n\n".join(text for text in test["text"]) + if DEBUG: + print( + f"[DATASET] Concatenated text preview: {result[:512]!r} ... 
[total chars: {len(result)}]" + ) + return result + + +def run_kl_divergence_on_models(reference_model, target_model): + """ + Compute KL divergence between two ONNX Runtime GenAI models on WikiText-2 dataset. + + This function loads both models, processes the WikiText-2 dataset in chunks, and + computes the KL divergence between their output distributions for each chunk. + The results are averaged across all chunks. + + Args: + reference_model (str): Path to the reference ONNX Runtime GenAI model directory. + target_model (str): Path to the target ONNX Runtime GenAI model directory. + + """ + ref_model = og.Model(reference_model) + tar_model = og.Model(target_model) + tokenizer_ref = og.Tokenizer(ref_model) + tokenizer_tar = og.Tokenizer(tar_model) + max_context_length = 1024 + dataset = get_wikitext2() + + input_ids_ref = tokenizer_ref.encode_batch([dataset]) + input_ids_tar = tokenizer_tar.encode_batch([dataset]) + # Handle possible dict output from tokenizer + if isinstance(input_ids_ref, dict) and "input_ids" in input_ids_ref: + input_ids_ref = input_ids_ref["input_ids"] + # Convert to numpy if needed + if hasattr(input_ids_ref, "as_numpy"): + input_ids_ref = input_ids_ref.as_numpy() + if DEBUG: + print("[TOKENIZER] Used as_numpy()") + if isinstance(input_ids_tar, dict) and "input_ids" in input_ids_tar: + input_ids_tar = input_ids_tar["input_ids"] + if hasattr(input_ids_tar, "as_numpy"): + input_ids_tar = input_ids_tar.as_numpy() + if DEBUG: + print("[TOKENIZER] Used as_numpy()") + input_ids_ref = np.array(input_ids_ref) + input_ids_tar = np.array(input_ids_tar) + + # Ensure input_ids is 2D (batch, seq_len) + if input_ids_ref.ndim == 1: + input_ids_ref = np.expand_dims(input_ids_ref, 0) + if DEBUG: + print(f"[SHAPE] Expanded dims, now: {input_ids_ref.shape}") + if input_ids_tar.ndim == 1: + input_ids_tar = np.expand_dims(input_ids_tar, 0) + if DEBUG: + print(f"[SHAPE] Expanded dims, now: {input_ids_tar.shape}") + # Convert input_ids to torch tensor + input_ids_ref = torch.tensor(input_ids_ref, dtype=torch.long) + input_ids_tar = torch.tensor(input_ids_tar, dtype=torch.long) + seq_len_ref = int(input_ids_ref.shape[1]) + seq_len_tar = int(input_ids_tar.shape[1]) + if DEBUG: + print(f"[INFO] Ref input length: {seq_len_ref}") + print(f"[INFO] Tar input length: {seq_len_tar}") + + if seq_len_ref != seq_len_tar: + print( + f"Error: Input tokenizer lengths for reference and target models do not match: " + f"{seq_len_ref} != {seq_len_tar}" + ) + return + if DEBUG: + print(f"[INFO] Input lengths match: {seq_len_ref}") + # Slide a window over the input to compute perplexity in chunks + total_kl_divergence = 0.0 + total_batch = 0 + for begin_loc in range(0, seq_len_ref, max_context_length): + end_loc = min(begin_loc + max_context_length, seq_len_ref) + # Extract the current chunk of input tokens + input_ids_chunk_ref = input_ids_ref[:, begin_loc:end_loc].clone() + input_ids_chunk_tar = input_ids_tar[:, begin_loc:end_loc].clone() + if DEBUG: + print(f"input_ids_chunk_ref.shape: {input_ids_chunk_ref.shape}") + print(f"input_ids_chunk_tar.shape: {input_ids_chunk_tar.shape}") + # Set up generator parameters for deterministic generation (no sampling) + params_ref = og.GeneratorParams(ref_model) + params_tar = og.GeneratorParams(tar_model) + params_ref.set_search_options( + max_length=int(input_ids_chunk_ref.shape[1]), do_sample=False, early_stopping=False + ) + params_tar.set_search_options( + max_length=int(input_ids_chunk_tar.shape[1]), do_sample=False, early_stopping=False + ) + # Create generator 
and append input tokens + generator_ref = og.Generator(ref_model, params_ref) + generator_ref.append_tokens(input_ids_chunk_ref.numpy()) + generator_tar = og.Generator(tar_model, params_tar) + generator_tar.append_tokens(input_ids_chunk_tar.numpy()) + + # Run the model forward pass without gradient calculation + with torch.no_grad(): + if DEBUG: + print("[INFER] Running model forward pass ...") + try: + generator_ref.generate_next_token() + generator_tar.generate_next_token() + except Exception as e: + print(f"[INFER] .generate_next_token() failed: {e}") + break # Fatal error + # Get logits output from the model + logits_ref = generator_ref.get_output("logits") + logits_tar = generator_tar.get_output("logits") + if DEBUG: + print(f"logits_ref.shape: {logits_ref.shape}") + print(f"logits_tar.shape: {logits_tar.shape}") + # Convert numpy arrays to torch tensors + logits_ref = torch.tensor(logits_ref, dtype=torch.float32) + logits_tar = torch.tensor(logits_tar, dtype=torch.float32) + # Compute log probabilities over vocabulary for each position + log_probs_ref = torch.nn.functional.log_softmax(logits_ref, dim=2).cpu().numpy() + log_probs_tar = torch.nn.functional.log_softmax(logits_tar, dim=2).cpu().numpy() + if DEBUG: + print(f"log_probs_ref.shape: {log_probs_ref.shape}") + print(f"log_probs_tar.shape: {log_probs_tar.shape}") + # Compute KL divergence + kl_divergence = 0.0 + # Reshape log_probs_ref and log_probs_tar from (1, 1024, 128256) to (1024, 128256) + log_probs_ref = log_probs_ref.squeeze(0) + log_probs_tar = log_probs_tar.squeeze(0) + + # log_probs_ref = torch.tensor(log_probs_ref, dtype=torch.float32) + # log_probs_tar = torch.tensor(log_probs_tar, dtype=torch.float32) + # kl_divergence = torch.nn.functional.kl_div( + # log_probs_ref, log_probs_tar, reduction='batchmean', log_target=True + # ) + kl_divergence = get_kl_divergence(log_probs_ref, log_probs_tar) + total_kl_divergence += kl_divergence + total_batch += 1 + if DEBUG: + print(f"KL divergence: {kl_divergence}") + avg_kl_divergence = total_kl_divergence / total_batch + if DEBUG: + print(f"Average KL divergence: {avg_kl_divergence}") + print(f"Total KL divergence: {total_kl_divergence}") + print(f"Total batch: {total_batch}") + print(f"Average KL divergence: {avg_kl_divergence}") + + +def main(): + """ + Command-line entry point for optimized KL divergence comparison of same-EP models. + + This script is optimized for comparing two ONNX Runtime GenAI models that use + the same execution provider, avoiding package switching overhead. It computes + KL divergence between model outputs on the WikiText-2 dataset. 
+ + Command-line Arguments: + --reference_model: Path to reference model directory (required) + --target_model: Path to target model directory (required) + + Example: + $ python KL_divergence_metrics_same_ep.py \\ + --reference_model "G:\\models\\cuda_fp16" \\ + --target_model "G:\\models\\cuda_int4" + """ + parser = argparse.ArgumentParser( + description="Run KL divergence evaluation on ONNX Runtime GenAI models" + ) + parser.add_argument( + "--reference_model", required=True, help="Path to reference model directory" + ) + parser.add_argument("--target_model", required=True, help="Path to target model directory") + args = parser.parse_args() + + # Validate that all model directories exist + valid_models = [] + if os.path.exists(args.reference_model): + valid_models.append(args.reference_model) + else: + print(f"Warning: Reference Model directory does not exist: {args.reference_model}") + if os.path.exists(args.target_model): + valid_models.append(args.target_model) + else: + print(f"Warning: Target Model directory does not exist: {args.target_model}") + if len(valid_models) != 2: + print("Error: No valid model directories provided") + return + + print( + f"Running KL divergence evaluation on reference model={valid_models[0]} and target model={valid_models[1]}" + ) + run_kl_divergence_on_models(valid_models[0], valid_models[1]) + + +if __name__ == "__main__": + main() diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md b/examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md new file mode 100644 index 000000000..ba2b8763b --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/README.md @@ -0,0 +1,286 @@ +# KL Divergence Model Validation Toolkit + +This toolkit provides comprehensive model validation capabilities using KL divergence metrics to compare different model implementations and execution providers. It's designed to evaluate the similarity between model outputs across different optimization techniques and hardware backends. + +## Overview + +The toolkit includes several Python scripts for: + +1. **Extracting logits** from both Hugging Face and ONNX Runtime GenAI models +2. **Computing KL divergence** between model pairs or multiple models +3. **Comparing execution providers** (CUDA, DirectML, CPU) against baseline models +4. **Validating model optimization** quality by measuring output similarity + +## Key Components + +### Core Scripts + +| Script | Purpose | Usage | +|--------|---------|--------| +| `extract_logits_hf.py` | Extract logits from Hugging Face models using transformers | Baseline model logit extraction | +| `extract_logits.py` | Extract logits from ONNX Runtime GenAI models | Optimized model logit extraction | +| `KL_divergence_metrics_same_ep.py` | Compare two ONNX Runtime GenAI models directly | Same execution provider comparison | +| `compute_kl_divergence.py` | Unified comparison framework | All-in-one comparison tool | + +### Datasets Used + +- **Wikitext-2** test split for consistent evaluation across all models +- Automatic dataset loading and preprocessing via HuggingFace datasets + +## Installation + +1. **Install base requirements:** + + ```bash + pip install -r requirements.txt + ``` + + Note: Install torch with CUDA for faster inference: + "pip install torch torchvision torchaudio --index-url " + +2. 
**Install execution provider-specific packages** (as needed): + + ```bash + # For CUDA support + pip install onnxruntime-genai-cuda + + # For DirectML support + pip install onnxruntime-genai-directml + + # For CPU support (default) + pip install onnxruntime-genai + ``` + +## Usage Examples + +### 1. Unified Comparison Tool + +The `compute_kl_divergence.py` script is the unified tool for all comparison scenarios: + +#### Compare HF baseline vs ONNX models + +```bash +python compute_kl_divergence.py \ + --hf_model "path/to/hf/model" \ + --ep cuda --path "path/to/cuda/model" \ + --output "hf_vs_cuda_results.json" +``` + +#### Compare HF vs multiple execution providers + +```bash +python compute_kl_divergence.py \ + --hf_model "path/to/hf/model" \ + --ep cuda --path "path/to/cuda/model" \ + --ep directml --path "path/to/directml/model" \ + --output "multi_provider_comparison.json" +``` + +#### Compare ONNX models WITHOUT HF baseline (NEW!) + +```bash +# Two models with same EP (automatically uses optimized same_ep script) +python compute_kl_divergence.py \ + --ep cuda --path "path/to/cuda_fp16/model" \ + --ep cuda --path "path/to/cuda_int4/model" \ + --output "cuda_fp16_vs_int4.json" + +# Two models with different EPs +python compute_kl_divergence.py \ + --ep cuda --path "path/to/cuda/model" \ + --ep directml --path "path/to/directml/model" \ + --output "cuda_vs_directml.json" + +# Multiple ONNX models with mixed EPs +python compute_kl_divergence.py \ + --ep cuda --path "path/to/cuda_fp16/model" \ + --ep cuda --path "path/to/cuda_int4/model" \ + --ep directml --path "path/to/directml/model" \ + --output "multi_onnx_comparison.json" +``` + +#### Enable debug output for detailed logging + +```bash +python compute_kl_divergence.py \ + --ep cuda --path "path/to/model1" \ + --ep directml --path "path/to/model2" \ + --output "results.json" \ + --debug +``` + +### 2. Direct Same-EP Comparison (Alternative) + +For comparing two ONNX models with the same execution provider, you can also use: + +```bash +python KL_divergence_metrics_same_ep.py \ + --reference_model "path/to/reference/model" \ + --target_model "path/to/target/model" +``` + +### 3. 
Extract Logits Separately (Advanced) + +If you need to extract logits separately for reuse: + +#### From Hugging Face Model + +```bash +python extract_logits_hf.py \ + --model_path "path/to/huggingface/model" \ + --output_file "hf_logits.pkl" \ + --device cuda \ + --debug +``` + +#### From ONNX Runtime GenAI Model + +```bash +python extract_logits.py \ + --model_path "path/to/onnx/model" \ + --output_file "onnx_logits.pkl" \ + --provider cuda \ + --debug +``` + +## Configuration Parameters + +### compute_kl_divergence.py Parameters + +- `--hf_model`: Path to Hugging Face baseline model (optional - can compare ONNX models directly) +- `--ep`: Execution provider (can be specified multiple times for multiple models) + - Supported: `cuda`, `directml`, `cpu` +- `--path`: Model path (must match order of --ep arguments) +- `--output`: Output JSON file for results (required) +- `--device`: Device for HF model inference (default: `cuda`, choices: `cuda`, `cpu`) +- `--keep_logits`: Keep extracted logits files after comparison +- `--debug`: Enable verbose debug output with detailed logging + +### Other Script Parameters + +- `--model_path`: Path to model (for extract_logits scripts) +- `--output_file`: Output file for extracted logits (`.pkl` format) +- `--provider`: Execution provider for ONNX models (`cuda`, `directml`, `cpu`) +- `--reference_model`: Reference model path (for same_ep script) +- `--target_model`: Target model path (for same_ep script) + +### Model Processing Parameters + +- **Max context length**: 1024 tokens (configurable in code) +- **Chunk processing**: Automatic chunking for memory management +- **Deterministic generation**: No sampling for consistent results + +## Output Files and Interpretation + +### Logits Files (`.pkl`) + +Pickled files containing: + +- **logits**: List of numpy arrays with model logits per chunk +- **chunk_info**: Metadata about each processed chunk +- **model_path**: Path to the source model +- **provider**: Execution provider used +- **total_chunks**: Number of chunks processed + +### Results Files (`.json`) + +JSON files containing: + +- **models**: Paths to all compared models +- **kl_divergences**: Pairwise KL divergence values + - `total`: Sum of KL divergences across all chunks + - `average`: Mean KL divergence per chunk +- **chunk_results**: Detailed per-chunk analysis +- **summary**: Interpretation and key metrics + +### Example Results Structure + +```json +{ + "models": { + "huggingface": "path/to/hf/model", + "cuda": "path/to/cuda/model", + "directml": "path/to/directml/model" + }, + "kl_divergences": { + "huggingface_vs_cuda": { + "total": , + "average": + }, + "huggingface_vs_directml": { + "total": , + "average": + }, + "cuda_vs_directml": { + "total": , + "average": + } + }, + "summary": { + "interpretation": "Lower KL divergence indicates more similar model outputs", + "baseline_reference": "huggingface", + "pairwise_averages": { + "huggingface_vs_cuda": , + "huggingface_vs_directml": , + "cuda_vs_directml": + } + } +} +``` + +## Interpreting KL Divergence Values + +| KL Divergence Range | Interpretation | +|-------------------|----------------| +| **0 - 1** | Nearly identical outputs | +| **1 - 10** | Very similar outputs | +| **10 - 50** | Moderately similar outputs | +| **50+** | Significantly different outputs | + +### Key Insights from Results + +- **Lower values** indicate better optimization quality (closer to baseline) +- **Baseline comparison** shows how much optimization affects output quality +- **Provider comparison** reveals 
differences between execution backends +- **Consistency check** ensures model optimization maintains output quality + +## Key Features + +### Flexible Comparison Modes + +1. **HF vs ONNX models**: Compare Hugging Face baseline against one or more ONNX models +2. **ONNX-only comparison**: Compare ONNX models directly without HF baseline +3. **Mixed execution providers**: Compare models across different hardware backends +4. **Multiple same-EP models**: Compare multiple variants of the same execution provider + +### Output Verbosity Control + +- **Without `--debug`**: Clean, minimal progress output showing key steps + - Model extraction progress + - Environment switching notifications + - Computation progress + - Final results summary +- **With `--debug`**: Comprehensive logging including: + - Detailed model paths and configurations + - Package installation details + - Chunk-by-chunk processing + - Validation warnings + - Temporary file management + - Full error tracebacks + +### Automatic Package Management + +- The script automatically installs/uninstalls the correct ONNX Runtime packages based on execution provider +- Minimizes package switching by reusing environments when possible +- Handles CUDA, DirectML, and CPU providers seamlessly + +Warning: This mutates your Python environment (pip uninstall/install). Run inside an isolated virtualenv/conda env to avoid impacting other projects. + +## Notes + +- The comparison uses the Wikitext-2 dataset for evaluation +- Processing is done in chunks (1024 tokens) to handle memory constraints +- The script automatically handles package installation/uninstallation for different providers +- Results are deterministic (no sampling) for consistent comparisons +- All pairwise comparisons are computed for multi-model scenarios +- HF model is optional - you can compare ONNX models directly diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py b/examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py new file mode 100644 index 000000000..99aa3c342 --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/compute_kl_divergence.py @@ -0,0 +1,836 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Generic model comparison script that compares a Hugging Face baseline model +against multiple ONNX Runtime GenAI models with different execution providers. 
+ +Usage: +python compute_kl_divergence.py --hf_model "F:\shared\Llama-3.1-8B-Instruct" + --ep cuda --path "G:\models\cuda_model" + --ep directml --path "G:\models\directml_model" + --output "comparison_results.json" +""" + +import argparse +import json +import os +import pickle +import subprocess +import sys +import time +from datetime import datetime + +import numpy as np +import torch + +# We'll use subprocess calls to run extraction scripts with fresh Python processes +# This ensures no import cache issues when switching packages + +# Mapping of execution providers to their corresponding ONNX Runtime packages +EP_PACKAGE_MAP = { + "cuda": "onnxruntime-genai-cuda", + "directml": "onnxruntime-genai-directml", + "cpu": "onnxruntime-genai", +} + +DEBUG = False # Global debug flag + + +def debug_print(message): + """ + Print debug message only if DEBUG flag is enabled. + + Args: + message (str): Debug message to print. + """ + if DEBUG: + print(f"[DEBUG] {message}") + + +def run_command(cmd, description="", capture_output=True): + """ + Execute a subprocess command with error handling. + + Args: + cmd (list[str]): Command and arguments to execute. + description (str, optional): Description of the command for logging. Defaults to "". + capture_output (bool, optional): Whether to capture stdout/stderr or show in real-time. + Defaults to True. + + Returns: + bool: True if command succeeded, False otherwise. + """ + debug_print(f"[INFO] {description}") + debug_print(f"Running: {' '.join(cmd)}") + + try: + if capture_output: + result = subprocess.run(cmd, check=True, capture_output=True, text=True, shell=False) + if result.stdout and DEBUG: + print(f"[OUT] {result.stdout}") + else: + # Real-time output - shows prints as they happen + result = subprocess.run(cmd, check=True, shell=False) + return True + except subprocess.CalledProcessError as e: + print(f"[ERROR] Command failed: {e}") + if capture_output and hasattr(e, "stdout") and e.stdout: + print(f"[STDOUT] {e.stdout}") + if capture_output and hasattr(e, "stderr") and e.stderr: + print(f"[STDERR] {e.stderr}") + return False + + +def get_python_executable(): + """Get the current Python executable being used""" + return sys.executable + + +def uninstall_onnxruntime_packages(): + """ + Uninstall all ONNX Runtime and ONNX Runtime GenAI packages. + + This ensures a clean environment before installing provider-specific packages + to avoid version conflicts. + """ + packages_to_remove = [ + "onnxruntime", + "onnxruntime-genai", + "onnxruntime-genai-cuda", + "onnxruntime-gpu", + "onnxruntime-directml", + "onnxruntime-genai-directml", + ] + + debug_print(f"Packages to remove: {packages_to_remove}") + python_exe = get_python_executable() + debug_print(f"Python executable: {python_exe}") + cmd = [python_exe, "-m", "pip", "uninstall", "-y", *packages_to_remove] + run_command(cmd, "Uninstalling existing ONNX Runtime packages") + + +def install_package(package_name): + """ + Install a specific Python package using pip. + + Args: + package_name (str): Name of the package to install. + + Returns: + bool: True if installation succeeded, False otherwise. 
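+
+    Example (illustrative; mirrors how main() calls this helper):
+        if not install_package(EP_PACKAGE_MAP["cuda"]):
+            print("[ERROR] Failed to install onnxruntime-genai-cuda")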
+ """ + debug_print(f"Installing package: {package_name}") + python_exe = get_python_executable() + debug_print(f"Python executable: {python_exe}") + cmd = [python_exe, "-m", "pip", "install", package_name, "--force-reinstall"] + debug_print(f"Install command: {' '.join(cmd)}") + return run_command(cmd, f"Installing {package_name}") + + +# Module cache clearing is no longer needed since we use subprocess calls + + +def extract_hf_logits_subprocess(model_path, device="cuda"): + """ + Extract logits from a Hugging Face transformer model using a subprocess. + + Runs extract_logits_hf.py in a separate process to avoid package conflicts. + Uses temporary file for data transfer between processes. + + Args: + model_path (str): Path to the Hugging Face model directory. + device (str, optional): Device for inference ('cuda' or 'cpu'). Defaults to "cuda". + + """ + print("[INFO] Extracting logits from Hugging Face baseline model...") + debug_print(f"Model path: {model_path}, Device: {device}") + + # Create temporary output file + import tempfile + + script_dir = os.path.dirname(os.path.abspath(__file__)) + with tempfile.NamedTemporaryFile(prefix="temp_logits_hf_", suffix=".pkl", delete=False) as tmp: + output_file = tmp.name + debug_print(f"Temporary output file: {output_file}") + + try: + python_exe = get_python_executable() + cmd = [ + python_exe, + os.path.join(script_dir, "extract_logits_hf.py"), + "--model_path", + model_path, + "--output_file", + output_file, + "--device", + device, + ] + if DEBUG: + cmd.append("--debug") + + if not run_command(cmd, "Running HF logits extraction", capture_output=False): + raise RuntimeError("HF logits extraction failed") + + # Load the extracted logits + debug_print(f"Loading logits from: {output_file}") + with open(output_file, "rb") as f: + logits_data = pickle.load(f) + + debug_print(f"Loaded logits data keys: {logits_data.keys()}") + debug_print( + f"Total chunks: {logits_data['total_chunks']}, Seq len: {logits_data['seq_len']}" + ) + + # Clean up temporary file + try: + os.remove(output_file) + debug_print(f"Cleaned up temporary file: {output_file}") + except Exception: + pass + + print(f"[INFO] HF logits extraction completed ({logits_data['total_chunks']} chunks)") + return logits_data + + except Exception as e: + # Clean up temporary file on error + import contextlib + + with contextlib.suppress(BaseException): + os.remove(output_file) + print(f"[ERROR] Failed to extract HF logits: {e}") + raise + + +def extract_onnx_logits_subprocess(model_path, provider): + """ + Extract logits from an ONNX Runtime GenAI model using a subprocess. + + Runs extract_logits.py in a separate process with the appropriate ONNX Runtime + package for the specified execution provider. Uses temporary file for data transfer. + + Args: + model_path (str): Path to the ONNX Runtime GenAI model directory. + provider (str): Execution provider ('cuda', 'directml', or 'cpu'). 
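+
+    Returns:
+        dict: Logits data produced by extract_logits.py, with keys 'logits',
+            'chunk_info', 'model_path', 'provider', 'seq_len',
+            'max_context_length', and 'total_chunks'.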
+ + """ + print(f"[INFO] Extracting logits from {provider.upper()} model...") + debug_print(f"Model path: {model_path}, Provider: {provider}") + + # Create temporary output file + import tempfile + + script_dir = os.path.dirname(os.path.abspath(__file__)) + with tempfile.NamedTemporaryFile(prefix="temp_logits_", suffix=".pkl", delete=False) as tmp: + output_file = tmp.name + debug_print(f"Temporary output file: {output_file}") + + try: + python_exe = get_python_executable() + cmd = [ + python_exe, + os.path.join(script_dir, "extract_logits.py"), + "--model_path", + model_path, + "--output_file", + output_file, + "--provider", + provider, + ] + if DEBUG: + cmd.append("--debug") + + if not run_command( + cmd, f"Running {provider.upper()} logits extraction", capture_output=False + ): + raise RuntimeError(f"{provider.upper()} logits extraction failed") + + # Load the extracted logits + debug_print(f"Loading logits from: {output_file}") + with open(output_file, "rb") as f: + logits_data = pickle.load(f) + + debug_print(f"Loaded logits data keys: {logits_data.keys()}") + debug_print( + f"Total chunks: {logits_data['total_chunks']}, Seq len: {logits_data['seq_len']}" + ) + + # Clean up temporary file + import contextlib + + with contextlib.suppress(BaseException): + os.remove(output_file) + debug_print(f"Cleaned up temporary file: {output_file}") + + print( + f"[INFO] {provider.upper()} logits extraction completed ({logits_data['total_chunks']} chunks)" + ) + return logits_data + + except Exception as e: + # Clean up temporary file on error + import contextlib + + with contextlib.suppress(BaseException): + os.remove(output_file) + print(f"[ERROR] Failed to extract {provider.upper()} logits: {e}") + raise + + +def compute_kl_divergence_from_logits(log_probs_ref, log_probs_tar): + """ + Compute Kullback-Leibler divergence between two log probability distributions. + + KL divergence measures how one probability distribution diverges from a reference + distribution. Lower values indicate more similar distributions. + + Args: + log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size). + log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size). + + Returns: + float: Average KL divergence across all positions. + + Note: + Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length + """ + debug_print( + f"Computing KL divergence - log_probs shapes: ref={log_probs_ref.shape}, tar={log_probs_tar.shape}" + ) + kl_divergence = 0.0 + for i in range(log_probs_ref.shape[0]): + log_probs_ref_i = np.array(log_probs_ref[i]) + log_probs_tar_i = np.array(log_probs_tar[i]) + prob_ref_i = np.exp(log_probs_ref_i) + kl_divergence += np.sum(prob_ref_i * np.abs(log_probs_ref_i - log_probs_tar_i)) + kl_divergence = kl_divergence / log_probs_ref.shape[0] + debug_print(f"KL divergence computed: {kl_divergence}") + return kl_divergence + + +def to_serializable(obj): + """ + Recursively convert numpy and torch types to native Python types for JSON serialization. + + Args: + obj: Object to convert (dict, list, tuple, np.ndarray, torch.Tensor, etc.). + + Returns: + Converted object with native Python types (int, float, list, dict, tuple). 
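+
+    Example (illustrative):
+        to_serializable({"kl": np.float32(0.25)})  # -> {"kl": 0.25}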
+ """ + if isinstance(obj, dict): + return {k: to_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [to_serializable(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(to_serializable(v) for v in obj) + elif isinstance(obj, (np.integer,)): + return int(obj) + elif isinstance(obj, (np.floating,)): + return float(obj) + elif isinstance(obj, (np.ndarray,)): + return obj.tolist() + elif isinstance(obj, (torch.Tensor,)): + return obj.detach().cpu().tolist() + else: + return obj + + +def compute_unified_comparison(model_logits_list, output_file): + """ + Compute pairwise KL divergence between all models and save results to JSON. + + This function performs an all-vs-all comparison of the provided models by computing + KL divergence for each chunk and averaging across all chunks. Results are saved + in a structured JSON format. + + Args: + model_logits_list (list): List of tuples (model_name, model_data) where: + - model_name (str): Identifier for the model (e.g., "hf_baseline", "cuda_1") + - model_data (dict): Dictionary containing: + - 'logits': List of numpy arrays (one per chunk) + - 'total_chunks': Number of chunks + - 'seq_len': Sequence length + - 'model_path': Path to model + - 'chunk_info': Chunk position info + output_file (str): Path to save the JSON results file. + + """ + print("\n[INFO] Computing unified KL divergence comparison...") + debug_print(f"Number of models to compare: {len(model_logits_list)}") + debug_print(f"Model names: {[name for name, _ in model_logits_list]}") + + # Validate compatibility + reference_data = model_logits_list[0][1] # Use first model as reference for validation + total_chunks = reference_data["total_chunks"] + seq_len = reference_data["seq_len"] + + debug_print(f"Reference model: {model_logits_list[0][0]}") + debug_print(f"Reference total_chunks: {total_chunks}, seq_len: {seq_len}") + + for model_name, data in model_logits_list: + debug_print( + f"Validating {model_name}: chunks={data['total_chunks']}, seq_len={data['seq_len']}" + ) + if data["total_chunks"] != total_chunks: + debug_print( + f"[WARNING] Chunk count mismatch in {model_name}: {data['total_chunks']} vs {total_chunks}" + ) + if data["seq_len"] != seq_len: + debug_print( + f"[WARNING] Sequence length mismatch in {model_name}: {data['seq_len']} vs {seq_len}" + ) + + print(f"[INFO] Computing KL divergences for {total_chunks} chunks...") + + # Process each chunk and compute all pairwise KL divergences + chunk_results = [] + pairwise_totals = {} + + # Initialize totals for all pairs + debug_print("Initializing pairwise comparisons:") + for i in range(len(model_logits_list)): + for j in range(i + 1, len(model_logits_list)): + model1_name = model_logits_list[i][0] + model2_name = model_logits_list[j][0] + pair_key = f"{model1_name}_vs_{model2_name}" + pairwise_totals[pair_key] = 0.0 + debug_print(f" Pair: {pair_key}") + + for chunk_idx in range(total_chunks): + debug_print(f"[PROGRESS] Processing chunk {chunk_idx + 1}/{total_chunks}...") + + chunk_result = { + "chunk_id": int(chunk_idx + 1), + "begin_loc": int(reference_data["chunk_info"][chunk_idx]["begin_loc"]), + "end_loc": int(reference_data["chunk_info"][chunk_idx]["end_loc"]), + "kl_divergences": {}, + } + + # Get logits for this chunk from all models + chunk_logits = [] + for model_name, model_data in model_logits_list: + logits = model_data["logits"][chunk_idx] + debug_print(f" {model_name} logits shape: {getattr(logits, 'shape', type(logits))}") + chunk_logits.append((model_name, logits)) + + # 
Find minimum sequence length for this chunk + min_seq_len = min(getattr(logits, "shape", [None, 0])[1] for _, logits in chunk_logits) + # Assume all have same vocab size + vocab_size = min(getattr(logits, "shape", [None, None, 0])[2] for _, logits in chunk_logits) + debug_print(f" Min seq len: {min_seq_len}, Vocab size: {vocab_size}") + + # Trim all logits to matching dimensions + trimmed_logits = [] + for model_name, logits in chunk_logits: + # Ensure logits is a numpy array + arr = np.array(logits) + trimmed = arr[:, :min_seq_len, :vocab_size] + debug_print(f" Trimmed {model_name} from {arr.shape} to {trimmed.shape}") + trimmed_logits.append((model_name, trimmed)) + + # Convert logits to log probabilities for all models + log_probs_list = [] + for model_name, logits in trimmed_logits: + logits_tensor = torch.tensor(logits, dtype=torch.float32) + log_probs = torch.nn.functional.log_softmax(logits_tensor, dim=2).cpu().numpy() + log_probs = np.squeeze(log_probs, axis=0) + log_probs_list.append((model_name, log_probs)) + + # Compute all pairwise KL divergences for this chunk + for i in range(len(log_probs_list)): + for j in range(i + 1, len(log_probs_list)): + model1_name, log_probs1 = log_probs_list[i] + model2_name, log_probs2 = log_probs_list[j] + + chunk_kl = compute_kl_divergence_from_logits(log_probs1, log_probs2) + pair_key = f"{model1_name}_vs_{model2_name}" + + # Instead of assigning to an object (which is not type-checked as dict), use dict update + kl_divergences: dict = chunk_result.get("kl_divergences", {}) + kl_divergences[pair_key] = float(chunk_kl) + chunk_result["kl_divergences"] = kl_divergences + pairwise_totals[pair_key] += chunk_kl + + debug_print(f" {pair_key}: {chunk_kl:.6f}") + + chunk_results.append(chunk_result) + + # Calculate average KL divergences + num_chunks = len(chunk_results) + pairwise_averages = {pair: total / num_chunks for pair, total in pairwise_totals.items()} + + debug_print("\nFinal KL divergence totals:") + for pair, total in pairwise_totals.items(): + debug_print(f" {pair}: total={total:.6f}, avg={pairwise_averages[pair]:.6f}") + + # Prepare results + results = { + "models": {model_name: str(data["model_path"]) for model_name, data in model_logits_list}, + "total_chunks": int(num_chunks), + "sequence_length": int(seq_len), + "max_context_length": int(reference_data["max_context_length"]), + "kl_divergences": { + pair: {"total": float(pairwise_totals[pair]), "average": float(pairwise_averages[pair])} + for pair in pairwise_totals + }, + "chunk_results": chunk_results, + "computation_timestamp": datetime.now().isoformat(), + "summary": { + "interpretation": "Lower KL divergence indicates more similar model outputs", + "baseline_reference": model_logits_list[0][0], + "pairwise_averages": {pair: float(avg) for pair, avg in pairwise_averages.items()}, + "chunks_processed": int(num_chunks), + }, + } + + return results + + +def validate_inputs(hf_model, ep_path_pairs): + """ + Validate that all model paths exist and execution providers are supported. + + Args: + hf_model (str or None): Path to Hugging Face model (optional). + ep_path_pairs (list): List of (execution_provider, model_path) tuples. + + Returns: + bool: True if all inputs are valid, False otherwise. 
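+
+    Example (illustrative; paths are placeholders):
+        validate_inputs(None, [("cuda", "G:\\models\\cuda_fp16"),
+                               ("cuda", "G:\\models\\cuda_int4")])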
+ """ + # Check HF model path (only if provided and it looks like a local path) + # If it doesn't exist locally, assume it's a HF model name to be downloaded + if hf_model and os.path.exists(hf_model): + # Verify it's a valid directory + if not os.path.isdir(hf_model): + print(f"[ERROR] Hugging Face model path is not a directory: {hf_model}") + return False + + # Check execution providers and paths + for ep, path in ep_path_pairs: + if ep not in EP_PACKAGE_MAP: + print(f"[ERROR] Unsupported execution provider: {ep}") + print(f"[ERROR] Supported providers: {list(EP_PACKAGE_MAP.keys())}") + return False + + if not os.path.exists(path): + print(f"[ERROR] Model path for {ep} does not exist: {path}") + return False + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Generic model comparison: HF baseline vs ONNX Runtime GenAI models (or ONNX-only comparison)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Compare HF vs CUDA model + python compute_kl_divergence.py --hf_model "F:\\shared\\Llama-3.1-8B-Instruct" + --ep cuda --path "G:\\models\\cuda_model" --output "hf_vs_cuda.json" + + # Compare HF vs CUDA model (download from Hugging Face) + python compute_kl_divergence.py --hf_model "meta-llama/Llama-3.1-8B-Instruct" + --ep cuda --path "G:\\models\\cuda_model" --output "hf_vs_cuda.json" + + # Compare HF vs CUDA vs DirectML models + python compute_kl_divergence.py --hf_model "F:\\shared\\Llama-3.1-8B-Instruct" + --ep cuda --path "G:\\models\\cuda_model" + --ep directml --path "G:\\models\\directml_model" + --output "hf_vs_cuda_vs_directml.json" + + # Compare multiple models with same EP (e.g., different CUDA model variants) + python compute_kl_divergence.py --hf_model "F:\\shared\\Llama-3.1-8B-Instruct" + --ep cuda --path "G:\\models\\cuda_fp16" + --ep cuda --path "G:\\models\\cuda_int4" + --ep directml --path "G:\\models\\directml_model" + --output "multi_model_comparison.json" + + # Compare two ONNX models with same EP (no HF model needed, uses optimized same_ep script) + python compute_kl_divergence.py --ep cuda --path "G:\\models\\cuda_fp16" + --ep cuda --path "G:\\models\\cuda_int4" + --output "cuda_fp16_vs_int4.json" + + # Compare ONNX models with different EPs (no HF model needed) + python compute_kl_divergence.py --ep cuda --path "G:\\models\\cuda_model" + --ep directml --path "G:\\models\\directml_model" + --output "cuda_vs_directml.json" + + # Compare multiple ONNX models with mixed EPs (no HF model needed) + python compute_kl_divergence.py --ep cuda --path "G:\\models\\cuda_fp16" + --ep cuda --path "G:\\models\\cuda_int4" + --ep directml --path "G:\\models\\directml_model" + --output "multi_onnx_comparison.json" + +Supported execution providers: cuda, directml, cpu +Note: Multiple models with same EP are supported and will be named ep_1, ep_2, etc. 
+Note: HF model is now optional - you can compare ONNX models directly +Note: If no HF model is provided and exactly 2 models with same EP are given, + the script will automatically use KL_divergence_metrics_same_ep.py for optimal performance + """, + ) + + parser.add_argument( + "--hf_model", + required=False, + default=None, + help="Path to Hugging Face baseline model (optional - can compare ONNX models directly)", + ) + parser.add_argument( + "--ep", + action="append", + required=True, + help="Execution provider (can be specified multiple times)", + ) + parser.add_argument( + "--path", + action="append", + required=True, + help="Model path (must match order of --ep arguments)", + ) + parser.add_argument("--output", required=True, help="Output JSON file for results") + parser.add_argument( + "--device", + default="cuda", + choices=["cuda", "cpu"], + help="Device for HF model inference (default: cuda)", + ) + parser.add_argument( + "--keep_logits", action="store_true", help="Keep extracted logits files after comparison" + ) + parser.add_argument("--debug", action="store_true", help="Enable verbose debug output") + + args = parser.parse_args() + + # Set global debug flag + global DEBUG + DEBUG = args.debug + + # Validate arguments + if len(args.ep) != len(args.path): + print("[ERROR] Number of --ep arguments must match number of --path arguments") + return 1 + + debug_print(f"Execution providers: {args.ep}") + debug_print(f"Model paths: {args.path}") + + # No limit on number of models - support any number of EP/path combinations + + ep_path_pairs = list(zip(args.ep, args.path)) + debug_print(f"EP-Path pairs: {ep_path_pairs}") + + # Check if we should delegate to KL_divergence_metrics_same_ep.py + # Condition: No HF model provided, exactly 2 models, and same execution provider + same_ep = ep_path_pairs[0][0] == ep_path_pairs[1][0] if len(ep_path_pairs) == 2 else "N/A" + debug_print( + f"Checking delegation conditions: hf_model={args.hf_model}, " + f"num_models={len(ep_path_pairs)}, same_ep={same_ep}" + ) + if ( + args.hf_model is None + and len(ep_path_pairs) == 2 + and ep_path_pairs[0][0] == ep_path_pairs[1][0] + ): + print("=" * 80) + print("DETECTED: Two ONNX models with same EP, no HF model") + print("Delegating to KL_divergence_metrics_same_ep.py") + print("=" * 80) + debug_print("Delegation conditions met - calling KL_divergence_metrics_same_ep.py") + print(f"Reference Model: {ep_path_pairs[0][1]}") + print(f"Target Model: {ep_path_pairs[1][1]}") + print(f"Execution Provider: {ep_path_pairs[0][0].upper()}") + print("=" * 80) + + # Install the correct ONNX Runtime package for this EP + ep = ep_path_pairs[0][0] + print(f"\n[INFO] Ensuring {ep.upper()} environment is set up...") + debug_print(f"Installing {EP_PACKAGE_MAP[ep]} for same_ep script") + uninstall_onnxruntime_packages() + if not install_package(EP_PACKAGE_MAP[ep]): + print(f"[ERROR] Failed to install {EP_PACKAGE_MAP[ep]}") + return 1 + debug_print(f"Successfully set up {ep.upper()} environment") + + # Call KL_divergence_metrics_same_ep.py + python_exe = get_python_executable() + cmd = [ + python_exe, + "KL_divergence_metrics_same_ep.py", + "--reference_model", + ep_path_pairs[0][1], + "--target_model", + ep_path_pairs[1][1], + ] + + print("\n[INFO] Running KL_divergence_metrics_same_ep.py...") + result = subprocess.run(cmd, shell=False) + + if result.returncode == 0: + print("\n[SUCCESS] KL divergence computation completed successfully") + else: + print("\n[ERROR] KL divergence computation failed") + + return result.returncode 
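+
+    # General case: extract logits from each model in its own subprocess,
+    # switching the installed onnxruntime-genai package only when the
+    # execution provider changes, then compute all pairwise KL divergences.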
+ + # Validate inputs + if not validate_inputs(args.hf_model, ep_path_pairs): + return 1 + + print("=" * 80) + if args.hf_model: + print("GENERIC MODEL COMPARISON (with HF baseline)") + else: + print("ONNX MODEL COMPARISON (no HF baseline)") + print("=" * 80) + if args.hf_model: + print(f"Hugging Face Model: {args.hf_model}") + for ep, path in ep_path_pairs: + print(f"{ep.upper()} Model: {path}") + print(f"Output: {args.output}") + if args.hf_model: + print(f"Device for HF: {args.device}") + print("=" * 80) + + start_time = time.time() + + try: + # Store all model logits data + model_logits_list = [] + + # Step 1: Extract logits from HF model (if provided) + if args.hf_model: + debug_print("\n Hugging Face Baseline Extraction") + hf_logits_data = extract_hf_logits_subprocess(args.hf_model, args.device) + model_logits_list.append(("huggingface", hf_logits_data)) + + # Step 2: Extract logits from each ONNX model + current_ep = None # Track current installed EP to avoid unnecessary reinstalls + + for i, (ep, path) in enumerate(ep_path_pairs): + # Create unique model name for same EP models + model_name = ( + f"{ep}_{i + 1}" + if ep_path_pairs.count((ep, path)) > 1 + or sum(1 for x in ep_path_pairs if x[0] == ep) > 1 + else ep + ) + + debug_print( + f"Processing model {i + 1}/{len(ep_path_pairs)}: ep={ep}, path={path}, model_name={model_name}" + ) + print(f"\n[INFO] Processing {ep.upper()} model ({i + 1}/{len(ep_path_pairs)})") + debug_print(f"Path: {path}") + + # Package management - only reinstall if EP changed + if current_ep != ep: + print(f"[INFO] Switching to {ep.upper()} environment...") + debug_print(f"EP changed from {current_ep} to {ep}") + debug_print("Uninstalling existing ONNX Runtime packages") + uninstall_onnxruntime_packages() + + debug_print(f"Installing {EP_PACKAGE_MAP[ep]}") + if not install_package(EP_PACKAGE_MAP[ep]): + print(f"[ERROR] Failed to install {EP_PACKAGE_MAP[ep]}") + return 1 + current_ep = ep + debug_print(f"Successfully switched to {ep} environment") + else: + debug_print(f"Reusing {ep.upper()} environment (already installed)") + debug_print(f"EP unchanged: {ep}") + + # Extract logits + debug_print(f"Extracting logits for {model_name}") + onnx_logits_data = extract_onnx_logits_subprocess(path, ep) + model_logits_list.append((model_name, onnx_logits_data)) + debug_print( + f"Added {model_name} to model_logits_list (total models: {len(model_logits_list)})" + ) + + # Step 3: Compute unified comparison + print("\n[INFO] Computing KL Divergences...") + debug_print(f"Total models for comparison: {len(model_logits_list)}") + debug_print(f"Model list: {[name for name, _ in model_logits_list]}") + results = compute_unified_comparison(model_logits_list, args.output) + + end_time = time.time() + + # Add timing information + results["computation_time_seconds"] = float(end_time - start_time) + + # Step 4: Save results + print(f"\n[INFO] Saving results to: {args.output}") + debug_print(f"Results keys: {results.keys()}") + debug_print("Serializing results to JSON") + with open(args.output, "w") as f: + json.dump(to_serializable(results), f, indent=2) + debug_print(f"Results saved successfully to {args.output}") + + # Step 5: Generate summary + print("\n" + "=" * 80) + print("COMPARISON COMPLETED SUCCESSFULLY!") + print("=" * 80) + print(f"Total execution time: {end_time - start_time:.2f} seconds") + print(f"Results saved to: {args.output}") + print() + print("MODELS COMPARED:") + for model_name, model_path in results["models"].items(): + print(f" {model_name.upper()}: 
{model_path}") + print() + + # Display KL divergence results + if "kl_divergences" in results: + print("KL DIVERGENCE RESULTS:") + kl_divs = results["kl_divergences"] + for comparison, values in kl_divs.items(): + comp_name = comparison.replace("_", " ").upper() + print(f" {comp_name}: {values['average']:.6f}") + + print() + debug_print("INTERPRETATION:") + debug_print("- Lower KL divergence = more similar model outputs") + debug_print("- All comparisons are pairwise between models") + debug_print("- Values show how much models differ from each other") + print("=" * 80) + + # Optional: Save logits to files if requested + if args.keep_logits: + debug_print("\n[SAVE] Saving logits to individual files...") + for model_name, model_data in model_logits_list: + logits_file = f"logits_{model_name}.pkl" + with open(logits_file, "wb") as f: + pickle.dump(model_data, f) + debug_print(f" Saved: {logits_file}") + debug_print("Note: Logits files saved for future reuse.") + + return 0 + + except KeyboardInterrupt: + print("\n[INFO] Comparison interrupted by user") + return 1 + except Exception as e: + debug_print(f"[ERROR] Unexpected error: {e}") + import traceback + + if DEBUG: + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py b/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py new file mode 100644 index 000000000..54db48e2a --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Extract logits from ONNX Runtime GenAI models and save them to file. +Follows the same logic as KL_divergence_metrics.py but saves logits instead of computing KL divergence. +""" + +import argparse +import os +import pickle +import time + +import numpy as np +import torch +from datasets import load_dataset + +# Lazy import for onnxruntime_genai - will be imported when needed +og = None + +DEBUG = False + + +def debug_print(message): + """ + Print debug message only if DEBUG flag is enabled. + + Args: + message (str): Debug message to print. + """ + if DEBUG: + print(f"[DEBUG] {message}") + + +def get_wikitext2(): + """ + Load and concatenate the WikiText-2 test dataset. + + Returns: + str: Concatenated text from all samples, separated by double newlines. + + Note: + Requires HuggingFace CLI authentication. 
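+        The same concatenated text is used by extract_logits_hf.py and
+        KL_divergence_metrics_same_ep.py, so extracted logits can be compared
+        chunk by chunk.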
+ """ + print("\n[INFO] Loading Wikitext-2 'test' split ...") + test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + print(f"[DATASET] Number of raw samples: {len(test)}") + + # Concatenate all text samples into a single string, separated by double newlines + result = "\n\n".join(text for text in test["text"]) + print(f"[DATASET] Total text length: {len(result)} characters") + return result + + +def extract_logits_from_model(model_path, provider="cuda"): + """ + Extract logits from an ONNX Runtime GenAI model on WikiText-2 dataset. + + Uses a sliding window approach to process the dataset in chunks and extract + model logits for each chunk. + + Args: + model_path (str): Path to the ONNX Runtime GenAI model directory. + provider (str, optional): Execution provider hint ('cuda', 'directml', 'cpu'). + The actual provider is determined by the installed package. + Defaults to "cuda". + + """ + print(f"\n[INFO] Loading model from: {model_path}") + print(f"[INFO] Using provider: {provider}") + + try: + # Import onnxruntime_genai when needed + global og + try: + import onnxruntime_genai as og_module + + og = og_module + print("[INFO] Successfully imported onnxruntime_genai") + except ImportError as e: + raise ImportError( + f"Failed to import onnxruntime_genai: {e}. " + f"Make sure the correct package is installed for provider '{provider}'" + ) + + # Load model and tokenizer with explicit provider configuration + print(f"[INFO] Creating model with provider: {provider}") + + # For ONNX Runtime GenAI, the execution provider is determined by the installed package + # We don't need to explicitly set it in the model creation + model = og.Model(model_path) + tokenizer = og.Tokenizer(model) + print("[INFO] Successfully loaded model and tokenizer") + + # Print available providers for debugging + try: + import onnxruntime as ort + + available_providers = ort.get_available_providers() + print(f"[INFO] Available ONNX Runtime providers: {available_providers}") + except Exception: + print("[INFO] Could not check available ONNX Runtime providers") + + except Exception as e: + print(f"[ERROR] Failed to load model: {e}") + print( + f"[ERROR] Make sure the correct onnxruntime-genai package is installed for provider '{provider}'" + ) + + # Add more detailed error information + try: + import onnxruntime as ort + + available_providers = ort.get_available_providers() + print(f"[ERROR] Currently available providers: {available_providers}") + + if provider == "cuda" and "CUDAExecutionProvider" not in available_providers: + print("[ERROR] CUDA provider not available. Install onnxruntime-genai-cuda") + elif provider == "directml" and "DmlExecutionProvider" not in available_providers: + print("[ERROR] DirectML provider not available. 
Install onnxruntime-genai-directml") + except Exception: + pass + + raise + + # Parameters (same as KL_divergence_metrics.py) + max_context_length = 1024 + + print(f"[INFO] Max context length: {max_context_length}") + + # Load dataset + dataset = get_wikitext2() + + # Tokenize + print("[INFO] Tokenizing dataset...") + input_ids = tokenizer.encode_batch([dataset]) + if isinstance(input_ids, dict) and "input_ids" in input_ids: + input_ids = input_ids["input_ids"] + if hasattr(input_ids, "as_numpy"): + input_ids = input_ids.as_numpy() + input_ids = np.array(input_ids) + + # Ensure input_ids is 2D (batch, seq_len) + if input_ids.ndim == 1: + input_ids = np.expand_dims(input_ids, 0) + + # Convert to torch tensor + input_ids = torch.tensor(input_ids, dtype=torch.long) + seq_len = int(input_ids.shape[1]) + + print(f"[INFO] Input sequence length: {seq_len}") + + # Store all logits + all_logits = [] + chunk_info = [] + + # Process chunks following the same logic as KL_divergence_metrics.py + for chunk_count, begin_loc in enumerate( + range(0, min(50 * max_context_length, seq_len), max_context_length), 1 + ): + if DEBUG: + print(f"[PROGRESS] Processing chunk {chunk_count}...") + + end_loc = min(begin_loc + max_context_length, seq_len) + + # Extract the current chunk of input tokens + input_ids_chunk = input_ids[:, begin_loc:end_loc].clone() + if DEBUG: + print(f" Chunk range: {begin_loc} to {end_loc}") + print(f" Chunk shape: {input_ids_chunk.shape}") + + # Set up generator parameters for deterministic generation (no sampling) + params = og.GeneratorParams(model) + params.set_search_options( + max_length=int(input_ids_chunk.shape[1]), do_sample=False, early_stopping=False + ) + + # Create generator and append input tokens + generator = og.Generator(model, params) + generator.append_tokens(input_ids_chunk.numpy()) + + # Run the model forward pass + with torch.no_grad(): + try: + generator.generate_next_token() + except Exception as e: + print(f"[ERROR] generate_next_token() failed: {e}") + break + + # Get logits output from the model + logits = generator.get_output("logits") + if hasattr(logits, "as_numpy"): + logits = logits.as_numpy() + if DEBUG: + print(f" Logits shape: {logits.shape}") + + # Convert to torch tensor and store + logits_tensor = torch.tensor(logits, dtype=torch.float32) + all_logits.append(logits_tensor.cpu().numpy()) + + # Store chunk information + chunk_info.append( + { + "chunk_id": chunk_count, + "begin_loc": begin_loc, + "end_loc": end_loc, + "shape": logits_tensor.shape, + } + ) + + print(f"[INFO] Extracted logits from {len(all_logits)} chunks") + + return { + "logits": all_logits, + "chunk_info": chunk_info, + "model_path": model_path, + "provider": provider, + "seq_len": seq_len, + "max_context_length": max_context_length, + "total_chunks": len(all_logits), + } + + +def main(): + """ + Command-line entry point for extracting logits from ONNX Runtime GenAI models. + + Extracts model logits on WikiText-2 dataset and saves them to a pickle file + for later KL divergence comparison. 
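+
+    Example (illustrative paths):
+        $ python extract_logits.py \\
+            --model_path "G:\\models\\cuda_int4" \\
+            --output_file "onnx_logits.pkl" \\
+            --provider cuda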
+ + """ + parser = argparse.ArgumentParser(description="Extract logits from ONNX Runtime GenAI model") + parser.add_argument("--model_path", required=True, help="Path to model directory") + parser.add_argument("--output_file", required=True, help="Output pickle file path") + parser.add_argument( + "--provider", + default="cuda", + choices=["cuda", "directml", "cpu"], + help="Execution provider (cuda, directml, or cpu)", + ) + parser.add_argument("--debug", action="store_true", help="Enable verbose debug output") + + args = parser.parse_args() + + # Set global debug flag + global DEBUG + DEBUG = args.debug + + # Validate model directory exists + if not os.path.exists(args.model_path): + print(f"Error: Model directory does not exist: {args.model_path}") + return 1 + + try: + # Extract logits + start_time = time.time() + logits_data = extract_logits_from_model(args.model_path, args.provider) + end_time = time.time() + + # Save to file + print(f"\n[INFO] Saving logits to: {args.output_file}") + with open(args.output_file, "wb") as f: + pickle.dump(logits_data, f) + + print(f"[INFO] Extraction completed in {end_time - start_time:.2f} seconds") + print(f"[INFO] Total chunks processed: {logits_data['total_chunks']}") + print(f"[INFO] Model: {args.model_path}") + print(f"[INFO] Provider: {args.provider}") + print(f"[INFO] Output: {args.output_file}") + + return 0 + + except Exception as e: + print(f"[ERROR] Failed to extract logits: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py b/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py new file mode 100644 index 000000000..a209dd153 --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/extract_logits_hf.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Extract logits from Hugging Face model using transformers library. +Follows the same logic as extract_logits.py but uses transformers instead of ONNX Runtime GenAI. +""" + +import argparse +import pickle +import time + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +DEBUG = False + + +def debug_print(message): + """ + Print debug message only if DEBUG flag is enabled. + + Args: + message (str): Debug message to print. + """ + if DEBUG: + print(f"[DEBUG] {message}") + + +def get_wikitext2(): + """ + Load and concatenate the WikiText-2 test dataset. + + Returns: + str: Concatenated text from all samples, separated by double newlines. + + Note: + Requires HuggingFace CLI authentication. 
+ """ + print("Loading Wikitext-2 dataset...") + test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + debug_print(f"Number of raw samples: {len(test)}") + + # Concatenate all text samples into a single string, separated by double newlines + result = "\n\n".join(text for text in test["text"]) + debug_print(f"Total text length: {len(result)} characters") + print(f"Dataset loaded ({len(result):,} characters)") + return result + + +def extract_logits_from_hf_model(model_path, device="cuda"): + """ + Extract logits from a Hugging Face transformer model on WikiText-2 dataset. + + Uses a sliding window approach to process the dataset in chunks and extract + model logits for each chunk using the transformers library. + + Args: + model_path (str): Path to the Hugging Face model directory or model name. + device (str, optional): Device for inference ('cuda' or 'cpu'). Defaults to "cuda". + + """ + print("Loading Hugging Face model...") + debug_print(f"Model path: {model_path}") + debug_print(f"Device: {device}") + + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16 if device == "cuda" else torch.float32, + device_map="auto" if device == "cuda" else None, + ) + + if device == "cpu": + model = model.to("cpu") + + model.eval() + + # Add padding token if it doesn't exist + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if getattr(getattr(model, "config", None), "pad_token_id", None) is None: + model.config.pad_token_id = tokenizer.pad_token_id + # Parameters (same as extract_logits.py) + max_context_length = 1024 + + print(f"[INFO] Max context length: {max_context_length}") + + # Load dataset + dataset = get_wikitext2() + + # Tokenize + print("[INFO] Tokenizing dataset...") + inputs = tokenizer(dataset, return_tensors="pt", truncation=False) + input_ids = inputs["input_ids"] + + seq_len = int(input_ids.shape[1]) + + print(f"[INFO] Input sequence length: {seq_len}") + + # Store all logits + all_logits = [] + chunk_info = [] + + # Process chunks following the same logic as extract_logits.py + for chunk_count, begin_loc in enumerate( + range(0, min(50 * max_context_length, seq_len), max_context_length), 1 + ): + if DEBUG: + print(f"[PROGRESS] Processing chunk {chunk_count}...") + + end_loc = min(begin_loc + max_context_length, seq_len) + + # Extract the current chunk of input tokens + input_ids_chunk = input_ids[:, begin_loc:end_loc] + if DEBUG: + print(f" Chunk range: {begin_loc} to {end_loc}") + print(f" Chunk shape: {input_ids_chunk.shape}") + + # Move to device + input_ids_chunk = input_ids_chunk.to(model.device) + # Run the model forward pass + with torch.no_grad(): + try: + outputs = model(input_ids_chunk) + logits = outputs.logits + if DEBUG: + print(f" Logits shape: {logits.shape}") + + # Store logits (convert to CPU and numpy) + logits_numpy = logits.cpu().numpy() + all_logits.append(logits_numpy) + + # Store chunk information + chunk_info.append( + { + "chunk_id": chunk_count, + "begin_loc": begin_loc, + "end_loc": end_loc, + "shape": logits_numpy.shape, + } + ) + + except Exception as e: + print(f"[ERROR] Model forward pass failed: {e}") + break + + print(f"[INFO] Extracted logits from {len(all_logits)} chunks") + + return { + "logits": all_logits, + "chunk_info": chunk_info, + "model_path": model_path, + "provider": "huggingface_transformers", + "device": device, + "seq_len": seq_len, + "max_context_length": max_context_length, 
+ "total_chunks": len(all_logits), + } + + +def main(): + """ + Command-line entry point for extracting logits from Hugging Face models. + + Extracts model logits on WikiText-2 dataset and saves them to a pickle file + for later KL divergence comparison with ONNX Runtime GenAI models. + + """ + parser = argparse.ArgumentParser(description="Extract logits from Hugging Face model") + parser.add_argument("--model_path", required=True, help="Path to Hugging Face model directory") + parser.add_argument("--output_file", required=True, help="Output pickle file path") + parser.add_argument( + "--device", default="cuda", choices=["cuda", "cpu"], help="Device to use (cuda or cpu)" + ) + parser.add_argument("--debug", action="store_true", help="Enable verbose debug output") + + args = parser.parse_args() + + # Set global debug flag + global DEBUG + DEBUG = args.debug + + try: + # Extract logits + start_time = time.time() + logits_data = extract_logits_from_hf_model(args.model_path, args.device) + end_time = time.time() + + # Save to file + print(f"\n[INFO] Saving logits to: {args.output_file}") + with open(args.output_file, "wb") as f: + pickle.dump(logits_data, f) + + print(f"[INFO] Extraction completed in {end_time - start_time:.2f} seconds") + print(f"[INFO] Total chunks processed: {logits_data['total_chunks']}") + print(f"[INFO] Model: {args.model_path}") + print(f"[INFO] Device: {args.device}") + print(f"[INFO] Output: {args.output_file}") + + return 0 + + except Exception as e: + print(f"[ERROR] Failed to extract logits: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt b/examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt new file mode 100644 index 000000000..81ae45432 --- /dev/null +++ b/examples/windows/accuracy_benchmark/kl_divergence_metrics/requirements.txt @@ -0,0 +1,18 @@ +--extra-index-url https://download.pytorch.org/whl/cu129 +accelerate +coloredlogs +datasets +flatbuffers +huggingface_hub[cli] +numpy +onnx +packaging +pandas +protobuf>=5.28.2 +pytest +sentencepiece +sympy +torch>=2.0.0 +torchaudio +torchvision +transformers diff --git a/examples/windows/accuracy_benchmark/perplexity_metrics/README.md b/examples/windows/accuracy_benchmark/perplexity_metrics/README.md new file mode 100644 index 000000000..aedb11c28 --- /dev/null +++ b/examples/windows/accuracy_benchmark/perplexity_metrics/README.md @@ -0,0 +1,230 @@ +# Perplexity Evaluation Tool + +## Overview + +This tool evaluates the perplexity of ONNX Runtime GenAI models and HuggingFace models using the [WikiText-2](https://huggingface.co/datasets/wikitext) dataset. Perplexity is a standard metric for language models: lower values indicate better predictive performance. + +## Attribution + +This script is originally based on [perplexity_metrics.py](https://github.com/microsoft/onnxruntime-genai/blob/main/tools/python/model_validation/perplexity_metrics.py) from the Microsoft ONNX Runtime GenAI repository. It has been modified to handle: + +- Multiple context lengths +- Configurable chunk sizes +- Enhanced prefill chunking handling +- HuggingFace model evaluation support + +## Scripts + +- **`perplexity_metrics.py`**: Core evaluation logic for computing perplexity. +- **`run_perplexity.py`**: Command-line utility for evaluating one or more models and saving results to CSV. 
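+
+The evaluation functions in `perplexity_metrics.py` can also be imported and called directly from Python (run from this directory so the module is importable). The sketch below is illustrative only: the model path and checkpoint name are placeholders.
+
+```python
+import torch
+
+from perplexity_metrics import calculate_perplexity_hf, perplexity_eval
+
+# ONNX Runtime GenAI model directory (must contain genai_config.json and tokenizer files)
+onnx_ppl = perplexity_eval("/path/to/onnx_model", input_len=1024)
+
+# Original HuggingFace checkpoint, evaluated with the same sliding-window settings
+hf_ppl = calculate_perplexity_hf(
+    "meta-llama/Llama-2-7b-hf",
+    max_length=1024,
+    stride=512,
+    device="cuda",
+    torch_dtype=torch.float16,
+)
+
+print(f"ONNX perplexity: {onnx_ppl:.4f}, HF perplexity: {hf_ppl:.4f}")
+```
+
+Both functions print progress and timing information and return the perplexity as a float.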
+ +## Requirements + +- Python 3.8+ +- CUDA 12.x (if using GPU acceleration) +- Install dependencies: + + **For CUDA 12.x (recommended for CUDA 12.1-12.9):** + + ```bash + pip install -r requirements.txt + ``` + +- [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) login is required to access the WikiText-2 dataset: + + ```bash + huggingface-cli login + ``` + +## Supported Models + +### ONNX Runtime GenAI Models + +- Any ONNX Runtime GenAI model exported with a compatible `genai_config.json` and tokenizer. +- Supported architectures include: Gemma, Llama, Mistral, Phi (language + vision), Qwen. +- Supported execution providers: CPU, DirectML, CUDA, NvTensorRtRtx. + +### HuggingFace Models + +- Any HuggingFace causal language model (e.g., `meta-llama/Llama-2-7b-hf`, `gpt2`, `mistralai/Mistral-7B-v0.1`). +- Models are automatically downloaded from the HuggingFace Hub if not cached locally. +- Supports custom data types (float16, bfloat16, float32) for efficient inference. + +## How to Run + +### Evaluate ONNX Models + +#### Single Model + +```bash +python run_perplexity.py --models /path/to/model +``` + +#### Multiple Models + +```bash +python run_perplexity.py --models /path/to/model1 /path/to/model2 +``` + +#### Custom Input Sequence Length(s) + +You can specify the input sequence length(s) to evaluate using the `--i` argument: + +```bash +python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 +``` + +#### Custom Prefill Chunk Size + +You can specify the prefill chunk size to evaluate using the `--chunk_size` argument: + +```bash +python run_perplexity.py --models /path/to/model --i 1024,2048,4096,8192,12288 --chunk_size=1024 +``` + +### Evaluate HuggingFace Models + +#### Basic HuggingFace Model Evaluation + +```bash +python run_perplexity.py --hf_model meta-llama/Llama-2-7b-hf --i 1024 +``` + +#### With Custom Data Type (Recommended for Performance) + +```bash +python run_perplexity.py --hf_model meta-llama/Llama-2-7b-hf --hf_dtype float16 --i 1024 +``` + +#### With Multiple Input Lengths + +```bash +python run_perplexity.py --hf_model meta-llama/Llama-2-7b-hf --hf_dtype float16 --i 1024,2048,4096 +``` + +#### On CPU (if no GPU available) + +```bash +python run_perplexity.py --hf_model gpt2 --hf_device cpu --i 1024 +``` + +### Evaluate Both ONNX and HuggingFace Models Together + +Compare ONNX and HuggingFace models side-by-side: + +```bash +python run_perplexity.py \ + --models /path/to/onnx_model \ + --hf_model meta-llama/Llama-2-7b-hf \ + --hf_dtype float16 \ + --i 1024 \ + --output comparison_results.csv +``` + +### HuggingFace Model Arguments + +- `--hf_model`: HuggingFace model name or local path (e.g., `meta-llama/Llama-2-7b-hf`) +- `--hf_device`: Device to run on (`cuda`, `cpu`, `cuda:0`, etc.) - default: `cuda` +- `--hf_dtype`: Data type for model weights - options: `float16`, `bfloat16`, `float32`, `fp16`, `bf16`, `fp32` - default: model default (usually float32) + +### Custom Output File + +```bash +python run_perplexity.py --models /path/to/model --output results.csv +``` + +## Expected Output + +Expected scores often fall between 2 and 1000; lower is better. See ranges below. 
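+
+For reference, both code paths compute the same quantity: the exponential of the negative mean token log-likelihood accumulated over the sliding windows. A minimal sketch of that final step (the helper name is illustrative, not part of the scripts):
+
+```python
+import numpy as np
+
+
+def perplexity_from_log_probs(total_log_probs: float, total_token_count: int) -> float:
+    """Perplexity = exp(-average log-likelihood); lower scores mean better predictions."""
+    avg_log_prob = total_log_probs / total_token_count
+    return float(np.exp(-avg_log_prob))
+
+
+# Example: 1,000 scored tokens with an average log-probability of -1.7 per token
+print(perplexity_from_log_probs(-1700.0, 1000))  # ~5.47
+```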
+ +### Perplexity Configuration Setting (for ONNX models) + +- If **kv_chunking** is enabled in the model configuration (i.e., `"chunk_size"` is present in the `"search"` section of `genai_config.json`), then: + - `max_input_seq_length` is set to **8192** + - `stride` is set to the value of `chunk_size` +- If **kv_chunking** is not enabled (default): + - `max_input_seq_length` is **1024** + - `stride` is **512** + +### For HuggingFace Models + +- Default `max_length` is **1024** +- Default `stride` is **512** (or `chunk_size` if specified) + +### Console Output + +```text +============================================================ +Evaluating HuggingFace model: meta-llama/Llama-2-7b-hf +============================================================ +[INFO] Loading Wikitext-2 'test' split ... +[TOKENIZER] Tokenizing ... + +[RESULT] Perplexity of meta-llama/Llama-2-7b-hf: 5.47 + +HuggingFace perplexity evaluation completed + +============================================================ +Evaluating perplexity for: /path/to/onnx_model +============================================================ +[INFO] Loading Wikitext-2 'test' split ... +[TOKENIZER] Tokenizing ... + +[RESULT] Perplexity of /path/to/onnx_model: 5.48 + +Perplexity evaluation completed successfully +``` + +### CSV Output + +Generated file contains: + +- Model Path (model directory or HuggingFace model name) +- Model Type (ONNX or HuggingFace) +- Input Length +- Perplexity score +- Status (Success/Failed) +- Error details (if any) + +## Debug Mode + +Set `DEBUG = True` in `perplexity_metrics.py` for detailed logs. + +## Typical Perplexity Ranges + +- Excellent: 2-20 +- Good: 20-40 +- OK: 40-80 +- Poor: 100+ + +## Common Use Cases + +### Compare ONNX vs. HuggingFace Model + +Verify that your ONNX exported model has similar perplexity to the original HuggingFace model: + +```bash +python run_perplexity.py \ + --models /path/to/exported_onnx_model \ + --hf_model meta-llama/Llama-2-7b-hf \ + --hf_dtype float16 \ + --i 1024 \ + --output validation_results.csv +``` + +### Evaluate Small Models (for quick testing) + +```bash +python run_perplexity.py --hf_model gpt2 --hf_dtype float16 --i 1024 +``` + +### Benchmark Multiple Quantization Variants + +```bash +python run_perplexity.py \ + --models /path/to/fp16_model /path/to/int8_model /path/to/int4_model \ + --hf_model original/model-name \ + --hf_dtype float16 \ + --i 2048 \ + --output quantization_comparison.csv +``` diff --git a/examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py b/examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py new file mode 100644 index 000000000..52d660276 --- /dev/null +++ b/examples/windows/accuracy_benchmark/perplexity_metrics/perplexity_metrics.py @@ -0,0 +1,554 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-License-Identifier: MIT +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# This file is based on perplexity_metrics.py from the ONNX Runtime GenAI project: +# https://github.com/microsoft/onnxruntime-genai/blob/main/tools/python/model_validation/perplexity_metrics.py +# +# Modifications Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Modifications made: +# - Added support for multiple context lengths +# - Added configurable chunk sizes +# - Enhanced prefill chunking handling + +import json +import time + +import numpy as np +import onnxruntime_genai as og +import torch +from datasets import load_dataset + +# Global debug flag - set to True for verbose output +DEBUG = False + + +def calculate_perplexity_hf( + model_name_or_path, max_length=1024, stride=512, device="cuda", torch_dtype=None +): + """ + Evaluate perplexity of a HuggingFace model on the WikiText-2 dataset. + + This function computes perplexity using a sliding window approach similar to the + ONNX Runtime GenAI version, but using native HuggingFace transformers. + + Args: + model_name_or_path (str): HuggingFace model name (e.g., 'meta-llama/Llama-2-7b-hf') + or path to a local model directory. + max_length (int, optional): Maximum input sequence length for evaluation. + Defaults to 1024. + stride (int, optional): Stride for sliding window evaluation. + Defaults to 512. + device (str, optional): Device to run the model on ('cuda', 'cpu', etc.). + Defaults to 'cuda'. + torch_dtype: PyTorch dtype for the model. If None, uses default (float32). + Common options: torch.float16, torch.bfloat16, torch.float32. + + Returns: + float: Computed perplexity score. Lower values indicate better model performance. + + Raises: + ImportError: If transformers package is not installed. + """ + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError as e: + raise ImportError( + "The 'transformers' package is required for HuggingFace model evaluation. 
" + "Install it with: pip install transformers" + ) from e + + time_start = time.time() + print(f"\n[RUN] === BEGIN calculate_perplexity_hf('{model_name_or_path}') ===") + print(f"[RUN] Loading HuggingFace model from: {model_name_or_path}") + + # Load tokenizer + print("[TOKENIZER] Loading tokenizer ...") + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + + # Set pad_token if not already set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load model + print(f"[MODEL] Loading model on device: {device}") + model_kwargs = {"device_map": device} + if torch_dtype is not None: + model_kwargs["torch_dtype"] = torch_dtype + print(f"[MODEL] Using dtype: {torch_dtype}") + + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs) + model.eval() + + # Load and prepare the evaluation dataset + dataset = get_wikitext2() + print("[TOKENIZER] Tokenizing ...") + + # Tokenize the entire dataset + encodings = tokenizer(dataset, return_tensors="pt", add_special_tokens=True) + input_ids = encodings.input_ids + + if DEBUG: + print(f"[TOKENIZER] Input shape: {input_ids.shape}, dtype: {input_ids.dtype}") + + seq_len = input_ids.size(1) + print(f"[INFO] Full input length: {seq_len}") + print(f"[INFO] max_length: {max_length}, stride: {stride}") + + max_eval_length = seq_len + + # Initialize accumulators for log probabilities + total_log_probs = 0.0 + total_token_count = 0 + prev_end_loc = 0 + + # Slide a window over the input to compute perplexity in chunks + for chunk_idx, begin_loc in enumerate(range(0, max_eval_length, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc + + if DEBUG: + print( + f"\n[LOOP] chunk_idx={chunk_idx} [begin={begin_loc} end={end_loc}] trg_len={trg_len}" + ) + + # Extract the current chunk of input tokens (keep on CPU until needed) + input_ids_chunk = input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids_chunk.clone() + + # Mask context tokens: only predict for last trg_len tokens in chunk + mask = np.ones(target_ids.shape, dtype=bool) + mask[:, :-trg_len] = False + target_ids_masked = target_ids.clone() + target_ids_masked[~torch.from_numpy(mask)] = -100 # -100 is the ignore index + + if DEBUG: + print(f"[MASK] Mask shape: {mask.shape}") + print(f"[TARGET_IDS_MASKED] Target ids masked: {target_ids_masked}") + + # Run the model forward pass without gradient calculation + with torch.no_grad(): + if DEBUG: + print("[INFER] Running model forward pass ...") + + outputs = model(input_ids_chunk) + logits = outputs.logits + + if DEBUG: + print(f"[LOGITS] Shape: {logits.shape}, dtype: {logits.dtype}") + + # Compute log probabilities over vocabulary for each position + log_probs = torch.nn.functional.log_softmax(logits, dim=2).cpu().numpy() + chunk_seq_len = log_probs.shape[1] + + # Language models predict next token: logits[i] predicts token[i+1] + # So we need logits[:-1] to match with target_ids[1:] + if chunk_seq_len > 1: + # Get log probabilities for all positions except the last + pred_log_probs = log_probs[0, :-1, :] # predictions for positions 0 to max_length-2 + # Get the target token ids for positions 1 to max_length-1 + target_ids_shifted = ( + target_ids_masked[0, 1:].cpu().numpy() + ) # targets at positions 1 to max_length-1 + + if DEBUG: + print(f"[TARGET_IDS_SHIFTED] Target ids shifted shape: {target_ids_shifted.shape}") + print(f"[PRED_LOG_PROBS] Pred log probs shape: {pred_log_probs.shape}") + print(f"chunk_seq_len: {chunk_seq_len}") + + # Only 
include tokens with label != -100 (matching masking) + mask_flat = target_ids_shifted != -100 + valid_indices = np.arange(len(target_ids_shifted))[mask_flat] + valid_targets = target_ids_shifted[mask_flat] + + if DEBUG: + print(f"[VALID_INDICES] Valid indices shape: {valid_indices.shape}") + print(f"[VALID_TARGETS] Valid targets shape: {valid_targets.shape}") + + # Gather the log probabilities for the correct target tokens + valid_log_probs = pred_log_probs[valid_indices, valid_targets] + + if DEBUG: + print(f"[VALID_LOG_PROBS] Valid log probs shape: {valid_log_probs.shape}") + else: + valid_log_probs = np.array([]) + mask_flat = np.array([], dtype=bool) + + # Accumulate log probabilities and token count (same as ONNX) + total_log_probs += float(np.sum(valid_log_probs)) + total_token_count += int(valid_log_probs.size) + + if DEBUG: + print( + f"[LOOP] This chunk: valid tokens={valid_log_probs.size}, sum={np.sum(valid_log_probs)}" + ) + print(f"[TALLY] total_log_probs: {total_log_probs}") + print(f"[TALLY] total_token_count: {total_token_count}") + + # Clear GPU cache to prevent OOM + del ( + outputs, + logits, + log_probs, + pred_log_probs, + input_ids_chunk, + target_ids, + target_ids_masked, + ) + if device == "cuda": + torch.cuda.empty_cache() + + # Update for next chunk + prev_end_loc = end_loc + if end_loc >= max_eval_length: + if DEBUG: + print("[LOOP] Reached evaluation limit.") + break + + # Compute average log probability and perplexity (same as ONNX) + avg_log_prob = total_log_probs / total_token_count + perplexity = np.exp(-avg_log_prob) # Note the negative sign! + + if DEBUG: + print(f"[FINAL] avg_log_prob: {avg_log_prob}") + + print(f"\n[RESULT] Perplexity of {model_name_or_path}: {perplexity}") + print("[RUN] === END calculate_perplexity_hf ===\n") + time_end = time.time() + print(f"[RUN] Time taken: {time_end - time_start:.2f} seconds") + + # Cleanup: Unload model from GPU memory + print("[CLEANUP] Unloading model from GPU...") + del model, tokenizer + if device == "cuda": + torch.cuda.empty_cache() + print("[CLEANUP] Model unloaded") + + return perplexity + + +def get_wikitext2(): + """ + Load and concatenate the WikiText-2 test dataset. + + Returns: + str: Concatenated text from all samples in the WikiText-2 test split, + with samples separated by double newlines. + + Note: + Requires HuggingFace CLI authentication to access the dataset. + """ + # Load the Wikitext-2 test split using HuggingFace datasets + print("\n[INFO] Loading Wikitext-2 'test' split ...") + test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + if DEBUG: + print(f"[DATASET] Number of raw samples: {len(test)}") + for i in range(3): + print(f"[DATASET] Sample[{i}]: {repr(test[i]['text'])[:200]} ...") + # Concatenate all text samples into a single string, separated by double newlines + result = "\n\n".join(text for text in test["text"]) + if DEBUG: + print( + f"[DATASET] Concatenated text preview: {result[:512]!r} ... [total chars: {len(result)}]" + ) + return result + + +def perplexity_eval(model_dir, input_len=1024, chunk_size=None): + """ + Evaluate perplexity of an ONNX Runtime GenAI model on the WikiText-2 dataset. + + This function computes perplexity using a sliding window approach. It supports + both standard evaluation and prefill chunking for longer context lengths. + + Args: + model_dir (str): Path to the ONNX Runtime GenAI model directory. + Must contain genai_config.json and tokenizer files. + input_len (int, optional): Maximum input sequence length for evaluation. 
+ Used as context length when KV chunking is enabled. + Defaults to 1024. + chunk_size (int, optional): Prefill chunk size for prefill chunking. + If provided, overrides the chunk_size in genai_config.json. + When set, enables evaluation with longer context lengths. + Defaults to None. + + Returns: + float: Computed perplexity score. Lower values indicate better model performance. + Typical ranges: 2-20 (excellent), 20-40 (good), 40-80 (ok), 100+ (poor). + + """ + time_start = time.time() + print(f"\n[RUN] === BEGIN perplexity_eval('{model_dir}') ===") + print(f"[RUN] Loading ONNX model from: {model_dir}") + chunking_failed = False + # Load the ONNX model + # Apply chunk_size overlay if provided + config = og.Config(model_dir) + if chunk_size is not None: + search_config = {"chunk_size": int(chunk_size)} + try: + print(f"[CONFIG] Applying chunk_size overlay: {chunk_size}") + config.overlay(json.dumps({"search": search_config})) + print(f"[CONFIG] Successfully applied chunk_size: {chunk_size}") + except Exception as e: + print(f"[WARNING] Failed to apply chunk_size overlay: {e}") + chunking_failed = True + model = og.Model(config) + + if DEBUG: + print("[RUN] Creating tokenizer ...") + # Create the tokenizer for the model + tokenizer = og.Tokenizer(model) + # Load model configuration from JSON file (optional) + model_cfg_json = None + try: + with open(f"{model_dir}/genai_config.json") as file: + model_cfg_json = json.load(file) + if DEBUG: + print( + f"[CONFIG] Model config loaded: {json.dumps(model_cfg_json.get('model', {}), indent=2)}" + ) + except Exception as e: + print(f"[WARNING] Could not read genai_config.json: {e}. Falling back to defaults.") + + max_context_length = 1024 + stride = 512 + kv_chunking_enabled = False + + # Check for chunk_size - prioritize parameter over config file + effective_chunk_size = None + if chunk_size is not None and not chunking_failed: + # Use the provided chunk_size parameter (overlaid) + effective_chunk_size = int(chunk_size) + kv_chunking_enabled = True + if DEBUG: + print(f"[CONFIG] Using provided chunk_size: {effective_chunk_size}") + elif model_cfg_json and "search" in model_cfg_json and "chunk_size" in model_cfg_json["search"]: + # Use chunk_size from existing config file + effective_chunk_size = model_cfg_json["search"]["chunk_size"] + kv_chunking_enabled = True + if DEBUG: + print(f"[CONFIG] Using config file chunk_size: {effective_chunk_size}") + + if DEBUG: + print( + f"[CONFIG] Effective chunk_size: {effective_chunk_size if kv_chunking_enabled else 'disabled'}" + ) + + if kv_chunking_enabled and effective_chunk_size: + if DEBUG: + print(f"[INFO] chunk size: {effective_chunk_size}") + print(f"[INFO] input length: {input_len}") + max_context_length = int(input_len) # Use input_len when chunking is enabled + stride = effective_chunk_size + if DEBUG: + print( + f"[CONFIG] KV chunking enabled with chunk_size: {effective_chunk_size}, input_len: {input_len}" + ) + elif DEBUG: + print(f"[CONFIG] KV chunking disabled, using default stride: {stride}") + + # Set chunk and stride lengths for evaluation + model_context_len = ( + int(model_cfg_json["model"]["context_length"]) + if model_cfg_json + and "model" in model_cfg_json + and "context_length" in model_cfg_json["model"] + else max_context_length + ) + max_length = min(max_context_length, model_context_len) + if DEBUG: + print(f"[INFO] max_length for chunk: {max_length}, stride for sliding window: {stride}") + + # Load and prepare the evaluation dataset + dataset = get_wikitext2() + print("[TOKENIZER] 
Tokenizing ...") + # Tokenize the entire dataset + input_ids = tokenizer.encode_batch([dataset]) + # Handle possible dict output from tokenizer + if isinstance(input_ids, dict) and "input_ids" in input_ids: + input_ids = input_ids["input_ids"] + # Convert to numpy if needed + if hasattr(input_ids, "as_numpy"): + input_ids = input_ids.as_numpy() + if DEBUG: + print("[TOKENIZER] Used as_numpy()") + input_ids = np.array(input_ids) + if DEBUG: + print(f"[TOKENIZER] Numpy array shape: {input_ids.shape}, dtype: {input_ids.dtype}") + # Ensure input_ids is 2D (batch, seq_len) + if input_ids.ndim == 1: + input_ids = np.expand_dims(input_ids, 0) + if DEBUG: + print(f"[SHAPE] Expanded dims, now: {input_ids.shape}") + + # Convert input_ids to torch tensor + input_ids = torch.tensor(input_ids, dtype=torch.long) + if DEBUG: + print(f"[TENSOR] Torch tensor shape: {input_ids.shape}, dtype: {input_ids.dtype}") + + # Determine the sequence length to use + seq_len = int(input_ids.shape[1]) + if DEBUG: + print(f"[INFO] Full input length: {seq_len}") + + # Initialize accumulators for log probabilities and token count + total_log_probs = 0.0 + total_token_count = 0 + prev_end_loc = 0 + # Slide a window over the input to compute perplexity in chunks + for chunk_idx, begin_loc in enumerate(range(0, seq_len, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc + if DEBUG: + print( + f"\n[LOOP] chunk_idx={chunk_idx} [begin={begin_loc} end={end_loc}] trg_len={trg_len}" + ) + + # Extract the current chunk of input tokens + input_ids_chunk = input_ids[:, begin_loc:end_loc].clone() + target_ids = input_ids_chunk.clone() + if DEBUG: + print(f"input_ids_chunk.shape: {input_ids_chunk.shape}") + # Mask context tokens: only predict for last trg_len tokens in chunk + mask = np.ones(target_ids.shape, dtype=bool) + mask[:, :-trg_len] = False + target_ids_masked = target_ids.clone() + target_ids_masked[~torch.from_numpy(mask)] = -100 # -100 is the ignore index + if DEBUG: + print(f"[MASK] Mask : {mask}") + print(f"[TARGET_IDS_MASKED] Target ids masked : {target_ids_masked}") + # Set up generator parameters for deterministic generation (no sampling) + params = og.GeneratorParams(model) + params.set_search_options( + max_length=int(input_ids_chunk.shape[1]), do_sample=False, early_stopping=False + ) + # Create generator and append input tokens + generator = og.Generator(model, params) + generator.append_tokens(input_ids_chunk.numpy()) + + # Run the model forward pass without gradient calculation + with torch.no_grad(): + if DEBUG: + print("[INFER] Running model forward pass ...") + try: + generator.generate_next_token() + except Exception as e: + print(f"[INFER] .generate_next_token() failed: {e}") + break # Fatal error + # Get logits output from the model + logits = generator.get_output("logits") + if hasattr(logits, "as_numpy"): + logits = logits.as_numpy() + if DEBUG: + print("[LOGITS] Used as_numpy()") + logits = torch.tensor(logits, dtype=torch.float32) + if DEBUG: + print(f"[LOGITS] Torch tensor shape: {logits.shape}, dtype: {logits.dtype}") + + # Compute log probabilities over vocabulary for each position + log_probs = torch.nn.functional.log_softmax(logits, dim=2).cpu().numpy() + chunk_seq_len = log_probs.shape[1] + # Language models predict next token: logits[i] predicts token[i+1] + # So we need logits[:-1] to match with target_ids[1:] + if chunk_seq_len > 1: + # Get log probabilities for all positions except the last + pred_log_probs = log_probs[0, :-1, :] # predictions for 
positions 0 to max_length-2 + # Get the target token ids for positions 1 to max_length-1 + target_ids_shifted = ( + target_ids_masked[0, 1:].cpu().numpy() + ) # targets at positions 1 to max_length-1 + if DEBUG: + print(f"[TARGET_IDS_SHIFTED] Target ids shifted shape: {target_ids_shifted.shape}") + print(f"[PRED_LOG_PROBS] Pred log probs shape: {pred_log_probs.shape}") + print(f"chunk_seq_len: {chunk_seq_len}") + + # Only include tokens with label != -100 (matching HF masking) + mask_flat = target_ids_shifted != -100 + if kv_chunking_enabled: + trg_len = min(trg_len, stride) + mask_flat = np.ones(trg_len, dtype=bool) + valid_indices = np.arange(0, trg_len - 1) + valid_targets = target_ids_shifted[-trg_len + 1 :] + else: + valid_indices = np.arange(len(target_ids_shifted))[mask_flat] + valid_targets = target_ids_shifted[mask_flat] + if DEBUG: + print(f"[VALID_INDICES] Valid indices shape: {valid_indices.shape}") + print(f"[VALID_TARGETS] Valid targets shape: {valid_targets.shape}") + # Gather the log probabilities for the correct target tokens + valid_log_probs = pred_log_probs[valid_indices, valid_targets] + if DEBUG: + print(f"[VALID_LOG_PROBS] Valid log probs shape: {valid_log_probs.shape}") + else: + valid_log_probs = np.array([]) + mask_flat = np.array([], dtype=bool) + + # Accumulate log probabilities and token count + total_log_probs += float(np.sum(valid_log_probs)) + total_token_count += int(valid_log_probs.size) + + if DEBUG: + print( + f"[LOOP] This chunk: valid tokens={valid_log_probs.size}, sum={np.sum(valid_log_probs)}" + ) + print(f"[TALLY] total_log_probs: {total_log_probs}") + print(f"[TALLY] total_token_count: {total_token_count}") + + # Update for next chunk + prev_end_loc = end_loc + if end_loc == seq_len: + if DEBUG: + print("[LOOP] Reached end of sequence.") + break + + # Compute average log probability and perplexity + avg_log_prob = total_log_probs / total_token_count + perplexity = np.exp(-avg_log_prob) + if DEBUG: + print(f"[FINAL] avg_log_prob: {avg_log_prob}") + print(f"\n[RESULT] Perplexity of {model_dir}: {perplexity}") + print("[RUN] === END perplexity_eval ===\n") + time_end = time.time() + print(f"[RUN] Time taken: {time_end - time_start:.2f} seconds") + return perplexity + + +# Example usage: +# perplexity_eval("/path/to/model_dir") +# +# To enable debug output, set DEBUG = True at the top of this file diff --git a/examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt b/examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt new file mode 100644 index 000000000..e3afaf9e7 --- /dev/null +++ b/examples/windows/accuracy_benchmark/perplexity_metrics/requirements.txt @@ -0,0 +1,23 @@ +# PyTorch with CUDA 12.x support (compatible with CUDA 12.1-12.9) +--extra-index-url https://download.pytorch.org/whl/cu129 +accelerate + +coloredlogs +datasets +flatbuffers +huggingface_hub[cli] +numpy +onnx +onnxruntime-genai +packaging +pandas +protobuf>=5.28.2 +pytest +sentencepiece +sympy +tokenizers>=0.14.1 +torch>=2.0.0 +torchaudio +torchvision +transformers>=4.36 + diff --git a/examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py b/examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py new file mode 100644 index 000000000..a66427b8a --- /dev/null +++ b/examples/windows/accuracy_benchmark/perplexity_metrics/run_perplexity.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys + +import pandas as pd + +# Ensure this directory is on sys.path for local imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if SCRIPT_DIR not in sys.path: + sys.path.insert(0, SCRIPT_DIR) +from perplexity_metrics import calculate_perplexity_hf, perplexity_eval # noqa: E402 + + +def run_perplexity_on_models( + model_dirs, + output_file="perplexity_results.csv", + i="1024", + chunk_size=None, + hf_model=None, + hf_device="cuda", + hf_dtype=None, +): + """ + Run perplexity evaluation on multiple ONNX Runtime GenAI models and/or a HuggingFace model. + + This function evaluates one or more models at different input sequence lengths, + saves results to a CSV file, and prints a summary report. Each model-length + combination is evaluated independently, with errors handled gracefully. + + Args: + model_dirs (list[str]): List of model directory paths to evaluate. + Each directory must contain a valid ONNX Runtime GenAI model. + output_file (str, optional): Path for the output CSV file containing results. + Defaults to "perplexity_results.csv". + i (str or list, optional): Input sequence lengths to evaluate. Can be: + - String: comma-separated values (e.g., "1024,2048,4096") + - List/tuple: sequence of integers + - Single int: one length to evaluate + Defaults to "1024". + chunk_size (int, optional): Prefill chunk size for KV cache chunking. + Required for input lengths > 1024. + Overrides chunk_size in model config if provided. + Defaults to None. + hf_model (str, optional): HuggingFace model name or path to evaluate. + If provided, will download and evaluate this model. + Defaults to None. + hf_device (str, optional): Device to run HuggingFace model on. + Defaults to "cuda". + hf_dtype (str, optional): Data type for HuggingFace model. + Options: "float16", "bfloat16", "float32". + Defaults to None (uses model default). 
+ + Returns: + pd.DataFrame: DataFrame containing evaluation results with columns: + - Model Path: Full path to model directory + - Model Type: "ONNX" or "HuggingFace" + - Input Length: Sequence length used for evaluation + - Perplexity: Computed perplexity score (or "N/A" if failed) + - Status: "Success" or "Failed" + - Error: Error message if failed, "None" if successful + + """ + results = [] + + # Parse input lengths + if isinstance(i, str): + i_list = [int(x.strip()) for x in i.split(",") if x.strip()] + elif isinstance(i, (list, tuple)): + i_list = [int(x) for x in i] + else: + i_list = [int(i)] + + # Evaluate HuggingFace model if provided + if hf_model is not None: + print(f"\n{'=' * 60}") + print(f"Evaluating HuggingFace model: {hf_model}") + print(f"{'=' * 60}") + + # Convert dtype string to torch dtype + import torch + + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + torch_dtype = dtype_map.get(hf_dtype.lower()) if hf_dtype else torch.float16 + + for input_len in i_list: + try: + print(f" Evaluating with input length: {input_len}") + if torch_dtype: + print(f" Using dtype: {torch_dtype}") + + # Calculate stride (use chunk_size if provided, otherwise use half of input_len) + stride = chunk_size if chunk_size is not None else input_len // 2 + + perplexity = calculate_perplexity_hf( + model_name_or_path=hf_model, + max_length=input_len, + stride=stride, + device=hf_device, + torch_dtype=torch_dtype, + ) + + results.append( + { + "Model Path": hf_model, + "Model Type": "HuggingFace", + "Input Length": int(input_len), + "Perplexity": float(perplexity), + "Status": "Success", + "Error": "None", + } + ) + except Exception as e: # noqa: PERF203 + print(f" Error for input length {input_len}: {e!s}") + results.append( + { + "Model Path": hf_model, + "Model Type": "HuggingFace", + "Input Length": int(input_len), + "Perplexity": "N/A", + "Status": "Failed", + "Error": str(e), + } + ) + + print(" HuggingFace perplexity evaluation completed") + + # Unload HuggingFace model from GPU memory before ONNX evaluation + print("[CLEANUP] Unloading HuggingFace model from GPU memory...") + import gc + + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + print("[CLEANUP] GPU memory freed") + + # Evaluate ONNX models + for model_dir in model_dirs: + print(f"\n{'=' * 60}") + print(f"Evaluating perplexity for: {model_dir}") + print(f"{'=' * 60}") + + try: + # Check if model directory exists + if not os.path.exists(model_dir): + print(f"Error: Model directory does not exist: {model_dir}") + results.append( + { + "Model Path": model_dir, + "Perplexity": "N/A", + "Status": "Directory not found", + "Error": "Directory does not exist", + } + ) + continue + + # Check if genai_config.json exists + config_path = os.path.join(model_dir, "genai_config.json") + if not os.path.exists(config_path): + print(f"Error: genai_config.json not found in: {model_dir}") + results.append( + { + "Model Path": model_dir, + "Model Type": "ONNX", + "Perplexity": "N/A", + "Status": "Invalid model format", + "Error": "genai_config.json not found", + } + ) + continue + + # For each input length, run perplexity_eval and record results + for input_len in i_list: + try: + print(f" Evaluating with input length: {input_len}") + if chunk_size is None: + print( + " Note: input length is ignored unless chunk_size is set or " + "config.search.chunk_size is present." 
+ ) + if chunk_size is not None: + print(f" Using chunk_size: {chunk_size}") + perplexity = perplexity_eval(model_dir, str(input_len), chunk_size) + else: + perplexity = perplexity_eval(model_dir, str(input_len)) + results.append( + { + "Model Path": model_dir, + "Model Type": "ONNX", + "Input Length": int(input_len), + "Perplexity": float(perplexity), + "Status": "Success", + "Error": "None", + } + ) + except Exception as e: # noqa: PERF203 + print(f" Error for input length {input_len}: {e!s}") + results.append( + { + "Model Path": model_dir, + "Model Type": "ONNX", + "Input Length": int(input_len), + "Perplexity": "N/A", + "Status": "Failed", + "Error": str(e), + } + ) + + print(" Perplexity evaluation completed successfully") + + except Exception as e: + print(f"Error during perplexity evaluation: {e!s}") + results.append( + { + "Model Path": model_dir, + "Model Type": "ONNX", + "Perplexity": "N/A", + "Status": "Failed", + "Error": str(e), + } + ) + + # Create results DataFrame and save to CSV + df = pd.DataFrame(results) + df.to_csv(output_file, index=False) + + print(f"\n{'=' * 60}") + print(f"Results saved to: {output_file}") + print(f"{'=' * 60}") + + # Print summary + successful = df[df["Status"] == "Success"] + failed = df[df["Status"] != "Success"] + + print("\nSummary:") + print(f" Successful evaluations: {len(successful)}") + print(f" Failed evaluations: {len(failed)}") + + if len(successful) > 0: + print("\nPerplexity Results:") + for _, row in successful.iterrows(): + print( + f" {os.path.basename(row['Model Path'])} [i={row.get('Input Length', '?')}]: " + f"{row['Perplexity']:.4f}" + if isinstance(row["Perplexity"], (int, float)) + else row["Perplexity"] + ) + + return df + + +def main(): + """ + Command-line entry point for perplexity evaluation. + + Parses command-line arguments and runs perplexity evaluation on specified + ONNX Runtime GenAI models and/or HuggingFace models. Results are saved to a CSV file. + + Command-line Arguments: + --models: One or more ONNX model directory paths (optional) + --hf_model: HuggingFace model name or path (optional) + --hf_device: Device for HuggingFace model (default: "cuda") + --hf_dtype: Data type for HuggingFace model (default: None) + --i: Comma-separated input sequence lengths (default: "1024") + --output: Output CSV file path (default: "perplexity_results.csv") + --chunk_size: Prefill chunk size for prefill chunking (optional) + + Examples: + # Evaluate ONNX models + $ python run_perplexity.py --models /path/to/model + $ python run_perplexity.py --models /path/to/model1 /path/to/model2 \\ + --i 1024,2048,4096 --chunk_size 1024 --output results.csv + + # Evaluate HuggingFace model + $ python run_perplexity.py --hf_model meta-llama/Llama-2-7b-hf --i 1024 + $ python run_perplexity.py --hf_model meta-llama/Llama-2-7b-hf \\ + --hf_dtype float16 --hf_device cuda --i 1024,2048 + + # Evaluate both ONNX and HuggingFace models + $ python run_perplexity.py --models /path/to/onnx_model \\ + --hf_model meta-llama/Llama-2-7b-hf --i 1024 + """ + parser = argparse.ArgumentParser( + description="Run perplexity evaluation on ONNX Runtime GenAI and/or HuggingFace models" + ) + parser.add_argument( + "--models", + nargs="+", + default=[], + help="List of ONNX model directory paths to evaluate (optional)", + ) + parser.add_argument( + "--i", + default="1024", + help="Comma-separated input seq lengths to be evaluated (e.g. 
1024,2048) please enter number >= 1024", + ) + parser.add_argument( + "--output", + default="perplexity_results.csv", + help="Output CSV file name (default: perplexity_results.csv)", + ) + parser.add_argument( + "--chunk_size", + type=int, + default=None, + help="Chunk size for KV caching optimization (optional)", + ) + parser.add_argument( + "--hf_model", + type=str, + default=None, + help="HuggingFace model name or path to evaluate (e.g., 'meta-llama/Llama-2-7b-hf')", + ) + parser.add_argument( + "--hf_device", + type=str, + default="cuda", + help="Device to run HuggingFace model on (default: 'cuda')", + ) + parser.add_argument( + "--hf_dtype", + type=str, + default=None, + choices=["float16", "bfloat16", "float32", "fp16", "bf16", "fp32"], + help="Data type for HuggingFace model (default: None, uses model default)", + ) + + args = parser.parse_args() + + # Validate that at least one model source is provided + if not args.models and not args.hf_model: + print("Error: You must provide either --models or --hf_model (or both)") + parser.print_help() + return + + # Validate that all model directories exist + valid_models = [] + for model_dir in args.models: + if os.path.exists(model_dir): + valid_models.append(model_dir) + else: + print(f"Warning: Model directory does not exist: {model_dir}") + + # Count total models to evaluate + total_models = len(valid_models) + (1 if args.hf_model else 0) + + print(f"Running perplexity evaluation on {total_models} model(s)...") + if args.chunk_size is not None: + print(f"Using chunk_size: {args.chunk_size}") + + run_perplexity_on_models( + valid_models, + args.output, + args.i, + args.chunk_size, + args.hf_model, + args.hf_device, + args.hf_dtype, + ) + + +if __name__ == "__main__": + main() diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index ce252bc8f..31ea42764 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -866,7 +866,7 @@ def get_layer_info( layers_8bit = kwargs.get("layers_8bit") gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) - if enable_mixed_quant: + if enable_mixed_quant or layers_8bit: layer_info = get_layer_precision_mapping( onnx_model, layers_8bit,