44against multiple ONNX Runtime GenAI models with different execution providers.
55
66Usage:
7- python compare_models_generic .py --hf_model "F:\shared\Llama-3.1-8B-Instruct"
7+ python compute_kl_divergence .py --hf_model "F:\shared\Llama-3.1-8B-Instruct"
88 --ep cuda --path "G:\models\cuda_model"
99 --ep directml --path "G:\models\directml_model"
1010 --output "comparison_results.json"
3636
3737
def debug_print(message):
    """
    Emit *message* to stdout with a ``[DEBUG]`` prefix when debugging is on.

    Args:
        message (str): Debug message to print.

    Note:
        Gated on the module-level ``DEBUG`` flag (defined elsewhere in this
        file); when it is falsy this function is a no-op.
    """
    if not DEBUG:
        return
    print(f"[DEBUG] {message}")
4247
4348
4449def run_command (cmd , description = "" , capture_output = True ):
45- """Run a command and handle errors"""
50+ """
51+ Execute a subprocess command with error handling.
52+
53+ Args:
54+ cmd (list[str]): Command and arguments to execute.
55+ description (str, optional): Description of the command for logging. Defaults to "".
56+ capture_output (bool, optional): Whether to capture stdout/stderr or show in real-time.
57+ Defaults to True.
58+
59+ Returns:
60+ bool: True if command succeeded, False otherwise.
61+ """
4662 debug_print (f"[INFO] { description } " )
4763 debug_print (f"Running: { ' ' .join (cmd )} " )
4864
4965 try :
5066 if capture_output :
51- result = subprocess .run (cmd , check = True , capture_output = True , text = True , shell = True )
67+ result = subprocess .run (cmd , check = True , capture_output = True , text = True , shell = False )
5268 if result .stdout and DEBUG :
5369 print (f"[OUT] { result .stdout } " )
5470 else :
5571 # Real-time output - shows prints as they happen
56- result = subprocess .run (cmd , check = True , shell = True )
72+ result = subprocess .run (cmd , check = True , shell = False )
5773 return True
5874 except subprocess .CalledProcessError as e :
5975 print (f"[ERROR] Command failed: { e } " )
@@ -70,7 +86,12 @@ def get_python_executable():
7086
7187
7288def uninstall_onnxruntime_packages ():
73- """Uninstall all ONNX Runtime packages"""
89+ """
90+ Uninstall all ONNX Runtime and ONNX Runtime GenAI packages.
91+
92+ This ensures a clean environment before installing provider-specific packages
93+ to avoid version conflicts.
94+ """
7495 packages_to_remove = [
7596 "onnxruntime" ,
7697 "onnxruntime-genai" ,
@@ -88,7 +109,15 @@ def uninstall_onnxruntime_packages():
88109
89110
90111def install_package (package_name ):
91- """Install a specific package"""
112+ """
113+ Install a specific Python package using pip.
114+
115+ Args:
116+ package_name (str): Name of the package to install.
117+
118+ Returns:
119+ bool: True if installation succeeded, False otherwise.
120+ """
92121 debug_print (f"Installing package: { package_name } " )
93122 python_exe = get_python_executable ()
94123 debug_print (f"Python executable: { python_exe } " )
@@ -101,19 +130,33 @@ def install_package(package_name):
101130
102131
103132def extract_hf_logits_subprocess (model_path , device = "cuda" ):
104- """Extract logits from Hugging Face model using subprocess"""
133+ """
134+ Extract logits from a Hugging Face transformer model using a subprocess.
135+
136+ Runs extract_logits_hf.py in a separate process to avoid package conflicts.
137+ Uses temporary file for data transfer between processes.
138+
139+ Args:
140+ model_path (str): Path to the Hugging Face model directory.
141+ device (str, optional): Device for inference ('cuda' or 'cpu'). Defaults to "cuda".
142+
143+ """
105144 print ("[INFO] Extracting logits from Hugging Face baseline model..." )
106145 debug_print (f"Model path: { model_path } , Device: { device } " )
107146
108147 # Create temporary output file
109- output_file = f"temp_logits_hf_{ int (time .time ())} .pkl"
148+ import tempfile
149+
150+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
151+ with tempfile .NamedTemporaryFile (prefix = "temp_logits_hf_" , suffix = ".pkl" , delete = False ) as tmp :
152+ output_file = tmp .name
110153 debug_print (f"Temporary output file: { output_file } " )
111154
112155 try :
113156 python_exe = get_python_executable ()
114157 cmd = [
115158 python_exe ,
116- "extract_logits_hf.py" ,
159+ os . path . join ( script_dir , "extract_logits_hf.py" ) ,
117160 "--model_path" ,
118161 model_path ,
119162 "--output_file" ,
@@ -158,19 +201,33 @@ def extract_hf_logits_subprocess(model_path, device="cuda"):
158201
159202
160203def extract_onnx_logits_subprocess (model_path , provider ):
161- """Extract logits from ONNX Runtime GenAI model using subprocess"""
204+ """
205+ Extract logits from an ONNX Runtime GenAI model using a subprocess.
206+
207+ Runs extract_logits.py in a separate process with the appropriate ONNX Runtime
208+ package for the specified execution provider. Uses temporary file for data transfer.
209+
210+ Args:
211+ model_path (str): Path to the ONNX Runtime GenAI model directory.
212+ provider (str): Execution provider ('cuda', 'directml', or 'cpu').
213+
214+ """
162215 print (f"[INFO] Extracting logits from { provider .upper ()} model..." )
163216 debug_print (f"Model path: { model_path } , Provider: { provider } " )
164217
165218 # Create temporary output file
166- output_file = f"temp_logits_{ provider } _{ int (time .time ())} .pkl"
219+ import tempfile
220+
221+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
222+ with tempfile .NamedTemporaryFile (prefix = "temp_logits_" , suffix = ".pkl" , delete = False ) as tmp :
223+ output_file = tmp .name
167224 debug_print (f"Temporary output file: { output_file } " )
168225
169226 try :
170227 python_exe = get_python_executable ()
171228 cmd = [
172229 python_exe ,
173- "extract_logits.py" ,
230+ os . path . join ( script_dir , "extract_logits.py" ) ,
174231 "--model_path" ,
175232 model_path ,
176233 "--output_file" ,
@@ -220,8 +277,20 @@ def extract_onnx_logits_subprocess(model_path, provider):
220277
221278def compute_kl_divergence_from_logits (log_probs_ref , log_probs_tar ):
222279 """
223- Compute KL divergence between two log probability distributions.
224- Same logic as in compute_kl_divergence.py
280+ Compute Kullback-Leibler divergence between two log probability distributions.
281+
282+ KL divergence measures how one probability distribution diverges from a reference
283+ distribution. Lower values indicate more similar distributions.
284+
285+ Args:
286+ log_probs_ref (np.ndarray): Reference log probabilities with shape (seq_len, vocab_size).
287+ log_probs_tar (np.ndarray): Target log probabilities with shape (seq_len, vocab_size).
288+
289+ Returns:
290+ float: Average KL divergence across all positions.
291+
292+ Note:
293+ Formula: KL(P||Q) = sum(P(x) * |log(P(x)) - log(Q(x))|) averaged over sequence length
225294 """
226295 debug_print (
227296 f"Computing KL divergence - log_probs shapes: ref={ log_probs_ref .shape } , tar={ log_probs_tar .shape } "
@@ -239,7 +308,13 @@ def compute_kl_divergence_from_logits(log_probs_ref, log_probs_tar):
239308
240309def to_serializable (obj ):
241310 """
242- Recursively convert numpy types and torch types to native Python types for JSON serialization.
311+ Recursively convert numpy and torch types to native Python types for JSON serialization.
312+
313+ Args:
314+ obj: Object to convert (dict, list, tuple, np.ndarray, torch.Tensor, etc.).
315+
316+ Returns:
317+ Converted object with native Python types (int, float, list, dict, tuple).
243318 """
244319 if isinstance (obj , dict ):
245320 return {k : to_serializable (v ) for k , v in obj .items ()}
@@ -261,8 +336,23 @@ def to_serializable(obj):
261336
262337def compute_unified_comparison (model_logits_list , output_file ):
263338 """
264- Compute KL divergence comparison between all models in a unified way
265- model_logits_list: List of tuples (model_name, model_data)
339+ Compute pairwise KL divergence between all models and save results to JSON.
340+
341+ This function performs an all-vs-all comparison of the provided models by computing
342+ KL divergence for each chunk and averaging across all chunks. Results are saved
343+ in a structured JSON format.
344+
345+ Args:
346+ model_logits_list (list): List of tuples (model_name, model_data) where:
347+ - model_name (str): Identifier for the model (e.g., "hf_baseline", "cuda_1")
348+ - model_data (dict): Dictionary containing:
349+ - 'logits': List of numpy arrays (one per chunk)
350+ - 'total_chunks': Number of chunks
351+ - 'seq_len': Sequence length
352+ - 'model_path': Path to model
353+ - 'chunk_info': Chunk position info
354+ output_file (str): Path to save the JSON results file.
355+
266356 """
267357 print ("\n [INFO] Computing unified KL divergence comparison..." )
268358 debug_print (f"Number of models to compare: { len (model_logits_list )} " )
@@ -325,7 +415,7 @@ def compute_unified_comparison(model_logits_list, output_file):
325415 # Find minimum sequence length for this chunk
326416 min_seq_len = min (getattr (logits , "shape" , [None , 0 ])[1 ] for _ , logits in chunk_logits )
327417 # Assume all have same vocab size
328- vocab_size = getattr (chunk_logits [ 0 ][ 1 ] , "shape" , [None , None , 0 ])[2 ]
418+ vocab_size = min ( getattr (logits , "shape" , [None , None , 0 ])[2 ] for _ , logits in chunk_logits )
329419 debug_print (f" Min seq len: { min_seq_len } , Vocab size: { vocab_size } " )
330420
331421 # Trim all logits to matching dimensions
@@ -396,7 +486,16 @@ def compute_unified_comparison(model_logits_list, output_file):
396486
397487
398488def validate_inputs (hf_model , ep_path_pairs ):
399- """Validate that all input paths exist and EPs are supported"""
489+ """
490+ Validate that all model paths exist and execution providers are supported.
491+
492+ Args:
493+ hf_model (str or None): Path to Hugging Face model (optional).
494+ ep_path_pairs (list): List of (execution_provider, model_path) tuples.
495+
496+ Returns:
497+ bool: True if all inputs are valid, False otherwise.
498+ """
400499 # Check HF model path (only if provided)
401500 if hf_model and not os .path .exists (hf_model ):
402501 print (f"[ERROR] Hugging Face model path does not exist: { hf_model } " )
@@ -556,7 +655,7 @@ def main():
556655 ]
557656
558657 print ("\n [INFO] Running KL_divergence_metrics_same_ep.py..." )
559- result = subprocess .run (cmd , shell = True )
658+ result = subprocess .run (cmd , shell = False )
560659
561660 if result .returncode == 0 :
562661 print ("\n [SUCCESS] KL divergence computation completed successfully" )
0 commit comments