
Commit 29a38c6

goldsborough authored and soumith committed Jan 15, 2019

Add cpp folder for C++ frontend examples (pytorch#492)

* Create C++ version of MNIST example
* Create C++ version of DCGAN example
* Update for Normalize transform

1 parent 5d27fdb commit 29a38c6

10 files changed: +678 additions, 0 deletions
 

.gitignore (+2)

@@ -2,3 +2,5 @@ dcgan/data
 data
 *.pyc
 OpenNMT/data
+cpp/mnist/build
+cpp/dcgan/build

cpp/.clang-format (+88)

@@ -0,0 +1,88 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...

cpp/dcgan/CMakeLists.txt (+20)

@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

find_package(Torch REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
      COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
          -d ${CMAKE_BINARY_DIR}/data
      ERROR_VARIABLE DOWNLOAD_ERROR)
  if (DOWNLOAD_ERROR)
    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 11)

cpp/dcgan/README.md (+56)

@@ -0,0 +1,56 @@
# DCGAN Example with the PyTorch C++ Frontend

This folder contains an example of training a DCGAN to generate MNIST digits
with the PyTorch C++ frontend.

The entire training code is contained in `dcgan.cpp`.

To build the code, run the following commands from your terminal:

```shell
$ cd dcgan
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to train the model:

```shell
$ ./dcgan
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
-> checkpoint 1
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
-> checkpoint 2
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.3676
-> checkpoint 3
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
-> checkpoint 4
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
-> checkpoint 5
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
-> checkpoint 6
[ 2/30][600/938] D_loss: 0.3815 | G_loss: 3.5696
-> checkpoint 7
[ 2/30][800/938] D_loss: 0.4039 | G_loss: 3.2759
-> checkpoint 8
[ 3/30][200/938] D_loss: 0.4236 | G_loss: 4.5132
-> checkpoint 9
[ 3/30][400/938] D_loss: 0.3645 | G_loss: 3.9759
-> checkpoint 10
...
```

The training script periodically saves batches of image samples. Use the
`display_samples.py` script in this folder to render a saved sample file as a
plot image. For example:

```shell
$ python display_samples.py -i dcgan-sample-10.pt
Saved out.png
```
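Not part of this commit, but a quick sanity check: the sample files the
training loop writes are serialized tensors, so they can also be reloaded from
C++. A minimal sketch, assuming a file named `dcgan-sample-1.pt` exists in the
working directory (produced by a CPU training run):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Reload one of the sample tensors written at a checkpoint. For this
  // generator the shape should be [10, 1, 28, 28]: ten single-channel
  // 28x28 images, scaled to [0, 1].
  torch::Tensor samples;
  torch::load(samples, "dcgan-sample-1.pt");
  std::cout << samples.sizes() << std::endl;
}
```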

cpp/dcgan/dcgan.cpp (+187)

@@ -0,0 +1,187 @@
#include <torch/torch.h>

#include <cmath>
#include <cstdio>
#include <iostream>

// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;

// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data";

// After how many batches to create a new checkpoint.
const int64_t kCheckpointEvery = 200;

// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;

// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

using namespace torch;

int main(int argc, const char* argv[]) {
  torch::manual_seed(1);

  // Create the device we pass around based on whether CUDA is available.
  torch::Device device(torch::kCPU);
  if (torch::cuda::is_available()) {
    std::cout << "CUDA is available! Training on GPU." << std::endl;
    device = torch::Device(torch::kCUDA);
  }

  nn::Sequential generator(
      // Layer 1
      nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(256),
      nn::Functional(torch::relu),
      // Layer 2
      nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(128),
      nn::Functional(torch::relu),
      // Layer 3
      nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(64),
      nn::Functional(torch::relu),
      // Layer 4
      nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::Functional(torch::tanh));
  generator->to(device);

  nn::Sequential discriminator(
      // Layer 1
      nn::Conv2d(
          nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 2
      nn::Conv2d(
          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(128),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 3
      nn::Conv2d(
          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(
              false)),
      nn::BatchNorm(256),
      nn::Functional(torch::leaky_relu, 0.2),
      // Layer 4
      nn::Conv2d(
          nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
      nn::Functional(torch::sigmoid));
  discriminator->to(device);

  // Assume the MNIST dataset is available under `kDataFolder`.
  auto dataset = torch::data::datasets::MNIST(kDataFolder)
                     .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                     .map(torch::data::transforms::Stack<>());
  const int64_t batches_per_epoch =
      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

  torch::optim::Adam generator_optimizer(
      generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam discriminator_optimizer(
      discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

  if (kRestoreFromCheckpoint) {
    torch::load(generator, "generator-checkpoint.pt");
    torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
    torch::load(discriminator, "discriminator-checkpoint.pt");
    torch::load(
        discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  }

  int64_t checkpoint_counter = 1;
  for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    int64_t batch_index = 0;
    for (torch::data::Example<>& batch : *data_loader) {
      // Train discriminator with real images.
      discriminator->zero_grad();
      torch::Tensor real_images = batch.data.to(device);
      torch::Tensor real_labels =
          torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
      torch::Tensor real_output = discriminator->forward(real_images);
      torch::Tensor d_loss_real =
          torch::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();

      // Train discriminator with fake images.
      torch::Tensor noise =
          torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator->forward(noise);
      torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
      torch::Tensor fake_output = discriminator->forward(fake_images.detach());
      torch::Tensor d_loss_fake =
          torch::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();

      torch::Tensor d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();

      // Train generator.
      generator->zero_grad();
      fake_labels.fill_(1);
      fake_output = discriminator->forward(fake_images);
      torch::Tensor g_loss =
          torch::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();

      // Count this batch before checking the log and checkpoint intervals.
      ++batch_index;
      if (batch_index % kLogInterval == 0) {
        std::printf(
            "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
            epoch,
            kNumberOfEpochs,
            batch_index,
            batches_per_epoch,
            d_loss.item<float>(),
            g_loss.item<float>());
      }

      if (batch_index % kCheckpointEvery == 0) {
        // Checkpoint the model and optimizer state.
        torch::save(generator, "generator-checkpoint.pt");
        torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
        torch::save(discriminator, "discriminator-checkpoint.pt");
        torch::save(
            discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
        // Sample the generator and save the images.
        torch::Tensor samples = generator->forward(torch::randn(
            {kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
        torch::save(
            (samples + 1.0) / 2.0,
            torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
        std::cout << "\n-> checkpoint " << checkpoint_counter++ << '\n';
      }
    }
  }

  std::cout << "Training complete!" << std::endl;
}
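Also not part of the commit: once `generator-checkpoint.pt` exists, the
generator can be reloaded for sampling without re-training. A minimal sketch,
assuming a CPU training run and that the layer definitions match the ones in
`dcgan.cpp` exactly (required by `torch::load`):

```cpp
#include <torch/torch.h>

int main() {
  using namespace torch;
  // Same architecture as in dcgan.cpp (kNoiseSize = 100).
  nn::Sequential generator(
      nn::Conv2d(
          nn::Conv2dOptions(100, 256, 4).with_bias(false).transposed(true)),
      nn::BatchNorm(256),
      nn::Functional(torch::relu),
      nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(128),
      nn::Functional(torch::relu),
      nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::BatchNorm(64),
      nn::Functional(torch::relu),
      nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
                     .stride(2)
                     .padding(1)
                     .with_bias(false)
                     .transposed(true)),
      nn::Functional(torch::tanh));
  torch::load(generator, "generator-checkpoint.pt");

  // Draw ten fresh samples. As in the checkpoint code above, the generator is
  // run in its default mode; only gradient tracking is disabled.
  torch::NoGradGuard no_grad;
  torch::Tensor samples = generator->forward(torch::randn({10, 100, 1, 1}));
  torch::save((samples + 1.0) / 2.0, "dcgan-sample-eval.pt");
}
```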

cpp/dcgan/display_samples.py (+28)

@@ -0,0 +1,28 @@
from __future__ import print_function
from __future__ import unicode_literals

import argparse

import matplotlib.pyplot as plt
import torch


parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()

module = torch.jit.load(options.sample_file)
images = list(module.parameters())[0]

for index in range(options.dimension * options.dimension):
    image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
    array = image.numpy()
    axis = plt.subplot(options.dimension, options.dimension, 1 + index)
    plt.imshow(array, cmap="gray")
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)

plt.savefig(options.out_file)
print("Saved", options.out_file)

cpp/mnist/CMakeLists.txt (+20)

@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(mnist)

find_package(Torch REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
      COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
          -d ${CMAKE_BINARY_DIR}/data
      ERROR_VARIABLE DOWNLOAD_ERROR)
  if (DOWNLOAD_ERROR)
    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

cpp/mnist/README.md (+35)

@@ -0,0 +1,35 @@
# MNIST Example with the PyTorch C++ Frontend

This folder contains an example of training a computer vision model to recognize
digits in images from the MNIST dataset, using the PyTorch C++ frontend.

The entire training code is contained in `mnist.cpp`.

To build the code, run the following commands from your terminal:

```shell
$ cd mnist
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to train the model:

```shell
$ ./mnist
Train Epoch: 1 [59584/60000] Loss: 0.4232
Test set: Average loss: 0.1989 | Accuracy: 0.940
Train Epoch: 2 [59584/60000] Loss: 0.1926
Test set: Average loss: 0.1338 | Accuracy: 0.959
Train Epoch: 3 [59584/60000] Loss: 0.1390
Test set: Average loss: 0.0997 | Accuracy: 0.969
Train Epoch: 4 [59584/60000] Loss: 0.1239
Test set: Average loss: 0.0875 | Accuracy: 0.972
...
```

cpp/mnist/mnist.cpp (+154)

@@ -0,0 +1,154 @@
#include <torch/torch.h>

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

// Where to find the MNIST dataset.
const char* kDataRoot = "./data";

// The batch size for training.
const int64_t kTrainBatchSize = 64;

// The batch size for testing.
const int64_t kTestBatchSize = 1000;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 10;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

struct Net : torch::nn::Module {
  Net()
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::FeatureDropout conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

template <typename DataLoader>
void train(
    int32_t epoch,
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t dataset_size) {
  model.train();
  size_t batch_idx = 0;
  for (auto& batch : data_loader) {
    auto data = batch.data.to(device), targets = batch.target.to(device);
    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);
    AT_ASSERT(!std::isnan(loss.template item<float>()));
    loss.backward();
    optimizer.step();

    if (batch_idx++ % kLogInterval == 0) {
      std::printf(
          "\rTrain Epoch: %d [%5zu/%5zu] Loss: %.4f",
          epoch,
          batch_idx * batch.data.size(0),
          dataset_size,
          loss.template item<float>());
    }
  }
}

template <typename DataLoader>
void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
  torch::NoGradGuard no_grad;
  model.eval();
  double test_loss = 0;
  int32_t correct = 0;
  for (const auto& batch : data_loader) {
    auto data = batch.data.to(device), targets = batch.target.to(device);
    auto output = model.forward(data);
    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     Reduction::Sum)
                     .template item<float>();
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::printf(
      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
      test_loss,
      static_cast<double>(correct) / dataset_size);
}

auto main() -> int {
  torch::manual_seed(1);

  torch::DeviceType device_type;
  if (torch::cuda::is_available()) {
    std::cout << "CUDA available! Training on GPU." << std::endl;
    device_type = torch::kCUDA;
  } else {
    std::cout << "Training on CPU." << std::endl;
    device_type = torch::kCPU;
  }
  torch::Device device(device_type);

  Net model;
  model.to(device);

  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t train_dataset_size = train_dataset.size().value();
  auto train_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(train_dataset), kTrainBatchSize);

  auto test_dataset =
      torch::data::datasets::MNIST(
          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t test_dataset_size = test_dataset.size().value();
  auto test_loader =
      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

  for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
    test(model, device, *test_loader, test_dataset_size);
  }
}
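Not part of the commit: a minimal sketch of single-image inference, reusing
the `Net`, `kDataRoot`, and includes defined in `mnist.cpp` above. It assumes
the model has already been trained (or its weights restored) and that the
MNIST test set is under `./data`:

```cpp
// Classify the first MNIST test image with a trained model.
void predict_one(Net& model, torch::Device device) {
  model.eval();
  torch::NoGradGuard no_grad;

  auto test_set = torch::data::datasets::MNIST(
      kDataRoot, torch::data::datasets::MNIST::Mode::kTest);
  auto example = test_set.get(0);

  // Apply the same normalization as the training pipeline, then add a
  // batch dimension: [1, 28, 28] -> [1, 1, 28, 28].
  auto input = ((example.data - 0.1307) / 0.3081).unsqueeze(0).to(device);
  auto prediction = model.forward(input).argmax(/*dim=*/1);
  std::printf(
      "predicted: %ld | actual: %ld\n",
      prediction.item<int64_t>(),
      example.target.item<int64_t>());
}
```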

cpp/tools/download_mnist.py (+88)

@@ -0,0 +1,88 @@
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys
import urllib

try:
    from urllib.error import URLError
    from urllib.request import urlretrieve
except ImportError:
    from urllib2 import URLError
    from urllib import urlretrieve

RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def report_download_progress(chunk_number, chunk_size, file_size):
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


def download(destination_path, url, quiet):
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
    else:
        print('Downloading {} ...'.format(url))
        try:
            hook = None if quiet else report_download_progress
            urlretrieve(url, destination_path, reporthook=hook)
        except URLError:
            raise RuntimeError('Error downloading resource!')
        finally:
            if not quiet:
                # Just a newline.
                print()


def unzip(zipped_path, quiet):
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return
    with gzip.open(zipped_path, 'rb') as zipped_file:
        with open(unzipped_path, 'wb') as unzipped_file:
            unzipped_file.write(zipped_file.read())
            if not quiet:
                print('Unzipped {} ...'.format(zipped_path))


def main():
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
            download(path, url, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')


if __name__ == '__main__':
    main()
