diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index c7525d2db..f312d00cb 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -38,7 +38,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")
 
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -56,6 +56,7 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config, kwargs)
     device = train_config.device
 
     # dist init
@@ -78,12 +79,30 @@ def main(**kwargs):
     # Load the pre-trained model and setup its configuration
     # config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
@@ -127,7 +146,6 @@ def main(**kwargs):
         model.print_trainable_parameters()
 
     # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_processer = tokenizer
 
     # Load and preprocess the dataset for training and validation
diff --git a/QEfficient/finetune/configs/dataset_config.py b/QEfficient/finetune/configs/dataset_config.py
index 2e7fb56fb..fd1081983 100644
--- a/QEfficient/finetune/configs/dataset_config.py
+++ b/QEfficient/finetune/configs/dataset_config.py
@@ -37,6 +37,14 @@ class gsm8k_dataset:
     test_split: str = "test"
 
 
+@dataclass
+class imdb_dataset:
+    dataset: str = "imdb_dataset"
+    train_split: str = "train"
+    test_split: str = "test"
+    num_labels: int = 2
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index 9010965a9..c50954c4c 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -29,6 +29,7 @@ class train_config:
     use_autocast: bool = True
     val_batch_size: int = 1
     dataset = "samsum_dataset"
+    task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     use_peft: bool = True  # use parameter efficient fine tuning
     from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index ac086d272..6613ad56e 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -18,11 +18,11 @@
     get_dataset as get_grammar_dataset,
 )
 from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset
-from QEfficient.finetune.dataset.samsum_dataset import (
-    get_preprocessed_samsum as get_samsum_dataset,
+from QEfficient.finetune.dataset.imdb_dataset import (
+    get_preprocessed_imdb as get_imdb_dataset,
 )
 from QEfficient.finetune.dataset.samsum_dataset import (
-    get_samsum_collate_fn,
+    get_preprocessed_samsum as get_samsum_dataset,
 )
 
 DATASET_PREPROC = {
@@ -31,8 +31,8 @@
     "samsum_dataset": get_samsum_dataset,
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
+    "imdb_dataset": get_imdb_dataset,
 }
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
-    "samsum_dataset": get_samsum_collate_fn,
 }
diff --git a/QEfficient/finetune/dataset/imdb_dataset.py b/QEfficient/finetune/dataset/imdb_dataset.py
new file mode 100644
index 000000000..9630f77f2
--- /dev/null
+++ b/QEfficient/finetune/dataset/imdb_dataset.py
@@ -0,0 +1,39 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+from itertools import chain
+
+import datasets
+
+
+def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
+    dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
+
+    if split == "test":
+        # Test set contains 25000 samples. Not all are required.
+        # 0-12499 are 0 labeled samples, 12500-24999 are 1 labeled samples.
+        dataset = dataset.select(chain(range(0, 500), range(12500, 13000)))
+
+    # Need to shuffle dataset as all the 0 labeled data is organized first and then all the 1 labeled data.
+    dataset = dataset.shuffle(seed=42)
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    def tokenize_add_label(sample):
+        data = tokenizer(
+            sample["text"],
+            add_special_tokens=True,
+            max_length=tokenizer.model_max_length,
+        )
+
+        data["labels"] = [sample["label"]]
+        return data
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+    return dataset
diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py
index 3bb552a39..71814599d 100644
--- a/QEfficient/finetune/dataset/samsum_dataset.py
+++ b/QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,8 +6,6 @@
 # -----------------------------------------------------------------------------
 
 import datasets
-import torch
-from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -48,22 +46,3 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
-
-
-def collate_fn(batch):
-    eos_token = batch[0]["input_ids"][-1]
-
-    input_ids = pad_sequence(
-        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    attn_mask = pad_sequence(
-        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
-    )
-    labels = pad_sequence(
-        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
-
-
-def get_samsum_collate_fn(dataset_processer, dataset_config):
-    return collate_fn
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index 58344b190..e979961d6 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -15,7 +15,6 @@
     LoraConfig,
     PrefixTuningConfig,
 )
-from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
 
 import QEfficient.finetune.configs.dataset_config as datasets
@@ -88,16 +87,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                 num_replicas=dist.get_world_size(),
                 shuffle=False,
             )
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
            )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 073742739..2bc701008 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -14,6 +14,7 @@
 
 import torch
 import torch.distributed as dist
