From 214e85c6371e875122edf5403cbdb457229f22c0 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Fri, 28 Feb 2025 13:43:43 +0530 Subject: [PATCH 01/16] Added finetuning support for BERT based models on IMDB dataset. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/cloud/finetune.py | 48 ++++--- QEfficient/finetune/configs/dataset_config.py | 8 ++ QEfficient/finetune/configs/training.py | 1 + QEfficient/finetune/dataset/dataset_config.py | 4 + QEfficient/finetune/dataset/imdb_dataset.py | 36 +++++ QEfficient/finetune/utils/config_utils.py | 7 +- QEfficient/finetune/utils/train_utils.py | 131 ++++++++++++++++-- 7 files changed, 205 insertions(+), 30 deletions(-) create mode 100644 QEfficient/finetune/dataset/imdb_dataset.py diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index c7525d2db..614c16180 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -38,7 +38,7 @@ print(f"Warning: {e}. Moving ahead without these qaic modules.") -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer # Suppress all warnings warnings.filterwarnings("ignore") @@ -56,6 +56,7 @@ def main(**kwargs): # update the configuration for the training process train_config = TRAIN_CONFIG() update_config(train_config, **kwargs) + dataset_config = generate_dataset_config(train_config, kwargs) device = train_config.device # dist init @@ -63,9 +64,9 @@ def main(**kwargs): # TODO: may have to init qccl backend, next try run with torchrun command torch_device = torch.device(device) assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert torch_device.index is None, ( - f"DDP requires specification of device type only, however provided device index as well: {torch_device}" - ) + assert ( + torch_device.index is None + ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}" dist.init_process_group(backend=train_config.dist_backend) # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank getattr(torch, torch_device.type).set_device(dist.get_rank()) @@ -78,12 +79,26 @@ def main(**kwargs): # Load the pre-trained model and setup its configuration # config = AutoConfig.from_pretrained(train_config.model_name) pretrained_model_path = login_and_download_hf_lm(train_config.model_name) - model = AutoModelForCausalLM.from_pretrained( - pretrained_model_path, - use_cache=False, - attn_implementation="sdpa", - torch_dtype=torch.float16, - ) + if train_config.task_type == "seq_classification": + model = AutoModelForSequenceClassification.from_pretrained( + pretrained_model_path, + num_labels=dataset_config.num_labels, + attn_implementation="sdpa", + torch_dtype=torch.float16, + ) + for param in getattr(model, model.base_model_prefix).parameters(): + param.requires_grad = False + + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.float32) + else: + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_path, + use_cache=False, + attn_implementation="sdpa", + torch_dtype=torch.float16, + ) # Load the tokenizer and add special tokens tokenizer = AutoTokenizer.from_pretrained( @@ -127,17 +142,16 @@ def main(**kwargs): model.print_trainable_parameters() # Get the dataset utils - dataset_config = generate_dataset_config(train_config, kwargs) dataset_processer = tokenizer # Load and 
preprocess the dataset for training and validation - dataset_train = get_preprocessed_dataset( - dataset_processer, dataset_config, split="train", context_length=train_config.context_length - ) + ctx_len = train_config.context_length + if ctx_len is None and hasattr(model.config, "max_position_embeddings"): + ctx_len = model.config.max_position_embeddings - dataset_val = get_preprocessed_dataset( - dataset_processer, dataset_config, split="test", context_length=train_config.context_length - ) + dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len) + + dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len) # TODO: vbaddi, check if its necessary to do this? # dataset_train = ConcatDataset( diff --git a/QEfficient/finetune/configs/dataset_config.py b/QEfficient/finetune/configs/dataset_config.py index 2e7fb56fb..fd1081983 100644 --- a/QEfficient/finetune/configs/dataset_config.py +++ b/QEfficient/finetune/configs/dataset_config.py @@ -37,6 +37,14 @@ class gsm8k_dataset: test_split: str = "test" +@dataclass +class imdb_dataset: + dataset: str = "imdb_dataset" + train_split: str = "train" + test_split: str = "test" + num_labels: int = 2 + + @dataclass class custom_dataset: dataset: str = "custom_dataset" diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py index 9010965a9..c50954c4c 100644 --- a/QEfficient/finetune/configs/training.py +++ b/QEfficient/finetune/configs/training.py @@ -29,6 +29,7 @@ class train_config: use_autocast: bool = True val_batch_size: int = 1 dataset = "samsum_dataset" + task_type = "generation" # "generation" / "seq_classification" peft_method: str = "lora" use_peft: bool = True # use parameter efficient fine tuning from_peft_checkpoint: str = "" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py index ac086d272..fd9ace0af 100644 --- a/QEfficient/finetune/dataset/dataset_config.py +++ b/QEfficient/finetune/dataset/dataset_config.py @@ -18,6 +18,9 @@ get_dataset as get_grammar_dataset, ) from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset +from QEfficient.finetune.dataset.imdb_dataset import ( + get_preprocessed_imdb as get_imdb_dataset, +) from QEfficient.finetune.dataset.samsum_dataset import ( get_preprocessed_samsum as get_samsum_dataset, ) @@ -31,6 +34,7 @@ "samsum_dataset": get_samsum_dataset, "gsm8k_dataset": get_gsm8k_dataset, "custom_dataset": get_custom_dataset, + "imdb_dataset": get_imdb_dataset, } DATALOADER_COLLATE_FUNC = { "custom_dataset": get_data_collator, diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py new file mode 100644 index 000000000..71f67e46e --- /dev/null +++ b/QEfficient/finetune/dataset/imdb_dataset.py @@ -0,0 +1,36 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import datasets + + +def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None): + dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True) + + # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data. + dataset = dataset.shuffle(seed=42) + + if split == "test": + # Test set contains 15000 samples. Not all are required. + dataset = dataset.select(range(0, 1000)) + + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + def tokenize_add_label(sample): + data = tokenizer( + sample["text"], + add_special_tokens=True, + max_length=context_length, + pad_to_max_length=True, + ) + + data["labels"] = sample["label"] + return data + + dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) + return dataset diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index 58344b190..c2896dcda 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -16,7 +16,7 @@ PrefixTuningConfig, ) from transformers import default_data_collator -from transformers.data import DataCollatorForSeq2Seq +from transformers.data import DataCollatorForSeq2Seq, DefaultDataCollator import QEfficient.finetune.configs.dataset_config as datasets from QEfficient.finetune.configs.peft_config import lora_config, prefix_config @@ -88,7 +88,10 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): num_replicas=dist.get_world_size(), shuffle=False, ) - kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) + if train_config.task_type == "seq_classification": + kwargs["collate_fn"] = DefaultDataCollator(dataset_processer) + else: + kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) else: kwargs["sampler"] = data_utils.DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 073742739..736df949f 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -14,6 +14,7 @@ import torch import torch.distributed as dist +import torchmetrics from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm @@ -103,6 +104,14 @@ def train( if train_config.enable_ddp: dist.broadcast(loss_0_counter, src=0) + acc_helper = None + if train_config.task_type == "seq_classification": + if local_rank is None: + num_classes = model.classifier.out_features + else: + num_classes = model.module.classifier.out_features + acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device) + # Start the training loop for epoch in range(train_config.num_epochs): if loss_0_counter.item() == train_config.convergence_counter: @@ -181,10 +190,20 @@ def train( filter_config=qaic_debug.DispatchFilterConfig.default(device), dump_root_dir=train_config.dump_root_dir + str(step), ) as verifier: - loss = model(**batch).loss # Forward call + model_outputs = model(**batch) + loss = model_outputs.loss # Forward call + if train_config.task_type == "seq_classification": + logits = model_outputs.logits + labels = batch["labels"] + acc_helper.forward(logits, labels) print("Mismatches detected:", verifier.get_perop_mismatch_count()) 
else: - loss = model(**batch).loss # Forward call + model_outputs = model(**batch) + loss = model_outputs.loss # Forward call + if train_config.task_type == "seq_classification": + logits = model_outputs.logits + labels = batch["labels"] + acc_helper.forward(logits, labels) total_loss += loss.detach().float() # Accumalate graidents @@ -280,7 +299,10 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) - train_perplexity = torch.exp(train_epoch_loss) + if train_config.task_type == "seq_classification": + train_perplexity = acc_helper.compute() + else: + train_perplexity = torch.exp(train_epoch_loss) train_prep.append(float(train_perplexity)) train_loss.append(float(train_epoch_loss)) @@ -291,14 +313,14 @@ def train( if train_config.run_validation: if train_config.enable_ddp: dist.barrier() - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( + eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper( model, train_config, eval_dataloader, local_rank, tokenizer, device ) if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) else: - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( + eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper( model, train_config, eval_dataloader, local_rank, tokenizer, device ) tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) @@ -321,9 +343,14 @@ def train( print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}") val_loss.append(float(eval_epoch_loss)) val_prep.append(float(eval_ppl)) - print( - f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" - ) + if train_config.task_type == "seq_classification": + print( + f"Epoch {epoch + 1}: train_acc={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + ) + else: + print( + f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + ) # Saving the results every epoch to plot later if train_config.save_metrics: @@ -346,10 +373,16 @@ def train( avg_eval_prep = sum(val_prep) / len(val_prep) avg_eval_loss = sum(val_loss) / len(val_loss) - results["avg_train_prep"] = avg_train_prep + if train_config.task_type == "seq_classification": + results["avg_train_acc"] = avg_train_prep + else: + results["avg_train_prep"] = avg_train_prep results["avg_train_loss"] = avg_train_loss if train_config.run_validation: - results["avg_eval_prep"] = avg_eval_prep + if train_config.task_type == "seq_classification": + results["avg_eval_acc"] = avg_eval_prep + else: + results["avg_eval_prep"] = avg_eval_prep results["avg_eval_loss"] = avg_eval_loss results["avg_epoch_time"] = avg_epoch_time results["avg_checkpoint_time"] = avg_checkpoint_time @@ -359,7 +392,7 @@ def train( return results -def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, device): +def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device): """ Evaluates the model on the given dataloader @@ -420,6 +453,82 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, devi return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity +def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device): + """ + Evaluates the model on the given dataloader + + Args: + 
model: The model to evaluate + eval_dataloader: The dataloader containing the evaluation data + local_rank: The rank of the current node in a distributed setting + tokenizer: The tokenizer used to decode predictions + + Returns: eval_acc, eval_epoch_loss + """ + model.eval() + if local_rank is None: + num_classes = model.classifier.out_features + else: + num_classes = model.module.classifier.out_features + + acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device) + + # special handling for qaic device and dtype + # model.to(device) + + # eval_preds = [] + val_step_loss = [] + val_step_acc = [] + + eval_loss = 0.0 # Initialize evaluation loss + total_eval_steps = 0 + # max_steps_reached = False # Flag to indicate max eval steps reached + + for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)): + total_eval_steps += 1 + # stop when the maximum number of eval steps is reached + if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step: + # max_steps_reached = True + break + for key in batch.keys(): + batch[key] = batch[key].to(device) + # Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + # Forward pass and compute loss + with ( + torch.autocast(device_type=device, dtype=torch.float16) if train_config.use_autocast else nullcontext() + ): + outputs = model(**batch) + loss = outputs.loss + logits = outputs.logits + labels = batch["labels"] + if train_config.save_metrics: + val_step_loss.append(loss.detach().float().item()) + val_acc = acc_helper.forward(logits, labels) + val_step_acc.append(val_acc.detach().float().item()) + + eval_loss += loss.detach().float() + # Decode predictions and add to evaluation predictions list + # preds = torch.argmax(outputs.logits, -1) + # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)) + + # Compute average loss and perplexity + eval_epoch_loss = eval_loss / len(eval_dataloader) + eval_acc = acc_helper.compute() + + # Print evaluation metrics + print(f" {eval_acc.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") + + return eval_acc, eval_epoch_loss, val_step_loss, val_step_acc + + +def evaluation_helper(model, train_config, eval_dataloader, local_rank, tokenizer, device): + if train_config.task_type == "seq_classification": + return evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device) + else: + return evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device) + + def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] From 390bce817c9fe40c1ba2db4b2c72252f9381e80c Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Fri, 28 Feb 2025 13:43:43 +0530 Subject: [PATCH 02/16] Added finetuning support for BERT based models on IMDB dataset. 
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/dataset/imdb_dataset.py | 11 +++++++---- QEfficient/finetune/utils/config_utils.py | 2 +- QEfficient/finetune/utils/train_utils.py | 13 ++++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py index 71f67e46e..6564a067d 100644 --- a/QEfficient/finetune/dataset/imdb_dataset.py +++ b/QEfficient/finetune/dataset/imdb_dataset.py @@ -5,18 +5,21 @@ # # ----------------------------------------------------------------------------- + import datasets +from itertools import chain def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None): dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True) - # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data. - dataset = dataset.shuffle(seed=42) - if split == "test": # Test set contains 15000 samples. Not all are required. - dataset = dataset.select(range(0, 1000)) + # 0-12499 are 0 labeled samples, 12500-24999 are 1 labeled samples. + dataset = dataset.select(chain(range(0, 500), range(12500, 13000))) + + # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data. + dataset = dataset.shuffle(seed=42) if tokenizer.pad_token is None: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index c2896dcda..93c0efbe8 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -89,7 +89,7 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): shuffle=False, ) if train_config.task_type == "seq_classification": - kwargs["collate_fn"] = DefaultDataCollator(dataset_processer) + kwargs["collate_fn"] = default_data_collator else: kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) else: diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 736df949f..a5de1780e 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -110,7 +110,7 @@ def train( num_classes = model.classifier.out_features else: num_classes = model.module.classifier.out_features - acc_helper = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device) + acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) # Start the training loop for epoch in range(train_config.num_epochs): @@ -195,7 +195,8 @@ def train( if train_config.task_type == "seq_classification": logits = model_outputs.logits labels = batch["labels"] - acc_helper.forward(logits, labels) + preds = torch.nn.functional.softmax(logits, dim=-1) + acc_helper.forward(preds, labels) print("Mismatches detected:", verifier.get_perop_mismatch_count()) else: model_outputs = model(**batch) @@ -203,7 +204,8 @@ def train( if train_config.task_type == "seq_classification": logits = model_outputs.logits labels = batch["labels"] - acc_helper.forward(logits, labels) + preds = torch.nn.functional.softmax(logits, dim=-1) + acc_helper.forward(preds, labels) total_loss += loss.detach().float() # Accumalate graidents @@ -471,7 +473,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, else: num_classes = model.module.classifier.out_features - acc_helper = 
torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes).to(device) + acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) # special handling for qaic device and dtype # model.to(device) @@ -504,7 +506,8 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, labels = batch["labels"] if train_config.save_metrics: val_step_loss.append(loss.detach().float().item()) - val_acc = acc_helper.forward(logits, labels) + preds = torch.nn.functional.softmax(logits, dim=-1) + val_acc = acc_helper.forward(preds, labels) val_step_acc.append(val_acc.detach().float().item()) eval_loss += loss.detach().float() From a9637b14bca7fc9e8f48f4bef0db8aa32e8ff627 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Fri, 28 Feb 2025 13:43:43 +0530 Subject: [PATCH 03/16] Added finetuning support for BERT based models on IMDB dataset. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index a5de1780e..5183f051e 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -185,7 +185,7 @@ def train( ref_device="cpu", ref_dtype=torch.float32, # adjust atol & rtol this as required - atol=1e-1, + atol=1, use_ref_output_on_mismatch=True, filter_config=qaic_debug.DispatchFilterConfig.default(device), dump_root_dir=train_config.dump_root_dir + str(step), From 2a09de271ca5b298af8b584ba20f7477d1719382 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Mon, 7 Apr 2025 09:44:29 +0000 Subject: [PATCH 04/16] Added torchmetrics as dependency and fixed loss computation for ddp case. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 12 +++++++++++- pyproject.toml | 5 +++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 5183f051e..712ceeda9 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -301,8 +301,18 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) + if train_config.enable_ddp: + # Get the correct train loss from all the nodes. 
+ dist.barrier() + dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) + train_epoch_loss /= dist.get_world_size() + if train_config.task_type == "seq_classification": - train_perplexity = acc_helper.compute() + accuracy = acc_helper.compute() + if train_config.enable_ddp: + dist.all_reduce(accuracy, op=dist.ReduceOp.SUM) + accuracy /= dist.get_world_size() + train_perplexity = accuracy else: train_perplexity = torch.exp(train_epoch_loss) diff --git a/pyproject.toml b/pyproject.toml index a218ec9ee..648d2ce4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = { file = "LICENSE" } authors = [{ name = "Qualcomm Cloud AI ML Team" }] keywords = ["transformers", "Cloud AI 100", "Inference"] classifiers = [ - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3", "Development Status :: 5 - Development/Unstable", "Intended Audience :: Developers", "Intended Audience :: Education", @@ -38,6 +38,7 @@ dependencies = [ "tensorboard", "fire", "py7zr", + "torchmetrics==1.7.0", "torch==2.4.1; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", @@ -60,7 +61,7 @@ namespaces = false [tool.setuptools.dynamic.version] attr = "QEfficient.__version__" - + [tool.ruff] line-length = 120 # Enable the isort rules. From 41766dd3e369e64e5b37e94cbff6e929e430c79e Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Mon, 7 Apr 2025 09:47:01 +0000 Subject: [PATCH 05/16] Fixed atol value. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 712ceeda9..e2c661759 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -185,7 +185,7 @@ def train( ref_device="cpu", ref_dtype=torch.float32, # adjust atol & rtol this as required - atol=1, + atol=1e-1, use_ref_output_on_mismatch=True, filter_config=qaic_debug.DispatchFilterConfig.default(device), dump_root_dir=train_config.dump_root_dir + str(step), From 74ede951c18af541435da213a3a650ebd98b87d5 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Tue, 8 Apr 2025 10:47:33 +0000 Subject: [PATCH 06/16] Fixed collate fn for bs>1. It will work fine for bs>1 for llama as well on single device. 
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/cloud/finetune.py | 8 ++------ QEfficient/finetune/dataset/imdb_dataset.py | 5 ++--- QEfficient/finetune/utils/config_utils.py | 7 +------ QEfficient/finetune/utils/train_utils.py | 8 ++++---- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 614c16180..90c9d2ec8 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -145,13 +145,9 @@ def main(**kwargs): dataset_processer = tokenizer # Load and preprocess the dataset for training and validation - ctx_len = train_config.context_length - if ctx_len is None and hasattr(model.config, "max_position_embeddings"): - ctx_len = model.config.max_position_embeddings + dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length) - dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=ctx_len) - - dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=ctx_len) + dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length) # TODO: vbaddi, check if its necessary to do this? # dataset_train = ConcatDataset( diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py index 6564a067d..edfb5e523 100644 --- a/QEfficient/finetune/dataset/imdb_dataset.py +++ b/QEfficient/finetune/dataset/imdb_dataset.py @@ -28,11 +28,10 @@ def tokenize_add_label(sample): data = tokenizer( sample["text"], add_special_tokens=True, - max_length=context_length, - pad_to_max_length=True, + max_length=tokenizer.model_max_length, ) - data["labels"] = sample["label"] + data["labels"] = [sample["label"]] return data dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index 93c0efbe8..b9b18e494 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -88,19 +88,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): num_replicas=dist.get_world_size(), shuffle=False, ) - if train_config.task_type == "seq_classification": - kwargs["collate_fn"] = default_data_collator - else: - kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) else: kwargs["sampler"] = data_utils.DistributedSampler( dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True ) kwargs["batch_size"] = batch_size kwargs["drop_last"] = True - kwargs["collate_fn"] = default_data_collator else: kwargs["batch_size"] = batch_size kwargs["drop_last"] = True - kwargs["collate_fn"] = default_data_collator + kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) return kwargs diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index e2c661759..13bcc0226 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -194,7 +194,7 @@ def train( loss = model_outputs.loss # Forward call if train_config.task_type == "seq_classification": logits = model_outputs.logits - labels = batch["labels"] + labels = batch["labels"][:, 0] preds = torch.nn.functional.softmax(logits, dim=-1) acc_helper.forward(preds, labels) print("Mismatches detected:", 
verifier.get_perop_mismatch_count()) @@ -203,7 +203,7 @@ def train( loss = model_outputs.loss # Forward call if train_config.task_type == "seq_classification": logits = model_outputs.logits - labels = batch["labels"] + labels = batch["labels"][:, 0] preds = torch.nn.functional.softmax(logits, dim=-1) acc_helper.forward(preds, labels) @@ -306,7 +306,7 @@ def train( dist.barrier() dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) train_epoch_loss /= dist.get_world_size() - + if train_config.task_type == "seq_classification": accuracy = acc_helper.compute() if train_config.enable_ddp: @@ -513,7 +513,7 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, outputs = model(**batch) loss = outputs.loss logits = outputs.logits - labels = batch["labels"] + labels = batch["labels"][:, 0] if train_config.save_metrics: val_step_loss.append(loss.detach().float().item()) preds = torch.nn.functional.softmax(logits, dim=-1) From 1047357435ba854f2e38d676b1d10fa852566dda Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Wed, 9 Apr 2025 16:02:16 +0530 Subject: [PATCH 07/16] Fixed train accuracy computation Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 13bcc0226..b2eae9b9b 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -309,6 +309,7 @@ def train( if train_config.task_type == "seq_classification": accuracy = acc_helper.compute() + acc_helper.reset() if train_config.enable_ddp: dist.all_reduce(accuracy, op=dist.ReduceOp.SUM) accuracy /= dist.get_world_size() From 4e21d3b4ac25fed9eb041dc38fd1d14d519159bc Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Mon, 14 Apr 2025 11:12:53 +0530 Subject: [PATCH 08/16] Fixed formatting issues. 
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/cloud/finetune.py | 14 +++++++++----- QEfficient/finetune/dataset/imdb_dataset.py | 3 ++- QEfficient/finetune/utils/config_utils.py | 3 +-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 90c9d2ec8..e23a1e656 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -64,9 +64,9 @@ def main(**kwargs): # TODO: may have to init qccl backend, next try run with torchrun command torch_device = torch.device(device) assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert ( - torch_device.index is None - ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}" + assert torch_device.index is None, ( + f"DDP requires specification of device type only, however provided device index as well: {torch_device}" + ) dist.init_process_group(backend=train_config.dist_backend) # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank getattr(torch, torch_device.type).set_device(dist.get_rank()) @@ -145,9 +145,13 @@ def main(**kwargs): dataset_processer = tokenizer # Load and preprocess the dataset for training and validation - dataset_train = get_preprocessed_dataset(dataset_processer, dataset_config, split="train", context_length=train_config.context_length) + dataset_train = get_preprocessed_dataset( + dataset_processer, dataset_config, split="train", context_length=train_config.context_length + ) - dataset_val = get_preprocessed_dataset(dataset_processer, dataset_config, split="test", context_length=train_config.context_length) + dataset_val = get_preprocessed_dataset( + dataset_processer, dataset_config, split="test", context_length=train_config.context_length + ) # TODO: vbaddi, check if its necessary to do this? # dataset_train = ConcatDataset( diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py index edfb5e523..9630f77f2 100644 --- a/QEfficient/finetune/dataset/imdb_dataset.py +++ b/QEfficient/finetune/dataset/imdb_dataset.py @@ -6,9 +6,10 @@ # ----------------------------------------------------------------------------- -import datasets from itertools import chain +import datasets + def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None): dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True) diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index b9b18e494..e979961d6 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -15,8 +15,7 @@ LoraConfig, PrefixTuningConfig, ) -from transformers import default_data_collator -from transformers.data import DataCollatorForSeq2Seq, DefaultDataCollator +from transformers.data import DataCollatorForSeq2Seq import QEfficient.finetune.configs.dataset_config as datasets from QEfficient.finetune.configs.peft_config import lora_config, prefix_config From 5c154aa4c29f8ef57f016198532b01db351db7d6 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Tue, 15 Apr 2025 16:55:08 +0530 Subject: [PATCH 09/16] Fixed few comments. 
Need to rebase first and then test Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index b2eae9b9b..1037995ba 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -301,12 +301,6 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) - if train_config.enable_ddp: - # Get the correct train loss from all the nodes. - dist.barrier() - dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) - train_epoch_loss /= dist.get_world_size() - if train_config.task_type == "seq_classification": accuracy = acc_helper.compute() acc_helper.reset() @@ -479,10 +473,10 @@ def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, Returns: eval_acc, eval_epoch_loss """ model.eval() - if local_rank is None: - num_classes = model.classifier.out_features - else: + if train_config.enable_ddp: num_classes = model.module.classifier.out_features + else: + num_classes = model.classifier.out_features acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) From e97fed885e2bd481a27d1fee08f235e20b923b49 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Tue, 15 Apr 2025 18:01:28 +0530 Subject: [PATCH 10/16] Removed collate_fn for samsum dataset as DataCollatorForSeq2Seq works for all the use cases supported. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/dataset/dataset_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py index fd9ace0af..7ae596404 100644 --- a/QEfficient/finetune/dataset/dataset_config.py +++ b/QEfficient/finetune/dataset/dataset_config.py @@ -38,5 +38,4 @@ } DATALOADER_COLLATE_FUNC = { "custom_dataset": get_data_collator, - "samsum_dataset": get_samsum_collate_fn, } From b3c707ecc5a9d7b11469793cf3627f48a24b10e4 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Tue, 15 Apr 2025 18:03:44 +0530 Subject: [PATCH 11/16] Ruff formatted code. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/dataset/dataset_config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py index 7ae596404..6613ad56e 100644 --- a/QEfficient/finetune/dataset/dataset_config.py +++ b/QEfficient/finetune/dataset/dataset_config.py @@ -24,9 +24,6 @@ from QEfficient.finetune.dataset.samsum_dataset import ( get_preprocessed_samsum as get_samsum_dataset, ) -from QEfficient.finetune.dataset.samsum_dataset import ( - get_samsum_collate_fn, -) DATASET_PREPROC = { "alpaca_dataset": partial(get_alpaca_dataset), From 8d983f28a26df0bdc0193e76434dd9af2a1734df Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Tue, 15 Apr 2025 18:33:50 +0530 Subject: [PATCH 12/16] Removed samsum collate_fn as it is dead code. 
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/dataset/samsum_dataset.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py index 3bb552a39..71814599d 100644 --- a/QEfficient/finetune/dataset/samsum_dataset.py +++ b/QEfficient/finetune/dataset/samsum_dataset.py @@ -6,8 +6,6 @@ # ----------------------------------------------------------------------------- import datasets -import torch -from torch.nn.utils.rnn import pad_sequence def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None): @@ -48,22 +46,3 @@ def tokenize_add_label(sample): dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) return dataset - - -def collate_fn(batch): - eos_token = batch[0]["input_ids"][-1] - - input_ids = pad_sequence( - [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token - ) - attn_mask = pad_sequence( - [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0 - ) - labels = pad_sequence( - [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token - ) - return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels} - - -def get_samsum_collate_fn(dataset_processer, dataset_config): - return collate_fn From bf0f083668d0c3dc42267c00e321be2f609439f1 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Wed, 16 Apr 2025 11:47:10 +0530 Subject: [PATCH 13/16] Refactored evaluation fun and renamed variable names to generic names. Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 211 ++++++++--------------- 1 file changed, 71 insertions(+), 140 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 1037995ba..b078e629f 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -62,9 +62,9 @@ def train( Returns: results dictionary containing average training and validation perplexity and loss """ - train_prep = [] + train_metric = [] train_loss = [] - val_prep = [] + val_metric = [] val_loss = [] if train_config.save_metrics: @@ -73,10 +73,10 @@ def train( metrics_filename = ( f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json" ) - train_step_perplexity = [] + train_step_metric = [] train_step_loss = [] val_step_loss = [] - val_step_perplexity = [] + val_step_metric = [] epoch_times = [] checkpoint_times = [] @@ -106,10 +106,10 @@ def train( acc_helper = None if train_config.task_type == "seq_classification": - if local_rank is None: - num_classes = model.classifier.out_features - else: + if train_config.enable_ddp: num_classes = model.module.classifier.out_features + else: + num_classes = model.classifier.out_features acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) # Start the training loop @@ -231,7 +231,11 @@ def train( if train_config.save_metrics: train_step_loss.append(loss.detach().float().item()) - train_step_perplexity.append(float(torch.exp(loss.detach().float()))) + if train_config.task_type == "seq_classification": + step_metric_val = acc_helper.compute() + else: + step_metric_val = float(torch.exp(loss.detach().float())) + train_step_metric.append(step_metric_val) if 
train_config.grad_scaler: scaler.scale(loss).backward() # backward pass @@ -266,12 +270,12 @@ def train( metrics_filename, train_step_loss, train_loss, - train_step_perplexity, - train_prep, + train_step_metric, + train_metric, val_step_loss, val_loss, - val_step_perplexity, - val_prep, + val_step_metric, + val_metric, ) if train_config.enable_ddp: if loss_0_counter.item() == train_config.convergence_counter: @@ -307,11 +311,11 @@ def train( if train_config.enable_ddp: dist.all_reduce(accuracy, op=dist.ReduceOp.SUM) accuracy /= dist.get_world_size() - train_perplexity = accuracy + train_metric = accuracy else: - train_perplexity = torch.exp(train_epoch_loss) + train_metric = torch.exp(train_epoch_loss) - train_prep.append(float(train_perplexity)) + train_metric.append(float(train_metric)) train_loss.append(float(train_epoch_loss)) # Update the learning rate as needed @@ -320,21 +324,21 @@ def train( if train_config.run_validation: if train_config.enable_ddp: dist.barrier() - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper( - model, train_config, eval_dataloader, local_rank, tokenizer, device + eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper( + model, train_config, eval_dataloader, device ) if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) else: - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper( - model, train_config, eval_dataloader, local_rank, tokenizer, device + eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper( + model, train_config, eval_dataloader, device ) tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) if train_config.save_metrics: val_step_loss.extend(temp_val_loss) - val_step_perplexity.extend(temp_step_perplexity) + val_step_metric.extend(temp_step_metric) # saving the adapters after completion of each epoch if train_config.save_model: @@ -349,14 +353,14 @@ def train( best_val_loss = eval_epoch_loss print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}") val_loss.append(float(eval_epoch_loss)) - val_prep.append(float(eval_ppl)) + val_metric.append(float(eval_metric)) if train_config.task_type == "seq_classification": print( - f"Epoch {epoch + 1}: train_acc={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + f"Epoch {epoch + 1}: train_acc={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" ) else: print( - f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + f"Epoch {epoch + 1}: train_metric={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" ) # Saving the results every epoch to plot later @@ -365,31 +369,25 @@ def train( metrics_filename, train_step_loss, train_loss, - train_step_perplexity, - train_prep, + train_step_metric, + train_metric, val_step_loss, val_loss, - val_step_perplexity, - val_prep, + val_step_metric, + val_metric, ) avg_epoch_time = sum(epoch_times) / len(epoch_times) avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0 - avg_train_prep = sum(train_prep) / len(train_prep) + avg_train_metric = sum(train_metric) / len(train_metric) avg_train_loss = sum(train_loss) / len(train_loss) if train_config.run_validation: - avg_eval_prep = sum(val_prep) / len(val_prep) + 
avg_eval_metric = sum(val_metric) / len(val_metric) avg_eval_loss = sum(val_loss) / len(val_loss) - if train_config.task_type == "seq_classification": - results["avg_train_acc"] = avg_train_prep - else: - results["avg_train_prep"] = avg_train_prep + results["avg_train_metric"] = avg_train_metric results["avg_train_loss"] = avg_train_loss if train_config.run_validation: - if train_config.task_type == "seq_classification": - results["avg_eval_acc"] = avg_eval_prep - else: - results["avg_eval_prep"] = avg_eval_prep + results["avg_eval_metric"] = avg_eval_metric results["avg_eval_loss"] = avg_eval_loss results["avg_epoch_time"] = avg_epoch_time results["avg_checkpoint_time"] = avg_checkpoint_time @@ -399,39 +397,40 @@ def train( return results -def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device): +def evaluation_helper(model, train_config, eval_dataloader, device): """ Evaluates the model on the given dataloader Args: model: The model to evaluate eval_dataloader: The dataloader containing the evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - Returns: eval_ppl, eval_epoch_loss + Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric """ model.eval() + if train_config.task_type == "seq_classification": + if train_config.enable_ddp: + num_classes = model.module.classifier.out_features + else: + num_classes = model.classifier.out_features + acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) + # special handling for qaic device and dtype # model.to(device) - eval_preds = [] val_step_loss = [] - val_step_perplexity = [] + val_step_metric = [] eval_loss = 0.0 # Initialize evaluation loss - total_eval_steps = 0 - # max_steps_reached = False # Flag to indicate max eval steps reached for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)): - total_eval_steps += 1 # stop when the maximum number of eval steps is reached - if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step: - # max_steps_reached = True + if train_config.max_eval_step > 0 and step > train_config.max_eval_step: break for key in batch.keys(): batch[key] = batch[key].to(device) + # Ensure no gradients are computed for this scope to save memory with torch.no_grad(): # Forward pass and compute loss @@ -441,100 +440,32 @@ def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, outputs = model(**batch) loss = outputs.loss - if train_config.save_metrics: - val_step_loss.append(loss.detach().float().item()) - val_step_perplexity.append(float(torch.exp(loss.detach().float()))) - - eval_loss += loss.detach().float() - # Decode predictions and add to evaluation predictions list - preds = torch.argmax(outputs.logits, -1) - eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)) - - # Compute average loss and perplexity - eval_epoch_loss = eval_loss / len(eval_dataloader) - eval_ppl = torch.exp(eval_epoch_loss) - - # Print evaluation metrics - print(f" {eval_ppl.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") - - return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity - - -def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device): - """ - Evaluates the model on the given dataloader - - Args: - model: The model to evaluate - eval_dataloader: The dataloader containing the 
evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - - Returns: eval_acc, eval_epoch_loss - """ - model.eval() - if train_config.enable_ddp: - num_classes = model.module.classifier.out_features - else: - num_classes = model.classifier.out_features - - acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) - - # special handling for qaic device and dtype - # model.to(device) - - # eval_preds = [] - val_step_loss = [] - val_step_acc = [] - - eval_loss = 0.0 # Initialize evaluation loss - total_eval_steps = 0 - # max_steps_reached = False # Flag to indicate max eval steps reached + if train_config.task_type == "seq_classification": + logits = outputs.logits + labels = batch["labels"][:, 0] + preds = torch.nn.functional.softmax(logits, dim=-1) + val_acc = acc_helper.forward(preds, labels) + metric_val = val_acc.detach().float().item() + else: + metric_val = float(torch.exp(loss.detach().float())) - for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)): - total_eval_steps += 1 - # stop when the maximum number of eval steps is reached - if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step: - # max_steps_reached = True - break - for key in batch.keys(): - batch[key] = batch[key].to(device) - # Ensure no gradients are computed for this scope to save memory - with torch.no_grad(): - # Forward pass and compute loss - with ( - torch.autocast(device_type=device, dtype=torch.float16) if train_config.use_autocast else nullcontext() - ): - outputs = model(**batch) - loss = outputs.loss - logits = outputs.logits - labels = batch["labels"][:, 0] if train_config.save_metrics: val_step_loss.append(loss.detach().float().item()) - preds = torch.nn.functional.softmax(logits, dim=-1) - val_acc = acc_helper.forward(preds, labels) - val_step_acc.append(val_acc.detach().float().item()) + val_step_metric.append(metric_val) eval_loss += loss.detach().float() - # Decode predictions and add to evaluation predictions list - # preds = torch.argmax(outputs.logits, -1) - # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)) - # Compute average loss and perplexity + # Compute average loss and metric eval_epoch_loss = eval_loss / len(eval_dataloader) - eval_acc = acc_helper.compute() + if train_config.task_type == "seq_classification": + eval_metric = acc_helper.compute() + else: + eval_metric = torch.exp(eval_epoch_loss) # Print evaluation metrics - print(f" {eval_acc.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") + print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") - return eval_acc, eval_epoch_loss, val_step_loss, val_step_acc - - -def evaluation_helper(model, train_config, eval_dataloader, local_rank, tokenizer, device): - if train_config.task_type == "seq_classification": - return evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device) - else: - return evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device) + return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: @@ -571,22 +502,22 @@ def save_to_json( output_filename, train_step_loss, train_epoch_loss, - train_step_ppl, - train_epoch_ppl, + train_step_metric, + train_epoch_metric, val_step_loss, val_epoch_loss, - val_step_ppl, - val_epoch_ppl, + 
val_step_metric, + val_epoch_metric, ): metrics_data = { "train_step_loss": train_step_loss, "train_epoch_loss": train_epoch_loss, - "train_step_perplexity": train_step_ppl, - "train_epoch_perplexity": train_epoch_ppl, + "train_step_metric": train_step_metric, + "train_epoch_metric": train_epoch_metric, "val_step_loss": val_step_loss, "val_epoch_loss": val_epoch_loss, - "val_step_perplexity": val_step_ppl, - "val_epoch_perplexity": val_epoch_ppl, + "val_step_metric": val_step_metric, + "val_epoch_metric": val_epoch_metric, } with open(output_filename, "w") as f: json.dump(metrics_data, f) From 9ef972fc750d29e050e60efd59719397126bab60 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Wed, 16 Apr 2025 12:13:31 +0530 Subject: [PATCH 14/16] Fixed metric value for json logging Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index b078e629f..3f8fb4df7 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -232,7 +232,7 @@ def train( if train_config.save_metrics: train_step_loss.append(loss.detach().float().item()) if train_config.task_type == "seq_classification": - step_metric_val = acc_helper.compute() + step_metric_val = float(acc_helper.compute()) else: step_metric_val = float(torch.exp(loss.detach().float())) train_step_metric.append(step_metric_val) From 2713fdc8158b0d08d2b420d1592a5f278fe65cf1 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Thu, 17 Apr 2025 10:20:40 +0530 Subject: [PATCH 15/16] Minor fixes to variable names Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/finetune/utils/train_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 3f8fb4df7..84379967a 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -311,11 +311,11 @@ def train( if train_config.enable_ddp: dist.all_reduce(accuracy, op=dist.ReduceOp.SUM) accuracy /= dist.get_world_size() - train_metric = accuracy + metric_val = accuracy else: - train_metric = torch.exp(train_epoch_loss) + metric_val = torch.exp(train_epoch_loss) - train_metric.append(float(train_metric)) + train_metric.append(float(metric_val)) train_loss.append(float(train_epoch_loss)) # Update the learning rate as needed @@ -356,11 +356,11 @@ def train( val_metric.append(float(eval_metric)) if train_config.task_type == "seq_classification": print( - f"Epoch {epoch + 1}: train_acc={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" ) else: print( - f"Epoch {epoch + 1}: train_metric={train_metric:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" ) # Saving the results every epoch to plot later From b6d0ce9ddcf6136b1bd67746c58673bbd3a50381 Mon Sep 17 00:00:00 2001 From: Meet Patel <quic_meetkuma@quicinc.com> Date: Thu, 17 Apr 2025 11:20:10 +0530 Subject: [PATCH 16/16] Addressed few offline comments. 
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> --- QEfficient/cloud/finetune.py | 4 ++++ QEfficient/finetune/utils/train_utils.py | 6 +----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index e23a1e656..f312d00cb 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -86,6 +86,10 @@ def main(**kwargs): attn_implementation="sdpa", torch_dtype=torch.float16, ) + + if not hasattr(model, "base_model_prefix"): + raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.") + for param in getattr(model, model.base_model_prefix).parameters(): param.requires_grad = False diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 84379967a..2bc701008 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -306,12 +306,8 @@ def train( train_epoch_loss = total_loss / len(train_dataloader) if train_config.task_type == "seq_classification": - accuracy = acc_helper.compute() + metric_val = acc_helper.compute() acc_helper.reset() - if train_config.enable_ddp: - dist.all_reduce(accuracy, op=dist.ReduceOp.SUM) - accuracy /= dist.get_world_size() - metric_val = accuracy else: metric_val = torch.exp(train_epoch_loss)
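
A quick way to sanity-check the new IMDB loader outside the training loop, assuming the series is applied on top of a QEfficient checkout (the model choice and the None placeholder for dataset_config below are illustrative; the loader does not read dataset_config in this series):

    from transformers import AutoTokenizer

    from QEfficient.finetune.dataset.imdb_dataset import get_preprocessed_imdb

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # dataset_config is accepted for interface parity with the other dataset loaders,
    # but it is not used inside get_preprocessed_imdb, so None is enough for a smoke test.
    train_ds = get_preprocessed_imdb(None, tokenizer, split="train")
    val_ds = get_preprocessed_imdb(None, tokenizer, split="test")  # balanced 1000-sample subset

    print(len(train_ds), len(val_ds))
    print(train_ds[0].keys())  # input_ids, token_type_ids, attention_mask, labels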
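
The accuracy bookkeeping the later patches converge on (a shared MulticlassAccuracy helper fed with softmaxed logits, compute() once per epoch, then reset()) can also be exercised standalone; the logits and labels below are synthetic stand-ins for model_outputs.logits and batch["labels"][:, 0]:

    import torch
    import torchmetrics

    # num_classes mirrors dataset_config.num_labels (2 for the IMDB config added in patch 01).
    acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=2)

    for _ in range(4):  # stand-in for the per-batch training loop
        logits = torch.randn(8, 2)            # model_outputs.logits for a batch of 8
        labels = torch.randint(0, 2, (8,))    # batch["labels"][:, 0] after collation
        preds = torch.nn.functional.softmax(logits, dim=-1)
        acc_helper.forward(preds, labels)     # accumulates state, returns the step accuracy

    epoch_acc = acc_helper.compute()          # aggregated epoch accuracy (the train metric)
    acc_helper.reset()                        # cleared at the end of each epoch
    print(float(epoch_acc))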