Commit b9d52fc

[train_lcm_distill_lora_sdxl.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8446)
fix num_train_epochs

Co-authored-by: Sayak Paul <[email protected]>
1 parent 2ada094 commit b9d52fc

File tree

1 file changed: +18, -7 lines changed

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -1111,11 +1111,16 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     # 15. LR Scheduler creation
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     if args.scale_lr:
         args.learning_rate = (
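
Note: as a rough illustration of what this hunk changes (the numbers below are made up, not taken from the commit), consider passing only --num_train_epochs on several GPUs. Before the fix, the per-epoch update count was derived from the unsharded dataloader, so the scheduler was sized for far more steps than training would ever perform; the new code estimates the post-sharding length up front so the two agree.

# Rough arithmetic sketch with hypothetical values: 1000 batches before
# sharding, 4 processes, gradient accumulation of 2, 10 epochs.
import math

len_train_dataloader = 1000          # len(train_dataloader) before accelerator.prepare
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 10

# Old behaviour: per-epoch steps computed from the unsharded dataloader.
old_updates_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)   # 500
old_scheduler_steps = num_train_epochs * old_updates_per_epoch * num_processes          # 20000

# New behaviour: estimate the sharded length before creating the scheduler.
len_after_sharding = math.ceil(len_train_dataloader / num_processes)                    # 250
new_updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)     # 125
new_scheduler_steps = num_train_epochs * new_updates_per_epoch * num_processes          # 5000

# Training performs roughly new_scheduler_steps scheduler steps in total, so the
# old value (20000) left the schedule far from finished at the end of training.
print(old_scheduler_steps, new_scheduler_steps)
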
@@ -1130,8 +1135,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # 16. Prepare for training
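
Note: a minimal sketch (hypothetical step counts, not taken from this script) of why the num_training_steps passed to get_scheduler has to match the steps the loop actually takes: a cosine schedule sized for 20,000 steps but stepped only 5,000 times never completes its decay, so the learning rate ends training still at a large fraction of its peak value.

import torch
from diffusers.optimization import get_scheduler

# Toy optimizer over a single dummy parameter.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=20_000,  # over-counted, as with the pre-fix sizing
)
for _ in range(5_000):          # steps the training loop actually performs
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still far above the schedule's final LR
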
@@ -1142,8 +1147,14 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
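
Note: a small sketch (made-up lengths, with variable names mirroring the script) of the consistency check this hunk adds. If accelerator.prepare produces a dataloader whose length differs from the estimate used when the scheduler was created, the recomputed max_train_steps no longer matches the scheduler's step count and the warning fires.

import math

num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 10

len_train_dataloader_after_sharding = 250   # estimate used at scheduler creation
len_train_dataloader_after_prepare = 248    # actual length after accelerator.prepare

num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)                                                                     # 5000
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_prepare / gradient_accumulation_steps
)
max_train_steps = num_train_epochs * num_update_steps_per_epoch       # 1240

if num_training_steps_for_scheduler != max_train_steps * num_processes:   # 5000 != 4960
    print("warning: LR scheduler step count no longer matches the training loop")
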
