From 750ba881cab51d994b41d0cf91ae3b87cb5eb910 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 27 Mar 2025 18:43:35 +0530 Subject: [PATCH 01/11] Added support for gradient checkpointing in the finetuning script. Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 14 ++++++++++--- QEfficient/finetune/configs/training.py | 1 + QEfficient/finetune/dataset/dataset_config.py | 8 ++++++- QEfficient/finetune/dataset/samsum_dataset.py | 21 +++++++++++++++++++ QEfficient/finetune/utils/train_utils.py | 15 +++++++++---- 5 files changed, 51 insertions(+), 8 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index ea4527ddf..2f5e689c9 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -63,9 +63,9 @@ def main(**kwargs): # TODO: may have to init qccl backend, next try run with torchrun command torch_device = torch.device(device) assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert torch_device.index is None, ( - f"DDP requires specification of device type only, however provided device index as well: {torch_device}" - ) + assert ( + torch_device.index is None + ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}" dist.init_process_group(backend=train_config.dist_backend) # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank getattr(torch, torch_device.type).set_device(dist.get_rank()) @@ -114,6 +114,14 @@ def main(**kwargs): model = get_peft_model(model, peft_config) model.print_trainable_parameters() + # Enable gradient checkpointing + if train_config.gradient_checkpointing: + # Note: below attribute and method is only available in HuggingFace Transformer models. + if model.supports_gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) + else: + raise RuntimeError("Given model doesn't support gradient checkpointing. 
Please disable it and run it.") + # Get the dataset utils dataset_config = generate_dataset_config(train_config, kwargs) dataset_processer = tokenizer diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py index bc51b2a31..9010965a9 100644 --- a/QEfficient/finetune/configs/training.py +++ b/QEfficient/finetune/configs/training.py @@ -15,6 +15,7 @@ class train_config: batch_size_training: int = 1 context_length: int = None gradient_accumulation_steps: int = 4 + gradient_checkpointing: bool = False num_epochs: int = 1 max_train_step: int = 0 max_eval_step: int = 0 diff --git a/QEfficient/finetune/dataset/dataset_config.py b/QEfficient/finetune/dataset/dataset_config.py index d03316e1e..ac086d272 100644 --- a/QEfficient/finetune/dataset/dataset_config.py +++ b/QEfficient/finetune/dataset/dataset_config.py @@ -21,6 +21,9 @@ from QEfficient.finetune.dataset.samsum_dataset import ( get_preprocessed_samsum as get_samsum_dataset, ) +from QEfficient.finetune.dataset.samsum_dataset import ( + get_samsum_collate_fn, +) DATASET_PREPROC = { "alpaca_dataset": partial(get_alpaca_dataset), @@ -29,4 +32,7 @@ "gsm8k_dataset": get_gsm8k_dataset, "custom_dataset": get_custom_dataset, } -DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator} +DATALOADER_COLLATE_FUNC = { + "custom_dataset": get_data_collator, + "samsum_dataset": get_samsum_collate_fn, +} diff --git a/QEfficient/finetune/dataset/samsum_dataset.py b/QEfficient/finetune/dataset/samsum_dataset.py index 71814599d..3bb552a39 100644 --- a/QEfficient/finetune/dataset/samsum_dataset.py +++ b/QEfficient/finetune/dataset/samsum_dataset.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- import datasets +import torch +from torch.nn.utils.rnn import pad_sequence def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None): @@ -46,3 +48,22 @@ def tokenize_add_label(sample): dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) return dataset + + +def collate_fn(batch): + eos_token = batch[0]["input_ids"][-1] + + input_ids = pad_sequence( + [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token + ) + attn_mask = pad_sequence( + [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0 + ) + labels = pad_sequence( + [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token + ) + return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels} + + +def get_samsum_collate_fn(dataset_processer, dataset_config): + return collate_fn diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 085dcd9bd..fa9d1fa1d 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -178,10 +178,7 @@ def train( # adjust atol & rtol this as required atol=1e-1, use_ref_output_on_mismatch=True, - # report all mismatches - max_failures=None, - # generate unittest for each op once - repeat_same_op=True, + filter_config=qaic_debug.DispatchFilterConfig.default(device, opwise_limit=10), dump_root_dir=train_config.dump_root_dir + str(step), ) as verifier: loss = model(**batch).loss # Forward call @@ -215,6 +212,10 @@ def train( train_step_loss.append(loss.detach().float().item()) train_step_perplexity.append(float(torch.exp(loss.detach().float()))) + if train_config.gradient_checkpointing: + # Enforce 
that the loss retains its gradient tracking. + loss.requires_grad = True + if train_config.grad_scaler: scaler.scale(loss).backward() # backward pass else: @@ -283,6 +284,11 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) + # Get the correct train loss from all the nodes. + dist.barrier() + dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) + train_epoch_loss /= dist.get_world_size() + train_perplexity = torch.exp(train_epoch_loss) train_prep.append(float(train_perplexity)) @@ -299,6 +305,7 @@ def train( ) dist.barrier() dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM) + eval_epoch_loss /= dist.get_world_size() if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) From 2b705965704843e7a84187c3ed218d78adf15567 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 27 Mar 2025 18:43:35 +0530 Subject: [PATCH 02/11] Added support for gradient checkpointing in the finetuning script. Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 19 ++++++++++++------- QEfficient/finetune/utils/train_utils.py | 3 --- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 2f5e689c9..0b240f1a6 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -103,6 +103,18 @@ def main(**kwargs): # print the datatype of the model parameters # print(get_parameter_dtypes(model)) + # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model. + # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to + # apply gradient checkpointing related hooks to the input embeddings. Without this we will get + # "No inf checks were recorded for this optimizer." error. + # Enable gradient checkpointing + if train_config.gradient_checkpointing: + # Note: below attribute and method is only available in HuggingFace Transformer models. + if model.supports_gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) + else: + raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") + if train_config.use_peft: # Load the pre-trained peft model checkpoint and setup its configuration if train_config.from_peft_checkpoint: @@ -114,13 +126,6 @@ def main(**kwargs): model = get_peft_model(model, peft_config) model.print_trainable_parameters() - # Enable gradient checkpointing - if train_config.gradient_checkpointing: - # Note: below attribute and method is only available in HuggingFace Transformer models. - if model.supports_gradient_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) - else: - raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") # Get the dataset utils dataset_config = generate_dataset_config(train_config, kwargs) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index fa9d1fa1d..20aa29fd9 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -212,9 +212,6 @@ def train( train_step_loss.append(loss.detach().float().item()) train_step_perplexity.append(float(torch.exp(loss.detach().float()))) - if train_config.gradient_checkpointing: - # Enforce that the loss retains its gradient tracking. 
- loss.requires_grad = True if train_config.grad_scaler: scaler.scale(loss).backward() # backward pass From 11fc42ec6b8eca027da1caba96d1d5cdd6e1888e Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 27 Mar 2025 18:43:35 +0530 Subject: [PATCH 03/11] Added support for gradient checkpointing in the finetuning script. Signed-off-by: Meet Patel --- QEfficient/finetune/utils/train_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 20aa29fd9..15d27d9a6 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -281,10 +281,11 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) - # Get the correct train loss from all the nodes. - dist.barrier() - dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) - train_epoch_loss /= dist.get_world_size() + if train_config.enable_ddp: + # Get the correct train loss from all the nodes. + dist.barrier() + dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) + train_epoch_loss /= dist.get_world_size() train_perplexity = torch.exp(train_epoch_loss) From d74f7af4187da1e78f2cc90f6e44ca766dd6e24b Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Mon, 7 Apr 2025 15:57:15 +0530 Subject: [PATCH 04/11] Changed preserve_rng_state=True Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 0b240f1a6..94b6a76dc 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -111,7 +111,7 @@ def main(**kwargs): if train_config.gradient_checkpointing: # Note: below attribute and method is only available in HuggingFace Transformer models. if model.supports_gradient_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": True}) else: raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") From 1f5f249cc074c61fe99a7ce2cb481f7285a9fff6 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Mon, 7 Apr 2025 17:04:21 +0530 Subject: [PATCH 05/11] Changed preserve_rng_state=False as we dont support it yet. Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 94b6a76dc..0b240f1a6 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -111,7 +111,7 @@ def main(**kwargs): if train_config.gradient_checkpointing: # Note: below attribute and method is only available in HuggingFace Transformer models. if model.supports_gradient_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": True}) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) else: raise RuntimeError("Given model doesn't support gradient checkpointing. 
Please disable it and run it.") From fa2bc03251e32c33ccc5f964a1ed54197e78d356 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 14:51:10 +0530 Subject: [PATCH 06/11] Fixed comments Signed-off-by: Meet Patel --- QEfficient/finetune/utils/train_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 15d27d9a6..90def153c 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -281,12 +281,6 @@ def train( else: train_epoch_loss = total_loss / len(train_dataloader) - if train_config.enable_ddp: - # Get the correct train loss from all the nodes. - dist.barrier() - dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM) - train_epoch_loss /= dist.get_world_size() - train_perplexity = torch.exp(train_epoch_loss) train_prep.append(float(train_perplexity)) @@ -302,8 +296,6 @@ def train( model, train_config, eval_dataloader, local_rank, tokenizer, device ) dist.barrier() - dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM) - eval_epoch_loss /= dist.get_world_size() if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) From 19c16d7c7a32252232dffa0c98e89fcb3e0802c7 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 14:53:48 +0530 Subject: [PATCH 07/11] Ran ruff to format the files Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 7 +++---- QEfficient/finetune/utils/train_utils.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 0b240f1a6..398644443 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -63,9 +63,9 @@ def main(**kwargs): # TODO: may have to init qccl backend, next try run with torchrun command torch_device = torch.device(device) assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert ( - torch_device.index is None - ), f"DDP requires specification of device type only, however provided device index as well: {torch_device}" + assert torch_device.index is None, ( + f"DDP requires specification of device type only, however provided device index as well: {torch_device}" + ) dist.init_process_group(backend=train_config.dist_backend) # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank getattr(torch, torch_device.type).set_device(dist.get_rank()) @@ -126,7 +126,6 @@ def main(**kwargs): model = get_peft_model(model, peft_config) model.print_trainable_parameters() - # Get the dataset utils dataset_config = generate_dataset_config(train_config, kwargs) dataset_processer = tokenizer diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 90def153c..d075791a7 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -212,7 +212,6 @@ def train( train_step_loss.append(loss.detach().float().item()) train_step_perplexity.append(float(torch.exp(loss.detach().float()))) - if train_config.grad_scaler: scaler.scale(loss).backward() # backward pass else: From ba2615a35fd0e6ef6f3ee625aae8ed64d916e194 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 15:05:37 +0530 Subject: [PATCH 08/11] Reverted a change in opbyopverifier Signed-off-by: Meet Patel --- QEfficient/finetune/utils/train_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/QEfficient/finetune/utils/train_utils.py 
b/QEfficient/finetune/utils/train_utils.py index d075791a7..80b1984ec 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -178,7 +178,6 @@ def train( # adjust atol & rtol this as required atol=1e-1, use_ref_output_on_mismatch=True, - filter_config=qaic_debug.DispatchFilterConfig.default(device, opwise_limit=10), dump_root_dir=train_config.dump_root_dir + str(step), ) as verifier: loss = model(**batch).loss # Forward call From bb43de35715cdb2f190075d735e3a46400d1f9ee Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 16:17:26 +0530 Subject: [PATCH 09/11] Fixed recent comment Signed-off-by: Meet Patel --- QEfficient/cloud/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index 398644443..c7525d2db 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -110,7 +110,7 @@ def main(**kwargs): # Enable gradient checkpointing if train_config.gradient_checkpointing: # Note: below attribute and method is only available in HuggingFace Transformer models. - if model.supports_gradient_checkpointing: + if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False}) else: raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") From 3bdf63cde18e48bf91b6d4a0baf184bcff8b53b6 Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 16:18:01 +0530 Subject: [PATCH 10/11] Removed extra dist barrier Signed-off-by: Meet Patel --- QEfficient/finetune/utils/train_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 80b1984ec..9e6b24ef4 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -293,7 +293,6 @@ def train( eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation( model, train_config, eval_dataloader, local_rank, tokenizer, device ) - dist.barrier() if local_rank == 0: tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps) From 446ad8e6fbfad5dc8ae60cf6f4b3a93a9ea8fd1c Mon Sep 17 00:00:00 2001 From: Meet Patel Date: Thu, 10 Apr 2025 16:22:31 +0530 Subject: [PATCH 11/11] Need to bring back filter_config due to change in arguments Signed-off-by: Meet Patel --- QEfficient/finetune/utils/train_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 9e6b24ef4..073742739 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -178,6 +178,7 @@ def train( # adjust atol & rtol this as required atol=1e-1, use_ref_output_on_mismatch=True, + filter_config=qaic_debug.DispatchFilterConfig.default(device), dump_root_dir=train_config.dump_root_dir + str(step), ) as verifier: loss = model(**batch).loss # Forward call
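
Illustrative sketch (not part of the applied patches): the ordering constraint documented in PATCH 02 — enable gradient checkpointing on the base Hugging Face model *before* wrapping it with PEFT — shown in isolation. The checkpoint name and LoRA settings below are placeholders chosen for the example, not values taken from the QEfficient configs.

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint, used only for illustration.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

    # Enable gradient checkpointing on the base model first. This sets
    # model.is_gradient_checkpointing = True, which the peft library reads when
    # attaching its gradient-checkpointing hooks to the input embeddings; doing
    # this after get_peft_model (or not at all) is what triggers the
    # "No inf checks were recorded for this optimizer." error mentioned in PATCH 02.
    if getattr(model, "supports_gradient_checkpointing", False):
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"preserve_rng_state": False}
        )
    else:
        raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")

    # Only then wrap the model with PEFT/LoRA, mirroring the order in finetune.py after PATCH 02.
    peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

Here preserve_rng_state is left as False to match PATCH 05, whose message notes that preserving the RNG state is not supported yet on the target backend.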