@@ -112,6 +112,15 @@ def train(
112112 f"Not proceeding with epoch { epoch + 1 } since loss value has been <= { train_config .convergence_loss } for last { loss_0_counter .item ()} steps."
113113 )
114114 break
115+
116+ if train_config .use_peft and train_config .from_peft_checkpoint :
117+ intermediate_epoch = int (train_config .from_peft_checkpoint .split ("/" )[- 2 ].split ("_" )[- 1 ]) - 1
118+ if epoch < intermediate_epoch :
119+ print (f"Skipping epoch { epoch + 1 } since fine tuning has already completed for it." )
120+ # to bring the count of train_step in sync with where it left off
121+ total_train_steps += len (train_dataloader )
122+ continue
123+
115124 print (f"Starting epoch { epoch + 1 } /{ train_config .num_epochs } " )
116125 print (f"train_config.max_train_step: { train_config .max_train_step } " )
117126 # stop when the maximum number of training steps is reached
@@ -131,8 +140,23 @@ def train(
131140
132141 # enable profile for qaic
133142 qaic_profile .start_profiling (device , 1 ) if train_config .use_profiler else None
143+
134144 for step , batch in enumerate (train_dataloader ):
145+ # resume training from a particular checkpoint, assuming the dataset is not shuffled
146+ if train_config .use_peft and train_config .from_peft_checkpoint :
147+ intermediate_step = int (train_config .from_peft_checkpoint .split ("/" )[- 1 ].split ("_" )[- 1 ])
148+ intermediate_epoch = int (train_config .from_peft_checkpoint .split ("/" )[- 2 ].split ("_" )[- 1 ]) - 1
149+ # to bring the count of train_step in sync with where it left off
150+ if epoch == intermediate_epoch and step == 0 :
151+ total_train_steps += intermediate_step
152+ print (
153+ f"Skipping first { intermediate_step } steps for epoch { epoch + 1 } , since fine tuning has already completed for them."
154+ )
155+ if epoch == intermediate_epoch and step < intermediate_step :
156+ total_train_steps += 1
157+ continue
135158 total_train_steps += 1
159+
136160 # stop when the maximum number of training steps is reached
137161 if train_config .max_train_step > 0 and total_train_steps > train_config .max_train_step :
138162 max_steps_reached = True
@@ -206,9 +230,11 @@ def train(
206230 qaic_profile .stop_profiling (device ) if train_config .use_profiler else None
207231 if train_config .enable_ddp :
208232 if dist .get_rank () == 0 :
209- model .module .save_pretrained (train_config .output_dir + f"/trained_weights/step_{ step } " )
233+ model .module .save_pretrained (
234+ train_config .output_dir + f"/trained_weights/epoch_{ epoch + 1 } /step_{ step } "
235+ )
210236 else :
211- model .save_pretrained (train_config .output_dir + f"/trained_weights/step_{ step } " )
237 model .save_pretrained (train_config .output_dir + f"/trained_weights/epoch_{ epoch + 1 } /step_{ step } " )
212238
213239 pbar .set_description (
214240 f"Training Epoch: { epoch + 1 } /{ train_config .num_epochs } , step { step + 1 } /{ len (train_dataloader )} completed (loss: { loss .detach ().float ()} )"
@@ -243,17 +269,23 @@ def train(
243269 epoch_times .append (epoch_end_time )
244270
245271 if loss_0_counter .item () == train_config .convergence_counter :
246- train_epoch_loss = total_loss / step
272+ if train_config .use_peft and train_config .from_peft_checkpoint and epoch == intermediate_epoch :
273+ train_epoch_loss = total_loss / (step - intermediate_step )
274+ else :
275+ train_epoch_loss = total_loss / step
247276 else :
248- train_epoch_loss = total_loss / len (train_dataloader )
277+ if train_config .use_peft and train_config .from_peft_checkpoint and epoch == intermediate_epoch :
278+ train_epoch_loss = total_loss / (len (train_dataloader ) - intermediate_step )
279+ else :
280+ train_epoch_loss = total_loss / len (train_dataloader )
281+
249282 train_perplexity = torch .exp (train_epoch_loss )
250283
251284 train_prep .append (float (train_perplexity ))
252285 train_loss .append (float (train_epoch_loss ))
253286
254287 # Update the learning rate as needed
255288 lr_scheduler .step ()
256- should_save_model = train_config .save_model
257289
258290 if train_config .run_validation :
259291 if train_config .enable_ddp :
@@ -275,14 +307,14 @@ def train(
275307 if train_config .save_metrics :
276308 val_step_loss .extend (temp_val_loss )
277309 val_step_perplexity .extend (temp_step_perplexity )
278- should_save_model = train_config .save_model and eval_epoch_loss < best_val_loss
279310
280- if should_save_model :
311+ # saving the adapters after completion of each epoch
312+ if train_config .save_model :
281313 if train_config .enable_ddp :
282314 if dist .get_rank () == 0 :
283- model .module .save_pretrained (train_config .output_dir )
315 model .module .save_pretrained (train_config .output_dir + f"/complete_epoch_{ epoch + 1 } " )
284316 else :
285- model .save_pretrained (train_config .output_dir )
317 model .save_pretrained (train_config .output_dir + f"/complete_epoch_{ epoch + 1 } " )
286318
287319 if train_config .run_validation :
288320 if eval_epoch_loss < best_val_loss :
@@ -307,7 +339,6 @@ def train(
307339 val_step_perplexity ,
308340 val_prep ,
309341 )
310-
311342 avg_epoch_time = sum (epoch_times ) / len (epoch_times )
312343 avg_checkpoint_time = sum (checkpoint_times ) / len (checkpoint_times ) if len (checkpoint_times ) > 0 else 0
313344 avg_train_prep = sum (train_prep ) / len (train_prep )
0 commit comments