From 57cd336d5add6c9b0a2dbdb346e2d6bd6d2b524a Mon Sep 17 00:00:00 2001 From: Chenyang Yuan Date: Tue, 16 Apr 2024 23:29:10 -0400 Subject: [PATCH] Added DiT tests and bumped version --- pyproject.toml | 2 +- tests/test_diffusion.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 561fb81..d525a39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smalldiffusion" -version = "0.2" +version = "0.3" description = "A minimal but functional implementation of diffusion model training and sampling" readme = "README.md" requires-python = ">=3.10" diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index 5aa00d9..a79f1b3 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -175,7 +175,7 @@ def test_swissroll(self): trainer = training_loop(loader, model, schedule, epochs=epochs, lr=1e-3, accelerator=accelerator) - # Mainly to test that model trains without erroe + # Mainly to test that model trains without error losses = [ns.loss.item() for ns in trainer] self.assertEqual(len(losses), epochs) @@ -183,3 +183,12 @@ def test_swissroll(self): *_, sample = samples(model, schedule.sample_sigmas(sample_steps), gam=1, batchsize=B//2, accelerator=accelerator) self.assertEqual(sample.shape, (B//2, 2)) + +class TestDiT(unittest.TestCase): + def test_basic_setup(self): + # Just testing that model creation and forward pass works + model = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6) + x = torch.randn(10, 3, 16, 16) + sigma = torch.tensor(1) + y = model(x, sigma) + self.assertEqual(y.shape, x.shape)