#include <torch/torch.h>

#include <cmath>
#include <cstdio>
#include <iostream>

// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;

// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data";
// After how many batches to create a new checkpoint.
const int64_t kCheckpointEvery = 200;

// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;

// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

using namespace torch;

int main(int argc, const char* argv[]) {
  torch::manual_seed(1);

  // Create the device we pass around based on whether CUDA is available.
  torch::Device device(torch::kCPU);
  if (torch::cuda::is_available()) {
    std::cout << "CUDA is available! Training on GPU." << std::endl;
    device = torch::Device(torch::kCUDA);
  }

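  // The generator upsamples a [kNoiseSize, 1, 1] noise tensor to a 1x28x28
  // image through four transposed convolutions (1x1 -> 4x4 -> 7x7 -> 14x14 ->
  // 28x28); the final tanh squashes pixel values into [-1, 1].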
  nn::Sequential generator(
      // Layer 1
      nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(256),
      nn::Functional(torch::relu),
      // Layer 2
      nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(128),
      nn::Functional(torch::relu),
      // Layer 3
      nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(64),
      nn::Functional(torch::relu),
      // Layer 4
      nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::Functional(torch::tanh));
  generator->to(device);

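  // The discriminator mirrors the generator: strided convolutions downsample
  // the 1x28x28 input (28x28 -> 14x14 -> 7x7 -> 3x3 -> 1x1), and the final
  // sigmoid yields the probability that the input image is real.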
  nn::Sequential discriminator(
      // Layer 1
      nn::Conv2d(
          nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 2
      nn::Conv2d(
          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(128),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 3
      nn::Conv2d(
          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(256),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 4
      nn::Conv2d(
          nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
      nn::Functional(torch::sigmoid));
  discriminator->to(device);

  // Assume the MNIST dataset is available under `kDataFolder`.
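  // The Normalize transform shifts pixel values from [0, 1] to [-1, 1] so
  // that real images occupy the same range as the generator's tanh output.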
  auto dataset = torch::data::datasets::MNIST(kDataFolder)
                     .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                     .map(torch::data::transforms::Stack<>());
  const int64_t batches_per_epoch =
      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

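  // Both networks use Adam with a learning rate of 2e-4 and beta1 = 0.5, the
  // hyperparameters recommended in the DCGAN paper (Radford et al., 2015).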
  torch::optim::Adam generator_optimizer(
      generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam discriminator_optimizer(
      discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

  if (kRestoreFromCheckpoint) {
    torch::load(generator, "generator-checkpoint.pt");
    torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
    torch::load(discriminator, "discriminator-checkpoint.pt");
    torch::load(
        discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  }

  int64_t checkpoint_counter = 1;
  for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    int64_t batch_index = 0;
    for (torch::data::Example<>& batch : *data_loader) {
      // Train discriminator with real images.
      discriminator->zero_grad();
      torch::Tensor real_images = batch.data.to(device);
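      // One-sided label smoothing: real labels are drawn uniformly from
      // [0.8, 1.0] instead of being fixed at 1, which keeps the discriminator
      // from becoming overconfident.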
      torch::Tensor real_labels =
          torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
      torch::Tensor real_output = discriminator->forward(real_images);
      torch::Tensor d_loss_real =
          torch::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();

      // Train discriminator with fake images.
      torch::Tensor noise =
          torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator->forward(noise);
      torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
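      // `detach()` stops gradients from flowing back into the generator, so
      // this backward pass updates only the discriminator.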
      torch::Tensor fake_output = discriminator->forward(fake_images.detach());
      torch::Tensor d_loss_fake =
          torch::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();

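      // The two backward() calls above already accumulated gradients; the
      // summed loss is computed here only for logging.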
      torch::Tensor d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();

      // Train generator.
      generator->zero_grad();
      fake_labels.fill_(1);
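      // The generator is trained to make the discriminator label its fakes as
      // real, so the fake batch is relabeled with 1. The fakes are not
      // detached this time, letting gradients flow back into the generator.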
      fake_output = discriminator->forward(fake_images);
      torch::Tensor g_loss =
          torch::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();

      // Advance the counter once per batch (not only inside the logging call,
      // which would stall it after the first log and disable checkpointing)
      // so both interval checks below fire as intended.
      batch_index++;
      if (batch_index % kLogInterval == 0) {
        std::printf(
            "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
            epoch,
            kNumberOfEpochs,
            batch_index,
            batches_per_epoch,
            d_loss.item<float>(),
            g_loss.item<float>());
      }

      if (batch_index % kCheckpointEvery == 0) {
        // Checkpoint the model and optimizer state.
        torch::save(generator, "generator-checkpoint.pt");
        torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
        torch::save(discriminator, "discriminator-checkpoint.pt");
        torch::save(
            discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
        // Sample the generator and save the images.
        torch::Tensor samples = generator->forward(torch::randn(
            {kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
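        // Shift the samples from tanh's [-1, 1] output range to [0, 1] before
        // saving.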
        torch::save(
            (samples + 1.0) / 2.0,
            torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
        std::cout << "\n-> checkpoint " << checkpoint_counter++ << '\n';
      }
    }
  }

  std::cout << "Training complete!" << std::endl;
}