[QEff Finetune]: Added support for gradient checkpointing in the finetuning script. #338

Merged
merged 11 commits on Apr 11, 2025
12 changes: 12 additions & 0 deletions QEfficient/cloud/finetune.py
@@ -103,6 +103,18 @@ def main(**kwargs):
    # print the datatype of the model parameters
    # print(get_parameter_dtypes(model))

    # Note: gradient checkpointing must be enabled before calling PeftModel.from_pretrained or
    # get_peft_model, because both set model.is_gradient_checkpointing = True, which the peft
    # library uses to attach gradient-checkpointing hooks to the input embeddings. Without this,
    # training fails with a "No inf checks were recorded for this optimizer." error.
    # Enable gradient checkpointing
    if train_config.gradient_checkpointing:
        # Note: the attribute and method below are only available on HuggingFace Transformers models.
        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
        else:
            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and rerun.")

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
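For reference, a minimal standalone sketch of the ordering this hunk enforces: gradient checkpointing is switched on for the base HuggingFace model before the peft wrapper is applied, so peft sees is_gradient_checkpointing = True and registers its input-embedding hooks. The model name and LoRA settings below are illustrative assumptions, not the values used by the script.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative stand-ins for the script's train_config fields (assumptions, not real values).
model_name = "gpt2"
use_gradient_checkpointing = True

model = AutoModelForCausalLM.from_pretrained(model_name)

# 1) Enable gradient checkpointing on the base HF model first.
if use_gradient_checkpointing:
    if getattr(model, "supports_gradient_checkpointing", False):
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
    else:
        raise RuntimeError("Model does not support gradient checkpointing.")

# 2) Only then wrap with peft; the wrapper reads model.is_gradient_checkpointing
#    and attaches its gradient-checkpointing hooks to the input embeddings.
peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)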
1 change: 1 addition & 0 deletions QEfficient/finetune/configs/training.py
@@ -15,6 +15,7 @@ class train_config:
    batch_size_training: int = 1
    context_length: int = None
    gradient_accumulation_steps: int = 4
    gradient_checkpointing: bool = False
    num_epochs: int = 1
    max_train_step: int = 0
    max_eval_step: int = 0
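The new flag defaults to off. A hypothetical way to turn it on, assuming keyword arguments passed to the finetune entry point are forwarded into train_config; the override mechanism and the model_name field are assumptions, only gradient_checkpointing comes from this PR.

from QEfficient.cloud.finetune import main

# Hypothetical invocation: override train_config defaults via keyword arguments.
main(
    model_name="gpt2",            # assumed config field, for illustration only
    gradient_checkpointing=True,  # new flag added in this PR
    gradient_accumulation_steps=4,
)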
8 changes: 7 additions & 1 deletion QEfficient/finetune/dataset/dataset_config.py
@@ -21,6 +21,9 @@
from QEfficient.finetune.dataset.samsum_dataset import (
    get_preprocessed_samsum as get_samsum_dataset,
)
from QEfficient.finetune.dataset.samsum_dataset import (
    get_samsum_collate_fn,
)

DATASET_PREPROC = {
"alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
"gsm8k_dataset": get_gsm8k_dataset,
"custom_dataset": get_custom_dataset,
}
DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
DATALOADER_COLLATE_FUNC = {
"custom_dataset": get_data_collator,
"samsum_dataset": get_samsum_collate_fn,
}
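A sketch of how the registry could be consumed when wiring up a dataloader; resolve_collate_fn below is a hypothetical helper for illustration, not QEfficient's actual plumbing.

from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC

def resolve_collate_fn(dataset_name, tokenizer, dataset_config):
    # Hypothetical helper: look up a collate-fn getter by dataset name and call it;
    # datasets without an entry fall back to the DataLoader's default collation.
    getter = DATALOADER_COLLATE_FUNC.get(dataset_name)
    return getter(tokenizer, dataset_config) if getter else None

collate_fn = resolve_collate_fn("samsum_dataset", tokenizer=None, dataset_config=None)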
21 changes: 21 additions & 0 deletions QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,6 +6,8 @@
# -----------------------------------------------------------------------------

import datasets
import torch
from torch.nn.utils.rnn import pad_sequence


def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

    return dataset


def collate_fn(batch):
    # Use the last input token of the first sample (the EOS id) as the padding value.
    eos_token = batch[0]["input_ids"][-1]

    # Right-pad every field in the batch to the length of the longest sample.
    input_ids = pad_sequence(
        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
    )
    attn_mask = pad_sequence(
        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
    )
    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}


def get_samsum_collate_fn(dataset_processer, dataset_config):
    return collate_fn
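A quick usage sketch of the new collate function with toy token ids (2 stands in for the EOS id), showing the padding behaviour: input_ids and labels are right-padded with the first sample's last token, attention_mask with zeros.

from torch.utils.data import DataLoader

from QEfficient.finetune.dataset.samsum_dataset import get_samsum_collate_fn

# Toy samples of unequal length; the token ids are made up.
samples = [
    {"input_ids": [5, 6, 7, 2], "attention_mask": [1, 1, 1, 1], "labels": [5, 6, 7, 2]},
    {"input_ids": [8, 9, 2], "attention_mask": [1, 1, 1], "labels": [8, 9, 2]},
]

collate = get_samsum_collate_fn(dataset_processer=None, dataset_config=None)
loader = DataLoader(samples, batch_size=2, collate_fn=collate)

batch = next(iter(loader))
print(batch["input_ids"].shape)    # torch.Size([2, 4]): the shorter sample is padded with 2
print(batch["attention_mask"][1])  # tensor([1, 1, 1, 0], dtype=torch.int32)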
7 changes: 1 addition & 6 deletions QEfficient/finetune/utils/train_utils.py
@@ -178,10 +178,7 @@ def train(
# adjust atol & rtol this as required
atol=1e-1,
use_ref_output_on_mismatch=True,
# report all mismatches
max_failures=None,
# generate unittest for each op once
repeat_same_op=True,
filter_config=qaic_debug.DispatchFilterConfig.default(device),
dump_root_dir=train_config.dump_root_dir + str(step),
) as verifier:
loss = model(**batch).loss # Forward call
@@ -297,8 +294,6 @@ def train(
        eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
            model, train_config, eval_dataloader, local_rank, tokenizer, device
        )
        dist.barrier()
        dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
        if local_rank == 0:
            tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