+import torchmetrics
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
@@ -61,9 +62,9 @@ def train(
     Returns: results dictionary containing average training and validation perplexity and loss
     """
 
-    train_prep = []
+    train_metric = []
     train_loss = []
-    val_prep = []
+    val_metric = []
     val_loss = []
 
     if train_config.save_metrics:
@@ -72,10 +73,10 @@
         metrics_filename = (
f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json" ) - train_step_perplexity = [] + train_step_metric = [] train_step_loss = [] val_step_loss = [] - val_step_perplexity = [] + val_step_metric = [] epoch_times = [] checkpoint_times = [] @@ -103,6 +104,14 @@ def train( if train_config.enable_ddp: dist.broadcast(loss_0_counter, src=0) + acc_helper = None + if train_config.task_type == "seq_classification": + if train_config.enable_ddp: + num_classes = model.module.classifier.out_features + else: + num_classes = model.classifier.out_features + acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device) + # Start the training loop for epoch in range(train_config.num_epochs): if loss_0_counter.item() == train_config.convergence_counter: @@ -181,10 +190,22 @@ def train( filter_config=qaic_debug.DispatchFilterConfig.default(device), dump_root_dir=train_config.dump_root_dir + str(step), ) as verifier: - loss = model(**batch).loss # Forward call + model_outputs = model(**batch) + loss = model_outputs.loss # Forward call + if train_config.task_type == "seq_classification": + logits = model_outputs.logits + labels = batch["labels"][:, 0] + preds = torch.nn.functional.softmax(logits, dim=-1) + acc_helper.forward(preds, labels) print("Mismatches detected:", verifier.get_perop_mismatch_count()) else: - loss = model(**batch).loss # Forward call + model_outputs = model(**batch) + loss = model_outputs.loss # Forward call + if train_config.task_type == "seq_classification": + logits = model_outputs.logits + labels = batch["labels"][:, 0] + preds = torch.nn.functional.softmax(logits, dim=-1) + acc_helper.forward(preds, labels) total_loss += loss.detach().float() # Accumalate graidents @@ -210,7 +231,11 @@ def train( if train_config.save_metrics: train_step_loss.append(loss.detach().float().item()) - train_step_perplexity.append(float(torch.exp(loss.detach().float()))) + if train_config.task_type == "seq_classification": + step_metric_val = float(acc_helper.compute()) + else: + step_metric_val = float(torch.exp(loss.detach().float())) + train_step_metric.append(step_metric_val) if train_config.grad_scaler: scaler.scale(loss).backward() # backward pass @@ -245,12 +270,12 @@ def train( metrics_filename, train_step_loss, train_loss, - train_step_perplexity, - train_prep, + train_step_metric, + train_metric, val_step_loss, val_loss, - val_step_perplexity, - val_prep, + val_step_metric, + val_metric, ) if train_config.enable_ddp: if loss_0_counter.item() == train_config.convergence_counter: @@ -280,9 +305,13 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) - train_perplexity = torch.exp(train_epoch_loss) + if train_config.task_type == "seq_classification": + metric_val = acc_helper.compute() + acc_helper.reset() + else: + metric_val = torch.exp(train_epoch_loss) - train_prep.append(float(train_perplexity)) + train_metric.append(float(metric_val)) train_loss.append(float(train_epoch_loss)) # Update the learning rate as needed @@ -291,21 +320,21 @@ def train( if train_config.run_validation: if train_config.enable_ddp: dist.barrier() - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( - model, train_config, eval_dataloader, local_rank, tokenizer, device + eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper( + model, train_config, eval_dataloader, device ) if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, 
 
             else:
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
-                    model, train_config, eval_dataloader, local_rank, tokenizer, device
+                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                    model, train_config, eval_dataloader, device
                 )
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
-                val_step_perplexity.extend(temp_step_perplexity)
+                val_step_metric.extend(temp_step_metric)
 
         # saving the adapters after completion of each epoch
         if train_config.save_model:
@@ -320,10 +349,15 @@
                 best_val_loss = eval_epoch_loss
                 print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
-            val_prep.append(float(eval_ppl))
-        print(
-            f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-        )
+            val_metric.append(float(eval_metric))
+        if train_config.task_type == "seq_classification":
+            print(
+                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
+        else:
+            print(
+                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
@@ -331,25 +365,25 @@
                 metrics_filename,
                 train_step_loss,
                 train_loss,
-                train_step_perplexity,
-                train_prep,
+                train_step_metric,
+                train_metric,
                 val_step_loss,
                 val_loss,
-                val_step_perplexity,
-                val_prep,
+                val_step_metric,
+                val_metric,
             )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_metric = sum(train_metric) / len(train_metric)
     avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_metric = sum(val_metric) / len(val_metric)
         avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    results["avg_train_prep"] = avg_train_prep
+    results["avg_train_metric"] = avg_train_metric
     results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_metric"] = avg_eval_metric
         results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
@@ -359,39 +393,40 @@
     return results
 
 
-def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+def evaluation_helper(model, train_config, eval_dataloader, device):
     """
     Evaluates the model on the given dataloader
 
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
-        local_rank: The rank of the current node in a distributed setting
-        tokenizer: The tokenizer used to decode predictions
 
-    Returns: eval_ppl, eval_epoch_loss
+    Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
     model.eval()
 
+    if train_config.task_type == "seq_classification":
+        if train_config.enable_ddp:
+            num_classes = model.module.classifier.out_features
+        else:
+            num_classes = model.classifier.out_features
+        acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
+
     # special handling for qaic device and dtype
     # model.to(device)
 
-    eval_preds = []
     val_step_loss = []
-    val_step_perplexity = []
+    val_step_metric = []
     eval_loss = 0.0  # Initialize evaluation loss
-    total_eval_steps = 0
-    # max_steps_reached = False  # Flag to indicate max eval steps reached
 
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-        total_eval_steps += 1
         # stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-            # max_steps_reached = True
+        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
             break
         for key in batch.keys():
             batch[key] = batch[key].to(device)
+
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
@@ -401,23 +436,32 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, devi
                 outputs = model(**batch)
                 loss = outputs.loss
 
+                if train_config.task_type == "seq_classification":
+                    logits = outputs.logits
+                    labels = batch["labels"][:, 0]
+                    preds = torch.nn.functional.softmax(logits, dim=-1)
+                    val_acc = acc_helper.forward(preds, labels)
+                    metric_val = val_acc.detach().float().item()
+                else:
+                    metric_val = float(torch.exp(loss.detach().float()))
+
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
-                val_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                val_step_metric.append(metric_val)
 
             eval_loss += loss.detach().float()
-            # Decode predictions and add to evaluation predictions list
-            preds = torch.argmax(outputs.logits, -1)
-            eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
 
-    # Compute average loss and perplexity
+    # Compute average loss and metric
    eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_ppl = torch.exp(eval_epoch_loss)
+    if train_config.task_type == "seq_classification":
+        eval_metric = acc_helper.compute()
+    else:
+        eval_metric = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
-    print(f" {eval_ppl.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
-    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
+    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
 
 
 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
@@ -454,22 +498,22 @@ def save_to_json(
     output_filename,
     train_step_loss,
     train_epoch_loss,
-    train_step_ppl,
-    train_epoch_ppl,
+    train_step_metric,
+    train_epoch_metric,
     val_step_loss,
     val_epoch_loss,
-    val_step_ppl,
-    val_epoch_ppl,
+    val_step_metric,
+    val_epoch_metric,
 ):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
-        "train_step_perplexity": train_step_ppl,
-        "train_epoch_perplexity": train_epoch_ppl,
+        "train_step_metric": train_step_metric,
+        "train_epoch_metric": train_epoch_metric,
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
-        "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl,
+        "val_step_metric": val_step_metric,
+        "val_epoch_metric": val_epoch_metric,
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
diff --git a/pyproject.toml b/pyproject.toml
index a218ec9ee..648d2ce4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ license = { file = "LICENSE" }
 authors = [{ name = "Qualcomm Cloud AI ML Team" }]
 keywords = ["transformers", "Cloud AI 100", "Inference"]
 classifiers = [
-    "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3", "Development Status :: 5 - Development/Unstable", "Intended Audience :: Developers", "Intended Audience :: Education", @@ -38,6 +38,7 @@ dependencies = [ "tensorboard", "fire", "py7zr", + "torchmetrics==1.7.0", "torch==2.4.1; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", @@ -60,7 +61,7 @@ namespaces = false [tool.setuptools.dynamic.version] attr = "QEfficient.__version__" - + [tool.ruff] line-length = 120 # Enable the isort rules.