Commit 9e68dd2

changing params to go fast on H200s (#64)
* changing params to go fast on H200s
daanelson authored Jan 23, 2025
1 parent 0bf7bb1 commit 9e68dd2
Showing 1 changed file with 28 additions and 5 deletions.
train.py: 33 changes (28 additions & 5 deletions)
@@ -161,6 +161,10 @@ def train(
         description="Regular expression to match specific layers to optimize. Optimizing fewer layers results in shorter training times, but can also result in a weaker LoRA. For example, to target layers 7, 12, 16, 20 which seems to create good likeness with faster training (as discovered by lux in the Ostris discord, inspired by The Last Ben), use `transformer.single_transformer_blocks.(7|12|16|20).proj_out`.",
         default=None,
     ),
+    gradient_checkpointing: bool = Input(
+        description="Turn on gradient checkpointing; saves memory at the cost of training speed. Automatically enabled for batch sizes > 1.",
+        default=False,
+    ),
     hf_repo_id: str = Input(
         description="Hugging Face repository ID, if you'd like to upload the trained LoRA to Hugging Face. For example, lucataco/flux-dev-lora. If the given repo does not exist, a new public repo will be created.",
         default=None,
@@ -223,11 +227,32 @@ def train(
                 f"The regex '{layers_to_optimize_regex}' didn't match any layers. These layers can be optimized:\n"
                 + "\n".join(available_layers_to_optimize)
             )
+    quantize = False
+    resolutions = [int(res) for res in resolution.split(",")]

     sample_prompts = []
     if wandb_sample_prompts:
         sample_prompts = [p.strip() for p in wandb_sample_prompts.split("\n")]

+    if not gradient_checkpointing:
+        if (
+            torch.cuda.get_device_properties(0).total_memory
+            < 1024 * 1024 * 1024 * 100  # memory < 100 GB?
+        ):
+            print(
+                "Turning gradient checkpointing on and quantizing base model, GPU has less than 100 GB of memory"
+            )
+            gradient_checkpointing = True
+            quantize = True
+        elif batch_size > 1:
+            print("Turning gradient checkpointing on automatically for batch size > 1")
+            gradient_checkpointing = True
+        elif max(resolutions) > 1024:
+            print(
+                "Turning gradient checkpointing on; training resolution greater than 1024x1024"
+            )
+            gradient_checkpointing = True
+
     train_config = OrderedDict(
         {
             "job": "custom_job",
@@ -260,9 +285,7 @@ def train(
                                 # TODO: Do we need to cache to disk? It's faster not to.
                                 "cache_latents_to_disk": cache_latents_to_disk,
                                 "cache_latents": True,
-                                "resolution": [
-                                    int(res) for res in resolution.split(",")
-                                ],
+                                "resolution": resolutions,
                             }
                         ],
                         "train": {
@@ -272,7 +295,7 @@
                             "train_unet": True,
                             "train_text_encoder": False,
                             "content_or_style": "balanced",
-                            "gradient_checkpointing": True,
+                            "gradient_checkpointing": gradient_checkpointing,
                             "noise_scheduler": "flowmatch",
                             "optimizer": optimizer,
                             "lr": learning_rate,
@@ -282,7 +305,7 @@
                         "model": {
                             "name_or_path": str(WEIGHTS_PATH),
                             "is_flux": True,
-                            "quantize": True,
+                            "quantize": quantize,
                         },
                         "sample": {
                             "sampler": "flowmatch",
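The speedup on H200s comes from the branch added above: gradient checkpointing and base-model quantization are now only forced on when the GPU reports less than 100 GB of memory (or when the batch size or training resolution demands it), and an H200 reports roughly 141 GB. Below is a minimal standalone sketch of that decision, not part of the commit; it condenses the two elif branches into one, and torch, the example batch_size, and the resolutions list are assumptions for illustration.

# Standalone sketch of the auto-configuration this commit adds; not part of the diff.
# Assumes torch is installed and a CUDA device is visible; batch_size and resolutions
# are hypothetical example values standing in for the trainer's inputs.
import torch

batch_size = 1                    # example: the trainer's batch_size input
resolutions = [512, 768, 1024]    # example: parsed from the "resolution" input
gradient_checkpointing = False    # the new Input() default
quantize = False

if not gradient_checkpointing:
    total_memory = torch.cuda.get_device_properties(0).total_memory  # in bytes
    if total_memory < 1024 * 1024 * 1024 * 100:  # under 100 GB, e.g. an 80 GB A100
        # small GPU: trade training speed for memory
        gradient_checkpointing = True
        quantize = True
    elif batch_size > 1 or max(resolutions) > 1024:
        # bigger batches or >1024px training also need checkpointing
        gradient_checkpointing = True

# On an H200 (~141 GB) with batch_size 1 and resolutions <= 1024, both flags stay False,
# so the base model runs unquantized with checkpointing off, which is the speedup.
print(f"gradient_checkpointing={gradient_checkpointing}, quantize={quantize}")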
