Skip to content

Commit

Permalink
Added DiT tests and bumped version
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Apr 17, 2024
1 parent 4c0e5fe commit 57cd336
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smalldiffusion"
version = "0.2"
version = "0.3"
description = "A minimal but functional implementation of diffusion model training and sampling"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
11 changes: 10 additions & 1 deletion tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,20 @@ def test_swissroll(self):
trainer = training_loop(loader, model, schedule, epochs=epochs, lr=1e-3,
accelerator=accelerator)

# Mainly to test that model trains without erroe
# Mainly to test that model trains without error
losses = [ns.loss.item() for ns in trainer]
self.assertEqual(len(losses), epochs)

# Test sampling
*_, sample = samples(model, schedule.sample_sigmas(sample_steps), gam=1, batchsize=B//2,
accelerator=accelerator)
self.assertEqual(sample.shape, (B//2, 2))

class TestDiT(unittest.TestCase):
def test_basic_setup(self):
# Just testing that model creation and forward pass works
model = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)
x = torch.randn(10, 3, 16, 16)
sigma = torch.tensor(1)
y = model(x, sigma)
self.assertEqual(y.shape, x.shape)

0 comments on commit 57cd336

Please sign in to comment.