-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train.py
30 lines (21 loc) · 779 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import yaml
from inpainting.train import GAN
from inpainting.data import ImageDataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchvision.utils import make_grid
import numpy as np
import imageio
# Location of the training images and of the model/optimizer config file.
PATH = 'data/test_256'
CONFIG_PATH = 'model_config.yaml'


def to_uint8_image(chw):
    """Convert one channels-first float image in [-1, 1] to an HWC uint8 array.

    The pixel mapping ``(x + 1) / 2`` treats images as lying in [-1, 1].
    After that shift, values are clipped to [0, 1] and scaled to uint8
    explicitly. This means imageio never has to decide for itself how to
    convert (possibly out-of-range) float data.

    Args:
        chw: array-like of shape (C, H, W). It is passed through
            ``np.asarray``, so a CPU tensor without grad is accepted as well
            as a NumPy array.

    Returns:
        ``np.uint8`` array of shape (H, W, C).
    """
    hwc = np.moveaxis(np.asarray(chw), 0, -1)
    unit = np.clip((hwc + 1) / 2, 0.0, 1.0)
    return np.round(unit * 255).astype(np.uint8)


def main():
    """Load the config, dump two debug images of the input pipeline, and train the GAN."""
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and a TypeError since 6.0. safe_load covers plain YAML configs.
    # NOTE(review): switch to yaml.load(f, Loader=yaml.FullLoader) if the
    # config relies on Python-specific tags. The `with` closes the handle.
    with open(CONFIG_PATH, encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    print(PATH)
    # Build the dataset once; it is shared by the loader and the debug dump.
    dataset = ImageDataset(PATH)
    image_loader = DataLoader(
        dataset,
        batch_size=12, shuffle=True, num_workers=16, drop_last=True,
    )

    # Sanity-check the input pipeline: one raw sample, then one collated batch
    # tiled into a single image by make_grid.
    imageio.imwrite('color_debug_2.png', to_uint8_image(dataset[0]))
    grid = make_grid(next(iter(image_loader))).detach().cpu().numpy()
    imageio.imwrite('color_debug_1.png', to_uint8_image(grid))

    gan = GAN(config['Model'], config['OptParams'])
    # NOTE(review): `gpus=` is deprecated in PyTorch Lightning 1.7 and removed
    # in 2.0, where the equivalent is accelerator="gpu", devices=1. It is kept
    # here for the Lightning version this script targets — confirm.
    trainer = pl.Trainer(gpus=1)
    trainer.fit(gan, image_loader)


# Guard the entry point. With num_workers > 0, DataLoader workers started via
# the "spawn" method (the default on Windows and macOS) re-import this module.
# Without the guard, every worker would re-run the whole script.
if __name__ == "__main__":
    main()