Skip to content

Commit

Permalink
move stuff to ddpm
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Capelle committed Mar 15, 2023
1 parent 0352354 commit 49386a9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 40 deletions.
38 changes: 38 additions & 0 deletions cloud_diffusion/ddpm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from pathlib import Path
from functools import partial

import torch, wandb
from torch.nn import init
from torch.utils.data.dataloader import default_collate

import fastcore.all as fc
from fastprogress import progress_bar

from diffusers.schedulers import DDIMScheduler

from diffusers import UNet2DModel


Expand Down Expand Up @@ -49,6 +56,37 @@ def get_unet_params(model_name="unet_small", num_frames=4):
else:
raise(f"Model name not found: {model_name}, choose between 'unet_small' or 'unet_big'")

def init_ddpm(model):
    """Weight-init recipe from Jeremy Howard's fastai 2023 "bag of tricks".

    Zeroes the second conv of every resnet (so each residual block starts as
    an identity map) and the final output conv, and applies orthogonal init
    to the downsampling convs. Mutates `model` in place; returns None.
    """
    # Zero residual branches on both the down and the up path.
    for blocks in (model.down_blocks, model.up_blocks):
        for blk in blocks:
            for res in blk.resnets:
                res.conv2.weight.data.zero_()

    # Orthogonal init on the down path only, as in the original recipe.
    # fc.L(None) yields an empty list, so blocks without downsamplers are skipped.
    for blk in model.down_blocks:
        for ds in fc.L(blk.downsamplers):
            init.orthogonal_(ds.conv.weight)

    # The model starts out predicting zeros.
    model.conv_out.weight.data.zero_()

@torch.no_grad()
def diffusers_sampler(model, past_frames, sched, **kwargs):
    """Sample one new frame conditioned on `past_frames` with a Diffusers scheduler.

    Args:
        model: denoiser called as `model(frames, t)`, returning predicted noise.
        past_frames: conditioning frames; assumes frames are stacked on dim 1
            (batch, frames, H, W) — TODO confirm against the dataset collate.
        sched: a Diffusers scheduler exposing `timesteps` and `step(...)`.
        **kwargs: forwarded to `sched.step` (e.g. `eta` for DDIM).

    Returns:
        The final denoised frame as a float tensor on the CPU.
    """
    model.eval()
    device = next(model.parameters()).device
    # Start from pure noise shaped like a single trailing frame.
    new_frame = torch.randn_like(past_frames[:, -1:], dtype=past_frames.dtype, device=device)
    for t in progress_bar(sched.timesteps, leave=False):
        noise = model(torch.cat([past_frames, new_frame], dim=1), t)
        new_frame = sched.step(noise, t, new_frame, **kwargs).prev_sample
    # Only the last frame is used, so don't accumulate every intermediate
    # prediction on the CPU (the original kept all of them in a list).
    return new_frame.float().cpu()

def ddim_sampler(steps=350, eta=1.):
    """Build a DDIM-based sampling function.

    Returns `diffusers_sampler` with a DDIM scheduler (configured for `steps`
    timesteps) and `eta` pre-bound, so callers only pass the model and the
    conditioning frames.
    """
    scheduler = DDIMScheduler()
    scheduler.set_timesteps(steps)
    return partial(diffusers_sampler, sched=scheduler, eta=eta)

class UNet2D(UNet2DModel):
    """Thin wrapper over Diffusers' UNet2DModel that returns a raw tensor.

    Diffusers wraps the prediction in a `UNet2DOutput`; this unpacks its
    `.sample` field so the model can be used with plain-tensor training loops.
    """
    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        return out.sample
Expand Down
36 changes: 0 additions & 36 deletions cloud_diffusion/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import random, argparse
from pathlib import Path
from functools import partial
import fastcore.all as fc

import wandb
import numpy as np
import torch
from torch import nn
from torch.nn import init

from fastprogress import progress_bar

from diffusers.schedulers import DDIMScheduler

from cloud_diffusion.wandb import log_images, save_model


Expand Down Expand Up @@ -101,37 +96,6 @@ def ls(path: Path):
"Return files on Path, sorted"
return sorted(list(path.iterdir()))

def init_ddpm(model):
    "From Jeremy's bag of tricks on fastai V2 2023"
    # Zero the 2nd conv of each resnet so every residual block starts as identity.
    for o in model.down_blocks:
        for p in o.resnets:
            p.conv2.weight.data.zero_()
        # fc.L(None) -> empty list, so blocks without downsamplers are skipped.
        for p in fc.L(o.downsamplers): init.orthogonal_(p.conv.weight)

    for o in model.up_blocks:
        # Same zeroing on the up path; upsamplers keep their default init.
        for p in o.resnets: p.conv2.weight.data.zero_()

    # Final output conv starts at zero: the untrained model predicts zeros.
    model.conv_out.weight.data.zero_()

@torch.no_grad()
def diffusers_sampler(model, past_frames, sched, **kwargs):
    "Using Diffusers built-in samplers"
    model.eval()
    # Run on whatever device the model lives on.
    device = next(model.parameters()).device
    # Start from pure noise shaped like one trailing frame; assumes frames are
    # stacked on dim 1 (batch, frames, H, W) — TODO confirm with the collate fn.
    new_frame = torch.randn_like(past_frames[:,-1:], dtype=past_frames.dtype, device=device)
    preds = []
    for t in progress_bar(sched.timesteps, leave=False):
        # Predict noise for the conditioning frames + current noisy frame.
        noise = model(torch.cat([past_frames, new_frame], dim=1), t)
        # Scheduler takes one denoising step; kwargs forwarded (e.g. eta for DDIM).
        new_frame = sched.step(noise, t, new_frame, **kwargs).prev_sample
        preds.append(new_frame.float().cpu())
    # Only the final denoised frame is returned.
    return preds[-1]

def ddim_sampler(steps=350, eta=1.):
    "DDIM sampler, faster and a bit better than the built-in sampler"
    # Returns diffusers_sampler with the DDIM scheduler and eta pre-bound,
    # so callers only pass (model, past_frames).
    ddim_sched = DDIMScheduler()
    ddim_sched.set_timesteps(steps)
    return partial(diffusers_sampler, sched=ddim_sched, eta=eta)


def parse_args(config):
"A brute force way to parse arguments, it is probably not a good idea to use it"
Expand Down
6 changes: 2 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@


from cloud_diffusion.dataset import download_dataset, CloudDataset
from cloud_diffusion.utils import (
MiniTrainer, init_ddpm,
ddim_sampler, set_seed, parse_args)
from cloud_diffusion.ddpm import collate_ddpm, get_unet_params, UNet2D
from cloud_diffusion.utils import MiniTrainer, set_seed, parse_args
from cloud_diffusion.ddpm import collate_ddpm, get_unet_params, UNet2D, init_ddpm, ddim_sampler


PROJECT_NAME = "ddpm_clouds"
Expand Down
1 change: 1 addition & 0 deletions train_uvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def train_func(config):

trainer = MiniTrainer(train_dataloader, valid_dataloader, model, optimizer, scheduler,
sampler, device, loss_func)
wandb.config.update(config)
trainer.fit(config)

if __name__=="__main__":
Expand Down

0 comments on commit 49386a9

Please sign in to comment.