-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathvqgan.py
363 lines (314 loc) · 14.5 KB
/
vqgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from taming import instantiate_from_config
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
class VQModel(pl.LightningModule):
    """Vector-quantized autoencoder trained with two optimizers.

    Data flows through ``Encoder`` -> 1x1 ``quant_conv`` -> ``VectorQuantizer``
    -> 1x1 ``post_quant_conv`` -> ``Decoder``. The loss module built from
    ``lossconfig`` is called once per optimizer index (0: autoencoder,
    1: discriminator) and must expose a ``discriminator`` attribute, whose
    parameters are handed to the second optimizer.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        """Build the encoder, decoder, quantizer and loss.

        Args:
            ddconfig: Keyword arguments forwarded to both ``Encoder`` and
                ``Decoder``; must contain ``"z_channels"``.
            lossconfig: Config passed to ``instantiate_from_config`` to
                create the loss module.
            n_embed: Forwarded to ``VectorQuantizer`` as its first argument.
            embed_dim: Forwarded to ``VectorQuantizer``; also the channel
                count between ``quant_conv`` and ``post_quant_conv``.
            ckpt_path: Optional checkpoint to restore weights from.
            ignore_keys: Optional iterable of key prefixes to drop from the
                checkpoint before loading. ``None`` means "drop nothing".
            image_key: Batch key read by ``get_input``.
            colorize_nlabels: If given, registers a random ``colorize``
                buffer of shape ``(3, colorize_nlabels, 1, 1)``.
            monitor: If given, stored as ``self.monitor`` (not read in this
                class; presumably consumed by the training script).
            remap: Forwarded to ``VectorQuantizer``.
            sane_index_shape: Forwarded to ``VectorQuantizer``.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        # Registered after checkpoint loading, so a "colorize" entry in the
        # checkpoint is not restored here (loading is non-strict).
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``state_dict`` from the checkpoint at ``path`` (non-strict).

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Optional iterable of string prefixes; every key that
                starts with one of them is removed before loading.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys) if ignore_keys is not None else ()
        for k in list(sd.keys()):
            # One tuple test per key: a key matching several prefixes is
            # deleted exactly once instead of raising KeyError on the
            # second ``del``.
            if prefixes and k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` and quantize it.

        Returns:
            The 3-tuple produced by the quantizer: ``(quant, emb_loss, info)``.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        """Map quantized latents back through ``post_quant_conv`` and the decoder."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode codebook indices via ``self.quantize.embed_code``."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        """Autoencode ``input``.

        Returns:
            ``(dec, diff)``: the reconstruction and the quantizer loss term.
        """
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and return it as a float, channels-first tensor.

        A 3-dimensional entry gets a trailing axis appended first. The
        ``permute(0, 3, 1, 2)`` moves the last axis to position 1, i.e. the
        input is assumed to be channels-last.
        """
        x = batch[k]
        if x.ndim == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one training step for the optimizer selected by ``optimizer_idx``.

        Index 0 trains the autoencoder, index 1 the discriminator. Any other
        index falls through and returns ``None``.
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Evaluate both loss branches on a validation batch and log them.

        The loss module must report a ``"val/rec_loss"`` entry in the
        autoencoder log dict.
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound ``log_dict`` method itself, not
        # a metrics dict. Kept as-is to avoid changing what the trainer
        # collects; confirm whether ``None`` was intended.
        return self.log_dict

    def configure_optimizers(self):
        """Create the autoencoder and discriminator Adam optimizers.

        ``self.learning_rate`` is not set anywhere in this class; presumably
        the training script assigns it before fitting (TODO confirm).
        """
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        """Return a dict with ``"inputs"`` and ``"reconstructions"`` tensors.

        Inputs with more than 3 channels are projected to 3 channels with
        ``to_rgb`` so they can be displayed.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        """Project ``x`` to 3 channels with the ``colorize`` kernel, rescaled to [-1, 1].

        Only valid when ``image_key == "segmentation"``. The ``colorize``
        buffer is created lazily with random weights if it does not exist.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # NOTE: a constant projection (max == min) divides by zero here.
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
class VQSegmentationModel(VQModel):
    """``VQModel`` variant trained with a single optimizer and no discriminator.

    The loss module is called as ``loss(qloss, x, xrec, split=...)`` and a
    random ``colorize`` projection for ``n_labels`` channels is always
    registered for image logging.
    """

    def __init__(self, n_labels, *args, **kwargs):
        """Initialise the parent model, then register the ``colorize`` buffer.

        Args:
            n_labels: Number of input channels of the ``colorize`` kernel,
                which has shape ``(3, n_labels, 1, 1)``.
            *args, **kwargs: Forwarded unchanged to ``VQModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        projection = torch.randn(3, n_labels, 1, 1)
        self.register_buffer("colorize", projection)

    def configure_optimizers(self):
        """Return one Adam optimizer over all autoencoder parameters."""
        step_size = self.learning_rate
        trainable = []
        for module in (self.encoder, self.decoder, self.quantize,
                       self.quant_conv, self.post_quant_conv):
            trainable.extend(module.parameters())
        return torch.optim.Adam(trainable, lr=step_size, betas=(0.5, 0.9))

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; return the loss."""
        images = self.get_input(batch, self.image_key)
        recon, quant_loss = self(images)
        loss_value, metrics = self.loss(quant_loss, images, recon, split="train")
        self.log_dict(metrics, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return loss_value

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch; return the loss.

        The loss module must report a ``"val/total_loss"`` entry, which is
        additionally logged on the progress bar.
        """
        images = self.get_input(batch, self.image_key)
        recon, quant_loss = self(images)
        loss_value, metrics = self.loss(quant_loss, images, recon, split="val")
        self.log_dict(metrics, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        self.log("val/total_loss", metrics["val/total_loss"],
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return loss_value

    @torch.no_grad()
    def log_images(self, batch, **kwargs):
        """Return ``{"inputs": ..., "reconstructions": ...}`` for visualisation.

        When the input has more than 3 channels, the reconstruction logits
        are turned into a one-hot map of their per-pixel argmax before both
        tensors are projected to 3 channels with ``to_rgb``.
        """
        images = self.get_input(batch, self.image_key).to(self.device)
        recon, _ = self(images)
        if images.shape[1] > 3:
            # colorize with random projection
            assert recon.shape[1] > 3
            n_classes = images.shape[1]
            # convert logits to indices, then back to a channels-first one-hot map
            labels = torch.argmax(recon, dim=1, keepdim=True)
            one_hot = F.one_hot(labels, num_classes=n_classes)
            recon = one_hot.squeeze(1).permute(0, 3, 1, 2).float()
            images = self.to_rgb(images)
            recon = self.to_rgb(recon)
        return {"inputs": images, "reconstructions": recon}
class VQNoDiscModel(VQModel):
    """``VQModel`` variant trained with a single optimizer and no discriminator.

    The loss module is called as
    ``loss(qloss, x, xrec, self.global_step, split=...)``.

    The step methods were written against the ``pl.TrainResult`` /
    ``pl.EvalResult`` API. That path is used unchanged whenever the installed
    PyTorch Lightning still provides those classes. Otherwise the same
    metrics are reported through ``self.log`` / ``self.log_dict``, as the
    other models in this file already do, instead of failing with
    ``AttributeError``.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None
                 ):
        """Forward all arguments to ``VQModel.__init__``.

        Args:
            ignore_keys: Optional iterable of checkpoint key prefixes to
                drop; ``None`` is passed on as an empty list.
            (all other arguments have the same meaning as in ``VQModel``)
        """
        super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
                         ckpt_path=ckpt_path,
                         ignore_keys=[] if ignore_keys is None else ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels)

    def training_step(self, batch, batch_idx):
        """Compute and report the autoencoder loss for one training batch.

        Returns:
            A ``pl.TrainResult`` minimising the loss when that class exists,
            otherwise the loss tensor itself.
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        # autoencode
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")

        if hasattr(pl, "TrainResult"):
            # Legacy Result API, behaviour unchanged.
            output = pl.TrainResult(minimize=aeloss)
            output.log("train/aeloss", aeloss,
                       prog_bar=True, logger=True, on_step=True, on_epoch=True)
            output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return output

        self.log("train/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        # Skip the key logged explicitly above; logging one key twice with
        # different settings may be rejected by newer Lightning versions.
        remaining = {k: v for k, v in log_dict_ae.items() if k != "train/aeloss"}
        self.log_dict(remaining, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Compute and report validation losses for one batch.

        The loss module must report a ``"val/rec_loss"`` entry. On the
        legacy path that value is also passed as ``checkpoint_on``; on the
        ``self.log`` path there is no equivalent, so the checkpoint callback
        has to be configured to monitor ``"val/rec_loss"`` itself.

        Returns:
            A ``pl.EvalResult`` when that class exists, otherwise ``None``.
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
        rec_loss = log_dict_ae["val/rec_loss"]

        if hasattr(pl, "EvalResult"):
            # Legacy Result API, behaviour unchanged.
            output = pl.EvalResult(checkpoint_on=rec_loss)
            output.log("val/rec_loss", rec_loss,
                       prog_bar=True, logger=True, on_step=True, on_epoch=True)
            output.log("val/aeloss", aeloss,
                       prog_bar=True, logger=True, on_step=True, on_epoch=True)
            output.log_dict(log_dict_ae)
            return output

        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        # Skip the keys logged explicitly above (see training_step).
        explicit = ("val/rec_loss", "val/aeloss")
        remaining = {k: v for k, v in log_dict_ae.items() if k not in explicit}
        self.log_dict(remaining)

    def configure_optimizers(self):
        """Return one Adam optimizer over all autoencoder parameters.

        ``self.learning_rate`` is not set in this class; presumably the
        training script assigns it before fitting (TODO confirm).
        """
        optimizer = torch.optim.Adam(list(self.encoder.parameters()) +
                                     list(self.decoder.parameters()) +
                                     list(self.quantize.parameters()) +
                                     list(self.quant_conv.parameters()) +
                                     list(self.post_quant_conv.parameters()),
                                     lr=self.learning_rate, betas=(0.5, 0.9))
        return optimizer
class GumbelVQ(VQModel):
    """``VQModel`` whose quantizer is replaced by ``GumbelQuantize``.

    The quantizer temperature is updated from ``temperature_scheduler`` at
    the start of every training step.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):
        """Build the parent model, then swap in the Gumbel quantizer.

        Args:
            temperature_scheduler_config: Config passed to
                ``instantiate_from_config``; the result is called with the
                global step and must return the quantizer temperature.
            kl_weight: Forwarded to ``GumbelQuantize``.
            remap: Forwarded to ``GumbelQuantize``.
            ignore_keys: Optional iterable of checkpoint key prefixes to
                drop; ``None`` is treated as an empty list.
            (all other arguments have the same meaning as in ``VQModel``)
        """
        z_channels = ddconfig["z_channels"]
        ignore_keys = [] if ignore_keys is None else ignore_keys
        # ckpt_path is deliberately withheld from the parent: the checkpoint
        # is loaded below, after the quantizer has been replaced.
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)   # annealing of temp

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        """Set the quantizer temperature from the scheduler at the current global step."""
        self.quantize.temperature = self.temperature_scheduler(self.global_step)

    def encode_to_prequant(self, x):
        """Return the encoder output after ``quant_conv``, before quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        """Not supported for the Gumbel quantizer."""
        raise NotImplementedError

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Update the temperature, then run the step for ``optimizer_idx``.

        Index 0 trains the autoencoder, index 1 the discriminator. Any other
        index falls through and returns ``None``.
        """
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Evaluate both loss branches on a validation batch and log them.

        The loss module must report a ``"val/rec_loss"`` entry in the
        autoencoder log dict.
        """
        x = self.get_input(batch, self.image_key)
        # The inherited forward() takes only the input; the former
        # ``return_pred_indices=True`` keyword raised TypeError.
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound ``log_dict`` method itself, not
        # a metrics dict. Kept as-is to avoid changing what the trainer
        # collects; confirm whether ``None`` was intended.
        return self.log_dict

    def log_images(self, batch, **kwargs):
        """Return a dict with ``"inputs"`` and ``"reconstructions"`` tensors."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log