Skip to content

Commit fdb1759

Browse files
AnandK27 authored and sayakpaul committed
[train_custom_diffusion.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#9308)
* Update train_custom_diffusion.py to fix the LR schedulers for `num_train_epochs`
* Fix saving text embeddings during safe serialization
* Fixed formatting
1 parent 348578e commit fdb1759

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

examples/custom_diffusion/train_custom_diffusion.py

+20 -8
Original file line number | Diff line number | Diff line change
@@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di
314314
for x, y in zip(modifier_token_id, args.modifier_token):
315315
learned_embeds_dict = {}
316316
learned_embeds_dict[y] = learned_embeds[x]
317-
filename = f"{output_dir}/{y}.bin"
318317

319318
if safe_serialization:
319+
filename = f"{output_dir}/{y}.safetensors"
320320
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
321321
else:
322+
filename = f"{output_dir}/{y}.bin"
322323
torch.save(learned_embeds_dict, filename)
323324

324325

@@ -1040,17 +1041,22 @@ def main(args):
10401041
)
10411042

10421043
# Scheduler and math around the number of training steps.
1043-
overrode_max_train_steps = False
1044-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1044+
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
1045+
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
10451046
if args.max_train_steps is None:
1046-
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1047-
overrode_max_train_steps = True
1047+
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
1048+
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
1049+
num_training_steps_for_scheduler = (
1050+
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
1051+
)
1052+
else:
1053+
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
10481054

10491055
lr_scheduler = get_scheduler(
10501056
args.lr_scheduler,
10511057
optimizer=optimizer,
1052-
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1053-
num_training_steps=args.max_train_steps * accelerator.num_processes,
1058+
num_warmup_steps=num_warmup_steps_for_scheduler,
1059+
num_training_steps=num_training_steps_for_scheduler,
10541060
)
10551061

10561062
# Prepare everything with our `accelerator`.
@@ -1065,8 +1071,14 @@ def main(args):
10651071

10661072
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
10671073
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1068-
if overrode_max_train_steps:
1074+
if args.max_train_steps is None:
10691075
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1076+
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1077+
logger.warning(
1078+
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1079+
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1080+
f"This inconsistency may result in the learning rate scheduler not functioning properly."
1081+
)
10701082
# Afterwards we recalculate our number of training epochs
10711083
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
10721084

0 commit comments

Comments
 (0)