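"""Compute the average KL divergence between a reference ONNX Runtime GenAI model
and a target ONNX Runtime GenAI model over the WikiText-2 test split."""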
import argparse
import os

import numpy as np
import onnxruntime_genai as og
import torch
from datasets import load_dataset

DEBUG = False


def get_kl_divergence(log_probs_ref, log_probs_tar):
    # Per token position, KL(P_ref || P_tar) = sum(p_ref * (log p_ref - log p_tar))
    # over the vocabulary; the result is averaged over all positions.
    kl_divergence = 0.0
    for i in range(log_probs_ref.shape[0]):
        prob_ref = np.exp(log_probs_ref[i])
        kl_divergence += np.sum(prob_ref * (log_probs_ref[i] - log_probs_tar[i]))
    kl_divergence = kl_divergence / log_probs_ref.shape[0]
    return kl_divergence
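# Note: the loop above is equivalent to the vectorized NumPy expression below
# (a sketch, assuming log_probs_* are 2D arrays of shape (seq_len, vocab_size)):
#   np.mean(np.sum(np.exp(log_probs_ref) * (log_probs_ref - log_probs_tar), axis=-1))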


def get_wikitext2():
    # Load the Wikitext-2 test split using HuggingFace datasets
    print("\n[INFO] Loading Wikitext-2 'test' split ...")
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    if DEBUG:
        print(f"[DATASET] Number of raw samples: {len(test)}")
        for i in range(3):
            print(f"[DATASET] Sample[{i}]: {repr(test[i]['text'])[:200]} ...")
    # Concatenate all text samples into a single string, separated by double newlines
    result = "\n\n".join(text for text in test["text"])
    if DEBUG:
        print(
            f"[DATASET] Concatenated text preview: {result[:512]!r} ... [total chars: {len(result)}]"
        )
    return result


def run_kl_divergence_on_models(reference_model, target_model):
    ref_model = og.Model(reference_model)
    tar_model = og.Model(target_model)
    tokenizer_ref = og.Tokenizer(ref_model)
    tokenizer_tar = og.Tokenizer(tar_model)
    max_context_length = 1024
    dataset = get_wikitext2()

    input_ids_ref = tokenizer_ref.encode_batch([dataset])
    input_ids_tar = tokenizer_tar.encode_batch([dataset])
    # Handle possible dict output from tokenizer
    if isinstance(input_ids_ref, dict) and "input_ids" in input_ids_ref:
        input_ids_ref = input_ids_ref["input_ids"]
    # Convert to numpy if needed
    if hasattr(input_ids_ref, "as_numpy"):
        input_ids_ref = input_ids_ref.as_numpy()
        if DEBUG:
            print("[TOKENIZER] Used as_numpy() for reference input")
    if isinstance(input_ids_tar, dict) and "input_ids" in input_ids_tar:
        input_ids_tar = input_ids_tar["input_ids"]
    if hasattr(input_ids_tar, "as_numpy"):
        input_ids_tar = input_ids_tar.as_numpy()
        if DEBUG:
            print("[TOKENIZER] Used as_numpy() for target input")
    input_ids_ref = np.array(input_ids_ref)
    input_ids_tar = np.array(input_ids_tar)

    # Ensure input_ids is 2D (batch, seq_len)
    if input_ids_ref.ndim == 1:
        input_ids_ref = np.expand_dims(input_ids_ref, 0)
        if DEBUG:
            print(f"[SHAPE] Expanded dims, now: {input_ids_ref.shape}")
    if input_ids_tar.ndim == 1:
        input_ids_tar = np.expand_dims(input_ids_tar, 0)
        if DEBUG:
            print(f"[SHAPE] Expanded dims, now: {input_ids_tar.shape}")
    # Convert input_ids to torch tensor
    input_ids_ref = torch.tensor(input_ids_ref, dtype=torch.long)
    input_ids_tar = torch.tensor(input_ids_tar, dtype=torch.long)
    seq_len_ref = int(input_ids_ref.shape[1])
    seq_len_tar = int(input_ids_tar.shape[1])
    if DEBUG:
        print(f"[INFO] Full input length (reference): {seq_len_ref}")
        print(f"[INFO] Full input length (target): {seq_len_tar}")

    if seq_len_ref != seq_len_tar:
        print(
            f"Error: Tokenized input lengths for the reference and target models do not match: "
            f"{seq_len_ref} != {seq_len_tar}"
        )
        return
    if DEBUG:
        print(f"[INFO] Input lengths match: {seq_len_ref}")
    # Slide a window over the input to compute KL divergence in chunks
    total_kl_divergence = 0.0
    total_batch = 0
    for begin_loc in range(0, seq_len_ref, max_context_length):
        end_loc = min(begin_loc + max_context_length, seq_len_ref)
        # Extract the current chunk of input tokens
        input_ids_chunk_ref = input_ids_ref[:, begin_loc:end_loc].clone()
        input_ids_chunk_tar = input_ids_tar[:, begin_loc:end_loc].clone()
        if DEBUG:
            print(f"input_ids_chunk_ref.shape: {input_ids_chunk_ref.shape}")
            print(f"input_ids_chunk_tar.shape: {input_ids_chunk_tar.shape}")
        # Set up generator parameters for deterministic generation (no sampling)
        params_ref = og.GeneratorParams(ref_model)
        params_tar = og.GeneratorParams(tar_model)
        params_ref.set_search_options(
            max_length=int(input_ids_chunk_ref.shape[1]), do_sample=False, early_stopping=False
        )
        params_tar.set_search_options(
            max_length=int(input_ids_chunk_tar.shape[1]), do_sample=False, early_stopping=False
        )
        # Create generator and append input tokens
        generator_ref = og.Generator(ref_model, params_ref)
        generator_ref.append_tokens(input_ids_chunk_ref.numpy())
        generator_tar = og.Generator(tar_model, params_tar)
        generator_tar.append_tokens(input_ids_chunk_tar.numpy())

        # Run the model forward pass without gradient calculation
        with torch.no_grad():
            if DEBUG:
                print("[INFER] Running model forward pass ...")
            try:
                generator_ref.generate_next_token()
                generator_tar.generate_next_token()
            except Exception as e:
                print(f"[INFER] .generate_next_token() failed: {e}")
                break  # Fatal error
            # Get logits output from the model
            logits_ref = generator_ref.get_output("logits")
            logits_tar = generator_tar.get_output("logits")
            if DEBUG:
                print(f"logits_ref.shape: {logits_ref.shape}")
                print(f"logits_tar.shape: {logits_tar.shape}")
            # Convert numpy arrays to torch tensors
            logits_ref = torch.tensor(logits_ref, dtype=torch.float32)
            logits_tar = torch.tensor(logits_tar, dtype=torch.float32)
        # Compute log probabilities over vocabulary for each position
        log_probs_ref = torch.nn.functional.log_softmax(logits_ref, dim=2).cpu().numpy()
        log_probs_tar = torch.nn.functional.log_softmax(logits_tar, dim=2).cpu().numpy()
        if DEBUG:
            print(f"log_probs_ref.shape: {log_probs_ref.shape}")
            print(f"log_probs_tar.shape: {log_probs_tar.shape}")
        # Compute KL divergence
        # Drop the batch dimension: (1, seq_len, vocab_size) -> (seq_len, vocab_size)
        log_probs_ref = log_probs_ref.squeeze(0)
        log_probs_tar = log_probs_tar.squeeze(0)

        # Equivalent torch formulation (note that torch.nn.functional.kl_div(input, target,
        # log_target=True) computes KL(target || input), so the reference log-probs go second):
        # kl_divergence = torch.nn.functional.kl_div(
        #     torch.from_numpy(log_probs_tar), torch.from_numpy(log_probs_ref),
        #     reduction="batchmean", log_target=True,
        # ).item()
        kl_divergence = get_kl_divergence(log_probs_ref, log_probs_tar)
        total_kl_divergence += kl_divergence
        total_batch += 1
        if DEBUG:
            print(f"KL divergence: {kl_divergence}")
    if total_batch == 0:
        print("Error: no chunks were evaluated, cannot compute average KL divergence")
        return
    avg_kl_divergence = total_kl_divergence / total_batch
    print(f"Total KL divergence: {total_kl_divergence}")
    print(f"Total batch: {total_batch}")
    print(f"Average KL divergence: {avg_kl_divergence}")


def main():
    parser = argparse.ArgumentParser(
        description="Run KL divergence evaluation on ONNX Runtime GenAI models"
    )
    parser.add_argument(
        "--reference_model", required=True, help="Path to reference model directory"
    )
    parser.add_argument("--target_model", required=True, help="Path to target model directory")
    args = parser.parse_args()

    # Validate that all model directories exist
    valid_models = []
    if os.path.exists(args.reference_model):
        valid_models.append(args.reference_model)
    else:
        print(f"Warning: Reference Model directory does not exist: {args.reference_model}")
    if os.path.exists(args.target_model):
        valid_models.append(args.target_model)
    else:
        print(f"Warning: Target Model directory does not exist: {args.target_model}")
    if len(valid_models) != 2:
        print("Error: Both model directories must exist")
        return

    print(
        f"Running KL divergence evaluation on reference model={valid_models[0]} and target model={valid_models[1]}"
    )
    run_kl_divergence_on_models(valid_models[0], valid_models[1])


if __name__ == "__main__":
    main()
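# Example invocation (the script name and model paths below are placeholders):
#   python kl_divergence.py --reference_model /path/to/reference_model --target_model /path/to/target_model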