
Added finetuning support for BERT-based models on the IMDB dataset. #292


Merged
16 commits merged on Apr 17, 2025
34 changes: 26 additions & 8 deletions QEfficient/cloud/finetune.py
@@ -38,7 +38,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")


-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -56,6 +56,7 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config, kwargs)
     device = train_config.device

     # dist init
@@ -78,12 +79,30 @@
     # Load the pre-trained model and setup its configuration
     # config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )

     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
@@ -127,7 +146,6 @@ def main(**kwargs):
         model.print_trainable_parameters()

     # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_processer = tokenizer

     # Load and preprocess the dataset for training and validation
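
For readers unfamiliar with the new seq_classification branch above: it loads a classification head on top of the checkpoint, freezes the whole encoder, and casts the remaining trainable parameters back to float32. A minimal standalone sketch of the same idea, assuming a generic BERT checkpoint (the model name below is illustrative, not part of this PR):

import torch
from transformers import AutoModelForSequenceClassification

# Load a classification head on top of an fp16 BERT-style encoder.
model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased",  # assumption: any BERT-style checkpoint behaves the same way
    num_labels=2,                     # IMDB is binary (negative / positive)
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)

# For BERT models, model.base_model_prefix == "bert", so this freezes the encoder
# and leaves only the classification head trainable.
for param in getattr(model, model.base_model_prefix).parameters():
    param.requires_grad = False

# Keep the trainable head in float32 for numerically stable updates.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")
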
8 changes: 8 additions & 0 deletions QEfficient/finetune/configs/dataset_config.py
@@ -37,6 +37,14 @@ class gsm8k_dataset:
     test_split: str = "test"


+@dataclass
+class imdb_dataset:
+    dataset: str = "imdb_dataset"
+    train_split: str = "train"
+    test_split: str = "test"
+    num_labels: int = 2
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
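
The new dataclass only carries defaults; num_labels is the value finetune.py forwards to AutoModelForSequenceClassification. A quick look at what it exposes (the import path follows the file above):

from QEfficient.finetune.configs.dataset_config import imdb_dataset

cfg = imdb_dataset()
print(cfg.dataset, cfg.train_split, cfg.test_split, cfg.num_labels)  # imdb_dataset train test 2
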
1 change: 1 addition & 0 deletions QEfficient/finetune/configs/training.py
@@ -29,6 +29,7 @@ class train_config:
     use_autocast: bool = True
     val_batch_size: int = 1
     dataset = "samsum_dataset"
+    task_type = "generation"  # "generation" / "seq_classification"
     peft_method: str = "lora"
     use_peft: bool = True  # use parameter efficient fine tuning
     from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
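
With the new flag, switching from causal-LM generation to sequence classification is a configuration change. A hypothetical way to drive it through main(**kwargs) -- the exact CLI plumbing is outside this diff, and the model name is illustrative:

from QEfficient.cloud.finetune import main

# update_config() copies these keyword overrides onto train_config, so the names
# must match the fields defined above (task_type, dataset, model_name).
main(
    model_name="google-bert/bert-base-uncased",  # assumption: any BERT-style checkpoint
    task_type="seq_classification",
    dataset="imdb_dataset",
)
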
8 changes: 4 additions & 4 deletions QEfficient/finetune/dataset/dataset_config.py
@@ -18,11 +18,11 @@
     get_dataset as get_grammar_dataset,
 )
 from QEfficient.finetune.dataset.gsm8k_dataset import get_gsm8k_dataset
-from QEfficient.finetune.dataset.samsum_dataset import (
-    get_preprocessed_samsum as get_samsum_dataset,
+from QEfficient.finetune.dataset.imdb_dataset import (
+    get_preprocessed_imdb as get_imdb_dataset,
 )
 from QEfficient.finetune.dataset.samsum_dataset import (
-    get_samsum_collate_fn,
+    get_preprocessed_samsum as get_samsum_dataset,
 )

 DATASET_PREPROC = {
@@ -31,8 +31,8 @@
     "samsum_dataset": get_samsum_dataset,
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
+    "imdb_dataset": get_imdb_dataset,
 }
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
-    "samsum_dataset": get_samsum_collate_fn,
 }
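
For context, the registry above is what the dataset utilities dispatch on; a hedged sketch of that lookup (the helper name here is illustrative -- this PR only extends the mapping):

def load_split(dataset_config, tokenizer, split):
    # Hypothetical dispatcher: picks the preprocessing function registered for
    # dataset_config.dataset, e.g. "imdb_dataset" -> get_imdb_dataset.
    preprocessor = DATASET_PREPROC[dataset_config.dataset]
    return preprocessor(dataset_config, tokenizer, split)
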
39 changes: 39 additions & 0 deletions QEfficient/finetune/dataset/imdb_dataset.py
@@ -0,0 +1,39 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+from itertools import chain
+
+import datasets
+
+
+def get_preprocessed_imdb(dataset_config, tokenizer, split, context_length=None):
+    dataset = datasets.load_dataset("stanfordnlp/imdb", split=split, trust_remote_code=True)
+
+    if split == "test":
+        # Test set contains 25000 samples. Not all are required.
+        # 0-12499 are 0 labeled samples, 12500-24999 are 1 labeled samples.
+        dataset = dataset.select(chain(range(0, 500), range(12500, 13000)))
+
+    # Need to shuffle the dataset, as all the 0 labeled data comes first and then all the 1 labeled data.
+    dataset = dataset.shuffle(seed=42)
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    def tokenize_add_label(sample):
+        data = tokenizer(
+            sample["text"],
+            add_special_tokens=True,
+            max_length=tokenizer.model_max_length,
+        )
+
+        data["labels"] = [sample["label"]]
+        return data
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+    return dataset
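
A quick way to sanity-check the preprocessing above (the tokenizer choice is an assumption; dataset_config is accepted but unused by the function):

from transformers import AutoTokenizer

from QEfficient.finetune.dataset.imdb_dataset import get_preprocessed_imdb

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  # illustrative checkpoint
test_ds = get_preprocessed_imdb(None, tokenizer, split="test")

print(len(test_ds))       # 1000 rows: 500 negative + 500 positive, shuffled
print(test_ds[0].keys())  # tokenizer outputs (input_ids, attention_mask, ...) plus "labels"
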
21 changes: 0 additions & 21 deletions QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,8 +6,6 @@
 # -----------------------------------------------------------------------------

 import datasets
-import torch
-from torch.nn.utils.rnn import pad_sequence


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -48,22 +46,3 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

     return dataset
-
-
-def collate_fn(batch):
-    eos_token = batch[0]["input_ids"][-1]
-
-    input_ids = pad_sequence(
-        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    attn_mask = pad_sequence(
-        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
-    )
-    labels = pad_sequence(
-        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
-    )
-    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
-
-
-def get_samsum_collate_fn(dataset_processer, dataset_config):
-    return collate_fn
5 changes: 1 addition & 4 deletions QEfficient/finetune/utils/config_utils.py
@@ -15,7 +15,6 @@
     LoraConfig,
     PrefixTuningConfig,
 )
-from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq

 import QEfficient.finetune.configs.dataset_config as datasets
@@ -88,16 +87,14 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                 num_replicas=dist.get_world_size(),
                 shuffle=False,
             )
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
         else:
             kwargs["sampler"] = data_utils.DistributedSampler(
                 dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
     else:
         kwargs["batch_size"] = batch_size
         kwargs["drop_last"] = True
-        kwargs["collate_fn"] = default_data_collator
+    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
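
The change above drops the custom samsum collator and default_data_collator in favour of a single DataCollatorForSeq2Seq for every branch, which pads each batch dynamically. A small illustration of that padding behaviour as I understand the HF collator, using a BERT tokenizer as an example:

from transformers import AutoTokenizer
from transformers.data import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  # illustrative
collator = DataCollatorForSeq2Seq(tokenizer)

batch = collator([
    {"input_ids": [101, 2023, 2003, 2307, 102], "attention_mask": [1, 1, 1, 1, 1], "labels": [1]},
    {"input_ids": [101, 3308, 102], "attention_mask": [1, 1, 1], "labels": [0]},
])
print(batch["input_ids"].shape)  # torch.Size([2, 5]) -- the shorter sample is padded to the longest one
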