From c140e976cbc0924356713bbb6c3808d71579bbbd Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 9 Apr 2025 09:44:48 +0300
Subject: [PATCH 1/3] add fix

---
 .../train_dreambooth_lora_flux_advanced.py    | 27 ++++++++++++++-----
 .../dreambooth/train_dreambooth_lora_flux.py  | 27 ++++++++++++++-----
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index f45e0a51d226..11a424f01eb4 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1915,17 +1915,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(
+            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1961,7 +1967,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index debdafd04ba1..f1b6d4d2a3cf 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1524,17 +1524,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(
+            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1561,7 +1567,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

From 757ac183f0a5328caf26c4b19fc55407e657774d Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 9 Apr 2025 09:55:26 +0300
Subject: [PATCH 2/3] add fix

---
 .../train_dreambooth_lora_flux_advanced.py    |  2 --
 examples/dreambooth/train_dreambooth_flux.py  | 26 ++++++++++++++-----
 .../dreambooth/train_dreambooth_lora_flux.py  |  1 -
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 11a424f01eb4..b5c55bfabfb9 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1955,7 +1955,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             lr_scheduler,
         )
     else:
-        print("I SHOULD BE HERE")
         transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
         )
@@ -1975,7 +1974,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
                 f"This inconsistency may result in the learning rate scheduler not functioning properly."
             )
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 6b5adb7a10a8..cd4cc9215df3 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1407,17 +1407,23 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(
+            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1444,8 +1450,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index f1b6d4d2a3cf..87ec166f8381 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1575,7 +1575,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
                 f"This inconsistency may result in the learning rate scheduler not functioning properly."
             )
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
From bfd3ddee18d65b690e66b0f91ab7426ba2a4c209 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Mon, 21 Apr 2025 07:04:16 +0000
Subject: [PATCH 3/3] Apply style fixes

---
 .../train_dreambooth_lora_flux_advanced.py        | 5 ++---
 examples/dreambooth/train_dreambooth_flux.py      | 5 ++---
 examples/dreambooth/train_dreambooth_lora_flux.py | 5 ++---
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index a86372fcd6ec..b756fba7d659 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1919,10 +1919,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
-        num_update_steps_per_epoch = math.ceil(
-            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
-                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
         )
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 5fb1725eeac2..fcc1737e6588 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1411,10 +1411,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
-        num_update_steps_per_epoch = math.ceil(
-            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
-                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
         )
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 8db83006bdca..9de4973b6f64 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1528,10 +1528,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
-        num_update_steps_per_epoch = math.ceil(
-            len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
-                args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
         )
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
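
For reference, the step arithmetic that patches 1 and 2 wire into the three scripts can be reproduced in isolation. The sketch below is a hypothetical, self-contained helper, not part of the diffusers scripts; the function name and the example numbers are made up, and it assumes accelerate's usual behaviour of stepping a prepared scheduler once per process per optimizer step, which is why both counts are scaled by the number of processes.

    import math

    def scheduler_step_counts(
        num_batches, num_processes, gradient_accumulation_steps, num_train_epochs, lr_warmup_steps, max_train_steps=None
    ):
        # Both counts are multiplied by num_processes to match how a prepared
        # scheduler is stepped under accelerate (see the PR cited in the patch).
        num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
        if max_train_steps is None:
            # The scheduler is built before accelerator.prepare(), so the
            # per-process (sharded) dataloader length is estimated by hand.
            len_train_dataloader_after_sharding = math.ceil(num_batches / num_processes)
            num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
            num_training_steps_for_scheduler = num_train_epochs * num_processes * num_update_steps_per_epoch
        else:
            num_training_steps_for_scheduler = max_train_steps * num_processes
        return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler

    # Example: 25 batches, 2 GPUs, gradient accumulation 4, 10 epochs, 100 warmup steps.
    # ceil(25 / 2) = 13 sharded batches -> ceil(13 / 4) = 4 update steps per epoch,
    # so the scheduler is built for 10 * 2 * 4 = 80 training steps and 100 * 2 = 200 warmup steps.
    print(scheduler_step_counts(25, 2, 4, 10, 100))  # (200, 80)

The warning added in the second hunks fires when the dataloader length observed after accelerator.prepare() no longer matches this pre-sharding estimate, since the scheduler was already sized with the estimated counts.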