Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions sandbox/autoencoder/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/bin/python

import os
import glob
import numpy as np

import torch
from torch.utils import data
from imageio import imread

# The 70 glyph categories a rendered sample can belong to, in label order:
# upper-case letters, lower-case letters, spelled-out digits, then a handful
# of named punctuation glyphs (PostScript-style names, as emitted by the
# renderer).
chars = (
    list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    + list('abcdefghijklmnopqrstuvwxyz')
    + ['zero', 'one', 'two', 'three', 'four',
       'five', 'six', 'seven', 'eight', 'nine']
    + ['exclam', 'numbersign', 'dollar', 'percent', 'ampersand',
       'asterisk', 'question', 'at']
)

# Glyph name -> integer class index (position in `chars`).
char_dict = {name: index for index, name in enumerate(chars)}

class Dataset(data.Dataset):
    """Glyph dataset laid out on disk as ``<path>/shard_<k>/<index>.*``.

    Each example ``index`` lives in shard ``index % num_shards`` and
    consists of:
      * ``<index>.png``      — the rendered glyph image,
      * ``<index>.pts.npy``  — its contour points,
      * ``<index>.cat``      — (optional) the glyph's category name.

    Args:
        path: root directory containing the ``shard_*`` subdirectories.
        conditional: when True, ``__getitem__`` also loads the category
            label from the ``.cat`` file; otherwise the label slot is an
            empty list.
    """

    def __init__(self, path, conditional=False):
        entries = os.listdir(path)
        # NOTE: the plain files at the root were meant to encode the example
        # count in their filename, but that mechanism is currently disabled
        # (see the commented-out assert / num_examples below).
        count_file = [
            f for f in entries if os.path.isfile(os.path.join(path, f))
        ]
        shards = [
            f for f in entries if os.path.isdir(os.path.join(path, f))
        ]

        #assert len(count_file) == 1

        self.path = path
        self.conditional = conditional
        # Hard-coded count; the on-disk count file is ignored for now.
        self.num_examples = 1000000  #int(count_file[0])
        self.num_shards = len(shards)

    def __len__(self):
        # Total number of addressable examples (fixed, not derived from disk).
        return self.num_examples

    def __getitem__(self, index):
        shard = "shard_{}".format(index % self.num_shards)
        image_file = os.path.join(self.path, shard, "{}.png".format(index))
        contour_file = os.path.join(self.path, shard,
                                    "{}.pts.npy".format(index))

        if self.conditional:
            cat_file = os.path.join(self.path, shard, "{}.cat".format(index))
            with open(cat_file, 'r') as f:
                category = char_dict[f.read()]
        else:
            category = []

        # Contours: stored (segments, points, coords) -> (segments, coords, points).
        contours = np.load(contour_file).transpose(0, 2, 1)
        # Image: HWC -> CHW, scaled from [0, 255] into [0, ~1) as float32.
        image = (np.array(imread(image_file)).transpose(2, 0, 1)
                 / 256.).astype(np.float32)

        return contours, image, category


if __name__ == "__main__":
    # Smoke test: pull one batch from the renders directory and report the
    # contour / image tensor shapes, then stop.
    dataset = Dataset("../pil/renders/160_samples/")
    loader = data.DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

    for batch in loader:
        print(batch[0].shape)
        print(batch[1].shape)
        exit()
153 changes: 153 additions & 0 deletions sandbox/autoencoder/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#!/bin/python

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

def AdaIN(x, gains, biases):
    """Adaptive instance normalization over the last axis of a rank-3 tensor.

    Normalizes each (batch, channel) row of ``x`` (shape (N, C, L)) to zero
    mean and unit variance along the length axis, then applies the supplied
    per-instance affine parameters.

    Args:
        x: input tensor of shape (N, C, L).
        gains: multiplicative parameters, broadcastable against ``x``.
        biases: additive parameters, broadcastable against ``x``.

    Returns:
        ``gains * normalize(x) + biases``, same shape as the broadcast result.
    """
    assert len(x.shape) == 3
    eps = 1e-8  # guards rsqrt against zero variance
    mu = x.mean(dim=2, keepdim=True)
    inv_std = torch.rsqrt(x.var(dim=2, keepdim=True) + eps)
    return biases + gains * ((x - mu) * inv_std)

