@@ -1111,11 +1111,16 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     # 15. LR Scheduler creation
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     if args.scale_lr:
         args.learning_rate = (
@@ -1130,8 +1135,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # 16. Prepare for training
@@ -1142,8 +1147,14 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
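A rough sketch of the step math this change encodes, using made-up values for the dataset size, process count, accumulation steps, epochs, and warmup (the variable names mirror the diff; the numbers and the standalone form are illustrative only, and it assumes, as discussed in the PR linked in the diff, that the scheduler is stepped on every process):

    import math

    # Hypothetical values, for illustration only.
    dataloader_len_before_prepare = 1000  # len(train_dataloader) before accelerator.prepare
    num_processes = 4                     # accelerator.num_processes
    gradient_accumulation_steps = 2       # args.gradient_accumulation_steps
    num_train_epochs = 3                  # args.num_train_epochs
    lr_warmup_steps = 100                 # args.lr_warmup_steps

    # After accelerator.prepare, each process only sees its shard of the dataloader,
    # so the per-process length is estimated by dividing by the number of processes.
    len_train_dataloader_after_sharding = math.ceil(dataloader_len_before_prepare / num_processes)        # 250
    num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)  # 125

    # Since every process advances the scheduler, both the warmup and total step
    # counts handed to get_scheduler are scaled by the number of processes.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                                       # 400
    num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes       # 1500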