Added conditioning
yuanchenyang committed Jul 17, 2024
1 parent 7a80531 commit 693299a
Showing 4 changed files with 122 additions and 54 deletions.
5 changes: 4 additions & 1 deletion src/smalldiffusion/__init__.py
@@ -8,5 +8,8 @@
)

from .model import (
TimeInputMLP, ModelMixin, get_sigma_embeds, IdealDenoiser, DiT
ModelMixin, Scaled, PredX0, PredV,
TimeInputMLP, IdealDenoiser, DiT,
get_sigma_embeds, SigmaEmbedderSinCos,
CondEmbedderLabel
)
48 changes: 28 additions & 20 deletions src/smalldiffusion/diffusion.py
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from types import SimpleNamespace
from typing import Optional
from typing import Optional, Union, Tuple

class Schedule:
'''Diffusion noise schedules parameterized by sigma'''
@@ -73,12 +73,15 @@ def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02,
# Given a batch of data x0, returns:
# eps : i.i.d. normal with same shape as x0
# sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting
def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
def generate_train_sample(x0: Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]],
schedule: Schedule, conditional: bool=False):
cond = x0[1] if conditional else None
x0 = x0[0] if conditional else x0
sigma = schedule.sample_batch(x0)
while len(sigma.shape) < len(x0.shape):
sigma = sigma.unsqueeze(-1)
eps = torch.randn_like(x0)
return sigma, eps
return x0, sigma, eps, cond
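
For orientation, a minimal sketch of the two calling conventions of generate_train_sample. The toy tensors, the label range, and the ScheduleDDPM import are illustrative assumptions, not part of this commit; any Schedule subclass with sample_batch would do.

import torch
from smalldiffusion import ScheduleDDPM               # assumed export name
from smalldiffusion.diffusion import generate_train_sample

schedule = ScheduleDDPM()
x0_batch = torch.randn(8, 2)                          # toy batch of 2-D points
labels   = torch.randint(0, 10, (8,))                 # toy integer class labels

# Unconditional: pass the data tensor directly; cond comes back as None.
x0, sigma, eps, cond = generate_train_sample(x0_batch, schedule)

# Conditional: pass an (x0, labels) pair and set conditional=True.
x0, sigma, eps, cond = generate_train_sample((x0_batch, labels), schedule, conditional=True)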

# Model objects
# Always called with (x, sigma):
@@ -87,20 +90,22 @@ def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
# Otherwise, x[i] will be paired with sigma[i] when calling model
# Have a `rand_input` method for generating random xt during sampling

def training_loop(loader : DataLoader,
model : nn.Module,
schedule : Schedule,
accelerator: Optional[Accelerator] = None,
epochs : int = 10000,
lr : float = 1e-3):
def training_loop(loader : DataLoader,
model : nn.Module,
schedule : Schedule,
accelerator : Optional[Accelerator] = None,
epochs : int = 10000,
lr : float = 1e-3,
conditional : bool = False):
accelerator = accelerator or Accelerator()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
for _ in (pbar := tqdm(range(epochs))):
for x0 in loader:
model.train()
optimizer.zero_grad()
sigma, eps = generate_train_sample(x0, schedule)
loss = model.get_loss(x0, sigma, eps)
x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
loss = model.get_loss(x0, sigma, eps, cond=cond)
yield SimpleNamespace(**locals()) # For extracting training statistics
accelerator.backward(loss)
optimizer.step()
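
Since training_loop is a generator that yields its locals once per batch, the caller drives it by iterating. Below is a hedged sketch of conditional training on a toy labelled image dataset; the dataset and hyperparameters are placeholders and the ScheduleDDPM name is assumed, while DiT, CondEmbedderLabel and training_loop come from this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset
from smalldiffusion import DiT, CondEmbedderLabel, ScheduleDDPM   # ScheduleDDPM assumed
from smalldiffusion.diffusion import training_loop

# Toy labelled dataset: 3x16x16 images with 10 classes (illustrative only).
data   = TensorDataset(torch.randn(256, 3, 16, 16), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=32, shuffle=True)

model    = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
               cond_embed_class=CondEmbedderLabel, cond_num_classes=10)
schedule = ScheduleDDPM()

# Each yielded namespace exposes the loop's locals, e.g. ns.loss for logging.
for ns in training_loop(loader, model, schedule, epochs=2, lr=1e-3, conditional=True):
    pass
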
@@ -114,19 +119,22 @@ def samples(model : nn.Module,
sigmas : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
gam : float = 1., # Suggested to use gam >= 1
mu : float = 0., # Requires mu in [0, 1)
cfg_scale : float = 0., # 0 means no classifier-free guidance
batchsize : int = 1,
xt : Optional[torch.FloatTensor] = None,
accelerator: Optional[Accelerator] = None,
batchsize : int = 1):
cond : Optional[torch.Tensor] = None,
accelerator: Optional[Accelerator] = None):
accelerator = accelerator or Accelerator()
if xt is None:
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0]
else:
batchsize = xt.shape[0]
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
if cfg_scale > 0:
assert cond is not None and cond.shape[0] == xt.shape[0], 'cond must have same batch size as xt!'
cond = cond.to(xt.device)
eps = None
for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
eps, eps_prev = model.predict_eps(xt, sig.to(xt)), eps
model.eval()
eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
sig_p = (sig_prev/sig**mu)**(1/(1-mu)) # sig_prev == sig**mu sig_p**(1-mu)
eta = (sig_prev**2 - sig_p**2).sqrt()
xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(batchsize).to(xt)
xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
yield xt
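
A hedged sketch of classifier-free-guided sampling with the new cond and cfg_scale arguments. The untrained model, the label values, and the hand-built sigma ladder are illustrative only; the package's own schedule helpers could supply sigmas instead.

import torch
from smalldiffusion import DiT, CondEmbedderLabel
from smalldiffusion.diffusion import samples

model = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
            cond_embed_class=CondEmbedderLabel, cond_num_classes=10)

sigmas = torch.cat([torch.logspace(1, -2, 20), torch.tensor([0.])])  # 20 steps, sigma 10 -> 0
labels = torch.tensor([0, 1, 2, 3])                                  # one label per sample

*_, x0_hat = samples(model, sigmas, cfg_scale=4., batchsize=4, cond=labels)
print(x0_hat.shape)                                                  # torch.Size([4, 3, 16, 16])
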
85 changes: 61 additions & 24 deletions src/smalldiffusion/model.py
@@ -26,11 +26,23 @@ def rand_input(self, batchsize):
return torch.randn((batchsize,) + self.input_dims)

# Currently predicts eps; override the following methods to predict, for example, x0
def get_loss(self, x0, sigma, eps):
return nn.MSELoss()(eps, self(x0 + sigma * eps, sigma))

def predict_eps(self, x, sigma):
return self(x, sigma)
def get_loss(self, x0, sigma, eps, cond=None):
return nn.MSELoss()(eps, self(x0 + sigma * eps, sigma, cond=cond))

def predict_eps(self, x, sigma, cond=None):
return self(x, sigma, cond=cond)

def predict_eps_cfg(self, x, sigma, cond, cfg_scale):
if cond is None:
return self.predict_eps(x, sigma)
uncond = torch.full_like(cond, self.cond_embed.null_cond) # (B,)
if cfg_scale == 0:
return self.predict_eps(x, sigma, cond=uncond)
assert sigma.shape == tuple(), 'CFG sampling only supports singleton sigma!'
eps_cond, eps_uncond = self.predict_eps( # (B,), (B,)
torch.cat([x, x]), sigma, torch.cat([cond, uncond]) # (2B,)
).chunk(2)
return eps_cond + cfg_scale * (eps_cond - eps_uncond)
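
The returned blend is algebraically (1 + cfg_scale) * eps_cond - cfg_scale * eps_uncond, i.e. an extrapolation away from the unconditional prediction. A tiny self-contained check of that identity (the tensors are dummies):

import torch
w = 4.0                                                     # plays the role of cfg_scale
eps_cond, eps_uncond = torch.randn(2, 3), torch.randn(2, 3)
guided = eps_cond + w * (eps_cond - eps_uncond)
assert torch.allclose(guided, (1 + w) * eps_cond - w * eps_uncond)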

## Modifiers for models, such as including scaling or changing model predictions

@@ -39,28 +51,28 @@ def alpha(sigma):

# Scale model input so that its norm stays constant for all sigma
def Scaled(cls: ModelMixin):
def forward(self, x, sigma):
return cls.forward(self, x * alpha(sigma).sqrt(), sigma)
def forward(self, x, sigma, cond=None):
return cls.forward(self, x * alpha(sigma).sqrt(), sigma, cond=cond)
return type(cls.__name__ + 'Scaled', (cls,), dict(forward=forward))

# Train model to predict x0 instead of eps
def PredX0(cls: ModelMixin):
def get_loss(self, x0, sigma, eps):
return nn.MSELoss()(x0, self(x0 + sigma * eps, sigma))
def predict_eps(self, x, sigma):
x0_hat = self(x, sigma)
def get_loss(self, x0, sigma, eps, cond=None):
return nn.MSELoss()(x0, self(x0 + sigma * eps, sigma, cond=cond))
def predict_eps(self, x, sigma, cond=None):
x0_hat = self(x, sigma, cond=cond)
return (x - x0_hat)/sigma
return type(cls.__name__ + 'PredX0', (cls,),
dict(get_loss=get_loss, predict_eps=predict_eps))

# Train model to predict v (https://arxiv.org/pdf/2202.00512.pdf) instead of eps
def PredV(cls: ModelMixin):
def get_loss(self, x0, sigma, eps):
def get_loss(self, x0, sigma, eps, cond=None):
xt = x0 + sigma * eps
v = alpha(sigma).sqrt() * eps - (1-alpha(sigma)).sqrt() * x0
return nn.MSELoss()(v, self(xt, sigma))
def predict_eps(self, x, sigma):
v_hat = self(x, sigma)
return nn.MSELoss()(v, self(xt, sigma, cond=cond))
def predict_eps(self, x, sigma, cond=None):
v_hat = self(x, sigma, cond=cond)
return alpha(sigma).sqrt() * (v_hat + (1-alpha(sigma)).sqrt() * x)
return type(cls.__name__ + 'PredV', (cls,),
dict(get_loss=get_loss, predict_eps=predict_eps))
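
Because each modifier returns a subclass, they compose by nesting, exactly as exercised in the tests below. A short sketch; the DiT hyperparameters mirror those tests and are otherwise arbitrary:

import torch
from smalldiffusion import DiT, Scaled, PredX0

# Scaled(PredX0(DiT)): a DiT trained to predict x0, on inputs rescaled to constant norm.
Model = Scaled(PredX0(DiT))
model = Model(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)

x, sigma = torch.randn(2, 3, 16, 16), torch.tensor(1.)
print(model.predict_eps(x, sigma).shape)   # torch.Size([2, 3, 16, 16])
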
@@ -79,7 +91,7 @@ def __init__(self, dim=2, hidden_dims=(16,128,256,128,16)):
self.net = nn.Sequential(*layers)
self.input_dims = (dim,)

def forward(self, x, sigma):
def forward(self, x, sigma, cond=None):
# x shape: b x dim
# sigma shape: b x 1 or scalar
sigma_embeds = get_sigma_embeds(x.shape[0], sigma.squeeze()) # shape: b x 2
@@ -204,8 +216,9 @@ def get_pos_embed(in_dim, patch_size, dim, N=10000):

class DiT(nn.Module, ModelMixin):
def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
head_dim=64, num_heads=6, mlp_ratio=4.0, sig_embed_factor=0.5,
sig_embed_class=None):
head_dim=64, num_heads=6, mlp_ratio=4.0,
sig_embed_class=None, sig_embed_factor=0.5,
cond_embed_class=None, cond_dropout_prob=0.1, cond_num_classes=None):
super().__init__()
self.in_dim = in_dim
self.channels = channels
@@ -219,6 +232,11 @@ def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
self.sig_embed = (sig_embed_class or SigmaEmbedderSinCos)(
dim, scaling_factor=sig_embed_factor
)
self.conditional = cond_embed_class is not None
if self.conditional:
self.cond_embed = cond_embed_class(
dim, num_classes=cond_num_classes, dropout_prob=cond_dropout_prob
)

self.blocks = nn.ModuleList([
DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
@@ -252,16 +270,35 @@ def unpatchify(self, x):
ph=patches, pw=patches,
psh=self.patch_size, psw=self.patch_size)

def forward(self, x, sigma):
# (B, C, H, W), Union[(B, 1, 1, 1), ()] -> (B, C, H, W)
def forward(self, x, sigma, cond=None):
# x: (B, C, H, W), sigma: Union[(B, 1, 1, 1), ()], cond: (B, *)
# returns: (B, C, H, W)
# N = num_patches, D = dim = head_dim * num_heads
x = self.x_embed(x) + self.pos_embed # (B, N, D)
y = self.sig_embed(x.shape[0], sigma.squeeze()) # (B, D)
x = self.x_embed(x) + self.pos_embed # (B, N, D)
y = self.sig_embed(x.shape[0], sigma.squeeze()) # (B, D)
if self.conditional:
assert x.shape[0] == cond.shape[0], 'Conditioning must have same batch size as x!'
y += self.cond_embed(cond) # (B, D)
for block in self.blocks:
x = block(x, y) # (B, N, D)
x = self.final_linear(self.final_norm(x, y)) # (B, N, patchsize**2 * channels)
x = block(x, y) # (B, N, D)
x = self.final_linear(self.final_norm(x, y)) # (B, N, patchsize**2 * channels)
return self.unpatchify(x)
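
A minimal conditional forward pass, mirroring the shapes used in the tests at the bottom of this commit; with cond_num_classes=10, label 10 is reserved as the null (unconditional) token by CondEmbedderLabel below.

import torch
from smalldiffusion import DiT, CondEmbedderLabel

model  = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
             cond_embed_class=CondEmbedderLabel, cond_num_classes=10)
x      = torch.randn(10, 3, 16, 16)
sigma  = torch.tensor(1.)
labels = torch.randint(0, 10, (10,))
print(model(x, sigma, cond=labels).shape)   # torch.Size([10, 3, 16, 16])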

# Embedding table for conditioning on labels assumed to be in [0, num_classes),
# unconditional label encoded as: num_classes
class CondEmbedderLabel(nn.Module):
def __init__(self, hidden_size, num_classes, dropout_prob):
super().__init__()
self.embeddings = nn.Embedding(num_classes + 1, hidden_size)
self.null_cond = num_classes
self.dropout_prob = dropout_prob

def forward(self, labels): # (B,) -> (B, D)
if self.training:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
labels = torch.where(drop_ids, self.null_cond, labels)
return self.embeddings(labels)
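
A small sketch of the embedder's training-time label dropout, which is what later enables classifier-free guidance; the hidden size, batch size and dropout rate are illustrative.

import torch
from smalldiffusion import CondEmbedderLabel

embed  = CondEmbedderLabel(hidden_size=64, num_classes=10, dropout_prob=0.5)
labels = torch.randint(0, 10, (8,))

embed.train()                  # roughly half the labels are replaced by the null token (10)
print(embed(labels).shape)     # torch.Size([8, 64])

embed.eval()                   # no dropout at evaluation time
print(embed(labels).shape)     # torch.Size([8, 64])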

# A simple embedding that works just as well as usual sinusoidal embedding
class SigmaEmbedderSinCos(nn.Module):
def __init__(self, hidden_size, scaling_factor=0.5):
38 changes: 29 additions & 9 deletions tests/test_diffusion.py
@@ -10,11 +10,12 @@
def get_hf_sigmas(scheduler):
return (1/scheduler.alphas_cumprod - 1).sqrt()

class DummyModel(ModelMixin):
class DummyModel(torch.nn.Module, ModelMixin):
def __init__(self, dims):
super().__init__()
self.input_dims = dims

def __call__(self, x, sigma):
def __call__(self, x, sigma, cond=None):
gen = torch.Generator().manual_seed(int(sigma * 100000))
return torch.randn((x.shape[0],) + self.input_dims, generator=gen)

@@ -184,11 +185,30 @@ def test_swissroll(self):
accelerator=accelerator)
self.assertEqual(sample.shape, (B//2, 2))

# Just testing that model creation and forward pass works
class TestDiT(unittest.TestCase):
def test_basic_setup(self):
# Just testing that model creation and forward pass works
model = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)
x = torch.randn(10, 3, 16, 16)
sigma = torch.tensor(1)
y = model(x, sigma)
self.assertEqual(y.shape, x.shape)
def setUp(self):
self.modifiers = [
Scaled, PredX0, PredV,
lambda x: x,
lambda x: Scaled(PredX0(x)),
lambda x: Scaled(PredV(x))
]

def test_uncond(self):
for modifier in self.modifiers:
model = modifier(DiT)(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)
x = torch.randn(10, 3, 16, 16)
sigma = torch.tensor(1)
y = model.predict_eps(x, sigma)
self.assertEqual(y.shape, x.shape)

def test_cond(self):
for modifier in self.modifiers:
model = modifier(DiT)(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
cond_embed_class=CondEmbedderLabel, cond_num_classes=10)
x = torch.randn(10, 3, 16, 16)
sigma = torch.tensor(1)
labels = torch.tensor([1,2,3,4,5] + [10]*5)
y = model.predict_eps_cfg(x, sigma, cond=labels, cfg_scale=4.0)
self.assertEqual(y.shape, x.shape)
