Skip to content

Commit a55f3b8

Browse files
author
Swati Allabadi
committed
[QEff. Finetune] Adding the support to resume the fine tuning using pre-computed checkpoints (#233)
1) Adds support for resuming fine tuning using checkpoints from a previous run that stopped partway through. 2) Checkpoints, both intermediate and per completed epoch, are saved for each epoch through these changes. 3) There is no need to pass tokenizer_name if a model_name is passed; tokenizer_name defaults to the same value as model_name. If a tokenizer_name different from model_name is required, it can be passed separately as an argument in the command. --------- Signed-off-by: Swati Allabadi <[email protected]> Co-authored-by: Swati Allabadi <[email protected]>
1 parent 5520757 commit a55f3b8

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

QEfficient/finetune/configs/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@dataclass
1111
class train_config:
1212
model_name: str = "meta-llama/Llama-3.2-1B"
13-
tokenizer_name: str = "meta-llama/Llama-3.2-1B"
13+
tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name
1414
run_validation: bool = True
1515
batch_size_training: int = 1
1616
context_length: int = None

QEfficient/finetune/utils/train_utils.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ def train(
112112
f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
113113
)
114114
break
115+
116+
if train_config.use_peft and train_config.from_peft_checkpoint:
117+
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
118+
if epoch < intermediate_epoch:
119+
print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
120+
# to bring the count of train_step in sync with where it left off
121+
total_train_steps += len(train_dataloader)
122+
continue
123+
115124
print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
116125
print(f"train_config.max_train_step: {train_config.max_train_step}")
117126
# stop when the maximum number of training steps is reached
@@ -131,8 +140,23 @@ def train(
131140

132141
# enable profile for qaic
133142
qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None
143+
134144
for step, batch in enumerate(train_dataloader):
145+
# resume training from a particular checkpoint, assuming the dataset is not shuffled
146+
if train_config.use_peft and train_config.from_peft_checkpoint:
147+
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
148+
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
149+
# to bring the count of train_step in sync with where it left off
150+
if epoch == intermediate_epoch and step == 0:
151+
total_train_steps += intermediate_step
152+
print(
153+
f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
154+
)
155+
if epoch == intermediate_epoch and step < intermediate_step:
156+
total_train_steps += 1
157+
continue
135158
total_train_steps += 1
159+
136160
# stop when the maximum number of training steps is reached
137161
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
138162
max_steps_reached = True
@@ -206,9 +230,11 @@ def train(
206230
qaic_profile.stop_profiling(device) if train_config.use_profiler else None
207231
if train_config.enable_ddp:
208232
if dist.get_rank() == 0:
209-
model.module.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
233+
model.module.save_pretrained(
234+
train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}"
235+
)
210236
else:
211-
model.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
237+
model.save_pretrained(train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}")
212238

213239
pbar.set_description(
214240
f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
@@ -243,17 +269,23 @@ def train(
243269
epoch_times.append(epoch_end_time)
244270

245271
if loss_0_counter.item() == train_config.convergence_counter:
246-
train_epoch_loss = total_loss / step
272+
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
273+
train_epoch_loss = total_loss / (step - intermediate_step)
274+
else:
275+
train_epoch_loss = total_loss / step
247276
else:
248-
train_epoch_loss = total_loss / len(train_dataloader)
277+
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
278+
train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
279+
else:
280+
train_epoch_loss = total_loss / len(train_dataloader)
281+
249282
train_perplexity = torch.exp(train_epoch_loss)
250283

251284
train_prep.append(float(train_perplexity))
252285
train_loss.append(float(train_epoch_loss))
253286

254287
# Update the learning rate as needed
255288
lr_scheduler.step()
256-
should_save_model = train_config.save_model
257289

258290
if train_config.run_validation:
259291
if train_config.enable_ddp:
@@ -275,14 +307,14 @@ def train(
275307
if train_config.save_metrics:
276308
val_step_loss.extend(temp_val_loss)
277309
val_step_perplexity.extend(temp_step_perplexity)
278-
should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
279310

280-
if should_save_model:
311+
# saving the adapters after completion of each epoch
312+
if train_config.save_model:
281313
if train_config.enable_ddp:
282314
if dist.get_rank() == 0:
283-
model.module.save_pretrained(train_config.output_dir)
315+
model.module.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
284316
else:
285-
model.save_pretrained(train_config.output_dir)
317+
model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
286318

287319
if train_config.run_validation:
288320
if eval_epoch_loss < best_val_loss:
@@ -307,7 +339,6 @@ def train(
307339
val_step_perplexity,
308340
val_prep,
309341
)
310-
311342
avg_epoch_time = sum(epoch_times) / len(epoch_times)
312343
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
313344
avg_train_prep = sum(train_prep) / len(train_prep)

0 commit comments

Comments
 (0)