diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index ea4527ddf..c7525d2db 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -103,6 +103,18 @@ def main(**kwargs):
     # print the datatype of the model parameters
     # print(get_parameter_dtypes(model))
 
+    # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
+    # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
+    # apply gradient checkpointing related hooks to the input embeddings. Without this we will get
+    # "No inf checks were recorded for this optimizer." error.
+    # Enable gradient checkpointing
+    if train_config.gradient_checkpointing:
+        # Note: below attribute and method is only available in HuggingFace Transformer models.
+        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
+        else:
+            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index bc51b2a31..9010965a9 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -15,6 +15,7 @@ class train_config:
     batch_size_training: int = 1
     context_length: int = None
     gradient_accumulation_steps: int = 4
+    gradient_checkpointing: bool = False
     num_epochs: int = 1
     max_train_step: int = 0
     max_eval_step: int = 0
diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py
index d03316e1e..ac086d272 100644
--- a/QEfficient/finetune/dataset/dataset_config.py
+++ b/QEfficient/finetune/dataset/dataset_config.py
@@ -21,6 +21,9 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
+from QEfficient.finetune.dataset.samsum_dataset import (
+    get_samsum_collate_fn,
+)
 
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
 }
-DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator,
+    "samsum_dataset": get_samsum_collate_fn,
+}
diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py
index 71814599d..3bb552a39 100644
--- a/QEfficient/finetune/dataset/samsum_dataset.py
+++ b/QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,6 +6,8 @@
 # -----------------------------------------------------------------------------
 
 import datasets
+import torch
+from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
+
+
+def collate_fn(batch):
+    eos_token = batch[0]["input_ids"][-1]
+
+    input_ids = pad_sequence(
+        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    attn_mask = pad_sequence(
+        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
+    )
+    labels = pad_sequence(
+        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
+
+
+def get_samsum_collate_fn(dataset_processer, dataset_config):
+    return collate_fn
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index 085dcd9bd..073742739 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -178,10 +178,7 @@ def train(
                         # adjust atol & rtol this as required
                         atol=1e-1,
                         use_ref_output_on_mismatch=True,
-                        # report all mismatches
-                        max_failures=None,
-                        # generate unittest for each op once
-                        repeat_same_op=True,
+                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
                         dump_root_dir=train_config.dump_root_dir + str(step),
                     ) as verifier:
                         loss = model(**batch).loss  # Forward call
@@ -297,8 +294,6 @@ def train(
                 eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
                     model, train_config, eval_dataloader, local_rank, tokenizer, device
                 )
-                dist.barrier()
-                dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)