Skip to content

Commit 715e1a4

Browse files
authored
Initial incorporation of a general training loop (tensorflow#586)
This is the initial incorporation of a general callback-based training loop, originally designed by @sgugger and proposed as the DifferentiableStep option here. As a first step, the following models have been converted to use this new training loop in place of the previous custom loop: LeNet-MNIST ResNet-CIFAR10 MobileNetV1-Imagenette MobileNetV2-Imagenette An initial set of callbacks have been provided that draw an animated progress bar on the console during training, and display the average loss and top-1 classification accuracy. These metric updates can either be continuous during training and validation, or can appear only at the end of an epoch (this is a performance option, because currently training will slow by up to 30% if continuous updates are enabled). Which metrics to display, if any, are also configurable. By default, X10 is used where available for training models, and this loop fully supports X10 or eager mode devices. As a next step, all but one or two classification examples will be reworked to use this loop, and timing functionality will be introduced to have this be the default loop within our benchmarks. This pull request is now ready for review.
1 parent 4cb8c9b commit 715e1a4

18 files changed

+721
-292
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ add_subdirectory(GAN)
131131
add_subdirectory(DCGAN)
132132
add_subdirectory(FastStyleTransfer)
133133
add_subdirectory(Examples)
134+
add_subdirectory(TrainingLoop)
134135

135136
if(BUILD_TESTING)
136137
add_subdirectory(Tests)

Examples/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ add_subdirectory(NeuMF-MovieLens)
1010
add_subdirectory(GPT2-Inference)
1111
add_subdirectory(WordSeg)
1212
add_subdirectory(Fractals)
13+
add_subdirectory(VGG-Imagewoof)

Examples/LeNet-MNIST/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
add_executable(LeNet-MNIST
22
main.swift)
33
target_link_libraries(LeNet-MNIST PRIVATE
4+
Datasets
45
ImageClassificationModels
5-
Datasets)
6+
TrainingLoop)
67

78

89
install(TARGETS LeNet-MNIST

Examples/LeNet-MNIST/main.swift

+13-80
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,23 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import TensorFlow
1615
import Datasets
16+
import TensorFlow
17+
import TrainingLoop
1718

1819
let epochCount = 12
1920
let batchSize = 128
2021

21-
// Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode
22+
// Until https://github.com/tensorflow/swift-apis/issues/993 is fixed, default to the eager-mode
2223
// device on macOS instead of X10.
2324
#if os(macOS)
2425
let device = Device.defaultTFEager
2526
#else
2627
let device = Device.defaultXLA
2728
#endif
2829

29-
let dataset = MNIST(batchSize: batchSize)
30+
let dataset = MNIST(batchSize: batchSize, on: device)
31+
3032
// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
3133
var classifier = Sequential {
3234
Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
@@ -38,84 +40,15 @@ var classifier = Sequential {
3840
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
3941
Dense<Float>(inputSize: 84, outputSize: 10)
4042
}
41-
classifier.move(to: device)
4243

4344
var optimizer = SGD(for: classifier, learningRate: 0.1)
44-
optimizer = SGD(copying: optimizer, to: device)
45-
46-
print("Beginning training...")
47-
48-
struct Statistics {
49-
var correctGuessCount = Tensor<Int32>(0, on: Device.default)
50-
var totalGuessCount = Tensor<Int32>(0, on: Device.default)
51-
var totalLoss = Tensor<Float>(0, on: Device.default)
52-
var batches: Int = 0
53-
54-
var accuracy: Float {
55-
Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
56-
}
5745

58-
var averageLoss: Float {
59-
totalLoss.scalarized() / Float(batches)
60-
}
46+
let trainingProgress = TrainingProgress()
47+
var trainingLoop = TrainingLoop(
48+
training: dataset.training,
49+
validation: dataset.validation,
50+
optimizer: optimizer,
51+
lossFunction: softmaxCrossEntropy,
52+
callbacks: [trainingProgress.update])
6153

62-
init(on device: Device = Device.default) {
63-
correctGuessCount = Tensor<Int32>(0, on: device)
64-
totalGuessCount = Tensor<Int32>(0, on: device)
65-
totalLoss = Tensor<Float>(0, on: device)
66-
}
67-
68-
mutating func update(logits: Tensor<Float>, labels: Tensor<Int32>, loss: Tensor<Float>) {
69-
let correct = logits.argmax(squeezingAxis: 1) .== labels
70-
correctGuessCount += Tensor<Int32>(correct).sum()
71-
totalGuessCount += Int32(labels.shape[0])
72-
totalLoss += loss
73-
batches += 1
74-
}
75-
}
76-
77-
// The training loop.
78-
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
79-
var trainStats = Statistics(on: device)
80-
var testStats = Statistics(on: device)
81-
82-
Context.local.learningPhase = .training
83-
for batch in epochBatches {
84-
let (eagerImages, eagerLabels) = (batch.data, batch.label)
85-
let images = Tensor(copying: eagerImages, to: device)
86-
let labels = Tensor(copying: eagerLabels, to: device)
87-
// Compute the gradient with respect to the model.
88-
let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
89-
let ŷ = classifier(images)
90-
let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
91-
trainStats.update(logits: ŷ, labels: labels, loss: loss)
92-
return loss
93-
}
94-
// Update the model's differentiable variables along the gradient vector.
95-
optimizer.update(&classifier, along: 𝛁model)
96-
LazyTensorBarrier()
97-
}
98-
99-
Context.local.learningPhase = .inference
100-
for batch in dataset.validation {
101-
let (eagerImages, eagerLabels) = (batch.data, batch.label)
102-
let images = Tensor(copying: eagerImages, to: device)
103-
let labels = Tensor(copying: eagerLabels, to: device)
104-
// Compute loss on test set
105-
let ŷ = classifier(images)
106-
let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
107-
LazyTensorBarrier()
108-
testStats.update(logits: ŷ, labels: labels, loss: loss)
109-
}
110-
111-
print(
112-
"""
113-
[Epoch \(epoch + 1)] \
114-
Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
115-
Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
116-
(\(String(format: "%.1f", trainStats.accuracy))%), \
117-
Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
118-
Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
119-
(\(String(format: "%.3f", testStats.accuracy))%)
120-
""")
121-
}
54+
try! trainingLoop.fit(&classifier, epochs: epochCount, on: device)