class MiniBlock(nn.Module):
    """Pre-activation residual unit for 1-D feature maps.

    Computes ``x + conv2(relu(bn2(conv1(relu(bn1(x))))))``; the "same"
    padding keeps the sequence length unchanged so the skip connection
    lines up.
    """

    def __init__(self, num_channels, kernel_size):
        super(MiniBlock, self).__init__()

        pad = (kernel_size - 1) // 2  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv1d(
            num_channels, num_channels, kernel_size, bias=False, padding=pad)
        self.conv2 = nn.Conv1d(
            num_channels, num_channels, kernel_size, bias=False, padding=pad)

        self.bn1 = nn.BatchNorm1d(num_channels)
        self.bn2 = nn.BatchNorm1d(num_channels)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return out + x

class ConditionalMiniBlock(nn.Module):
    """Residual unit whose normalization is class-conditional via AdaIN.

    Per class, an embedding provides two (gain, bias) pairs — one for each
    of the two AdaIN -> ReLU -> Conv1d stages — and the result is added back
    onto the input.

    Args:
        num_channels: channel width of the 1-D feature map.
        kernel_size: odd convolution kernel size ("same" padding).
        num_classes: number of distinct condition classes.
    """

    def __init__(self, num_channels, kernel_size, num_classes):
        super(ConditionalMiniBlock, self).__init__()
        self.num_channels = num_channels

        self.conv1 = nn.Conv1d(
            num_channels, num_channels, kernel_size, bias=False, padding=(kernel_size - 1) // 2)
        self.conv2 = nn.Conv1d(
            num_channels, num_channels, kernel_size, bias=False, padding=(kernel_size - 1) // 2)

        # Each embedding row holds the parameters for BOTH AdaIN stages:
        # the first num_channels entries feed stage 1, the rest stage 2.
        self.gains = nn.Embedding(num_classes, 2*num_channels)
        self.biases = nn.Embedding(num_classes, 2*num_channels)

        # NOTE: unused in forward (AdaIN replaces batch norm here); retained
        # so existing checkpoints / state dicts keep loading.
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.bn2 = nn.BatchNorm1d(num_channels)

    def forward(self, x, category):
        gains = self.gains(category)
        biases = self.biases(category)
        c = self.num_channels

        a = F.relu(AdaIN(x,
                         gains[:, :c].view(-1, c, 1),
                         biases[:, :c].view(-1, c, 1)))
        a = self.conv1(a)

        # BUG FIX: the second stage must normalize conv1's output `a`, not
        # the block input `x` — the original discarded conv1 entirely.
        a = F.relu(AdaIN(a,
                         gains[:, c:].view(-1, c, 1),
                         biases[:, c:].view(-1, c, 1)))
        a = self.conv2(a)

        return a + x


class Block(nn.Module):
    """Two MiniBlocks applied back to back (unconditional decoder block).

    ``**kwargs`` is accepted and ignored so this class is call-compatible
    with ConditionalBlock (e.g. a ``category=`` argument is swallowed).
    """

    def __init__(self, num_channels, kernel_size, **kwargs):
        super(Block, self).__init__()

        self.mini_block1 = MiniBlock(num_channels, kernel_size)
        self.mini_block2 = MiniBlock(num_channels, kernel_size)

    def forward(self, x, **kwargs):
        return self.mini_block2(self.mini_block1(x))

class ConditionalBlock(nn.Module):
    """Two ConditionalMiniBlocks applied back to back.

    The same ``category`` tensor conditions both sub-blocks.
    """

    def __init__(self, num_channels, kernel_size, num_classes):
        super(ConditionalBlock, self).__init__()

        self.mini_block1 = ConditionalMiniBlock(num_channels, kernel_size, num_classes)
        self.mini_block2 = ConditionalMiniBlock(num_channels, kernel_size, num_classes)

    def forward(self, x, category):
        return self.mini_block2(self.mini_block1(x, category), category)


class Decoder(nn.Module):
    """1-D convolutional decoder from a code sequence to contour parameters.

    A stack of residual blocks (optionally class-conditional) processes the
    input; 1x1 "transition" convolutions widen the channel count wherever
    consecutive entries of ``num_channels`` grow. The head produces 3*4
    channels reshaped to (N, 3, 4, L), and the output concatenates a rolled
    copy of the last two rows in front, giving (N, 3, 6, L).
    (Presumably the 4 rows are per-segment curve parameters and the rolled
    rows stitch neighbouring segments together — TODO confirm.)

    Args:
        in_channels: channels of the input sequence.
        num_channels: per-block channel widths (default was a mutable list;
            now an equivalent tuple to avoid the shared-default pitfall).
        kernel_size: odd kernel size for the residual blocks.
        class_conditional: use ConditionalBlock (requires ``num_classes``).
        num_classes: number of condition classes; ignored when
            ``class_conditional`` is False.
    """

    def __init__(self,
                 in_channels=1,
                 num_channels=(64, 64, 64, 64),
                 kernel_size=3,
                 class_conditional=False,
                 num_classes=None
                 ):
        super(Decoder, self).__init__()

        if class_conditional:
            assert num_classes is not None
            block = ConditionalBlock
        else:
            block = Block

        self.conv1 = nn.Conv1d(
            in_channels, num_channels[0], kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2)

        # FIX: build the ModuleDict from a dict comprehension instead of a
        # *set* of (key, module) tuples — a set's iteration order is
        # arbitrary, so the original's submodule registration order was
        # nondeterministic. One 1x1 conv per index where the width grows.
        self.transitions = nn.ModuleDict({
            str(i): nn.Conv1d(num_channels[i], num_channels[i + 1],
                              kernel_size=1)
            for i, growth in enumerate(np.diff(num_channels)) if growth > 0
        })

        # Block i runs at width num_channels[i]; transitions (above) bridge
        # the width changes between consecutive blocks.
        self.blocks = nn.ModuleList(
            [block(c, kernel_size, num_classes=num_classes)
             for c in num_channels])

        self.conv_out = nn.Conv1d(num_channels[-1], 3 * 4, kernel_size=1)

        # One-shot debug flag: forward prints the feature shape once.
        self.called = False

    def forward(self, x, category=None):
        a = self.conv1(x)

        for i, blk in enumerate(self.blocks):
            a = blk(a, category=category)
            key = str(i)
            if key in self.transitions:
                a = self.transitions[key](a)

        if not self.called:
            print(a.shape)  # one-time shape trace, kept from the original
            self.called = True

        a = self.conv_out(a).reshape(a.shape[0], 3, 4, -1)

        # Prepend the last two parameter rows of the *previous* position
        # (roll by 1 along the length axis): (N, 3, 2+4, L).
        return torch.cat([torch.roll(a[:, :, 2:, :], 1, dims=-1), a], dim=2)


if __name__ == "__main__":
    # Smoke test. FIX: the sample input below has 2 channels, so the decoder
    # must be constructed with in_channels=2 — the original used the default
    # in_channels=1, which makes conv1 raise a channel-mismatch error.
    dec = Decoder(in_channels=2)
    z = dec(torch.randn(16, 2, 256))
    print(z.shape)
118 changes: 118 additions & 0 deletions sandbox/autoencoder/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/bin/python

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from CoordConv import CoordConv


class MiniBlock(nn.Module):
    """Pre-activation residual unit for 2-D feature maps.

    Computes ``x + conv2(relu(bn2(conv1(relu(bn1(x))))))``; "same" padding
    preserves the spatial size so the skip connection lines up.
    """

    def __init__(self, num_channels, kernel_size):
        super(MiniBlock, self).__init__()

        pad = (kernel_size - 1) // 2  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv2d(
            num_channels, num_channels, kernel_size, bias=False, padding=pad)
        self.conv2 = nn.Conv2d(
            num_channels, num_channels, kernel_size, bias=False, padding=pad)

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return out + x


class Block(nn.Module):
    """Two 2-D MiniBlocks applied back to back."""

    def __init__(self, num_channels, kernel_size):
        super(Block, self).__init__()

        self.mini_block1 = MiniBlock(num_channels, kernel_size)
        self.mini_block2 = MiniBlock(num_channels, kernel_size)

    def forward(self, x):
        return self.mini_block2(self.mini_block1(x))


class Encoder(nn.Module):
    """2-D convolutional encoder from an image to a normalized latent code.

    A CoordConv stem is followed by residual Blocks; 1x1 "transition"
    convolutions widen the channel count wherever consecutive entries of
    ``num_channels`` grow. After optional pooling, features are flattened,
    projected to ``code_size`` dims and passed through a non-affine
    BatchNorm.

    Args:
        in_channels: channels of the input image.
        num_channels: per-block channel widths.
        kernel_size: odd kernel size for the residual blocks.
        code_size: dimensionality of the output code.
        pooling: optional module applied once after all blocks.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(16, 32, 32, 64, 64, 128, 256),
                 kernel_size=3,
                 code_size=16,
                 pooling=None):
        super(Encoder, self).__init__()

        self.conv1 = CoordConv(
            in_channels, num_channels[0], kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2)

        # FIX: dict comprehension instead of a *set* of (key, module) tuples
        # — set iteration order is arbitrary, so submodule registration order
        # was nondeterministic. (Leftover debug print of this dict removed.)
        self.transitions = nn.ModuleDict({
            str(i): nn.Conv2d(num_channels[i], num_channels[i + 1],
                              kernel_size=1)
            for i, growth in enumerate(np.diff(num_channels)) if growth > 0
        })

        self.blocks = nn.ModuleList(
            [Block(c, kernel_size) for c in num_channels])

        self.pooling = pooling

        # FIX: the original hard-coded 16 here, silently ignoring code_size.
        # NOTE(review): the input width 256 is the *flattened* feature count,
        # i.e. it assumes the spatial dims collapse to 1x1 at 256 channels —
        # confirm against the pooling configuration actually used.
        self.fc = nn.Linear(256, code_size)
        self.bn_out = nn.BatchNorm1d(code_size, affine=False)

        # One-shot debug flag: forward prints the feature shape once.
        self.called = False

    def forward(self, x):
        a = self.conv1(x)

        for i, block in enumerate(self.blocks):
            a = block(a)
            key = str(i)
            if key in self.transitions:
                a = self.transitions[key](a)

        if self.pooling is not None:
            a = self.pooling(a)

        if not self.called:
            print(a.shape)  # one-time shape trace, kept from the original
            self.called = True

        return self.bn_out(self.fc(a.reshape(a.shape[0], -1)))

class Discriminator(nn.Module):
    """Two-layer MLP critic over latent codes, with spectral normalization.

    Args:
        code_size: dimensionality of the input code.

    Returns (from forward): a single unbounded score per input.
    """

    def __init__(self, code_size=16):
        super(Discriminator, self).__init__()

        num_features = 128

        self.fc1 = nn.Linear(code_size, num_features)
        self.fc2 = nn.Linear(num_features, 1)

        # FIX: select the weight-bearing layers by their public type instead
        # of poking the private `_parameters` dict; for this architecture the
        # set of modules hit (the two Linear layers) is identical.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.utils.spectral_norm(m)

    def forward(self, x):
        # NOTE(review): squeeze() drops *all* singleton dims, including the
        # batch dim when batch size is 1 — confirm callers never pass N == 1.
        x = x.squeeze()
        a = F.relu(self.fc1(x))

        return self.fc2(a)


if __name__ == "__main__":
    # Smoke test. FIX: the random input must match the default in_channels=3
    # — the original fed a 2-channel tensor into a 3-channel stem.
    # NOTE(review): with a 256x256 input and a single 2x2 max-pool, the
    # flattened feature count far exceeds the 256 the fc layer expects, so
    # this still needs a pooling setup that collapses the spatial dims.
    enc = Encoder(pooling=nn.MaxPool2d(2, 2))
    z = enc(torch.randn(32, 3, 256, 256))
    print(z.shape)
Loading