Skip to content

Commit 1f9fd71

Browse files
committed
added code
1 parent 33dfc8e commit 1f9fd71

File tree

2 files changed

+199
-11
lines changed

2 files changed

+199
-11
lines changed

doc/pub/week15/ipynb/week15.ipynb

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,10 @@
691691
"id": "53f1e714",
692692
"metadata": {
693693
"collapsed": false,
694-
"editable": true
694+
"editable": true,
695+
"jupyter": {
696+
"outputs_hidden": false
697+
}
695698
},
696699
"outputs": [],
697700
"source": [
@@ -722,7 +725,10 @@
722725
"id": "e1c56c04",
723726
"metadata": {
724727
"collapsed": false,
725-
"editable": true
728+
"editable": true,
729+
"jupyter": {
730+
"outputs_hidden": false
731+
}
726732
},
727733
"outputs": [],
728734
"source": [
@@ -759,7 +765,10 @@
759765
"id": "21777496",
760766
"metadata": {
761767
"collapsed": false,
762-
"editable": true
768+
"editable": true,
769+
"jupyter": {
770+
"outputs_hidden": false
771+
}
763772
},
764773
"outputs": [],
765774
"source": [
@@ -790,7 +799,10 @@
790799
"id": "769872cf",
791800
"metadata": {
792801
"collapsed": false,
793-
"editable": true
802+
"editable": true,
803+
"jupyter": {
804+
"outputs_hidden": false
805+
}
794806
},
795807
"outputs": [],
796808
"source": [
@@ -838,7 +850,10 @@
838850
"id": "34b144b0",
839851
"metadata": {
840852
"collapsed": false,
841-
"editable": true
853+
"editable": true,
854+
"jupyter": {
855+
"outputs_hidden": false
856+
}
842857
},
843858
"outputs": [],
844859
"source": [
@@ -867,7 +882,10 @@
867882
"id": "468b3848",
868883
"metadata": {
869884
"collapsed": false,
870-
"editable": true
885+
"editable": true,
886+
"jupyter": {
887+
"outputs_hidden": false
888+
}
871889
},
872890
"outputs": [],
873891
"source": [
@@ -897,9 +915,28 @@
897915
"id": "2b837425",
898916
"metadata": {
899917
"collapsed": false,
900-
"editable": true
918+
"editable": true,
919+
"jupyter": {
920+
"outputs_hidden": false
921+
}
901922
},
902-
"outputs": [],
923+
"outputs": [
924+
{
925+
"ename": "RuntimeError",
926+
"evalue": "The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 1",
927+
"output_type": "error",
928+
"traceback": [
929+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
930+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
931+
"Cell \u001b[0;32mIn[7], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m x, _ \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[1;32m 7\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 8\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mdiffusion_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m opt\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 10\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
932+
"Cell \u001b[0;32mIn[6], line 7\u001b[0m, in \u001b[0;36mdiffusion_loss\u001b[0;34m(model, x0)\u001b[0m\n\u001b[1;32m 5\u001b[0m noise \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn_like(x0)\n\u001b[1;32m 6\u001b[0m x_noisy \u001b[38;5;241m=\u001b[39m q_sample(x0, t, noise)\n\u001b[0;32m----> 7\u001b[0m pred_noise \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_noisy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mmse_loss(pred_noise, noise)\n",
933+
"File \u001b[0;32m~/miniforge3/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
934+
"File \u001b[0;32m~/miniforge3/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
935+
"Cell \u001b[0;32mIn[4], line 24\u001b[0m, in \u001b[0;36mSimpleUNet.forward\u001b[0;34m(self, x, t)\u001b[0m\n\u001b[1;32m 22\u001b[0m temb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_mlp(t) \u001b[38;5;66;03m# [oai_citation:8‡GitHub](https://github.com/tonyduan/diffusion?utm_source=chatgpt.com)\u001b[39;00m\n\u001b[1;32m 23\u001b[0m temb \u001b[38;5;241m=\u001b[39m temb\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 24\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[43mh\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\n\u001b[1;32m 25\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdec1(h))\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdec2(h)\n",
936+
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 1"
937+
]
938+
}
939+
],
903940
"source": [
904941
"model = SimpleUNet(channels).to(device)\n",
905942
"opt = torch.optim.Adam(model.parameters(), lr=lr)\n",
@@ -932,9 +969,29 @@
932969
"id": "62e04251",
933970
"metadata": {
934971
"collapsed": false,
935-
"editable": true
972+
"editable": true,
973+
"jupyter": {
974+
"outputs_hidden": false
975+
}
936976
},
937-
"outputs": [],
977+
"outputs": [
978+
{
979+
"ename": "RuntimeError",
980+
"evalue": "The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 1",
981+
"output_type": "error",
982+
"traceback": [
983+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
984+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
985+
"Cell \u001b[0;32mIn[8], line 20\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# Generate samples\u001b[39;00m\n\u001b[0;32m---> 20\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43mp_sample_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchannels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimg_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimg_size\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m samples \u001b[38;5;241m=\u001b[39m samples\u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcpu()\n\u001b[1;32m 22\u001b[0m grid \u001b[38;5;241m=\u001b[39m torchvision\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mmake_grid(samples, nrow\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
986+
"File \u001b[0;32m~/miniforge3/envs/myenv/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
987+
"Cell \u001b[0;32mIn[8], line 6\u001b[0m, in \u001b[0;36mp_sample_loop\u001b[0;34m(model, shape)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mreversed\u001b[39m(\u001b[38;5;28mrange\u001b[39m(T)):\n\u001b[1;32m 5\u001b[0m t \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfull((shape[\u001b[38;5;241m0\u001b[39m],), i, device\u001b[38;5;241m=\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m/\u001b[39mT\n\u001b[0;32m----> 6\u001b[0m eps_pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m beta_t \u001b[38;5;241m=\u001b[39m betas[i]\n\u001b[1;32m 8\u001b[0m alpha_t \u001b[38;5;241m=\u001b[39m alphas[i]\n",
988+
"File \u001b[0;32m~/miniforge3/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
989+
"File \u001b[0;32m~/miniforge3/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
990+
"Cell \u001b[0;32mIn[4], line 24\u001b[0m, in \u001b[0;36mSimpleUNet.forward\u001b[0;34m(self, x, t)\u001b[0m\n\u001b[1;32m 22\u001b[0m temb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_mlp(t) \u001b[38;5;66;03m# [oai_citation:8‡GitHub](https://github.com/tonyduan/diffusion?utm_source=chatgpt.com)\u001b[39;00m\n\u001b[1;32m 23\u001b[0m temb \u001b[38;5;241m=\u001b[39m temb\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 24\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[43mh\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\n\u001b[1;32m 25\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdec1(h))\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdec2(h)\n",
991+
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 1"
992+
]
993+
}
994+
],
938995
"source": [
939996
"@torch.no_grad()\n",
940997
"def p_sample_loop(model, shape):\n",
@@ -978,7 +1035,25 @@
9781035
]
9791036
}
9801037
],
981-
"metadata": {},
1038+
"metadata": {
1039+
"kernelspec": {
1040+
"display_name": "Python 3 (ipykernel)",
1041+
"language": "python",
1042+
"name": "python3"
1043+
},
1044+
"language_info": {
1045+
"codemirror_mode": {
1046+
"name": "ipython",
1047+
"version": 3
1048+
},
1049+
"file_extension": ".py",
1050+
"mimetype": "text/x-python",
1051+
"name": "python",
1052+
"nbconvert_exporter": "python",
1053+
"pygments_lexer": "ipython3",
1054+
"version": "3.9.15"
1055+
}
1056+
},
9821057
"nbformat": 4,
9831058
"nbformat_minor": 5
9841059
}

doc/src/week15/programs/diff.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Data Loading: Uses torchvision’s MNIST loader with basic normalization.
3+
Noise Scheduler: Linearly increases noise over time steps.
4+
Model: A small convolutional network inspired by a U-Net (without skip connections).
5+
Forward Diffusion: Adds noise based on a given timestep.
6+
Training Loop: Learns to predict the noise added at each step.
7+
Sampling: Generates new images by reversing the diffusion process.
8+
"""
9+
10+
import torch
11+
import torch.nn as nn
12+
import torch.nn.functional as F
13+
from torchvision import datasets, transforms
14+
from torch.utils.data import DataLoader
15+
import numpy as np
16+
import matplotlib.pyplot as plt
17+
18+
# Configurations
19+
image_size = 28
20+
batch_size = 64
21+
num_steps = 1000
22+
device = "cpu"
23+
epochs = 1 # For demonstration
24+
25+
# Dataset
26+
transform = transforms.Compose([
27+
transforms.ToTensor(),
28+
transforms.Lambda(lambda x: x * 2 - 1) # Scale to [-1, 1]
29+
])
30+
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
31+
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
32+
33+
# Linear noise schedule
34+
beta_start, beta_end = 1e-4, 0.02
35+
betas = torch.linspace(beta_start, beta_end, num_steps).to(device)
36+
alphas = 1. - betas
37+
alphas_cumprod = torch.cumprod(alphas, axis=0)
38+
39+
# Simple convolutional model (mini U-Net)
40+
class SimpleUNet(nn.Module):
    """Small convolutional noise-prediction network (mini U-Net, no skips).

    forward(x, t) takes a batch of images x of shape (B, 1, H, W) and the
    per-sample integer timesteps t of shape (B,), and returns a tensor with
    the same shape as x containing the predicted noise.
    """

    def __init__(self):
        super().__init__()
        # enc1 takes 2 input channels: the noisy image plus a constant
        # per-sample timestep map, so the network is actually conditioned
        # on t.  (The previous version concatenated the timestep map and
        # then sliced it away again with x[:, :1], discarding t entirely —
        # the model could never learn a time-dependent reverse process.)
        self.enc1 = nn.Conv2d(2, 32, 3, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, padding=1)
        self.dec1 = nn.ConvTranspose2d(64, 32, 3, padding=1)
        self.out = nn.ConvTranspose2d(32, 1, 3, padding=1)

    def forward(self, x, t):
        # Broadcast the normalized timestep to a (B, 1, H, W) feature map.
        t_embed = t[:, None, None, None].float() / num_steps
        t_embed = t_embed.expand(x.shape)
        h = torch.cat([x, t_embed], dim=1)  # (B, 2, H, W)
        h = F.relu(self.enc1(h))
        h = F.relu(self.enc2(h))
        h = F.relu(self.dec1(h))
        return self.out(h)
56+
57+
# Instantiate the noise-prediction network and its Adam optimiser.
model = SimpleUNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
59+
60+
# Forward diffusion process
61+
def q_sample(x0, t, noise=None):
    """Forward diffusion: draw x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,
    where eps is standard Gaussian noise (drawn here if not supplied).
    t is a (B,) tensor of integer timesteps, one per sample.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    # Index the cumulative alpha products per sample and reshape for
    # broadcasting over the (B, C, H, W) image batch.
    a_bar = torch.sqrt(alphas_cumprod[t])[:, None, None, None]
    one_minus = torch.sqrt(1 - alphas_cumprod[t])[:, None, None, None]
    return a_bar * x0 + one_minus * noise
67+
68+
# Training
69+
def train():
    """Train the model with the standard DDPM denoising objective.

    For each batch: draw one uniformly random timestep per image, noise
    the image with q_sample, and minimise the MSE between the true noise
    and the network's prediction of it.
    """
    model.train()
    for epoch in range(epochs):
        for batch, (x, _) in enumerate(train_loader):
            x = x.to(device)
            # One random timestep per sample in the batch.
            t = torch.randint(0, num_steps, (x.shape[0],), device=device)
            noise = torch.randn_like(x)
            loss = F.mse_loss(model(q_sample(x, t, noise), t), noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch}, Loss: {loss.item():.4f}")
86+
87+
# Reverse sampling
88+
@torch.no_grad()
def sample(model, n=8):
    """Generate n images by running the learned reverse-diffusion chain.

    Starts from pure Gaussian noise and applies the DDPM ancestral
    sampling update for i = T-1, ..., 0, returning images in roughly
    [-1, 1] (the training scale).
    """
    model.eval()
    x = torch.randn(n, 1, image_size, image_size).to(device)
    for i in reversed(range(num_steps)):
        t = torch.full((n,), i, device=device, dtype=torch.long)
        eps_hat = model(x, t)
        beta, alpha, alpha_bar = betas[i], alphas[i], alphas_cumprod[i]
        # Fresh noise is injected at every step except the last (i == 0),
        # where the update is deterministic.
        z = torch.randn_like(x) if i > 0 else 0
        x = (1 / torch.sqrt(alpha)) * (x - beta / torch.sqrt(1 - alpha_bar) * eps_hat) + torch.sqrt(beta) * z
    return x
104+
105+
# Run
106+
if __name__ == "__main__":
    train()
    # Draw samples and map them from [-1, 1] back to [0, 1] for display.
    samples = sample(model, n=8).cpu()
    samples = (samples + 1) / 2
    # Lay the eight digits out side by side as one wide grayscale image.
    grid = torch.cat([digit.squeeze(0) for digit in samples], dim=1)
    plt.imshow(grid, cmap='gray')
    plt.axis('off')
    plt.show()

0 commit comments

Comments
 (0)