Examples/MobileNetV1-Imagenette/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
add_executable(MobileNetV1-Imagenette
22
main.swift)
33
target_link_libraries(MobileNetV1-Imagenette PRIVATE
4+
Datasets
45
ImageClassificationModels
5-
Datasets)
6+
TrainingLoop)
67

78

89
install(TARGETS MobileNetV1-Imagenette

Examples/MobileNetV1-Imagenette/main.swift

+17-52
Original file line numberDiff line numberDiff line change
@@ -15,61 +15,26 @@
1515
import Datasets
1616
import ImageClassificationModels
1717
import TensorFlow
18+
import TrainingLoop
1819

19-
let epochCount = 10
20-
let batchSize = 64
21-
22-
let dataset = Imagenette(
23-
batchSize: batchSize,
24-
inputSize: .resized320,
25-
outputSize: 224
26-
)
20+
// Until https://github.com/tensorflow/swift-apis/issues/993 is fixed, default to the eager-mode
21+
// device on macOS instead of X10.
22+
#if os(macOS)
23+
let device = Device.defaultTFEager
24+
#else
25+
let device = Device.defaultXLA
26+
#endif
2727

28+
let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224, on: device)
2829
var model = MobileNetV1(classCount: 10)
29-
3030
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9)
3131

32-
print("Starting training...")
33-
34-
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
35-
Context.local.learningPhase = .training
36-
var trainingLossSum: Float = 0
37-
var trainingBatchCount = 0
38-
for batch in epochBatches {
39-
let (images, labels) = (batch.data, batch.label)
40-
let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
41-
let logits = model(images)
42-
return softmaxCrossEntropy(logits: logits, labels: labels)
43-
}
44-
trainingLossSum += loss.scalarized()
45-
trainingBatchCount += 1
46-
optimizer.update(&model, along: gradients)
47-
}
48-
49-
Context.local.learningPhase = .inference
50-
var testLossSum: Float = 0
51-
var testBatchCount = 0
52-
var correctGuessCount = 0
53-
var totalGuessCount = 0
54-
for batch in dataset.validation {
55-
let (images, labels) = (batch.data, batch.label)
56-
let logits = model(images)
57-
testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized()
58-
testBatchCount += 1
59-
60-
let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
61-
correctGuessCount = correctGuessCount
62-
+ Int(
63-
Tensor<Int32>(correctPredictions).sum().scalarized())
64-
totalGuessCount = totalGuessCount + batch.data.shape[0]
65-
}
32+
let trainingProgress = TrainingProgress()
33+
var trainingLoop = TrainingLoop(
34+
training: dataset.training,
35+
validation: dataset.validation,
36+
optimizer: optimizer,
37+
lossFunction: softmaxCrossEntropy,
38+
callbacks: [trainingProgress.update])
6639

