Merge pull request #4 from gcervantes8/hf-accelerate
Added Hugging Face Accelerate support: multi-GPU and distributed training!
gcervantes8 authored Dec 20, 2023
2 parents 5756b6e + 68325fd commit 7831e70
Showing 19 changed files with 464 additions and 381 deletions.
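
For orientation before the per-file diffs, the snippet below is a minimal, self-contained sketch of the training-loop pattern that Hugging Face Accelerate enables: a single `Accelerator` object handles device placement, distributed (DDP) wrapping, and gradient handling. The model, optimizer, and data here are placeholders for illustration and are not this repository's actual training code.

```python
# Minimal Accelerate training-loop sketch (placeholder model/data, not the repo's code).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator


def main():
    # Picks up the multi-GPU / TPU setup chosen via `accelerate config`.
    accelerator = Accelerator()

    model = nn.Linear(64, 1)  # placeholder network
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    dataset = TensorDataset(torch.randn(256, 64), torch.randn(256, 1))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # prepare() moves everything to the right device(s) and wraps the model for DDP.
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    loss_fn = nn.MSELoss()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        accelerator.backward(loss)  # replaces loss.backward(); handles scaling and sync
        optimizer.step()


if __name__ == "__main__":
    main()
```

Run with `accelerate launch script.py` and the same code executes on one GPU, several GPUs, or a TPU without changes.
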
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
13 changes: 8 additions & 5 deletions README.md
@@ -1,15 +1,18 @@
![Fast Image Gans with a picture of a fig to the left of it](logo/FigsName.png)
[![Python](https://img.shields.io/badge/Python-3.7--3.11-blue)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-GPL--3.0-yellow)](https://github.com/gcervantes8/Game-Image-Generator/blob/master/LICENSE) [![Python package](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml/badge.svg)](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml)

[![Python](https://img.shields.io/badge/Python-3.8--3.11-blue)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-GPL--3.0-yellow)](https://github.com/gcervantes8/Game-Image-Generator/blob/master/LICENSE) [![Python package](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml/badge.svg)](https://github.com/gcervantes8/Game-Image-Generator/actions/workflows/python-package.yml)


With this project, you can train Generative Adversarial Networks (GANs). While you can train with any type of image,
this repository focuses on generating images from games.

## Features

- PyTorch 2.0 Compile
- Mixed Precision training
- PyTorch 2 Compile
- Mixed Precision training (fp16 or bf16)
- Gradient Accumulation
- Inception Score and FID evaluation
- HF 🤗 Accelerate - adds multi-GPU and TPU support
- Easy to start training
- Testing

@@ -26,8 +29,8 @@ Provided in the code is a sample of the coil-100 dataset, which is used for test

## Requirements
The following are the Python packages needed.
- [Pytorch](https://pytorch.org/get-started/locally/), 1.9+
- [torchvision](https://pypi.org/project/torchvision/) 0.9+
- [PyTorch](https://pytorch.org/get-started/locally/) 2.0+
- [torchvision](https://pypi.org/project/torchvision/) 0.15+
- [SciPy](https://scipy.org/install/) 1.7+
- [TorchMetrics](https://torchmetrics.readthedocs.io/en/stable/)
- [torchinfo](https://github.com/TylerYep/torchinfo)
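
The new README feature lines above (mixed precision in fp16 or bf16, gradient accumulation) correspond to standard `Accelerator` options. The sketch below shows how those features map onto the library's public API; it is an illustration under that assumption, not this project's exact configuration, and the model and data are again placeholders.

```python
# Mixed precision + gradient accumulation via Accelerate (illustrative placeholders).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(
    mixed_precision="bf16",         # or "fp16"; requires matching hardware support
    gradient_accumulation_steps=4,  # apply gradients once every 4 batches
)

model = nn.Linear(64, 1)  # placeholder network
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
loader = DataLoader(TensorDataset(torch.randn(256, 64), torch.randn(256, 1)), batch_size=32)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

loss_fn = nn.MSELoss()
for inputs, targets in loader:
    # accumulate() defers gradient sync and the optimizer step until the
    # configured number of batches has been processed.
    with accelerator.accumulate(model):
        loss = loss_fn(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```
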
8 changes: 5 additions & 3 deletions requirements.txt
@@ -1,10 +1,12 @@
torch >= 1.9
torchvision >= 0.9
torch >= 2.0
torchvision >= 0.15
torchinfo
torch-ema
Pillow
torchmetrics
scipy >= 1.7
tqdm
tensorboard
six
six
bitsandbytes
accelerate
28 changes: 17 additions & 11 deletions src/data_load.py → src/data/data_load.py
@@ -10,8 +10,8 @@
import torch
import PIL
import torchvision.datasets as torch_data_set
import torchvision.transforms as transforms
from src import os_helper
import torchvision.transforms.v2 as transforms
from src.utils import os_helper


def normalize(images, norm_mean=torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32),
@@ -38,13 +38,13 @@ def color_transform(images, brightness=0.1, contrast=0.05, saturation=0.1, hue=0
return train_transform_augment(images)


def data_loader_from_config(data_config, data_dtype=torch.float32, using_gpu=False):
def data_loader_from_config(data_config, using_gpu=False):
data_dir = data_config['train_dir']
os_helper.is_valid_dir(data_dir, 'Invalid training data directory\nPath is an invalid directory: ' + data_dir)
image_height, image_width = get_image_height_and_width(data_config)
batch_size = int(data_config['batch_size'])
n_workers = int(data_config['workers'])
return create_data_loader(data_dir, image_height, image_width, dtype=data_dtype, using_gpu=using_gpu,
return create_data_loader(data_dir, image_height, image_width, using_gpu=using_gpu,
batch_size=batch_size, n_workers=n_workers)


@@ -69,28 +69,34 @@ def get_num_classes(data_config):
data_loader = data_loader_from_config(data_config)
return len(data_loader.dataset.classes)

def to_int16(label):
return torch.tensor(label, dtype=torch.int16)

def create_data_loader(data_dir: str, image_height: int, image_width: int, dtype=torch.float32, using_gpu=False,
def create_data_loader(data_dir: str, image_height: int, image_width: int, image_dtype=torch.float16, using_gpu=False,
batch_size=1, n_workers=1):

data_transform = transforms.Compose([transforms.Resize((image_height, image_width)),
transforms.ToTensor(),
transforms.ConvertImageDtype(dtype)
transforms.ToImage(),
transforms.ToDtype(image_dtype, scale=True), # Float16 is tiny bit faster, and bit more VRAM. Strange.
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
label_transform = to_int16
try:
data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform)
data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform, target_transform=label_transform)
except FileNotFoundError:
raise FileNotFoundError('Data directory provided should contain directories that have images in them, '
'directory provided: ' + data_dir)

# Create the data-loader
torch_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size,
shuffle=True, num_workers=n_workers, pin_memory=using_gpu)
torch_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True,
num_workers=n_workers, pin_memory=using_gpu, drop_last=True)
return torch_loader


# Returns images of size: (batch_size, num_channels, height, width)
def get_data_batch(data_loader, device):
def get_data_batch(data_loader, device, unnormalize_batch=False):
if unnormalize_batch:
return unnormalize(next(iter(data_loader))[0]).to(device)
return next(iter(data_loader))[0].to(device)


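
The `create_data_loader` changes above move to the `torchvision.transforms.v2` API, where `ToImage` followed by `ToDtype(..., scale=True)` replaces the older `ToTensor`/`ConvertImageDtype` pair. The standalone sketch below shows what that pipeline does to a single image; it assumes a torchvision release recent enough to expose `ToImage`/`ToDtype` and feeds in a synthetic image rather than the repository's dataset.

```python
# What the new transforms.v2 pipeline does to one image (synthetic input).
import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from PIL import Image

pipeline = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToImage(),                           # PIL / ndarray -> image tensor (uint8)
    transforms.ToDtype(torch.float16, scale=True),  # uint8 [0, 255] -> float16 [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> roughly [-1, 1]
])

pil_image = Image.fromarray(np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8))
tensor = pipeline(pil_image)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 64, 64]) torch.float16
```
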
8 changes: 5 additions & 3 deletions src/gan_inference.py
@@ -12,16 +12,18 @@
from torchvision import transforms
import math
import PIL
from src import saver_and_loader, os_helper
from utils import saver_and_loader
from src.configs import ini_parser
# from src.metrics import score_metrics
from src.data_load import unnormalize, get_num_classes, create_latent_vector
from src import create_model
from data.data_load import unnormalize, get_num_classes, create_latent_vector
from models import create_model

import os
import logging
import argparse

from utils import os_helper


def generate_batch_image(ini_config, gan_model, num_images: int):
model_arch_config, data_config = ini_config['MODEL ARCHITECTURE'], ini_config['DATA']
222 changes: 0 additions & 222 deletions src/gan_model.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/losses/loss_functions.py
@@ -27,7 +27,7 @@ def supported_losses():
# Given a loss function returns a 3-tuple
# The loss function, the fake label, and the real label
# Returns 3-tuple of None if the loss function is not supported
def supported_loss_functions(loss_name: str, device=None):
def supported_loss_functions(loss_name: str):
loss_functions = _losses()
if loss_name in loss_functions:
loss_fn, fake_label, real_label = loss_functions[loss_name.lower()]