From 214e85c6371e875122edf5403cbdb457229f22c0 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Fri, 28 Feb 2025 13:43:43 +0530
Subject: [PATCH 01/16] Added finetuning support for BERT-based models on the
 IMDB dataset.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
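Notes: a minimal usage sketch for the new task type (the checkpoint name below
is illustrative; assumes kwargs are forwarded to train_config via
update_config, as this patch does):

    from QEfficient.cloud.finetune import main

    # Fine-tune a BERT-style classifier on IMDB with the new task type.
    # The backbone is frozen by this patch; only the classification head
    # (kept in fp32) is trained.
    main(
        model_name="google-bert/bert-base-uncased",  # hypothetical checkpoint
        task_type="seq_classification",
        dataset="imdb_dataset",
    )
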
 QEfficient/cloud/finetune.py                  |  48 ++++---
 QEfficient/finetune/configs/dataset_config.py |   8 ++
 QEfficient/finetune/configs/training.py       |   1 +
 QEfficient/finetune/dataset/dataset_config.py |   4 +
 QEfficient/finetune/dataset/imdb_dataset.py   |  36 +++++
 QEfficient/finetune/utils/config_utils.py     |   7 +-
 QEfficient/finetune/utils/train_utils.py      | 131 ++++++++++++++++--
 7 files changed, 205 insertions(+), 30 deletions(-)
 create mode 100644 QEfficient/finetune/dataset/imdb_dataset.py

diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index c7525d2db..614c16180 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -38,7 +38,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")
 
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -56,6 +56,7 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config, kwargs)
     device = train_config.device
 
     # dist init
@@ -63,9 +64,9 @@ def main(**kwargs):
         # TODO: may have to init qccl backend, next try run with torchrun command
         torch_device = torch.device(device)
         assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert torch_device.index is None, (
-            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-        )
+        assert (
+            torch_device.index is None
+        ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
         dist.init_process_group(backend=train_config.dist_backend)
         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
         getattr(torch, torch_device.type).set_device(dist.get_rank())
@@ -78,12 +79,26 @@ def main(**kwargs):
     # Load the pre-trained model and setup its configuration
     # config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
@@ -127,17 +142,16 @@ def main(**kwargs):
         model.print_trainable_parameters()
 
     # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-    )
+    ctx_len = train_config.context_length
+    if ctx_len is None and hasattr(model.config, "max_position_embeddings"):
+        ctx_len = model.config.max_position_embeddings
 
-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
+    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len)
+
+    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len)
 
     # TODO: vbaddi, check if its necessary to do this?
     # dataset_train = ConcatDataset(
diff --git a/QEfficient/finetune/configs/dataset_config.py b/QEfficient/finetune/configs/dataset_config.py
index 2e7fb56fb..fd1081983 100644
--- a/QEfficient/finetune/configs/dataset_config.py
+++ b/QEfficient/finetune/configs/dataset_config.py
@@ -37,6 +37,14 @@ class gsm8k_dataset:
     test_split: str = "test"
 
 
+@dataclass
+class imdb_dataset:
+    dataset: str = "imdb_dataset"
+    train_split: str = "train"
+    test_split: str = "test"
+    num_labels: int = 2
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index 9010965a9..c50954c4c 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -29,6 +29,7 @@ class train_config:
     use_autocast: bool = True
     val_batch_size: int = 1
     dataset = "samsum_dataset"
+    task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     use_peft: bool = True  # use parameter efficient fine tuning
     from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index ac086d272..fd9ace0af 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -18,6 +18,9 @@
     get_dataset as get_grammar_dataset,
 )
 from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset
+from QEfficient.finetune.dataset.imdb_dataset import (
+    get_preprocessed_imdb as get_imdb_dataset,
+)
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
@@ -31,6 +34,7 @@
     "samsum_dataset": get_samsum_dataset,
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
+    "imdb_dataset": get_imdb_dataset,
 }
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py
new file mode 100644
index 000000000..71f67e46e
--- /dev/null
+++ b/QEfficient/finetune/dataset/imdb_dataset.py
@@ -0,0 +1,36 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import datasets
+
+
+def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
+    dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
+
+    # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data.
+    dataset = dataset.shuffle(seed=42)
+
+    if split == "test":
+        # Test set contains 25000 samples. Not all are required.
+        dataset = dataset.select(range(0, 1000))
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    def tokenize_add_label(sample):
+        data = tokenizer(
+            sample["text"],
+            add_special_tokens=True,
+            max_length=context_length,
+            pad_to_max_length=True,
+        )
+
+        data["labels"] = sample["label"]
+        return data
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+    return dataset
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index 58344b190..c2896dcda 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -16,7 +16,7 @@
     PrefixTuningConfig,
 )
 from transformers import default_data_collator
-from transformers.data import DataCollatorForSeq2Seq
+from transformers.data import DataCollatorForSeq2Seq, DefaultDataCollator
 
 import QEfficient.finetune.configs.dataset_config as datasets
 from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
@@ -88,7 +88,10 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                     num_replicas=dist.get_world_size(),
                     shuffle=False,
                 )
-                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+                if train_config.task_type == "seq_classification":
+                    kwargs["collate_fn"] = DefaultDataCollator(dataset_processer)
+                else:
+                    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 073742739..736df949f 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -14,6 +14,7 @@
 
 import torch
 import torch.distributed as dist
+import torchmetrics
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
@@ -103,6 +104,14 @@ def train(
     if train_config.enable_ddp:
         dist.broadcast(loss_0_counter, src=0)
 
+    acc_helper = None
+    if train_config.task_type == "seq_classification":
+        if local_rank is None:
+            num_classes = model.classifier.out_features
+        else:
+            num_classes = model.module.classifier.out_features
+        acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
@@ -181,10 +190,20 @@ def train(
                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
                         dump_root_dir=train_config.dump_root_dir + str(step),
                     ) as verifier:
-                        loss = model(**batch).loss  # Forward call
+                        model_outputs = model(**batch)
+                        loss = model_outputs.loss  # Forward call
+                        if train_config.task_type == "seq_classification":
+                            logits = model_outputs.logits
+                            labels = batch["labels"]
+                            acc_helper.forward(logits, labels)
                     print("Mismatches detected:", verifier.get_perop_mismatch_count())
                 else:
-                    loss = model(**batch).loss  # Forward call
+                    model_outputs = model(**batch)
+                    loss = model_outputs.loss  # Forward call
+                    if train_config.task_type == "seq_classification":
+                        logits = model_outputs.logits
+                        labels = batch["labels"]
+                        acc_helper.forward(logits, labels)
 
             total_loss += loss.detach().float()
             # Accumulate gradients
@@ -280,7 +299,10 @@ def train(
             else:
                 train_epoch_loss = total_loss / len(train_dataloader)
 
-        train_perplexity = torch.exp(train_epoch_loss)
+        if train_config.task_type == "seq_classification":
+            train_perplexity = acc_helper.compute()
+        else:
+            train_perplexity = torch.exp(train_epoch_loss)
 
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
@@ -291,14 +313,14 @@ def train(
         if train_config.run_validation:
             if train_config.enable_ddp:
                 dist.barrier()
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
+                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
                     model, train_config, eval_dataloader, local_rank, tokenizer, device
                 )
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             else:
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
+                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
                     model, train_config, eval_dataloader, local_rank, tokenizer, device
                 )
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
@@ -321,9 +343,14 @@ def train(
                 print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_prep.append(float(eval_ppl))
-        print(
-            f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-        )
+        if train_config.task_type == "seq_classification":
+            print(
+                f"Epoch {epoch + 1}: train_acc={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
+        else:
+            print(
+                f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
@@ -346,10 +373,16 @@ def train(
         avg_eval_prep = sum(val_prep) / len(val_prep)
         avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    results["avg_train_prep"] = avg_train_prep
+    if train_config.task_type == "seq_classification":
+        results["avg_train_acc"] = avg_train_prep
+    else:
+        results["avg_train_prep"] = avg_train_prep
     results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        results["avg_eval_prep"] = avg_eval_prep
+        if train_config.task_type == "seq_classification":
+            results["avg_eval_acc"] = avg_eval_prep
+        else:
+            results["avg_eval_prep"] = avg_eval_prep
         results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
@@ -359,7 +392,7 @@ def train(
     return results
 
 
-def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device):
     """
     Evaluates the model on the given dataloader
 
@@ -420,6 +453,82 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, devi
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 
+def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+    """
+    Evaluates the model on the given dataloader
+
+    Args:
+        model: The model to evaluate
+        eval_dataloader: The dataloader containing the evaluation data
+        local_rank: The rank of the current node in a distributed setting
+        tokenizer: The tokenizer used to decode predictions
+
+    Returns: eval_acc, eval_epoch_loss
+    """
+    model.eval()
+    if local_rank is None:
+        num_classes = model.classifier.out_features
+    else:
+        num_classes = model.module.classifier.out_features
+
+    acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device)
+
+    # special handling for qaic device and dtype
+    # model.to(device)
+
+    # eval_preds = []
+    val_step_loss = []
+    val_step_acc = []
+
+    eval_loss = 0.0  # Initialize evaluation loss
+    total_eval_steps = 0
+    # max_steps_reached = False  # Flag to indicate max eval steps reached
+
+    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+        total_eval_steps += 1
+        #  stop when the maximum number of eval steps is reached
+        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
+            # max_steps_reached = True
+            break
+        for key in batch.keys():
+            batch[key] = batch[key].to(device)
+        # Ensure no gradients are computed for this scope to save memory
+        with torch.no_grad():
+            # Forward pass and compute loss
+            with (
+                torch.autocast(device_type=device, dtype=torch.float16) if train_config.use_autocast else nullcontext()
+            ):
+                outputs = model(**batch)
+            loss = outputs.loss
+            logits = outputs.logits
+            labels = batch["labels"]
+            if train_config.save_metrics:
+                val_step_loss.append(loss.detach().float().item())
+                val_acc = acc_helper.forward(logits, labels)
+                val_step_acc.append(val_acc.detach().float().item())
+
+            eval_loss += loss.detach().float()
+        # Decode predictions and add to evaluation predictions list
+        # preds = torch.argmax(outputs.logits, -1)
+        # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
+
+    # Compute average loss and perplexity
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_acc = acc_helper.compute()
+
+    # Print evaluation metrics
+    print(f" {eval_acc.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+
+    return eval_acc, eval_epoch_loss, val_step_loss, val_step_acc
+
+
+def evaluation_helper(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+    if train_config.task_type == "seq_classification":
+        return evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device)
+    else:
+        return evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device)
+
+
 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     # find out the minimum max_seq_length required during fine-tuning (saves memory!)
     lengths = [len(d["input_ids"]) for d in data]

From 390bce817c9fe40c1ba2db4b2c72252f9381e80c Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Fri, 28 Feb 2025 13:43:43 +0530
Subject: [PATCH 02/16] Fixed IMDB test split selection, accuracy metric, and
 collator for sequence classification finetuning.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
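Notes: a standalone sketch of the metric update pattern this patch switches to
(dummy tensors; num_classes=2 matches the IMDB dataset config):

    import torch
    import torchmetrics

    acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=2)

    logits = torch.tensor([[0.2, 1.3], [2.0, -1.0]])     # raw model outputs
    labels = torch.tensor([1, 0])
    preds = torch.nn.functional.softmax(logits, dim=-1)  # class probabilities
    acc_helper.forward(preds, labels)   # updates state, returns batch accuracy
    print(acc_helper.compute())         # accuracy aggregated over all updates
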
 QEfficient/finetune/dataset/imdb_dataset.py | 11 +++++++----
 QEfficient/finetune/utils/config_utils.py   |  2 +-
 QEfficient/finetune/utils/train_utils.py    | 13 ++++++++-----
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py
index 71f67e46e..6564a067d 100644
--- a/QEfficient/finetune/dataset/imdb_dataset.py
+++ b/QEfficient/finetune/dataset/imdb_dataset.py
@@ -5,18 +5,21 @@
 #
 # -----------------------------------------------------------------------------
 
+
 import datasets
+from itertools import chain
 
 
 def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
     dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
 
-    # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data.
-    dataset = dataset.shuffle(seed=42)
-
     if split == "test":
         # Test set contains 25000 samples. Not all are required.
-        dataset = dataset.select(range(0, 1000))
+        # 0-12499 are 0 labeled samples, 12500-24999 are 1 labeled samples.
+        dataset = dataset.select(chain(range(0, 500), range(12500, 13000)))
+
+    # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data.
+    dataset = dataset.shuffle(seed=42)
 
     if tokenizer.pad_token is None:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index c2896dcda..93c0efbe8 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -89,7 +89,7 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                     shuffle=False,
                 )
                 if train_config.task_type == "seq_classification":
-                    kwargs["collate_fn"] = DefaultDataCollator(dataset_processer)
+                    kwargs["collate_fn"] = default_data_collator
                 else:
                     kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 736df949f..a5de1780e 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -110,7 +110,7 @@ def train(
             num_classes = model.classifier.out_features
         else:
             num_classes = model.module.classifier.out_features
-        acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device)
+        acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
     # Start the training loop
     for epoch in range(train_config.num_epochs):
@@ -195,7 +195,8 @@ def train(
                         if train_config.task_type == "seq_classification":
                             logits = model_outputs.logits
                             labels = batch["labels"]
-                            acc_helper.forward(logits, labels)
+                            preds = torch.nn.functional.softmax(logits, dim=-1)
+                            acc_helper.forward(preds, labels)
                     print("Mismatches detected:", verifier.get_perop_mismatch_count())
                 else:
                     model_outputs = model(**batch)
@@ -203,7 +204,8 @@ def train(
                     if train_config.task_type == "seq_classification":
                         logits = model_outputs.logits
                         labels = batch["labels"]
-                        acc_helper.forward(logits, labels)
+                        preds = torch.nn.functional.softmax(logits, dim=-1)
+                        acc_helper.forward(preds, labels)
 
             total_loss += loss.detach().float()
             # Accumulate gradients
@@ -471,7 +473,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
     else:
         num_classes = model.module.classifier.out_features
 
-    acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device)
+    acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
     # special handling for qaic device and dtype
     # model.to(device)
@@ -504,7 +506,8 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
             labels = batch["labels"]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
-                val_acc = acc_helper.forward(logits, labels)
+                preds = torch.nn.functional.softmax(logits, dim=-1)
+                val_acc = acc_helper.forward(preds, labels)
                 val_step_acc.append(val_acc.detach().float().item())
 
             eval_loss += loss.detach().float()

From a9637b14bca7fc9e8f48f4bef0db8aa32e8ff627 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Fri, 28 Feb 2025 13:43:43 +0530
Subject: [PATCH 03/16] Relaxed atol for the qaic debug verifier.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index a5de1780e..5183f051e 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -185,7 +185,7 @@ def train(
                         ref_device="cpu",
                         ref_dtype=torch.float32,
                         # adjust atol & rtol this as required
-                        atol=1e-1,
+                        atol=1,
                         use_ref_output_on_mismatch=True,
                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
                         dump_root_dir=train_config.dump_root_dir + str(step),

From 2a09de271ca5b298af8b584ba20f7477d1719382 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Mon, 7 Apr 2025 09:44:29 +0000
Subject: [PATCH 04/16] Added torchmetrics as a dependency and fixed loss
 computation for the DDP case.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
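Notes: a minimal sketch of the cross-rank loss averaging added here (assumes a
process group is already initialized, as finetune.py does via
init_process_group):

    import torch
    import torch.distributed as dist

    def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
        dist.barrier()                                # wait for all ranks
        dist.all_reduce(value, op=dist.ReduceOp.SUM)  # in-place sum across ranks
        value /= dist.get_world_size()                # turn the sum into a mean
        return value
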
 QEfficient/finetune/utils/train_utils.py | 12 +++++++++++-
 pyproject.toml                           |  5 +++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 5183f051e..712ceeda9 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -301,8 +301,18 @@ def train(
             else:
                 train_epoch_loss = total_loss / len(train_dataloader)
 
+        if train_config.enable_ddp:
+            # Get the correct train loss from all the nodes.
+            dist.barrier()
+            dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
+            train_epoch_loss /= dist.get_world_size()
+            
         if train_config.task_type == "seq_classification":
-            train_perplexity = acc_helper.compute()
+            accuracy = acc_helper.compute()
+            if train_config.enable_ddp:
+                dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
+                accuracy /= dist.get_world_size()
+            train_perplexity = accuracy
         else:
             train_perplexity = torch.exp(train_epoch_loss)
 
diff --git a/pyproject.toml b/pyproject.toml
index a218ec9ee..648d2ce4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ license = { file = "LICENSE" }
 authors = [{ name = "Qualcomm Cloud AI ML Team" }]
 keywords = ["transformers", "Cloud AI 100", "Inference"]
 classifiers = [
-    "Programming Language :: Python :: 3", 
+    "Programming Language :: Python :: 3",
     "Development Status :: 5 - Development/Unstable",
     "Intended Audience :: Developers",
     "Intended Audience :: Education",
@@ -38,6 +38,7 @@ dependencies = [
     "tensorboard",
     "fire",
     "py7zr",
+    "torchmetrics==1.7.0",
     "torch==2.4.1; platform_machine=='aarch64'",
     # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
     "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
@@ -60,7 +61,7 @@ namespaces = false
 
 [tool.setuptools.dynamic.version]
 attr = "QEfficient.__version__"
- 
+
 [tool.ruff]
 line-length = 120
 # Enable the isort rules.

From 41766dd3e369e64e5b37e94cbff6e929e430c79e Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Mon, 7 Apr 2025 09:47:01 +0000
Subject: [PATCH 05/16] Reverted atol back to 1e-1 for the qaic debug verifier.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 712ceeda9..e2c661759 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -185,7 +185,7 @@ def train(
                         ref_device="cpu",
                         ref_dtype=torch.float32,
                         # adjust atol & rtol this as required
-                        atol=1,
+                        atol=1e-1,
                         use_ref_output_on_mismatch=True,
                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
                         dump_root_dir=train_config.dump_root_dir + str(step),

From 74ede951c18af541435da213a3a650ebd98b87d5 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Tue, 8 Apr 2025 10:47:33 +0000
Subject: [PATCH 06/16] Fixed collate_fn for batch size > 1; it now also works
 for Llama with batch size > 1 on a single device.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
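Notes: a standalone sketch of the DataCollatorForSeq2Seq padding behaviour this
patch relies on for batch size > 1 (the checkpoint name is illustrative; any
tokenizer with a pad token works):

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    collator = DataCollatorForSeq2Seq(tokenizer)

    features = [
        {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "labels": [1]},
        {"input_ids": [101, 2307, 3185, 102], "attention_mask": [1, 1, 1, 1], "labels": [0]},
    ]
    batch = collator(features)       # pads input_ids/attention_mask to the longest sample
    print(batch["input_ids"].shape)  # torch.Size([2, 4]); labels keep shape (2, 1)
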
 QEfficient/cloud/finetune.py                | 8 ++------
 QEfficient/finetune/dataset/imdb_dataset.py | 5 ++---
 QEfficient/finetune/utils/config_utils.py   | 7 +------
 QEfficient/finetune/utils/train_utils.py    | 8 ++++----
 4 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index 614c16180..90c9d2ec8 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -145,13 +145,9 @@ def main(**kwargs):
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
-    ctx_len = train_config.context_length
-    if ctx_len is None and hasattr(model.config, "max_position_embeddings"):
-        ctx_len = model.config.max_position_embeddings
+    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length)
 
-    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len)
-
-    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len)
+    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length)
 
     # TODO: vbaddi, check if its necessary to do this?
     # dataset_train = ConcatDataset(
diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py
index 6564a067d..edfb5e523 100644
--- a/QEfficient/finetune/dataset/imdb_dataset.py
+++ b/QEfficient/finetune/dataset/imdb_dataset.py
@@ -28,11 +28,10 @@ def tokenize_add_label(sample):
         data = tokenizer(
             sample["text"],
             add_special_tokens=True,
-            max_length=context_length,
-            pad_to_max_length=True,
+            max_length=tokenizer.model_max_length,
         )
 
-        data["labels"] = sample["label"]
+        data["labels"] = [sample["label"]]
         return data
 
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index 93c0efbe8..b9b18e494 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -88,19 +88,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                     num_replicas=dist.get_world_size(),
                     shuffle=False,
                 )
-                if train_config.task_type == "seq_classification":
-                    kwargs["collate_fn"] = default_data_collator
-                else:
-                    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index e2c661759..13bcc0226 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -194,7 +194,7 @@ def train(
                         loss = model_outputs.loss  # Forward call
                         if train_config.task_type == "seq_classification":
                             logits = model_outputs.logits
-                            labels = batch["labels"]
+                            labels = batch["labels"][:, 0]
                             preds = torch.nn.functional.softmax(logits, dim=-1)
                             acc_helper.forward(preds, labels)
                     print("Mismatches detected:", verifier.get_perop_mismatch_count())
@@ -203,7 +203,7 @@ def train(
                     loss = model_outputs.loss  # Forward call
                     if train_config.task_type == "seq_classification":
                         logits = model_outputs.logits
-                        labels = batch["labels"]
+                        labels = batch["labels"][:, 0]
                         preds = torch.nn.functional.softmax(logits, dim=-1)
                         acc_helper.forward(preds, labels)
 
@@ -306,7 +306,7 @@ def train(
             dist.barrier()
             dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
             train_epoch_loss /= dist.get_world_size()
-            
+
         if train_config.task_type == "seq_classification":
             accuracy = acc_helper.compute()
             if train_config.enable_ddp:
@@ -513,7 +513,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
                 outputs = model(**batch)
             loss = outputs.loss
             logits = outputs.logits
-            labels = batch["labels"]
+            labels = batch["labels"][:, 0]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
                 preds = torch.nn.functional.softmax(logits, dim=-1)

From 1047357435ba854f2e38d676b1d10fa852566dda Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Wed, 9 Apr 2025 16:02:16 +0530
Subject: [PATCH 07/16] Fixed train accuracy computation by resetting the
 accuracy helper every epoch.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 13bcc0226..b2eae9b9b 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -309,6 +309,7 @@ def train(
 
         if train_config.task_type == "seq_classification":
             accuracy = acc_helper.compute()
+            acc_helper.reset()
             if train_config.enable_ddp:
                 dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
                 accuracy /= dist.get_world_size()

From 4e21d3b4ac25fed9eb041dc38fd1d14d519159bc Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Mon, 14 Apr 2025 11:12:53 +0530
Subject: [PATCH 08/16] Fixed formatting issues.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/cloud/finetune.py                | 14 +++++++++-----
 QEfficient/finetune/dataset/imdb_dataset.py |  3 ++-
 QEfficient/finetune/utils/config_utils.py   |  3 +--
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index 90c9d2ec8..e23a1e656 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -64,9 +64,9 @@ def main(**kwargs):
         # TODO: may have to init qccl backend, next try run with torchrun command
         torch_device = torch.device(device)
         assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert (
-            torch_device.index is None
-        ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
+        assert torch_device.index is None, (
+            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
+        )
         dist.init_process_group(backend=train_config.dist_backend)
         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
         getattr(torch, torch_device.type).set_device(dist.get_rank())
@@ -145,9 +145,13 @@ def main(**kwargs):
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length)
+    dataset_train = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
+    )
 
-    dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length)
+    dataset_val = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
+    )
 
     # TODO: vbaddi, check if its necessary to do this?
     # dataset_train = ConcatDataset(
diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py
index edfb5e523..9630f77f2 100644
--- a/QEfficient/finetune/dataset/imdb_dataset.py
+++ b/QEfficient/finetune/dataset/imdb_dataset.py
@@ -6,9 +6,10 @@
 # -----------------------------------------------------------------------------
 
 
-import datasets
 from itertools import chain
 
+import datasets
+
 
 def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
     dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index b9b18e494..e979961d6 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -15,8 +15,7 @@
     LoraConfig,
     PrefixTuningConfig,
 )
-from transformers import default_data_collator
-from transformers.data import DataCollatorForSeq2Seq, DefaultDataCollator
+from transformers.data import DataCollatorForSeq2Seq
 
 import QEfficient.finetune.configs.dataset_config as datasets
 from QEfficient.finetune.configs.peft_config import lora_config, prefix_config

From 5c154aa4c29f8ef57f016198532b01db351db7d6 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Tue, 15 Apr 2025 16:55:08 +0530
Subject: [PATCH 09/16] Addressed a few review comments; needs a rebase before
 testing.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index b2eae9b9b..1037995ba 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -301,12 +301,6 @@ def train(
             else:
                 train_epoch_loss = total_loss / len(train_dataloader)
 
-        if train_config.enable_ddp:
-            # Get the correct train loss from all the nodes.
-            dist.barrier()
-            dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
-            train_epoch_loss /= dist.get_world_size()
-
         if train_config.task_type == "seq_classification":
             accuracy = acc_helper.compute()
             acc_helper.reset()
@@ -479,10 +473,10 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer,
     Returns: eval_acc, eval_epoch_loss
     """
     model.eval()
-    if local_rank is None:
-        num_classes = model.classifier.out_features
-    else:
+    if train_config.enable_ddp:
         num_classes = model.module.classifier.out_features
+    else:
+        num_classes = model.classifier.out_features
 
     acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 

From e97fed885e2bd481a27d1fee08f235e20b923b49 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Tue, 15 Apr 2025 18:01:28 +0530
Subject: [PATCH 10/16] Removed the collate_fn for the samsum dataset, as
 DataCollatorForSeq2Seq covers all supported use cases.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/dataset/dataset_config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index fd9ace0af..7ae596404 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -38,5 +38,4 @@
 }
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
-    "samsum_dataset": get_samsum_collate_fn,
 }

From b3c707ecc5a9d7b11469793cf3627f48a24b10e4 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Tue, 15 Apr 2025 18:03:44 +0530
Subject: [PATCH 11/16] Formatted code with ruff.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/dataset/dataset_config.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index 7ae596404..6613ad56e 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -24,9 +24,6 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
-from QEfficient.finetune.dataset.samsum_dataset import (
-    get_samsum_collate_fn,
-)
 
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),

From 8d983f28a26df0bdc0193e76434dd9af2a1734df Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Tue, 15 Apr 2025 18:33:50 +0530
Subject: [PATCH 12/16] Removed samsum collate_fn as it is dead code.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/dataset/samsum_dataset.py | 21 -------------------
 1 file changed, 21 deletions(-)

diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py
index 3bb552a39..71814599d 100644
--- a/QEfficient/finetune/dataset/samsum_dataset.py
+++ b/QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,8 +6,6 @@
 # -----------------------------------------------------------------------------
 
 import datasets
-import torch
-from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -48,22 +46,3 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
-
-
-def collate_fn(batch):
-    eos_token = batch[0]["input_ids"][-1]
-
-    input_ids = pad_sequence(
-        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    attn_mask = pad_sequence(
-        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
-    )
-    labels = pad_sequence(
-        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
-
-
-def get_samsum_collate_fn(dataset_processer, dataset_config):
-    return collate_fn

From bf0f083668d0c3dc42267c00e321be2f609439f1 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Wed, 16 Apr 2025 11:47:10 +0530
Subject: [PATCH 13/16] Refactored the evaluation function and renamed
 variables to generic names.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
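Notes: the two evaluation paths are folded into a single helper; a minimal
sketch of the per-task metric selection it centralizes:

    import torch

    # Perplexity = exp(mean loss) for generation; aggregated accuracy for
    # sequence classification (acc_helper being a torchmetrics MulticlassAccuracy).
    def epoch_metric(task_type, epoch_loss, acc_helper=None):
        if task_type == "seq_classification":
            return acc_helper.compute()
        return torch.exp(epoch_loss)
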
 QEfficient/finetune/utils/train_utils.py | 211 ++++++++---------------
 1 file changed, 71 insertions(+), 140 deletions(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 1037995ba..b078e629f 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -62,9 +62,9 @@ def train(
 
     Returns: results dictionary containing average training and validation perplexity and loss
     """
-    train_prep = []
+    train_metric = []
     train_loss = []
-    val_prep = []
+    val_metric = []
     val_loss = []
 
     if train_config.save_metrics:
@@ -73,10 +73,10 @@ def train(
         metrics_filename = (
             f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
         )
-        train_step_perplexity = []
+        train_step_metric = []
         train_step_loss = []
         val_step_loss = []
-        val_step_perplexity = []
+        val_step_metric = []
 
     epoch_times = []
     checkpoint_times = []
@@ -106,10 +106,10 @@ def train(
 
     acc_helper = None
     if train_config.task_type == "seq_classification":
-        if local_rank is None:
-            num_classes = model.classifier.out_features
-        else:
+        if train_config.enable_ddp:
             num_classes = model.module.classifier.out_features
+        else:
+            num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
     # Start the training loop
@@ -231,7 +231,11 @@ def train(
 
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
-                train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                if train_config.task_type == "seq_classification":
+                    step_metric_val = float(acc_helper.compute())
+                else:
+                    step_metric_val = float(torch.exp(loss.detach().float()))
+                train_step_metric.append(step_metric_val)
 
             if train_config.grad_scaler:
                 scaler.scale(loss).backward()  # backward pass
@@ -266,12 +270,12 @@ def train(
                     metrics_filename,
                     train_step_loss,
                     train_loss,
-                    train_step_perplexity,
-                    train_prep,
+                    train_step_metric,
+                    train_metric,
                     val_step_loss,
                     val_loss,
-                    val_step_perplexity,
-                    val_prep,
+                    val_step_metric,
+                    val_metric,
                 )
             if train_config.enable_ddp:
                 if loss_0_counter.item() == train_config.convergence_counter:
@@ -307,11 +311,11 @@ def train(
             if train_config.enable_ddp:
                 dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
                 accuracy /= dist.get_world_size()
-            train_perplexity = accuracy
+            train_epoch_metric = accuracy
         else:
-            train_perplexity = torch.exp(train_epoch_loss)
+            train_epoch_metric = torch.exp(train_epoch_loss)
 
-        train_prep.append(float(train_perplexity))
+        train_metric.append(float(train_epoch_metric))
         train_loss.append(float(train_epoch_loss))
 
         # Update the learning rate as needed
@@ -320,21 +324,21 @@ def train(
         if train_config.run_validation:
             if train_config.enable_ddp:
                 dist.barrier()
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
-                    model, train_config, eval_dataloader, local_rank, tokenizer, device
+                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                    model, train_config, eval_dataloader, device
                 )
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             else:
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
-                    model, train_config, eval_dataloader, local_rank, tokenizer, device
+                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                    model, train_config, eval_dataloader, device
                 )
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
-                val_step_perplexity.extend(temp_step_perplexity)
+                val_step_metric.extend(temp_step_metric)
 
         # saving the adapters after completion of each epoch
         if train_config.save_model:
@@ -349,14 +353,14 @@ def train(
                 best_val_loss = eval_epoch_loss
                 print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
-            val_prep.append(float(eval_ppl))
+            val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
             print(
-                f"Epoch {epoch + 1}: train_acc={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_acc={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
             print(
-                f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_metric={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
 
         # Saving the results every epoch to plot later
@@ -365,31 +369,25 @@ def train(
                 metrics_filename,
                 train_step_loss,
                 train_loss,
-                train_step_perplexity,
-                train_prep,
+                train_step_metric,
+                train_metric,
                 val_step_loss,
                 val_loss,
-                val_step_perplexity,
-                val_prep,
+                val_step_metric,
+                val_metric,
             )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_metric = sum(train_metric) / len(train_metric)
     avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_metric = sum(val_metric) / len(val_metric)
         avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    if train_config.task_type == "seq_classification":
-        results["avg_train_acc"] = avg_train_prep
-    else:
-        results["avg_train_prep"] = avg_train_prep
+    results["avg_train_metric"] = avg_train_metric
     results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        if train_config.task_type == "seq_classification":
-            results["avg_eval_acc"] = avg_eval_prep
-        else:
-            results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_metric"] = avg_eval_metric
         results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
@@ -399,39 +397,40 @@ def train(
     return results
 
 
-def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+def evaluation_helper(model, train_config, eval_dataloader, device):
     """
     Evaluates the model on the given dataloader
 
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
-        local_rank: The rank of the current node in a distributed setting
-        tokenizer: The tokenizer used to decode predictions
 
-    Returns: eval_ppl, eval_epoch_loss
+    Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
     model.eval()
 
+    if train_config.task_type == "seq_classification":
+        if train_config.enable_ddp:
+            num_classes = model.module.classifier.out_features
+        else:
+            num_classes = model.classifier.out_features
+        acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
+
     # special handling for qaic device and dtype
     # model.to(device)
 
-    eval_preds = []
     val_step_loss = []
-    val_step_perplexity = []
+    val_step_metric = []
 
     eval_loss = 0.0  # Initialize evaluation loss
-    total_eval_steps = 0
-    # max_steps_reached = False  # Flag to indicate max eval steps reached
 
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-        total_eval_steps += 1
         #  stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-            # max_steps_reached = True
+        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
             break
         for key in batch.keys():
             batch[key] = batch[key].to(device)
+
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
@@ -441,100 +440,32 @@ def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer,
                 outputs = model(**batch)
             loss = outputs.loss
 
-            if train_config.save_metrics:
-                val_step_loss.append(loss.detach().float().item())
-                val_step_perplexity.append(float(torch.exp(loss.detach().float())))
-
-            eval_loss += loss.detach().float()
-        # Decode predictions and add to evaluation predictions list
-        preds = torch.argmax(outputs.logits, -1)
-        eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
-
-    # Compute average loss and perplexity
-    eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_ppl = torch.exp(eval_epoch_loss)
-
-    # Print evaluation metrics
-    print(f" {eval_ppl.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
-
-    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
-
-
-def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device):
-    """
-    Evaluates the model on the given dataloader
-
-    Args:
-        model: The model to evaluate
-        eval_dataloader: The dataloader containing the evaluation data
-        local_rank: The rank of the current node in a distributed setting
-        tokenizer: The tokenizer used to decode predictions
-
-    Returns: eval_acc, eval_epoch_loss
-    """
-    model.eval()
-    if train_config.enable_ddp:
-        num_classes = model.module.classifier.out_features
-    else:
-        num_classes = model.classifier.out_features
-
-    acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
-
-    # special handling for qaic device and dtype
-    # model.to(device)
-
-    # eval_preds = []
-    val_step_loss = []
-    val_step_acc = []
-
-    eval_loss = 0.0  # Initialize evaluation loss
-    total_eval_steps = 0
-    # max_steps_reached = False  # Flag to indicate max eval steps reached
+            if train_config.task_type == "seq_classification":
+                logits = outputs.logits
+                labels = batch["labels"][:, 0]
+                preds = torch.nn.functional.softmax(logits, dim=-1)
+                val_acc = acc_helper.forward(preds, labels)
+                metric_val = val_acc.detach().float().item()
+            else:
+                metric_val = float(torch.exp(loss.detach().float()))
 
-    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-        total_eval_steps += 1
-        #  stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-            # max_steps_reached = True
-            break
-        for key in batch.keys():
-            batch[key] = batch[key].to(device)
-        # Ensure no gradients are computed for this scope to save memory
-        with torch.no_grad():
-            # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device, dtype=torch.float16) if train_config.use_autocast else nullcontext()
-            ):
-                outputs = model(**batch)
-            loss = outputs.loss
-            logits = outputs.logits
-            labels = batch["labels"][:, 0]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
-                preds = torch.nn.functional.softmax(logits, dim=-1)
-                val_acc = acc_helper.forward(preds, labels)
-                val_step_acc.append(val_acc.detach().float().item())
+                val_step_metric.append(metric_val)
 
             eval_loss += loss.detach().float()
-        # Decode predictions and add to evaluation predictions list
-        # preds = torch.argmax(outputs.logits, -1)
-        # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
 
-    # Compute average loss and perplexity
+    # Compute average loss and metric
     eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_acc = acc_helper.compute()
+    if train_config.task_type == "seq_classification":
+        eval_metric = acc_helper.compute()
+    else:
+        eval_metric = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
-    print(f" {eval_acc.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
-    return eval_acc, eval_epoch_loss, val_step_loss, val_step_acc
-
-
-def evaluation_helper(model, train_config, eval_dataloader, local_rank, tokenizer, device):
-    if train_config.task_type == "seq_classification":
-        return evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device)
-    else:
-        return evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device)
+    return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric
 
 
 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
@@ -571,22 +502,22 @@ def save_to_json(
     output_filename,
     train_step_loss,
     train_epoch_loss,
-    train_step_ppl,
-    train_epoch_ppl,
+    train_step_metric,
+    train_epoch_metric,
     val_step_loss,
     val_epoch_loss,
-    val_step_ppl,
-    val_epoch_ppl,
+    val_step_metric,
+    val_epoch_metric,
 ):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
-        "train_step_perplexity": train_step_ppl,
-        "train_epoch_perplexity": train_epoch_ppl,
+        "train_step_metric": train_step_metric,
+        "train_epoch_metric": train_epoch_metric,
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
-        "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl,
+        "val_step_metric": val_step_metric,
+        "val_epoch_metric": val_epoch_metric,
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
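
The unified evaluation path above picks the per-step metric from the task type: a running multiclass accuracy for sequence classification, and exp(loss) as per-step perplexity for everything else. A minimal standalone sketch of that selection logic follows; the helper name step_metric and the toy 2-class tensors are illustrative assumptions, not repository code.

import torch
import torchmetrics

def step_metric(task_type, loss, logits=None, labels=None, acc_helper=None):
    # Mirror of the branch in the evaluation loop: accuracy for classification,
    # per-step perplexity (exp of the batch loss) otherwise.
    if task_type == "seq_classification":
        preds = torch.nn.functional.softmax(logits, dim=-1)
        return acc_helper.forward(preds, labels).detach().float().item()
    return float(torch.exp(loss.detach().float()))

acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=2)
logits = torch.randn(4, 2)
labels = torch.randint(0, 2, (4,))
loss = torch.nn.functional.cross_entropy(logits, labels)
print(step_metric("seq_classification", loss, logits, labels, acc_helper))
print(step_metric("generation", loss))  # any non-classification task falls back to perplexity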

From 9ef972fc750d29e050e60efd59719397126bab60 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Wed, 16 Apr 2025 12:13:31 +0530
Subject: [PATCH 14/16] Fixed metric value for json logging

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index b078e629f..3f8fb4df7 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -232,7 +232,7 @@ def train(
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
                 if train_config.task_type == "seq_classification":
-                    step_metric_val = acc_helper.compute()
+                    step_metric_val = float(acc_helper.compute())
                 else:
                     step_metric_val = float(torch.exp(loss.detach().float()))
                 train_step_metric.append(step_metric_val)
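
The float cast above matters because Metric.compute() returns a 0-dim torch.Tensor, which json.dump cannot serialize; converting to a plain Python float keeps the saved metrics file writable. A small sketch of that point (toy predictions, not repository data):

import json
import torch
import torchmetrics

acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=2)
acc_helper.update(torch.tensor([1, 0, 1]), torch.tensor([1, 0, 0]))
value = acc_helper.compute()              # 0-dim tensor, not JSON serializable
# json.dumps({"acc": value})              # would raise TypeError
print(json.dumps({"acc": float(value)}))  # plain float serializes cleanly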

From 2713fdc8158b0d08d2b420d1592a5f278fe65cf1 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Thu, 17 Apr 2025 10:20:40 +0530
Subject: [PATCH 15/16] Minor fixes to variable names

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/finetune/utils/train_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 3f8fb4df7..84379967a 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -311,11 +311,11 @@ def train(
             if train_config.enable_ddp:
                 dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
                 accuracy /= dist.get_world_size()
-            train_metric = accuracy
+            metric_val = accuracy
         else:
-            train_metric = torch.exp(train_epoch_loss)
+            metric_val = torch.exp(train_epoch_loss)
 
-        train_metric.append(float(train_metric))
+        train_metric.append(float(metric_val))
         train_loss.append(float(train_epoch_loss))
 
         # Update the learning rate as needed
@@ -356,11 +356,11 @@ def train(
             val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
             print(
-                f"Epoch {epoch + 1}: train_acc={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
             print(
-                f"Epoch {epoch + 1}: train_metric={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
 
         # Saving the results every epoch to plot later
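
The rename avoids name shadowing: train_metric is the per-epoch history list, and rebinding that name to the epoch scalar would break the append call that follows. A toy illustration (standalone sketch, not repository code):

import torch

train_metric = []                               # history list kept across epochs
epoch_value = torch.exp(torch.tensor(0.35))     # stand-in for the per-epoch scalar

# Pre-patch pattern: rebinding the list name to the scalar breaks the append below.
# train_metric = epoch_value
# train_metric.append(float(train_metric))      # AttributeError: Tensor has no append

# Post-patch pattern: keep the scalar under its own name.
metric_val = epoch_value
train_metric.append(float(metric_val))
print(train_metric)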

From b6d0ce9ddcf6136b1bd67746c58673bbd3a50381 Mon Sep 17 00:00:00 2001
From: Meet Patel <quic_meetkuma@quicinc.com>
Date: Thu, 17 Apr 2025 11:20:10 +0530
Subject: [PATCH 16/16] Addressed a few offline comments.

Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
---
 QEfficient/cloud/finetune.py             | 4 ++++
 QEfficient/finetune/utils/train_utils.py | 6 +-----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index e23a1e656..f312d00cb 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -86,6 +86,10 @@ def main(**kwargs):
             attn_implementation="sdpa",
             torch_dtype=torch.float16,
         )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("The given Hugging Face model does not have a 'base_model_prefix' attribute.")
+
         for param in getattr(model, model.base_model_prefix).parameters():
             param.requires_grad = False
 
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 84379967a..2bc701008 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -306,12 +306,8 @@ def train(
                 train_epoch_loss = total_loss / len(train_dataloader)
 
         if train_config.task_type == "seq_classification":
-            accuracy = acc_helper.compute()
+            metric_val = acc_helper.compute()
             acc_helper.reset()
-            if train_config.enable_ddp:
-                dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
-                accuracy /= dist.get_world_size()
-            metric_val = accuracy
         else:
             metric_val = torch.exp(train_epoch_loss)
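
Dropping the manual all_reduce relies on torchmetrics keeping its accuracy state per process and synchronizing it when compute() is called while torch.distributed is initialized, so the extra SUM/world-size averaging is redundant; this rationale is inferred, not stated in the patch. A single-process sketch of the compute()/reset() pattern the hunk keeps:

import torch
import torchmetrics

acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=2)
for _ in range(3):                               # stand-in for training steps
    logits = torch.randn(8, 2)
    labels = torch.randint(0, 2, (8,))
    acc_helper.update(torch.nn.functional.softmax(logits, dim=-1), labels)

metric_val = acc_helper.compute()                # aggregated (and DDP-synced) accuracy
acc_helper.reset()                               # clear state before the next epoch
print(float(metric_val))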