Skip to content

Commit 29a38c6

Browse files
goldsboroughsoumith
authored andcommitted
Add cpp folder for C++ frontend examples (#492)
* Create C++ version of MNIST example * Create C++ version of DCGAN example * Update for Normalize transform
1 parent 5d27fdb commit 29a38c6

10 files changed

+678
-0
lines changed

Diff for: .gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ dcgan/data
22
data
33
*.pyc
44
OpenNMT/data
5+
cpp/mnist/build
6+
cpp/dcgan/build

Diff for: cpp/.clang-format

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
---
2+
AccessModifierOffset: -1
3+
AlignAfterOpenBracket: AlwaysBreak
4+
AlignConsecutiveAssignments: false
5+
AlignConsecutiveDeclarations: false
6+
AlignEscapedNewlinesLeft: true
7+
AlignOperands: false
8+
AlignTrailingComments: false
9+
AllowAllParametersOfDeclarationOnNextLine: false
10+
AllowShortBlocksOnASingleLine: false
11+
AllowShortCaseLabelsOnASingleLine: false
12+
AllowShortFunctionsOnASingleLine: Empty
13+
AllowShortIfStatementsOnASingleLine: false
14+
AllowShortLoopsOnASingleLine: false
15+
AlwaysBreakAfterReturnType: None
16+
AlwaysBreakBeforeMultilineStrings: true
17+
AlwaysBreakTemplateDeclarations: true
18+
BinPackArguments: false
19+
BinPackParameters: false
20+
BraceWrapping:
21+
AfterClass: false
22+
AfterControlStatement: false
23+
AfterEnum: false
24+
AfterFunction: false
25+
AfterNamespace: false
26+
AfterObjCDeclaration: false
27+
AfterStruct: false
28+
AfterUnion: false
29+
BeforeCatch: false
30+
BeforeElse: false
31+
IndentBraces: false
32+
BreakBeforeBinaryOperators: None
33+
BreakBeforeBraces: Attach
34+
BreakBeforeTernaryOperators: true
35+
BreakConstructorInitializersBeforeComma: false
36+
BreakAfterJavaFieldAnnotations: false
37+
BreakStringLiterals: false
38+
ColumnLimit: 80
39+
CommentPragmas: '^ IWYU pragma:'
40+
CompactNamespaces: false
41+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
42+
ConstructorInitializerIndentWidth: 4
43+
ContinuationIndentWidth: 4
44+
Cpp11BracedListStyle: true
45+
DerivePointerAlignment: false
46+
DisableFormat: false
47+
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
48+
IncludeCategories:
49+
- Regex: '^<.*\.h(pp)?>'
50+
Priority: 1
51+
- Regex: '^<.*'
52+
Priority: 2
53+
- Regex: '.*'
54+
Priority: 3
55+
IndentCaseLabels: true
56+
IndentWidth: 2
57+
IndentWrappedFunctionNames: false
58+
KeepEmptyLinesAtTheStartOfBlocks: false
59+
MacroBlockBegin: ''
60+
MacroBlockEnd: ''
61+
MaxEmptyLinesToKeep: 1
62+
NamespaceIndentation: None
63+
ObjCBlockIndentWidth: 2
64+
ObjCSpaceAfterProperty: false
65+
ObjCSpaceBeforeProtocolList: false
66+
PenaltyBreakBeforeFirstCallParameter: 1
67+
PenaltyBreakComment: 300
68+
PenaltyBreakFirstLessLess: 120
69+
PenaltyBreakString: 1000
70+
PenaltyExcessCharacter: 1000000
71+
PenaltyReturnTypeOnItsOwnLine: 2000000
72+
PointerAlignment: Left
73+
ReflowComments: true
74+
SortIncludes: true
75+
SpaceAfterCStyleCast: false
76+
SpaceBeforeAssignmentOperators: true
77+
SpaceBeforeParens: ControlStatements
78+
SpaceInEmptyParentheses: false
79+
SpacesBeforeTrailingComments: 1
80+
SpacesInAngles: false
81+
SpacesInContainerLiterals: true
82+
SpacesInCStyleCastParentheses: false
83+
SpacesInParentheses: false
84+
SpacesInSquareBrackets: false
85+
Standard: Cpp11
86+
TabWidth: 8
87+
UseTab: Never
88+
...

Diff for: cpp/dcgan/CMakeLists.txt

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2+
project(dcgan)
3+
4+
find_package(Torch REQUIRED)
5+
6+
option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
7+
if (DOWNLOAD_MNIST)
8+
message(STATUS "Downloading MNIST dataset")
9+
execute_process(
10+
COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
11+
-d ${CMAKE_BINARY_DIR}/data
12+
ERROR_VARIABLE DOWNLOAD_ERROR)
13+
if (DOWNLOAD_ERROR)
14+
message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
15+
endif()
16+
endif()
17+
18+
add_executable(dcgan dcgan.cpp)
19+
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
20+
set_property(TARGET dcgan PROPERTY CXX_STANDARD 11)

Diff for: cpp/dcgan/README.md

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# DCGAN Example with the PyTorch C++ Frontend
2+
3+
This folder contains an example of training a DCGAN to generate MNIST digits
4+
with the PyTorch C++ frontend.
5+
6+
The entire training code is contained in `dcgan.cpp`.
7+
8+
To build the code, run the following commands from your terminal:
9+
10+
```shell
11+
$ cd dcgan
12+
$ mkdir build
13+
$ cd build
14+
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
15+
$ make
16+
```
17+
18+
where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
19+
distribution, which you can get from the [PyTorch
20+
homepage](https://pytorch.org/get-started/locally/).
21+
22+
Execute the compiled binary to train the model:
23+
24+
```shell
25+
$ ./dcgan
26+
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
27+
-> checkpoint 1
28+
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
29+
-> checkpoint 2
30+
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.36760
31+
-> checkpoint 3
32+
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
33+
-> checkpoint 4
34+
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
35+
-> checkpoint 5
36+
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
37+
-> checkpoint 6
38+
[ 2/30][600/938] D_loss: 0.3815 | G_loss: 3.5696
39+
-> checkpoint 7
40+
[ 2/30][800/938] D_loss: 0.4039 | G_loss: 3.2759
41+
-> checkpoint 8
42+
[ 3/30][200/938] D_loss: 0.4236 | G_loss: 4.5132
43+
-> checkpoint 9
44+
[ 3/30][400/938] D_loss: 0.3645 | G_loss: 3.9759
45+
-> checkpoint 10
46+
...
47+
```
48+
49+
The training script periodically generates image samples. Use the
50+
`display_samples.py` script situated in this folder to generate a plot image.
51+
For example:
52+
53+
```shell
54+
$ python display_samples.py -i dcgan-sample-10.png
55+
Saved out.png
56+
```

Diff for: cpp/dcgan/dcgan.cpp

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include <torch/torch.h>
2+
3+
#include <cmath>
4+
#include <cstdio>
5+
#include <iostream>
6+
7+
// The size of the noise vector fed to the generator.
8+
const int64_t kNoiseSize = 100;
9+
10+
// The batch size for training.
11+
const int64_t kBatchSize = 64;
12+
13+
// The number of epochs to train.
14+
const int64_t kNumberOfEpochs = 30;
15+
16+
// Where to find the MNIST dataset.
17+
const char* kDataFolder = "./data";
18+
19+
// After how many batches to create a new checkpoint periodically.
20+
const int64_t kCheckpointEvery = 200;
21+
22+
// How many images to sample at every checkpoint.
23+
const int64_t kNumberOfSamplesPerCheckpoint = 10;
24+
25+
// Set to `true` to restore models and optimizers from previously saved
26+
// checkpoints.
27+
const bool kRestoreFromCheckpoint = false;
28+
29+
// After how many batches to log a new update with the loss value.
30+
const int64_t kLogInterval = 10;
31+
32+
using namespace torch;
33+
34+
int main(int argc, const char* argv[]) {
35+
torch::manual_seed(1);
36+
37+
// Create the device we pass around based on whether CUDA is available.
38+
torch::Device device(torch::kCPU);
39+
if (torch::cuda::is_available()) {
40+
std::cout << "CUDA is available! Training on GPU." << std::endl;
41+
device = torch::Device(torch::kCUDA);
42+
}
43+
44+
nn::Sequential generator(
45+
// Layer 1
46+
nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
47+
.with_bias(false)
48+
.transposed(true)),
49+
nn::BatchNorm(256),
50+
nn::Functional(torch::relu),
51+
// Layer 2
52+
nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
53+
.stride(2)
54+
.padding(1)
55+
.with_bias(false)
56+
.transposed(true)),
57+
nn::BatchNorm(128),
58+
nn::Functional(torch::relu),
59+
// Layer 3
60+
nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
61+
.stride(2)
62+
.padding(1)
63+
.with_bias(false)
64+
.transposed(true)),
65+
nn::BatchNorm(64),
66+
nn::Functional(torch::relu),
67+
// Layer 4
68+
nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
69+
.stride(2)
70+
.padding(1)
71+
.with_bias(false)
72+
.transposed(true)),
73+
nn::Functional(torch::tanh));
74+
generator->to(device);
75+
76+
nn::Sequential discriminator(
77+
// Layer 1
78+
nn::Conv2d(
79+
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
80+
nn::Functional(torch::leaky_relu, 0.2),
81+
// Layer 2
82+
nn::Conv2d(
83+
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
84+
nn::BatchNorm(128),
85+
nn::Functional(torch::leaky_relu, 0.2),
86+
// Layer 3
87+
nn::Conv2d(
88+
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
89+
nn::BatchNorm(256),
90+
nn::Functional(torch::leaky_relu, 0.2),
91+
// Layer 4
92+
nn::Conv2d(
93+
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
94+
nn::Functional(torch::sigmoid));
95+
discriminator->to(device);
96+
97+
// Assume the MNIST dataset is available under `kDataFolder`;
98+
auto dataset = torch::data::datasets::MNIST(kDataFolder)
99+
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
100+
.map(torch::data::transforms::Stack<>());
101+
const int64_t batches_per_epoch =
102+
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
103+
104+
auto data_loader = torch::data::make_data_loader(
105+
std::move(dataset),
106+
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
107+
108+
torch::optim::Adam generator_optimizer(
109+
generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
110+
torch::optim::Adam discriminator_optimizer(
111+
discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
112+
113+
if (kRestoreFromCheckpoint) {
114+
torch::load(generator, "generator-checkpoint.pt");
115+
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
116+
torch::load(discriminator, "discriminator-checkpoint.pt");
117+
torch::load(
118+
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
119+
}
120+
121+
int64_t checkpoint_counter = 1;
122+
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
123+
int64_t batch_index = 0;
124+
for (torch::data::Example<>& batch : *data_loader) {
125+
// Train discriminator with real images.
126+
discriminator->zero_grad();
127+
torch::Tensor real_images = batch.data.to(device);
128+
torch::Tensor real_labels =
129+
torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
130+
torch::Tensor real_output = discriminator->forward(real_images);
131+
torch::Tensor d_loss_real =
132+
torch::binary_cross_entropy(real_output, real_labels);
133+
d_loss_real.backward();
134+
135+
// Train discriminator with fake images.
136+
torch::Tensor noise =
137+
torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
138+
torch::Tensor fake_images = generator->forward(noise);
139+
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
140+
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
141+
torch::Tensor d_loss_fake =
142+
torch::binary_cross_entropy(fake_output, fake_labels);
143+
d_loss_fake.backward();
144+
145+
torch::Tensor d_loss = d_loss_real + d_loss_fake;
146+
discriminator_optimizer.step();
147+
148+
// Train generator.
149+
generator->zero_grad();
150+
fake_labels.fill_(1);
151+
fake_output = discriminator->forward(fake_images);
152+
torch::Tensor g_loss =
153+
torch::binary_cross_entropy(fake_output, fake_labels);
154+
g_loss.backward();
155+
generator_optimizer.step();
156+
157+
if (batch_index % kLogInterval == 0) {
158+
std::printf(
159+
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
160+
epoch,
161+
kNumberOfEpochs,
162+
++batch_index,
163+
batches_per_epoch,
164+
d_loss.item<float>(),
165+
g_loss.item<float>());
166+
}
167+
168+
if (batch_index % kCheckpointEvery == 0) {
169+
// Checkpoint the model and optimizer state.
170+
torch::save(generator, "generator-checkpoint.pt");
171+
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
172+
torch::save(discriminator, "discriminator-checkpoint.pt");
173+
torch::save(
174+
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
175+
// Sample the generator and save the images.
176+
torch::Tensor samples = generator->forward(torch::randn(
177+
{kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
178+
torch::save(
179+
(samples + 1.0) / 2.0,
180+
torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
181+
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
182+
}
183+
}
184+
}
185+
186+
std::cout << "Training complete!" << std::endl;
187+
}

Diff for: cpp/dcgan/display_samples.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import print_function
2+
from __future__ import unicode_literals
3+
4+
import argparse
5+
6+
import matplotlib.pyplot as plt
7+
import torch
8+
9+
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument("-i", "--sample-file", required=True)
12+
parser.add_argument("-o", "--out-file", default="out.png")
13+
parser.add_argument("-d", "--dimension", type=int, default=3)
14+
options = parser.parse_args()
15+
16+
module = torch.jit.load(options.sample_file)
17+
images = list(module.parameters())[0]
18+
19+
for index in range(options.dimension * options.dimension):
20+
image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
21+
array = image.numpy()
22+
axis = plt.subplot(options.dimension, options.dimension, 1 + index)
23+
plt.imshow(array, cmap="gray")
24+
axis.get_xaxis().set_visible(False)
25+
axis.get_yaxis().set_visible(False)
26+
27+
plt.savefig(options.out_file)
28+
print("Saved ", options.out_file)

0 commit comments

Comments
 (0)