Add training script and pretrained model
Cylumn committed Aug 2, 2022
1 parent 013cb8e commit 35f668f
Showing 6 changed files with 234 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
 __pycache__
 .idea
 .ipynb_checkpoints
+logs/
 tester.ipynb
Binary file added daare_pretrained.pt
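The pretrained checkpoint mirrors the save format at the end of the new train.py below: a dict holding the 'state_dict' of the nn.DataParallel-wrapped DAARE plus the training 'args'. A minimal loading sketch under that assumption; note that the CDAE components must be re-added before loading, since __init__ no longer builds the first one:

import torch
from torch import nn
from model.daare import DAARE

checkpoint = torch.load('daare_pretrained.pt', map_location='cpu')
args = checkpoint['args']
daare = nn.DataParallel(DAARE(depth=args.depth,
                              hidden_channels=args.n_hidden,
                              kernel=args.kernel,
                              norm=True,
                              img_size=args.img_size))
# Rebuild the CDAE stack to match the checkpoint before loading weights
for idx in range(args.n_cdae):
    daare.module.add_cdae(residual=(idx > 0), norm=(idx < args.n_norm))
daare.load_state_dict(checkpoint['state_dict'])
daare.eval()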
4 changes: 3 additions & 1 deletion lib/dataset.py
@@ -8,6 +8,8 @@
 from torch.utils.data import Dataset
 from sklearn.preprocessing import StandardScaler
 
+from tqdm import tqdm
+
 from data import simulate
 

@@ -25,7 +27,7 @@ def __init__(self, n_dataset, args):
         self.observations = []
         brushes = simulate.read_brushes(args)
 
-        for _ in range(n_dataset):
+        for _ in (tqdm(range(n_dataset), leave=False, position=0, bar_format=args.tqdm_format) if args.verbose else range(n_dataset)):
             y = simulate.ground_truth(brushes, args)
             x = simulate.noise(y, args)

7 changes: 3 additions & 4 deletions model/daare.py
@@ -114,7 +114,6 @@ def __init__(self,

         # Add the first component
         self.components = nn.ModuleList()
-        self.add_cdae(residual=False, norm=norm)

     def add_cdae(self,
                  residual: bool = True,
@@ -174,7 +173,7 @@ def forward(self,
             z_inter = component(torch.cat([x, x_inter], axis=1))
 
         if return_intermediate:
-            # Return difference of incremental observation and incremental noise
-            return (x_inter - z_inter).detach()
+            return x_inter, z_inter
         else:
-            return x_inter, z_inter
+            # Return difference of incremental observation and incremental noise
+            return (x_inter - z_inter).detach()
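With this change, return_intermediate=True exposes the per-component training signals instead of the final output. A minimal usage sketch, assuming a built model daare and an observation batch x:

# Training-time call: incremental observation and its predicted noise
x_inter, z_inter = daare(x, return_intermediate=True)
# Default call: denoised output, i.e. (x_inter - z_inter).detach()
denoised = daare(x)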
4 changes: 3 additions & 1 deletion requirements.txt
@@ -5,4 +5,6 @@ matplotlib
 numpy
 pandas
 scikit-learn
-torch
+tensorboard
+torch
+tqdm
230 changes: 224 additions & 6 deletions train.py
@@ -4,9 +4,221 @@
Date Created: 08/02/2022
'''

import torch
import argparse
import numpy as np
import pandas as pd
from torch import nn
from torch.utils.data import DataLoader

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from data import simulate
from lib.dataset import AKRDataset
from model.daare import DAARE


def init_dataset(args):
    """
    Initializes the training and validation datasets.
    :param args: Command line arguments.
    :return: Returns a tuple (loader_train, loader_valid) of the training dataloader and the validation dataloader.
    """
    # Training dataset
    if args.verbose:
        print(f'> Loading training dataset of size {args.n_train}')
    data_train = AKRDataset(args.n_train, args)
    # Validation dataset
    if args.verbose:
        print(f'> Loading validation dataset of size {args.n_valid}')
    data_valid = AKRDataset(args.n_valid, args)
    loader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True)
    loader_valid = DataLoader(data_valid, batch_size=args.batch_size, shuffle=True)

    return loader_train, loader_valid
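A quick smoke-test sketch for the loaders, assuming parsed args; batches unpack as (observation, ground truth), matching how run_epoch consumes them below:

loader_train, loader_valid = init_dataset(args)
x, y = next(iter(loader_train))  # noised observation batch, ground-truth AKR batch
print(x.shape, y.shape)          # leading dim is batch_size; spatial dims follow args.img_size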


def init_model(args):
    """
    Initializes the DAARE model, devices, and torch environment parameters.
    :param args: Command line arguments.
    :return: Returns a tuple (daare, device) of the DataParallel container for DAARE and the available device.
    """
    # GPU Speedup
    torch.backends.cudnn.benchmark = True
    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    daare = nn.DataParallel(DAARE(depth=args.depth,
                                  hidden_channels=args.n_hidden,
                                  kernel=args.kernel,
                                  norm=True,
                                  img_size=args.img_size),
                            device_ids=args.device_ids)
    daare.to(device)
    # Log model
    if args.verbose:
        print(f'> Device: {device}')
        print('> DAARE Model:')
        print(f'\tDepth: {args.depth}')
        print(f'\tHidden Channels: {args.n_hidden}')
        print(f'\tKernel: {args.kernel}')

    return daare, device


def get_loss(criterion: nn.Module,
             daare: nn.Module,
             x: torch.Tensor,
             y: torch.Tensor):
    """
    Calculates loss between true noise and predicted noise.
    :param criterion: The criterion to use to calculate difference.
    :param daare: The DAARE model.
    :param x: The input AKR observation.
    :param y: The ground truth AKR.
    :return: A Tensor containing a single loss with grad.
    """
    # Calculate intermediate observation and noise predictions
    x_inter, z_inter = daare(x, return_intermediate=True)
    noise = x_inter - y
    return criterion(z_inter, noise)
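Each new CDAE is fit to the residual noise left in the current estimate, not to the clean signal directly. A toy shape check of that target, with illustrative sizes rather than the model's:

import torch
from torch import nn

x_inter = torch.randn(2, 1, 64, 96)  # current denoised estimate
y = torch.randn(2, 1, 64, 96)        # ground-truth AKR
z_inter = torch.zeros_like(x_inter)  # an (untrained) noise prediction
noise = x_inter - y                  # residual the next CDAE should predict
loss = nn.MSELoss()(z_inter, noise)  # equivalent to get_loss with an MSE criterion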


def train_daare(criterion: nn.Module,
                daare: nn.Module,
                opt: torch.optim.Optimizer,
                x: torch.Tensor,
                y: torch.Tensor):
    """
    Back-propagates DAARE with the given optimizer.
    :param criterion: The criterion to use to calculate difference.
    :param daare: The DAARE model.
    :param opt: The optimizer.
    :param x: The input AKR observation.
    :param y: The ground truth AKR.
    :return: A float containing the loss value.
    """
    # Zero the gradients
    opt.zero_grad()
    # Calculate loss and update
    loss = get_loss(criterion, daare, x, y)
    loss.backward()
    opt.step()
    return loss.item()


def run_epoch(n_loader: int,
              loader: DataLoader,
              is_train: bool,
              daare: nn.Module,
              criterion: nn.Module,
              opt: torch.optim.Optimizer,
              device: torch.device,
              writer: SummaryWriter,
              idx_component: int,
              idx_epoch: int,
              args):
    """
    Runs a single epoch across a given dataloader.
    :param n_loader: The number of samples in the dataloader.
    :param loader: The dataloader.
    :param is_train: Whether this epoch should be run in train or validation mode.
    :param daare: The DAARE model.
    :param criterion: The loss criterion.
    :param opt: The optimizer.
    :param device: The device to train on.
    :param writer: The logs writer.
    :param idx_component: The index of the current component.
    :param idx_epoch: The index of the current epoch.
    :param args: Command line arguments.
    :return: Total loss from the epoch.
    """
    # Set DAARE to the right mode
    if is_train:
        daare.train()
    else:
        daare.eval()

    # Run epoch
    loss_total = 0
    n_batches = int(n_loader / loader.batch_size)
    for idx_batch, data in tqdm(enumerate(loader), total=n_batches,
                                position=0, leave=True, bar_format=args.tqdm_format):
        # Load data
        x, y = data[0].to(device), data[1].to(device)
        # Back-propagate on DAARE
        if is_train:
            loss = train_daare(criterion, daare, opt, x, y) / n_batches
        else:
            loss = get_loss(criterion, daare, x, y).item() / n_batches

        # Log loss
        if not args.disable_logs:
            writer.add_scalar(f'Component {idx_component} loss/{("train" if is_train else "valid")}',
                              loss,
                              (idx_epoch - 1) * n_batches + idx_batch)
        loss_total += loss

    return loss_total


def start_training(args):
    # Initialization
    loader_train, loader_valid = init_dataset(args)
    daare, device = init_model(args)
    mse_loss = nn.MSELoss()

    # Set DAARE to train
    daare.train()

    # Logs
    if args.verbose:
        print(f'> Use logs: {not args.disable_logs}')
    if not args.disable_logs:
        writer = SummaryWriter(f'{args.path_to_logs}/{args.model_name}')
    else:
        writer = None

    if args.verbose:
        print(f'> Begin training for {args.n_cdae} components')
    # Loop for each CDAE component
    for idx_component in range(args.n_cdae):
        # Add a new CDAE component
        daare.module.add_cdae(residual=(idx_component > 0), norm=(idx_component < args.n_norm))
        # Init optimizer
        opt = torch.optim.Adam(daare.parameters(), lr=args.learning_rate)

        # Training Loop
        for idx_epoch in range(1, args.n_epochs_per_cdae + 1):
            print(f"CDAE[{idx_component}]: Epoch {idx_epoch} of {args.n_epochs_per_cdae}")

            # Train
            loss_train = run_epoch(n_loader=args.n_train, loader=loader_train, is_train=True,
                                   daare=daare, criterion=mse_loss, opt=opt, device=device,
                                   writer=writer, idx_component=idx_component, idx_epoch=idx_epoch, args=args)
            # Validation
            loss_valid = run_epoch(n_loader=args.n_valid, loader=loader_valid, is_train=False,
                                   daare=daare, criterion=mse_loss, opt=opt, device=device,
                                   writer=writer, idx_component=idx_component, idx_epoch=idx_epoch, args=args)

            # Flush logs
            if not args.disable_logs:
                writer.flush()

            # Print
            print(f"loss_train: {loss_train * 1e4:7.2f}", end=' | ')
            print(f"loss_valid: {loss_valid * 1e4:7.2f}")

    # Close logs (writer is None when logging is disabled)
    if not args.disable_logs:
        writer.close()

    # Save model
    state_dict = {
        'state_dict': daare.state_dict(),
        'args': args
    }
    torch.save(state_dict, f'{args.out_path}/{args.model_name}.pt')


@@ -15,10 +227,17 @@ def get_args():
    # Paths
    parser.add_argument('--path_to_data', default='data', type=str, help='Path to the data directory.')
    parser.add_argument('--path_to_logs', default='logs', type=str, help='Path to the logs directory.')
    parser.add_argument('--out_path', default='./', type=str, help='Path to the output directory.')

    # Hardware
    parser.add_argument('--device_ids', default=[0, 1], type=int, nargs=2,
                        help="Device ids of the GPUs, if GPUs are available.")

    # Options
    parser.add_argument('--model_name', default='daare_v1', type=str, help='Name of the model when logging and saving.')
    parser.add_argument('--verbose', action='store_true', help='Trains with debugging outputs and print statements.')
    parser.add_argument('--tqdm_format', default='{l_bar}{bar:20}{r_bar}{bar:-10b}', type=str,
                        help='Flag bar_format for the TQDM progress bar.')
    parser.add_argument('--disable_logs', action='store_true', help='Disables logging to the output log directory.')
    parser.add_argument('--refresh_brushes_file', action='store_true',
                        help='Rereads brush images and saves them to data/brushes.csv')
@@ -56,14 +275,16 @@ def get_args():
     parser.add_argument('--n_hidden', default=8, type=int, help='Size of each hidden Conv2d layer.')
     parser.add_argument('--kernel', default=[13, 5], type=int, nargs=2,
                         help='Kernel shape for the convolutional layers.')
-    parser.add_argument('--n_layernorm', default=3, type=int,
+    parser.add_argument('--n_norm', default=3, type=int,
                         help='The first n convolutional autoencoders to apply layernorm to.')
 
     # Training parameters
     parser.add_argument('--n_train', default=4096, type=int,
                         help='The number of training samples that are included in the training set.')
     parser.add_argument('--n_valid', default=1024, type=int,
                         help='The number of validation samples that are included in the validation set.')
+    parser.add_argument('--batch_size', default=16, type=int,
+                        help='Batch size to use in training and validation.')
     parser.add_argument('--n_epochs_per_cdae', default=10, type=int,
                         help='The number of epochs to train each convolutional denoising autoencoder.')
     parser.add_argument('--learning_rate', default=1e-4, type=float,
@@ -75,16 +296,13 @@
     args.kernel = tuple(args.kernel)
 
     # Assertions
-    assert (args.n_cdae >= args.n_layernorm), 'Number of layernorms is larger than the number of CDAEs.'
+    assert (args.n_cdae >= args.n_norm), 'Number of layernorms is larger than the number of CDAEs.'
 
     return args


 if __name__ == '__main__':
     # Get Arguments
     args = get_args()
+    start_training(args)
 
-    brushes = simulate.read_brushes(args)
-    x = simulate.ground_truth(brushes, args)
-    x = simulate.noise(x, args)
-    print(x)
