diff --git a/huggan/pytorch/cyclegan/train.py b/huggan/pytorch/cyclegan/train.py
index 3faa50ef..a4f9dfca 100644
--- a/huggan/pytorch/cyclegan/train.py
+++ b/huggan/pytorch/cyclegan/train.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 import numpy as np
 import itertools
@@ -7,7 +8,8 @@
 import time
 import sys
 
-from PIL import Image
+from PIL import Image, ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
 from torchvision.utils import save_image, make_grid
@@ -25,6 +27,9 @@
 import torch.nn as nn
 import torch
 
+
+logger = logging.getLogger(__name__)
+
 def parse_args(args=None):
     parser = argparse.ArgumentParser()
     parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
@@ -39,7 +44,6 @@ def parse_args(args=None):
     parser.add_argument("--image_size", type=int, default=256, help="Size of images for training")
     parser.add_argument("--channels", type=int, default=3, help="Number of image channels")
     parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
-    parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
     parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
     parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
     parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
@@ -54,17 +58,19 @@ def parse_args(args=None):
         "and an Nvidia Ampere GPU.",
     )
     parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
+    parser.add_argument("--output_dir", type=Path, default=Path("./output-cyclegan"), help="Name of the directory to dump generated images during training.")
+    parser.add_argument("--wandb", action="store_true", help="If passed, will log to Weights and Biases.")
+    parser.add_argument(
+        "--logging_steps",
+        type=int,
+        default=50,
+        help="Number of steps between logging events.",
+    )
     parser.add_argument(
         "--push_to_hub",
         action="store_true",
         help="Whether to push the model to the HuggingFace hub after training.",
     )
-    parser.add_argument(
-        "--pytorch_dump_folder_path",
-        required="--push_to_hub" in sys.argv,
-        type=Path,
-        help="Path to save the model. Will be created if it doesn't exist already.",
-    )
     parser.add_argument(
         "--model_name",
         required="--push_to_hub" in sys.argv,
@@ -78,7 +84,16 @@ def parse_args(args=None):
         type=str,
         help="Organization name to push to, in case args.push_to_hub is specified.",
     )
-    return parser.parse_args(args=args)
+    args = parser.parse_args(args=args)
+
+    if args.push_to_hub:
+        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
+        assert args.model_name is not None, "Need a `model_name` to create a repo when `--push_to_hub` is passed."
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    return args
 
 
 def weights_init_normal(m):
@@ -93,11 +108,18 @@
 
 
 def training_function(config, args):
+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
     accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu, mixed_precision=args.mixed_precision)
-
-    # Create sample and checkpoint directories
-    os.makedirs("images/%s" % args.dataset_name, exist_ok=True)
-    os.makedirs("saved_models/%s" % args.dataset_name, exist_ok=True)
+
+    # Set up logging; we only want one process per machine to log things on the screen.
+    # accelerator.is_local_main_process is only True for one process per machine.
+    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    if accelerator.is_local_main_process:
+        # set up Weights and Biases if requested
+        if args.wandb:
+            import wandb
+
+            wandb.init(project=str(args.output_dir).split("/")[-1])
 
     # Losses
     criterion_GAN = torch.nn.MSELoss()
@@ -159,11 +181,8 @@ def training_function(config, args):
     ])
 
     def transforms(examples):
-        examples["A"] = [transform(image.convert("RGB")) for image in examples["imageA"]]
-        examples["B"] = [transform(image.convert("RGB")) for image in examples["imageB"]]
-
-        del examples["imageA"]
-        del examples["imageB"]
+        examples["imageA"] = [transform(image.convert("RGB")) for image in examples["imageA"]]
+        examples["imageB"] = [transform(image.convert("RGB")) for image in examples["imageB"]]
 
         return examples
 
@@ -177,14 +196,14 @@
     dataloader = DataLoader(train_ds, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers)
     val_dataloader = DataLoader(val_ds, batch_size=5, shuffle=True, num_workers=1)
 
-    def sample_images(batches_done):
+    def sample_images():
         """Saves a generated sample from the test set"""
         batch = next(iter(val_dataloader))
         G_AB.eval()
         G_BA.eval()
-        real_A = batch["A"]
+        real_A = batch["imageA"]
         fake_B = G_AB(real_A)
-        real_B = batch["B"]
+        real_B = batch["imageB"]
         fake_A = G_BA(real_B)
         # Arange images along x-axis
         real_A = make_grid(real_A, nrow=5, normalize=True)
@@ -193,7 +212,8 @@
         fake_B = make_grid(fake_B, nrow=5, normalize=True)
         # Arange images along y-axis
         image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
-        save_image(image_grid, "images/%s/%s.png" % (args.dataset_name, batches_done), normalize=False)
+
+        return image_grid
 
     G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, dataloader, val_dataloader = accelerator.prepare(G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, dataloader, val_dataloader)
 
@@ -203,11 +223,11 @@
     prev_time = time.time()
 
     for epoch in range(args.epoch, args.num_epochs):
-        for i, batch in enumerate(dataloader):
+        for step, batch in enumerate(dataloader):
 
             # Set model input
-            real_A = batch["A"]
-            real_B = batch["B"]
+            real_A = batch["imageA"]
+            real_B = batch["imageB"]
 
             # Adversarial ground truths
             valid = torch.ones((real_A.size(0), *output_shape), device=accelerator.device)
@@ -291,52 +311,62 @@
             # --------------
 
             # Determine approximate time left
-            batches_done = epoch * len(dataloader) + i
+            batches_done = epoch * len(dataloader) + step
            batches_left = args.num_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
 
-            # Print log
-            sys.stdout.write(
-                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
-                % (
-                    epoch,
-                    args.num_epochs,
-                    i,
-                    len(dataloader),
-                    loss_D.item(),
-                    loss_G.item(),
-                    loss_GAN.item(),
-                    loss_cycle.item(),
-                    loss_identity.item(),
-                    time_left,
-                )
-            )
+            # Log results
+            if (step + 1) % args.logging_steps == 0:
+                loss_D = loss_D.detach()
+                loss_G = loss_G.detach()
+                loss_GAN = loss_GAN.detach()
+                loss_cycle = loss_cycle.detach()
+                loss_identity = loss_identity.detach()
+
+                if accelerator.state.num_processes > 1:
+                    loss_D = accelerator.gather(loss_D).sum() / accelerator.state.num_processes
+                    loss_G = accelerator.gather(loss_G).sum() / accelerator.state.num_processes
+                    loss_GAN = accelerator.gather(loss_GAN).sum() / accelerator.state.num_processes
+                    loss_cycle = accelerator.gather(loss_cycle).sum() / accelerator.state.num_processes
+                    loss_identity = accelerator.gather(loss_identity).sum() / accelerator.state.num_processes
+
+                train_logs = {
+                    "epoch": epoch,
+                    "discriminator_loss": loss_D,
+                    "generator_loss": loss_G,
+                    "GAN_loss": loss_GAN,
+                    "cycle_loss": loss_cycle,
+                    "identity_loss": loss_identity,
+                    # "time_left": time_left,
+                }
+                log_str = ""
+                for k, v in train_logs.items():
+                    log_str += "| {}: {:.3e}".format(k, v)
+
+                if accelerator.is_local_main_process:
+                    logger.info(log_str)
+                    if args.wandb:
+                        wandb.log(train_logs)
 
             # If at sample interval save image
             if batches_done % args.sample_interval == 0:
-                sample_images(batches_done)
+                image_grid = sample_images()
+                file_name = args.output_dir / f"{batches_done}.png"
+                save_image(image_grid, file_name, normalize=False)
+                if accelerator.is_local_main_process and args.wandb:
+                    wandb.log({"generated_examples": wandb.Image(str(file_name))})
 
         # Update learning rates
         lr_scheduler_G.step()
         lr_scheduler_D_A.step()
         lr_scheduler_D_B.step()
 
-        if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
-            # Save model checkpoints
-            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (args.dataset_name, epoch))
-            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (args.dataset_name, epoch))
-            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (args.dataset_name, epoch))
-            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (args.dataset_name, epoch))
-
 
     # Optionally push to hub
-    if args.push_to_hub:
-        save_directory = args.pytorch_dump_folder_path
-        if not save_directory.exists():
-            save_directory.mkdir(parents=True)
-
-        G_AB.push_to_hub(
-            repo_path_or_name=save_directory / args.model_name,
+    if accelerator.is_main_process and args.push_to_hub:
+        model = accelerator.unwrap_model(G_AB)
+        model.push_to_hub(
+            repo_path_or_name=args.output_dir / args.model_name,
             organization=args.organization_name,
         )
 
@@ -344,9 +374,6 @@ def main():
     args = parse_args()
     print(args)
 
-    # Make directory for saving generated images
-    os.makedirs("images", exist_ok=True)
-
     training_function({}, args)
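
Note for reviewers: the gather-and-average logging path added above only takes effect under multi-process training. Below is a minimal standalone sketch of the same pattern, not part of the patch, assuming `torch` and `accelerate` are installed; the loss tensor is a dummy placeholder, not the script's real CycleGAN losses.

# sketch.py -- illustrative only, mirrors the logging pattern from the diff
import torch
from accelerate import Accelerator

accelerator = Accelerator(cpu=True)
loss_D = torch.tensor(0.25, device=accelerator.device)  # dummy discriminator loss

# Detach, then average across processes; single-process runs skip the gather.
loss_D = loss_D.detach()
if accelerator.state.num_processes > 1:
    loss_D = accelerator.gather(loss_D).sum() / accelerator.state.num_processes

# Only the local main process writes to the screen, as in the patch.
if accelerator.is_local_main_process:
    print("| discriminator_loss: {:.3e}".format(loss_D.item()))

Running this under `accelerate launch --num_processes 2 sketch.py` exercises the gather branch; a plain `python sketch.py` run takes the single-process path.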