From c5bcd95d20343452be828d9c82c6abdc93070066 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Mon, 10 Sep 2018 20:45:03 -0700
Subject: [PATCH 1/3] Create C++ version of MNIST example

---
 .gitignore                  |   1 +
 cpp/.clang-format           |  88 +++++++++++++++++++
 cpp/mnist/CMakeLists.txt    |  20 +++++
 cpp/mnist/download_mnist.py |  88 +++++++++++++++++++
 cpp/mnist/mnist.cpp         | 169 ++++++++++++++++++++++++++++++++++++
 5 files changed, 366 insertions(+)
 create mode 100644 cpp/.clang-format
 create mode 100644 cpp/mnist/CMakeLists.txt
 create mode 100644 cpp/mnist/download_mnist.py
 create mode 100644 cpp/mnist/mnist.cpp

diff --git a/.gitignore b/.gitignore
index bf65fb8af5..12fdd1bfe1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ dcgan/data
 data
 *.pyc
 OpenNMT/data
+cpp/mnist/build
diff --git a/cpp/.clang-format b/cpp/.clang-format
new file mode 100644
index 0000000000..dd7771d727
--- /dev/null
+++ b/cpp/.clang-format
@@ -0,0 +1,88 @@
+---
+AccessModifierOffset: -1
+AlignAfterOpenBracket: AlwaysBreak
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlinesLeft: true
+AlignOperands: false
+AlignTrailingComments: false
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowShortBlocksOnASingleLine: false
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: Empty
+AllowShortIfStatementsOnASingleLine: false
+AllowShortLoopsOnASingleLine: false
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: true
+AlwaysBreakTemplateDeclarations: true
+BinPackArguments: false
+BinPackParameters: false
+BraceWrapping:
+  AfterClass: false
+  AfterControlStatement: false
+  AfterEnum: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterObjCDeclaration: false
+  AfterStruct: false
+  AfterUnion: false
+  BeforeCatch: false
+  BeforeElse: false
+  IndentBraces: false
+BreakBeforeBinaryOperators: None
+BreakBeforeBraces: Attach
+BreakBeforeTernaryOperators: true
+BreakConstructorInitializersBeforeComma: false
+BreakAfterJavaFieldAnnotations: false
+BreakStringLiterals: false
+ColumnLimit: 80
+CommentPragmas: '^ IWYU pragma:'
+CompactNamespaces: false
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: true
+DerivePointerAlignment: false
+DisableFormat: false
+ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
+IncludeCategories:
+  - Regex: '^<.*\.h(pp)?>'
+    Priority: 1
+  - Regex: '^<.*'
+    Priority: 2
+  - Regex: '.*'
+    Priority: 3
+IndentCaseLabels: true
+IndentWidth: 2
+IndentWrappedFunctionNames: false
+KeepEmptyLinesAtTheStartOfBlocks: false
+MacroBlockBegin: ''
+MacroBlockEnd: ''
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBlockIndentWidth: 2
+ObjCSpaceAfterProperty: false
+ObjCSpaceBeforeProtocolList: false
+PenaltyBreakBeforeFirstCallParameter: 1
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 2000000
+PointerAlignment: Left
+ReflowComments: true
+SortIncludes: true
+SpaceAfterCStyleCast: false
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeParens: ControlStatements
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles: false
+SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+Standard: Cpp11
+TabWidth: 8
+UseTab: Never
+...
diff --git a/cpp/mnist/CMakeLists.txt b/cpp/mnist/CMakeLists.txt
new file mode 100644
index 0000000000..b38293972e
--- /dev/null
+++ b/cpp/mnist/CMakeLists.txt
@@ -0,0 +1,20 @@
+cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
+project(mnist)
+
+find_package(Torch REQUIRED)
+
+option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
+if (DOWNLOAD_MNIST)
+  message(STATUS "Downloading MNIST dataset")
+  execute_process(
+    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/download_mnist.py
+      -d ${CMAKE_BINARY_DIR}/data
+    ERROR_VARIABLE DOWNLOAD_ERROR)
+  if (DOWNLOAD_ERROR)
+    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
+  endif()
+endif()
+
+add_executable(mnist mnist.cpp)
+target_compile_features(mnist PUBLIC cxx_range_for)
+target_link_libraries(mnist ${TORCH_LIBRARIES})
diff --git a/cpp/mnist/download_mnist.py b/cpp/mnist/download_mnist.py
new file mode 100644
index 0000000000..2a5068ffb8
--- /dev/null
+++ b/cpp/mnist/download_mnist.py
@@ -0,0 +1,88 @@
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import gzip
+import os
+import sys
+import urllib
+
+try:
+    from urllib.error import URLError
+    from urllib.request import urlretrieve
+except ImportError:
+    from urllib2 import URLError
+    from urllib import urlretrieve
+
+RESOURCES = [
+    'train-images-idx3-ubyte.gz',
+    'train-labels-idx1-ubyte.gz',
+    't10k-images-idx3-ubyte.gz',
+    't10k-labels-idx1-ubyte.gz',
+]
+
+
+def report_download_progress(chunk_number, chunk_size, file_size):
+    if file_size != -1:
+        percent = min(1, (chunk_number * chunk_size) / file_size)
+        bar = '#' * int(64 * percent)
+        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
+
+
+def download(destination_path, url, quiet):
+    if os.path.exists(destination_path):
+        if not quiet:
+            print('{} already exists, skipping ...'.format(destination_path))
+    else:
+        print('Downloading {} ...'.format(url))
+        try:
+            hook = None if quiet else report_download_progress
+            urlretrieve(url, destination_path, reporthook=hook)
+        except URLError:
+            raise RuntimeError('Error downloading resource!')
+        finally:
+            if not quiet:
+                # Just a newline.
+                print()
+
+
+def unzip(zipped_path, quiet):
+    unzipped_path = os.path.splitext(zipped_path)[0]
+    if os.path.exists(unzipped_path):
+        if not quiet:
+            print('{} already exists, skipping ...'.format(unzipped_path))
+        return
+    with gzip.open(zipped_path, 'rb') as zipped_file:
+        with open(unzipped_path, 'wb') as unzipped_file:
+            unzipped_file.write(zipped_file.read())
+    if not quiet:
+        print('Unzipped {} ...'.format(zipped_path))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Download the MNIST dataset from the internet')
+    parser.add_argument(
+        '-d', '--destination', default='.', help='Destination directory')
+    parser.add_argument(
+        '-q',
+        '--quiet',
+        action='store_true',
+        help="Don't report about progress")
+    options = parser.parse_args()
+
+    if not os.path.exists(options.destination):
+        os.makedirs(options.destination)
+
+    try:
+        for resource in RESOURCES:
+            path = os.path.join(options.destination, resource)
+            url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
+            download(path, url, options.quiet)
+            unzip(path, options.quiet)
+    except KeyboardInterrupt:
+        print('Interrupted')
+
+
+if __name__ == '__main__':
+    main()
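As a quick sanity check that the downloaded and unzipped IDX files are usable, a minimal program along these lines works (a sketch, not part of the patch; it assumes `./data` holds the files produced by `download_mnist.py`):

```cpp
#include <torch/torch.h>

#include <iostream>

// Construct the dataset the same way mnist.cpp does and inspect one example.
int main() {
  auto dataset = torch::data::datasets::MNIST("./data");
  std::cout << "training examples: " << dataset.size().value() << '\n';
  auto example = dataset.get(0);
  std::cout << "image shape: " << example.data.sizes()
            << ", label: " << example.target.item<int64_t>() << '\n';
}
```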
diff --git a/cpp/mnist/mnist.cpp b/cpp/mnist/mnist.cpp
new file mode 100644
index 0000000000..419c3792e1
--- /dev/null
+++ b/cpp/mnist/mnist.cpp
@@ -0,0 +1,169 @@
+#include <torch/torch.h>
+
+#include <cstddef>
+#include <iostream>
+#include <string>
+#include <vector>
+
+struct Net : torch::nn::Module {
+  Net()
+      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
+        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
+        fc1(320, 50),
+        fc2(50, 10) {
+    register_module("conv1", conv1);
+    register_module("conv2", conv2);
+    register_module("conv2_drop", conv2_drop);
+    register_module("fc1", fc1);
+    register_module("fc2", fc2);
+  }
+
+  torch::Tensor forward(torch::Tensor x) {
+    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
+    x = torch::relu(
+        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
+    x = x.view({-1, 320});
+    x = torch::relu(fc1->forward(x));
+    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
+    x = fc2->forward(x);
+    return torch::log_softmax(x, /*dim=*/1);
+  }
+
+  torch::nn::Conv2d conv1;
+  torch::nn::Conv2d conv2;
+  torch::nn::FeatureDropout conv2_drop;
+  torch::nn::Linear fc1;
+  torch::nn::Linear fc2;
+};
+
+struct Options {
+  std::string data_root{"data"};
+  int32_t batch_size{64};
+  int32_t epochs{10};
+  double lr{0.01};
+  double momentum{0.5};
+  bool no_cuda{false};
+  int32_t seed{1};
+  int32_t test_batch_size{1000};
+  int32_t log_interval{10};
+};
+
+template <typename DataLoader>
+void train(
+    int32_t epoch,
+    const Options& options,
+    Net& model,
+    torch::Device device,
+    DataLoader& data_loader,
+    torch::optim::Optimizer& optimizer,
+    size_t dataset_size) {
+  model.train();
+  size_t batch_idx = 0;
+  for (auto& batch : data_loader) {
+    auto data = batch.data.to(device), targets = batch.target.to(device);
+    optimizer.zero_grad();
+    auto output = model.forward(data);
+    auto loss = torch::nll_loss(output, targets);
+    AT_ASSERT(!std::isnan(loss.template item<float>()));
+    loss.backward();
+    optimizer.step();
+
+    if (batch_idx++ % options.log_interval == 0) {
+      std::cout << "Train Epoch: " << epoch << " ["
+                << batch_idx * batch.data.size(0) << "/" << dataset_size
+                << "]\tLoss: " << loss.template item<float>() << std::endl;
+    }
+  }
+}
+
+template <typename DataLoader>
+void test(
+    Net& model,
+    torch::Device device,
+    DataLoader& data_loader,
+    size_t dataset_size) {
+  torch::NoGradGuard no_grad;
+  model.eval();
+  double test_loss = 0;
+  int32_t correct = 0;
+  for (const auto& batch : data_loader) {
+    auto data = batch.data.to(device), targets = batch.target.to(device);
+    auto output = model.forward(data);
+    test_loss += torch::nll_loss(
+                     output,
+                     targets,
+                     /*weight=*/{},
+                     Reduction::Sum)
+                     .template item<float>();
+    auto pred = output.argmax(1);
+    correct += pred.eq(targets).sum().template item<int64_t>();
+  }
+
+  test_loss /= dataset_size;
+  std::cout << "Test set: Average loss: " << test_loss
+            << ", Accuracy: " << static_cast<double>(correct) / dataset_size
+            << std::endl;
+}
+
+struct Normalize : public torch::data::transforms::TensorTransform<> {
+  Normalize(float mean, float stddev)
+      : mean_(torch::tensor(mean)), stddev_(torch::tensor(stddev)) {}
+  torch::Tensor operator()(torch::Tensor input) {
+    return input.sub(mean_).div(stddev_);
+  }
+  torch::Tensor mean_, stddev_;
+};
+
+auto main() -> int {
+  Options options;
+
+  torch::manual_seed(options.seed);
+
+  torch::DeviceType device_type;
+  if (torch::cuda::is_available() && !options.no_cuda) {
+    std::cout << "CUDA available! Training on GPU" << std::endl;
+    device_type = torch::kCUDA;
+  } else {
+    std::cout << "Training on CPU" << std::endl;
+    device_type = torch::kCPU;
+  }
+  torch::Device device(device_type);
+
+  Net model;
+  model.to(device);
+
+  auto train_dataset =
+      torch::data::datasets::MNIST(
+          options.data_root, torch::data::datasets::MNIST::Mode::kTrain)
+          .map(Normalize(0.1307, 0.3081))
+          .map(torch::data::transforms::Stack<>());
+  const size_t train_dataset_size = train_dataset.size().value();
+  auto train_loader =
+      torch::data::make_data_loader(
+          std::move(train_dataset), options.batch_size);
+
+  auto test_dataset =
+      torch::data::datasets::MNIST(
+          options.data_root, torch::data::datasets::MNIST::Mode::kTest)
+          .map(Normalize(0.1307, 0.3081))
+          .map(torch::data::transforms::Stack<>());
+  const size_t test_dataset_size = test_dataset.size().value();
+  auto test_loader = torch::data::make_data_loader(
+      std::move(test_dataset), options.batch_size);
+
+  torch::optim::SGD optimizer(
+      model.parameters(),
+      torch::optim::SGDOptions(options.lr).momentum(options.momentum));
+
+  for (size_t epoch = 1; epoch <= options.epochs; ++epoch) {
+    train(
+        epoch,
+        options,
+        model,
+        device,
+        *train_loader,
+        optimizer,
+        train_dataset_size);
+    test(model, device, *test_loader, test_dataset_size);
+  }
+}
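The custom `Normalize` transform defined in `mnist.cpp` above is a per-example tensor transform: applied by hand it is just `(x - mean) / stddev` with the usual MNIST statistics. A minimal sketch (not part of the patch):

```cpp
#include <torch/torch.h>

#include <iostream>

// Standardize a fake 1x28x28 image the way Normalize(0.1307, 0.3081) does.
int main() {
  torch::Tensor image = torch::rand({1, 28, 28});
  torch::Tensor standardized = image.sub(0.1307).div(0.3081);
  std::cout << "mean after normalize: "
            << standardized.mean().item<float>() << '\n';
}
```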
From f9ddd9f380b6092631be083d6d8473fc6e50254f Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Wed, 9 Jan 2019 15:46:24 -0800
Subject: [PATCH 2/3] Create C++ version of DCGAN example

---
 .gitignore                             |   1 +
 cpp/.clang-format                      |   2 +-
 cpp/dcgan/CMakeLists.txt               |  20 +++
 cpp/dcgan/README.md                    |  56 ++++++++
 cpp/dcgan/dcgan.cpp                    | 187 +++++++++++++++++++++++++
 cpp/dcgan/display_samples.py           |  28 ++++
 cpp/mnist/CMakeLists.txt               |   2 +-
 cpp/mnist/README.md                    |  35 +++++
 cpp/mnist/mnist.cpp                    |  96 ++++++-------
 cpp/{mnist => tools}/download_mnist.py |   0
 10 files changed, 374 insertions(+), 53 deletions(-)
 create mode 100644 cpp/dcgan/CMakeLists.txt
 create mode 100644 cpp/dcgan/README.md
 create mode 100644 cpp/dcgan/dcgan.cpp
 create mode 100644 cpp/dcgan/display_samples.py
 create mode 100644 cpp/mnist/README.md
 rename cpp/{mnist => tools}/download_mnist.py (100%)

diff --git a/.gitignore b/.gitignore
index 12fdd1bfe1..5c944a7acc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ data
 *.pyc
 OpenNMT/data
 cpp/mnist/build
+cpp/dcgan/build
diff --git a/cpp/.clang-format b/cpp/.clang-format
index dd7771d727..73304266bd 100644
--- a/cpp/.clang-format
+++ b/cpp/.clang-format
@@ -10,7 +10,7 @@ AllowAllParametersOfDeclarationOnNextLine: false
 AllowShortBlocksOnASingleLine: false
 AllowShortCaseLabelsOnASingleLine: false
 AllowShortFunctionsOnASingleLine: Empty
-AllowShortIfStatementsOnASingleLine: false
+AllowShortIfStatementsOnASingleLine: false
 AllowShortLoopsOnASingleLine: false
 AlwaysBreakAfterReturnType: None
 AlwaysBreakBeforeMultilineStrings: true
diff --git a/cpp/dcgan/CMakeLists.txt b/cpp/dcgan/CMakeLists.txt
new file mode 100644
index 0000000000..568ebbd2a3
--- /dev/null
+++ b/cpp/dcgan/CMakeLists.txt
@@ -0,0 +1,20 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(dcgan)
+
+find_package(Torch REQUIRED)
+
+option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
+if (DOWNLOAD_MNIST)
+  message(STATUS "Downloading MNIST dataset")
+  execute_process(
+    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
+      -d ${CMAKE_BINARY_DIR}/data
+    ERROR_VARIABLE DOWNLOAD_ERROR)
+  if (DOWNLOAD_ERROR)
+    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
+  endif()
+endif()
+
+add_executable(dcgan dcgan.cpp)
+target_link_libraries(dcgan "${TORCH_LIBRARIES}")
+set_property(TARGET dcgan PROPERTY CXX_STANDARD 11)
diff --git a/cpp/dcgan/README.md b/cpp/dcgan/README.md
new file mode 100644
index 0000000000..c992a6c532
--- /dev/null
+++ b/cpp/dcgan/README.md
@@ -0,0 +1,56 @@
+# DCGAN Example with the PyTorch C++ Frontend
+
+This folder contains an example of training a DCGAN to generate MNIST digits
+with the PyTorch C++ frontend.
+
+The entire training code is contained in `dcgan.cpp`.
+
+To build the code, run the following commands from your terminal:
+
+```shell
+$ cd dcgan
+$ mkdir build
+$ cd build
+$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+$ make
+```
+
+where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
+distribution, which you can get from the [PyTorch
+homepage](https://pytorch.org/get-started/locally/).
+
+Execute the compiled binary to train the model:
+
+```shell
+$ ./dcgan
+[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
+-> checkpoint 1
+[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
+-> checkpoint 2
+[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.3676
+-> checkpoint 3
+[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
+-> checkpoint 4
+[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
+-> checkpoint 5
+[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
+-> checkpoint 6
+[ 2/30][600/938] D_loss: 0.3815 | G_loss: 3.5696
+-> checkpoint 7
+[ 2/30][800/938] D_loss: 0.4039 | G_loss: 3.2759
+-> checkpoint 8
+[ 3/30][200/938] D_loss: 0.4236 | G_loss: 4.5132
+-> checkpoint 9
+[ 3/30][400/938] D_loss: 0.3645 | G_loss: 3.9759
+-> checkpoint 10
+...
+```
+
+The training script periodically generates image samples. Use the
+`display_samples.py` script situated in this folder to generate a plot image.
+For example:
+
+```shell
+$ python display_samples.py -i dcgan-sample-10.pt
+Saved out.png
+```
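A note on the sample files referenced by the README: `torch::save` on a plain tensor writes a TorchScript-style archive, which is why `display_samples.py` below can read it back with `torch.jit.load()` and pull the tensor out of `module.parameters()`. A round-trip sketch (illustrative file name, not part of the patch):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Save a fake batch of 10 1x28x28 "samples" the way dcgan.cpp does.
  torch::Tensor samples = torch::rand({10, 1, 28, 28});
  torch::save(samples, "dcgan-sample-0.pt");

  // Load it back into a tensor on the C++ side.
  torch::Tensor restored;
  torch::load(restored, "dcgan-sample-0.pt");
  std::cout << restored.sizes() << '\n';  // [10, 1, 28, 28]
}
```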
diff --git a/cpp/dcgan/dcgan.cpp b/cpp/dcgan/dcgan.cpp
new file mode 100644
index 0000000000..f747fce54a
--- /dev/null
+++ b/cpp/dcgan/dcgan.cpp
@@ -0,0 +1,187 @@
+#include <torch/torch.h>
+
+#include <cmath>
+#include <cstdio>
+#include <iostream>
+
+// The size of the noise vector fed to the generator.
+const int64_t kNoiseSize = 100;
+
+// The batch size for training.
+const int64_t kBatchSize = 64;
+
+// The number of epochs to train.
+const int64_t kNumberOfEpochs = 30;
+
+// Where to find the MNIST dataset.
+const char* kDataFolder = "./data";
+
+// After how many batches to create a new checkpoint periodically.
+const int64_t kCheckpointEvery = 200;
+
+// How many images to sample at every checkpoint.
+const int64_t kNumberOfSamplesPerCheckpoint = 10;
+
+// Set to `true` to restore models and optimizers from previously saved
+// checkpoints.
+const bool kRestoreFromCheckpoint = false;
+
+// After how many batches to log a new update with the loss value.
+const int64_t kLogInterval = 10;
+
+using namespace torch;
+
+int main(int argc, const char* argv[]) {
+  torch::manual_seed(1);
+
+  // Create the device we pass around based on whether CUDA is available.
+  torch::Device device(torch::kCPU);
+  if (torch::cuda::is_available()) {
+    std::cout << "CUDA is available! Training on GPU." << std::endl;
+    device = torch::Device(torch::kCUDA);
+  }
+
+  nn::Sequential generator(
+      // Layer 1
+      nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
+                     .with_bias(false)
+                     .transposed(true)),
+      nn::BatchNorm(256),
+      nn::Functional(torch::relu),
+      // Layer 2
+      nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
+                     .stride(2)
+                     .padding(1)
+                     .with_bias(false)
+                     .transposed(true)),
+      nn::BatchNorm(128),
+      nn::Functional(torch::relu),
+      // Layer 3
+      nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
+                     .stride(2)
+                     .padding(1)
+                     .with_bias(false)
+                     .transposed(true)),
+      nn::BatchNorm(64),
+      nn::Functional(torch::relu),
+      // Layer 4
+      nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
+                     .stride(2)
+                     .padding(1)
+                     .with_bias(false)
+                     .transposed(true)),
+      nn::Functional(torch::tanh));
+  generator->to(device);
+
+  nn::Sequential discriminator(
+      // Layer 1
+      nn::Conv2d(
+          nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
+      nn::Functional(torch::leaky_relu, 0.2),
+      // Layer 2
+      nn::Conv2d(
+          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
+      nn::BatchNorm(128),
+      nn::Functional(torch::leaky_relu, 0.2),
+      // Layer 3
+      nn::Conv2d(
+          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
+      nn::BatchNorm(256),
+      nn::Functional(torch::leaky_relu, 0.2),
+      // Layer 4
+      nn::Conv2d(
+          nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
+      nn::Functional(torch::sigmoid));
+  discriminator->to(device);
+
+  // Assume the MNIST dataset is available under `kDataFolder`.
+  auto dataset = torch::data::datasets::MNIST(kDataFolder)
+                     .map(torch::data::transforms::Normalize(0.5, 0.5))
+                     .map(torch::data::transforms::Stack<>());
+  const int64_t batches_per_epoch =
+      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
+
+  auto data_loader = torch::data::make_data_loader(
+      std::move(dataset),
+      torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
+
+  torch::optim::Adam generator_optimizer(
+      generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+  torch::optim::Adam discriminator_optimizer(
+      discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+
+  if (kRestoreFromCheckpoint) {
+    torch::load(generator, "generator-checkpoint.pt");
+    torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
+    torch::load(discriminator, "discriminator-checkpoint.pt");
+    torch::load(
+        discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
+  }
+
+  int64_t checkpoint_counter = 1;
+  for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
+    int64_t batch_index = 0;
+    for (torch::data::Example<>& batch : *data_loader) {
+      // Train discriminator with real images.
+      discriminator->zero_grad();
+      torch::Tensor real_images = batch.data.to(device);
+      torch::Tensor real_labels =
+          torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
+      torch::Tensor real_output = discriminator->forward(real_images);
+      torch::Tensor d_loss_real =
+          torch::binary_cross_entropy(real_output, real_labels);
+      d_loss_real.backward();
+
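+      // The gradients from the real-image pass above stay in the
+      // discriminator's parameter .grad buffers; the fake-image pass below
+      // accumulates into the same buffers, so discriminator_optimizer.step()
+      // applies the gradient of d_loss_real + d_loss_fake in one update.
+      // Detaching fake_images keeps this update out of the generator.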
+      // Train discriminator with fake images.
+      torch::Tensor noise =
+          torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
+      torch::Tensor fake_images = generator->forward(noise);
+      torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
+      torch::Tensor fake_output = discriminator->forward(fake_images.detach());
+      torch::Tensor d_loss_fake =
+          torch::binary_cross_entropy(fake_output, fake_labels);
+      d_loss_fake.backward();
+
+      torch::Tensor d_loss = d_loss_real + d_loss_fake;
+      discriminator_optimizer.step();
+
+      // Train generator.
+      generator->zero_grad();
+      fake_labels.fill_(1);
+      fake_output = discriminator->forward(fake_images);
+      torch::Tensor g_loss =
+          torch::binary_cross_entropy(fake_output, fake_labels);
+      g_loss.backward();
+      generator_optimizer.step();
+      batch_index++;
+
+      if (batch_index % kLogInterval == 0) {
+        std::printf(
+            "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
+            epoch,
+            kNumberOfEpochs,
+            batch_index,
+            batches_per_epoch,
+            d_loss.item<float>(),
+            g_loss.item<float>());
+      }
+
+      if (batch_index % kCheckpointEvery == 0) {
+        // Checkpoint the model and optimizer state.
+        torch::save(generator, "generator-checkpoint.pt");
+        torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
+        torch::save(discriminator, "discriminator-checkpoint.pt");
+        torch::save(
+            discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
+        // Sample the generator and save the images.
+        torch::Tensor samples = generator->forward(torch::randn(
+            {kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
+        torch::save(
+            (samples + 1.0) / 2.0,
+            torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
+        std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
+      }
+    }
+  }
+
+  std::cout << "Training complete!" << std::endl;
+}
diff --git a/cpp/dcgan/display_samples.py b/cpp/dcgan/display_samples.py
new file mode 100644
index 0000000000..c1f9e9c3d9
--- /dev/null
+++ b/cpp/dcgan/display_samples.py
@@ -0,0 +1,28 @@
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+
+import matplotlib.pyplot as plt
+import torch
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-i", "--sample-file", required=True)
+parser.add_argument("-o", "--out-file", default="out.png")
+parser.add_argument("-d", "--dimension", type=int, default=3)
+options = parser.parse_args()
+
+module = torch.jit.load(options.sample_file)
+images = list(module.parameters())[0]
+
+for index in range(options.dimension * options.dimension):
+    image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
+    array = image.numpy()
+    axis = plt.subplot(options.dimension, options.dimension, 1 + index)
+    plt.imshow(array, cmap="gray")
+    axis.get_xaxis().set_visible(False)
+    axis.get_yaxis().set_visible(False)
+
+plt.savefig(options.out_file)
+print("Saved ", options.out_file)
diff --git a/cpp/mnist/CMakeLists.txt b/cpp/mnist/CMakeLists.txt
index b38293972e..ad1f8ad567 100644
--- a/cpp/mnist/CMakeLists.txt
+++ b/cpp/mnist/CMakeLists.txt
@@ -7,7 +7,7 @@ option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
 if (DOWNLOAD_MNIST)
   message(STATUS "Downloading MNIST dataset")
   execute_process(
-    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/download_mnist.py
+    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
       -d ${CMAKE_BINARY_DIR}/data
     ERROR_VARIABLE DOWNLOAD_ERROR)
   if (DOWNLOAD_ERROR)
diff --git a/cpp/mnist/README.md b/cpp/mnist/README.md
new file mode 100644
index 0000000000..5a773a4368
--- /dev/null
+++ b/cpp/mnist/README.md
@@ -0,0 +1,35 @@
+# MNIST Example with the PyTorch C++ Frontend
+
+This folder contains an example of training a computer vision model to recognize
+digits in images from the MNIST dataset, using the PyTorch C++ frontend.
+
+The entire training code is contained in `mnist.cpp`.
+
+To build the code, run the following commands from your terminal:
+
+```shell
+$ cd mnist
+$ mkdir build
+$ cd build
+$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+$ make
+```
+
+where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
+distribution, which you can get from the [PyTorch
+homepage](https://pytorch.org/get-started/locally/).
+
+Execute the compiled binary to train the model:
+
+```shell
+$ ./mnist
+Train Epoch: 1 [59584/60000] Loss: 0.4232
+Test set: Average loss: 0.1989 | Accuracy: 0.940
+Train Epoch: 2 [59584/60000] Loss: 0.1926
+Test set: Average loss: 0.1338 | Accuracy: 0.959
+Train Epoch: 3 [59584/60000] Loss: 0.1390
+Test set: Average loss: 0.0997 | Accuracy: 0.969
+Train Epoch: 4 [59584/60000] Loss: 0.1239
+Test set: Average loss: 0.0875 | Accuracy: 0.972
+...
+```
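One detail worth noting about the `test()` routine kept by the refactored `mnist.cpp` below: evaluation relies on two independent switches. A minimal sketch (not part of the patch):

```cpp
#include <torch/torch.h>

// model.eval() clears the module's training flag, so the dropout guarded by
// is_training() in Net::forward becomes a no-op, while torch::NoGradGuard
// disables autograd tracking for the enclosing scope; test() needs both.
template <typename Model>
torch::Tensor predict(Model& model, torch::Tensor input) {
  torch::NoGradGuard no_grad;  // no gradient bookkeeping below
  model.eval();                // is_training() == false -> dropout disabled
  return model.forward(input).argmax(/*dim=*/1);
}
```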
diff --git a/cpp/mnist/mnist.cpp b/cpp/mnist/mnist.cpp
index 419c3792e1..880abcb82d 100644
--- a/cpp/mnist/mnist.cpp
+++ b/cpp/mnist/mnist.cpp
@@ -1,10 +1,26 @@
 #include <torch/torch.h>
 
 #include <cstddef>
+#include <cstdio>
 #include <iostream>
 #include <string>
 #include <vector>
 
+// Where to find the MNIST dataset.
+const char* kDataRoot = "./data";
+
+// The batch size for training.
+const int64_t kTrainBatchSize = 64;
+
+// The batch size for testing.
+const int64_t kTestBatchSize = 1000;
+
+// The number of epochs to train.
+const int64_t kNumberOfEpochs = 10;
+
+// After how many batches to log a new update with the loss value.
+const int64_t kLogInterval = 10;
+
 struct Net : torch::nn::Module {
   Net()
       : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
@@ -36,22 +52,9 @@ struct Net : torch::nn::Module {
   torch::nn::Linear fc2;
 };
 
-struct Options {
-  std::string data_root{"data"};
-  int32_t batch_size{64};
-  int32_t epochs{10};
-  double lr{0.01};
-  double momentum{0.5};
-  bool no_cuda{false};
-  int32_t seed{1};
-  int32_t test_batch_size{1000};
-  int32_t log_interval{10};
-};
-
 template <typename DataLoader>
 void train(
     int32_t epoch,
-    const Options& options,
     Net& model,
     torch::Device device,
     DataLoader& data_loader,
@@ -68,10 +71,13 @@ void train(
     loss.backward();
     optimizer.step();
 
-    if (batch_idx++ % options.log_interval == 0) {
-      std::cout << "Train Epoch: " << epoch << " ["
-                << batch_idx * batch.data.size(0) << "/" << dataset_size
-                << "]\tLoss: " << loss.template item<float>() << std::endl;
+    if (batch_idx++ % kLogInterval == 0) {
+      std::printf(
+          "\rTrain Epoch: %d [%5zu/%5zu] Loss: %.4f",
+          epoch,
+          batch_idx * batch.data.size(0),
+          dataset_size,
+          loss.template item<float>());
     }
   }
 }
@@ -100,9 +106,10 @@ void test(
   }
 
   test_loss /= dataset_size;
-  std::cout << "Test set: Average loss: " << test_loss
-            << ", Accuracy: " << static_cast<double>(correct) / dataset_size
-            << std::endl;
+  std::printf(
+      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
+      test_loss,
+      static_cast<double>(correct) / dataset_size);
 }
 
 struct Normalize : public torch::data::transforms::TensorTransform<> {
@@ -115,16 +122,14 @@ struct Normalize : public torch::data::transforms::TensorTransform<> {
 };
 
 auto main() -> int {
-  Options options;
-
-  torch::manual_seed(options.seed);
+  torch::manual_seed(1);
 
   torch::DeviceType device_type;
-  if (torch::cuda::is_available() && !options.no_cuda) {
-    std::cout << "CUDA available! Training on GPU" << std::endl;
+  if (torch::cuda::is_available()) {
+    std::cout << "CUDA available! Training on GPU." << std::endl;
     device_type = torch::kCUDA;
   } else {
-    std::cout << "Training on CPU" << std::endl;
+    std::cout << "Training on CPU." << std::endl;
     device_type = torch::kCPU;
   }
   torch::Device device(device_type);
@@ -132,38 +137,27 @@ auto main() -> int {
   Net model;
   model.to(device);
 
-  auto train_dataset =
-      torch::data::datasets::MNIST(
-          options.data_root, torch::data::datasets::MNIST::Mode::kTrain)
-          .map(Normalize(0.1307, 0.3081))
-          .map(torch::data::transforms::Stack<>());
+  auto train_dataset = torch::data::datasets::MNIST(kDataRoot)
+                           .map(Normalize(0.1307, 0.3081))
+                           .map(torch::data::transforms::Stack<>());
   const size_t train_dataset_size = train_dataset.size().value();
   auto train_loader =
       torch::data::make_data_loader(
-          std::move(train_dataset), options.batch_size);
+          std::move(train_dataset), kTrainBatchSize);
 
-  auto test_dataset =
-      torch::data::datasets::MNIST(
-          options.data_root, torch::data::datasets::MNIST::Mode::kTest)
-          .map(Normalize(0.1307, 0.3081))
-          .map(torch::data::transforms::Stack<>());
+  auto test_dataset = torch::data::datasets::MNIST(
+                          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
+                          .map(Normalize(0.1307, 0.3081))
+                          .map(torch::data::transforms::Stack<>());
   const size_t test_dataset_size = test_dataset.size().value();
-  auto test_loader = torch::data::make_data_loader(
-      std::move(test_dataset), options.batch_size);
+  auto test_loader =
+      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);
 
   torch::optim::SGD optimizer(
-      model.parameters(),
-      torch::optim::SGDOptions(options.lr).momentum(options.momentum));
-
-  for (size_t epoch = 1; epoch <= options.epochs; ++epoch) {
-    train(
-        epoch,
-        options,
-        model,
-        device,
-        *train_loader,
-        optimizer,
-        train_dataset_size);
+      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));
+
+  for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
+    train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
     test(model, device, *test_loader, test_dataset_size);
   }
 }
diff --git a/cpp/mnist/download_mnist.py b/cpp/tools/download_mnist.py
similarity index 100%
rename from cpp/mnist/download_mnist.py
rename to cpp/tools/download_mnist.py

From 76f4b5c22e3af846b6c3612559fac82f1e0f79fa Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Mon, 14 Jan 2019 08:42:54 -0800
Subject: [PATCH 3/3] Update for Normalize transform

---
 cpp/dcgan/dcgan.cpp |  2 +-
 cpp/mnist/mnist.cpp | 13 ++-----------
 2 files changed, 3 insertions(+), 12 deletions(-)

diff --git a/cpp/dcgan/dcgan.cpp b/cpp/dcgan/dcgan.cpp
index f747fce54a..2928955194 100644
--- a/cpp/dcgan/dcgan.cpp
+++ b/cpp/dcgan/dcgan.cpp
@@ -96,7 +96,7 @@ int main(int argc, const char* argv[]) {
 
   // Assume the MNIST dataset is available under `kDataFolder`.
   auto dataset = torch::data::datasets::MNIST(kDataFolder)
-                     .map(torch::data::transforms::Normalize(0.5, 0.5))
+                     .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                      .map(torch::data::transforms::Stack<>());
   const int64_t batches_per_epoch =
       std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
diff --git a/cpp/mnist/mnist.cpp b/cpp/mnist/mnist.cpp
index 880abcb82d..329fdc3a96 100644
--- a/cpp/mnist/mnist.cpp
+++ b/cpp/mnist/mnist.cpp
@@ -112,15 +112,6 @@ void test(
       static_cast<double>(correct) / dataset_size);
 }
 
-struct Normalize : public torch::data::transforms::TensorTransform<> {
-  Normalize(float mean, float stddev)
-      : mean_(torch::tensor(mean)), stddev_(torch::tensor(stddev)) {}
-  torch::Tensor operator()(torch::Tensor input) {
-    return input.sub(mean_).div(stddev_);
-  }
-  torch::Tensor mean_, stddev_;
-};
-
 auto main() -> int {
   torch::manual_seed(1);
 
@@ -138,7 +129,7 @@ auto main() -> int {
   model.to(device);
 
   auto train_dataset = torch::data::datasets::MNIST(kDataRoot)
-                           .map(Normalize(0.1307, 0.3081))
+                           .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                            .map(torch::data::transforms::Stack<>());
   const size_t train_dataset_size = train_dataset.size().value();
   auto train_loader =
@@ -147,7 +138,7 @@ auto main() -> int {
 
   auto test_dataset = torch::data::datasets::MNIST(
                           kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
-                          .map(Normalize(0.1307, 0.3081))
+                          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                           .map(torch::data::transforms::Stack<>());
   const size_t test_dataset_size = test_dataset.size().value();
   auto test_loader =