67-
let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
68-
print(
69-
"""
70-
[Epoch \(epoch)] \
71-
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
72-
Loss: \(testLossSum / Float(testBatchCount))
73-
"""
74-
)
75-
}
40+
try! trainingLoop.fit(&model, epochs: 10, on: device)

Examples/MobileNetV2-Imagenette/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
add_executable(MobileNetV2-Imagenette
22
main.swift)
33
target_link_libraries(MobileNetV2-Imagenette PRIVATE
4+
Datasets
45
ImageClassificationModels
5-
Datasets)
6+
TrainingLoop)
67

78

89
install(TARGETS MobileNetV2-Imagenette

Examples/MobileNetV2-Imagenette/main.swift

+17-52
Original file line numberDiff line numberDiff line change
@@ -15,61 +15,26 @@
1515
import Datasets
1616
import ImageClassificationModels
1717
import TensorFlow
18+
import TrainingLoop
1819

19-
let epochCount = 10
20-
let batchSize = 64
21-
22-
let dataset = Imagenette(
23-
batchSize: batchSize,
24-
inputSize: .resized320,
25-
outputSize: 224
26-
)
20+
// Until https://github.com/tensorflow/swift-apis/issues/993 is fixed, default to the eager-mode
21+
// device on macOS instead of X10.
22+
#if os(macOS)
23+
let device = Device.defaultTFEager
24+
#else
25+
let device = Device.defaultXLA
26+
#endif
2727

28+
let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224, on: device)
2829
var model = MobileNetV2(classCount: 10)
29-
3030
let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9)
3131

32-
print("Starting training...")
33-
34-
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
35-
Context.local.learningPhase = .training
36-
var trainingLossSum: Float = 0
37-
var trainingBatchCount = 0
38-
for batch in epochBatches {
39-
let (images, labels) = (batch.data, batch.label)
40-
let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
41-
let logits = model(images)
42-
return softmaxCrossEntropy(logits: logits, labels: labels)
43-
}
44-
trainingLossSum += loss.scalarized()
45-
trainingBatchCount += 1
46-
optimizer.update(&model, along: gradients)
47-
}
48-
49-
Context.local.learningPhase = .inference
50-
var testLossSum: Float = 0
51-
var testBatchCount = 0
52-
var correctGuessCount = 0
53-
var totalGuessCount = 0
54-
for batch in dataset.validation {
55-
let (images, labels) = (batch.data, batch.label)
56-
let logits = model(images)
57-
testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized()
58-
testBatchCount += 1
59-
60-
let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
61-
correctGuessCount = correctGuessCount
62-
+ Int(
63-
Tensor<Int32>(correctPredictions).sum().scalarized())
64-
totalGuessCount = totalGuessCount + batch.data.shape[0]
65-
}
32+
let trainingProgress = TrainingProgress()
33+
var trainingLoop = TrainingLoop(
34+
training: dataset.training,
35+
validation: dataset.validation,
36+
optimizer: optimizer,
37+
lossFunction: softmaxCrossEntropy,
38+
callbacks: [trainingProgress.update])
6639

67-
let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
68-
print(
69-
"""
70-
[Epoch \(epoch)] \
71-
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
72-
Loss: \(testLossSum / Float(testBatchCount))
73-
"""
74-
)
75-
}
40+
try! trainingLoop.fit(&model, epochs: 10, on: device)

Examples/ResNet-CIFAR10/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ add_executable(ResNet-CIFAR10
22
main.swift)
33
target_link_libraries(ResNet-CIFAR10 PRIVATE
44
ImageClassificationModels
5-
Datasets)
5+
Datasets
6+
TrainingLoop)
67

78

89
install(TARGETS ResNet-CIFAR10

0 commit comments

Comments
 (0)