forked from tensorflow/swift-models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainingLoop.swift
342 lines (313 loc) · 12 KB
/
TrainingLoop.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ModelSupport
import TensorFlow
// Workaround https://bugs.swift.org/browse/TF-1122 that prevents us from registering a
// loss function inside our TrainingLoop struct: store it in a reference-typed wrapper instead.
public final class LossFunctionWrapper<Output: Differentiable, Target> {
  /// The type of a wrapped loss function: differentiable in the model output, with the
  /// target excluded from differentiation.
  public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor<Float>
  /// The wrapped loss function.
  public var f: F
  /// Creates a wrapper storing `f`.
  init(_ f: @escaping F) { self.f = f }
}
/// Types whose elements represent a training loop.
///
/// - Note: This protocol is mainly there to give us an easy type for a generic `TrainingLoop`
///   and unless you need to rewrite your own training loop entirely, you should use
///   `TrainingLoop`.
public protocol TrainingLoopProtocol {
  // MARK: Associated types

  /// The type of the sequence of epochs for the training data.
  ///
  /// Each element of the sequence is one epoch: a collection of labeled batches.
  associatedtype Training
  where
    Training: Sequence, Training.Element: Collection,
    Training.Element.Element == LabeledData<Opt.Model.Input, Target>

  /// The type of the collection of batches for the validation data.
  associatedtype Validation
  where
    Validation: Collection,
    Validation.Element == LabeledData<Opt.Model.Input, Target>

  /// The type of the target of our model.
  associatedtype Target

  /// The type of the optimizer used.
  associatedtype Opt: Optimizer where Opt.Model: Module

  // MARK: Typealiases

  /// The type of the model.
  typealias Model = Opt.Model

  /// The type of the input of the model.
  typealias Input = Opt.Model.Input

  /// The type of the output of the model.
  typealias Output = Opt.Model.Output

  /// The type of a batch.
  typealias Batch = LabeledData<Input, Target>

  // In a wrapper for now because of TF-1122.
  /// The type of the loss function.
  typealias LossFunction = LossFunctionWrapper<Output, Target>

  // MARK: Data

  /// The training epochs.
  var training: Training { get }

  /// The validation batches.
  var validation: Validation { get }

  // MARK: Optimizer and loss function

  /// The optimizer.
  var optimizer: Opt { get set }

  /// The loss function.
  var lossFunction: LossFunction { get set }

  // MARK: Callbacks

  /// The callbacks used to customize the training loop.
  var callbacks: [TrainingLoopCallback<Self>] { get set }

  // MARK: Temporary state, exposed so callbacks can observe and modify the loop

  /// The last input fed to the model.
  var lastInput: Input? { get set }

  /// The last target.
  var lastTarget: Target? { get set }

  /// The last predictions of the model.
  var lastOutput: Output? { get set }

  /// The last gradients computed.
  var lastGradient: Model.TangentVector? { get set }

  /// The last loss.
  var lastLoss: Tensor<Float>? { get set }

  /// The number of epochs we are currently fitting for.
  var epochCount: Int? { get set }

  /// The index of the current epoch.
  var epochIndex: Int? { get set }

  /// The number of batches in the current collection of batches.
  var batchCount: Int? { get set }

  /// The index of the current batch.
  var batchIndex: Int? { get set }
}
/// The events that occur during a call to `fit` in the `TrainingLoop`.
///
/// - Note: The method is called `fit` and not `train` because it trains the model and
///   validates it. Each epoch is composed of a *training* phase and a *validation* phase.
public enum TrainingLoopEvent {
  /// The start of a fit.
  case fitStart
  /// The end of a fit.
  case fitEnd
  /// The start of one epoch (training + validation).
  case epochStart
  /// The end of one epoch (training + validation).
  case epochEnd
  /// The start of a training phase.
  case trainingStart
  /// The end of a training phase.
  case trainingEnd
  /// The start of a validation phase.
  case validationStart
  /// The end of a validation phase.
  case validationEnd
  /// The start of a training or inference step on a batch.
  case batchStart
  /// The end of a training or inference step on a batch.
  case batchEnd
  /// The start of the optimizer update, just after the differentiable step.
  case updateStart
  /// Just after the model prediction at inference, before computing the loss.
  case inferencePredictionEnd
}
/// Callbacks that can inject custom behavior in a training loop.
///
/// A callback receives the (mutable) loop and the event that just occurred. It may throw a
/// `TrainingLoopAction` to alter the loop's control flow, or any other error to abort fitting.
public typealias TrainingLoopCallback<L: TrainingLoopProtocol> = (
  _ loop: inout L, _ event: TrainingLoopEvent
) throws -> Void
/// A generic training loop.
///
/// - Parameter `Training`: the type of the sequence of epochs for training data.
/// - Parameter `Validation`: the type of the collection of batches for validation.
/// - Parameter `Target`: the type of the target.
/// - Parameter `Opt`: the type of the optimizer used.
public struct TrainingLoop<
  Training: Sequence, Validation: Collection, Target, Opt: Optimizer
>: TrainingLoopProtocol
where
  Training.Element: Collection, Training.Element.Element == LabeledData<Opt.Model.Input, Target>,
  Validation.Element == LabeledData<Opt.Model.Input, Target>, Opt.Model: Module
{
  // MARK: Typealiases

  /// The type of the model.
  public typealias Model = Opt.Model

  /// The type of the input of the model.
  public typealias Input = Opt.Model.Input

  /// The type of the output of the model.
  public typealias Output = Opt.Model.Output

  /// The type of a batch.
  public typealias Batch = LabeledData<Input, Target>

  // In a wrapper for now because of TF-1122.
  /// The type of the loss function.
  public typealias LossFunction = LossFunctionWrapper<Output, Target>

  // MARK: Data

  /// The training epochs.
  public let training: Training

  /// The validation batches.
  public let validation: Validation

  // MARK: Optimizer and loss function

  /// The optimizer.
  public var optimizer: Opt

  /// The loss function.
  public var lossFunction: LossFunction

  // MARK: Callbacks

  /// The callbacks used to customize the training loop.
  public var callbacks: [TrainingLoopCallback<Self>] = []

  // MARK: Temporary state, populated while fitting so callbacks can read/modify it

  /// The last input fed to the model.
  public var lastInput: Input? = nil

  /// The last target.
  public var lastTarget: Target? = nil

  /// The last predictions of the model.
  public var lastOutput: Output? = nil

  /// The last gradients computed.
  public var lastGradient: Model.TangentVector? = nil

  /// The last loss.
  public var lastLoss: Tensor<Float>? = nil

  /// The number of epochs we are currently fitting for.
  public var epochCount: Int? = nil

  /// The index of the current epoch.
  public var epochIndex: Int? = nil

  /// The number of batches in the current collection of batches.
  public var batchCount: Int? = nil

  /// The index of the current batch.
  public var batchIndex: Int? = nil

  /// Creates an instance from `training` and `validation` data, an `optimizer` and a
  /// `lossFunction`.
  ///
  /// - Parameter callbacks: Callbacks that the `TrainingLoop` will use in every call to fit.
  public init(
    training: Training, validation: Validation, optimizer: Opt,
    lossFunction: @escaping LossFunction.F, callbacks: [TrainingLoopCallback<Self>] = []
  ) {
    self.training = training
    self.validation = validation
    self.optimizer = optimizer
    self.lossFunction = LossFunction(lossFunction)
    self.callbacks = callbacks
  }
}
extension TrainingLoop {
  /// The default differentiable step.
  ///
  /// Computes predictions, loss, and gradients for the batch currently stored in
  /// `lastInput`/`lastTarget`, recording the results in `lastOutput`, `lastLoss`, and
  /// `lastGradient`. Silently does nothing when no batch has been stored.
  public mutating func differentiableStep(model: Model) throws {
    // Early out when the loop (or a callback) has not provided a batch.
    guard let data = lastInput else { return }
    guard let target = lastTarget else { return }
    (lastLoss, lastGradient) = valueWithGradient(at: model) { (model: Model) -> Tensor<Float> in
      let predictions = model(data)
      lastOutput = predictions
      return lossFunction.f(predictions, target)
    }
  }

  /// The step used for inference.
  ///
  /// Runs the model on `lastInput`, fires `.inferencePredictionEnd`, then computes the loss
  /// against `lastTarget`, storing results in `lastOutput` and `lastLoss`.
  public mutating func inferenceStep(model: Model) throws {
    guard let data = lastInput else { return }
    lastOutput = model(data)
    guard let target = lastTarget else { return }
    // Callbacks observe the raw prediction before the loss is computed.
    try handleEvent(.inferencePredictionEnd)
    // `lastOutput` was assigned just above, so this force unwrap cannot trap
    // unless a callback cleared it during `.inferencePredictionEnd`.
    lastLoss = lossFunction.f(lastOutput!, target)
  }

  /// The step used for training.
  ///
  /// Runs `differentiableStep`, fires `.updateStart`, then applies the computed gradients to
  /// `model` via the optimizer.
  ///
  /// - Note(review): `lastGradient!` assumes the differentiable step produced a gradient; if it
  ///   returned early (no batch stored) and no previous gradient exists, this traps — the
  ///   driving loop always stores a batch first, so this appears unreachable in practice.
  public mutating func trainingStep(
    model: inout Model, differentiableStep: (Model, inout Self) throws -> Void
  ) throws {
    try differentiableStep(model, &self)
    try handleEvent(.updateStart)
    optimizer.update(&model, along: lastGradient!)
  }
}
/// Control-flow actions for the training loop, thrown from callbacks or steps.
///
/// - Note: Each "end" event is still delivered after its corresponding "cancel" action, for
///   cleanup.
public enum TrainingLoopAction: Error {
  /// Aborts the current training/inference step and goes to the next batch.
  case cancelBatch
  /// Aborts the current training phase and goes to the validation phase.
  case cancelTraining
  /// Aborts the current validation phase and goes to the next epoch.
  case cancelValidation
  /// Aborts the current epoch and goes to the next epoch.
  case cancelEpoch
  /// Aborts the current fit and ends fitting.
  case cancelFit
}
extension TrainingLoop {
  /// Dispatches `event` to every registered callback, in registration order.
  ///
  /// Iterates over a snapshot of `callbacks` (arrays have value semantics), so a callback that
  /// mutates `loop.callbacks` does not change which callbacks receive the current event.
  mutating private func handleEvent(_ event: TrainingLoopEvent) throws {
    // Popping from a slice walks a copy of the current callback list front to back.
    var pending = callbacks[...]
    while let callback = pending.popFirst() {
      try callback(&self, event)
    }
  }
}
extension TrainingLoop {
  /// Performs `step` on each of `batches`.
  ///
  /// For every batch: stores it in `lastInput`/`lastTarget`, fires `.batchStart`, runs `step`,
  /// then fires `.batchEnd`. A `cancelBatch` thrown by a callback or by `step` skips the rest
  /// of that batch, but `.batchEnd` is still delivered afterwards.
  mutating private func multipleSteps<Batches: Collection>(
    on batches: Batches, step: (inout Self) throws -> Void
  ) throws where Batches.Element == Batch {
    batchCount = batches.count
    for (i, batch) in batches.enumerated() {
      batchIndex = i
      (lastInput, lastTarget) = (batch.data, batch.label)
      do {
        try handleEvent(.batchStart)
        try step(&self)
      } catch TrainingLoopAction.cancelBatch {}
      try handleEvent(.batchEnd)
      // Materialize pending lazy tensor operations once per batch.
      LazyTensorBarrier()
    }
  }
}
extension TrainingLoop {
  /// Fit the model for `epochs` using `callbacks` to customize the default training loop.
  ///
  /// Each epoch runs a training phase over one element of `training`, then a validation phase
  /// over `validation`. `TrainingLoopAction`s thrown by callbacks or steps cancel the
  /// corresponding scope; the matching "end" event is still fired afterwards.
  ///
  /// - Parameters:
  ///   - model: The model to train; updated in place by the optimizer.
  ///   - epochs: The number of epochs to fit for (at most; `training` may yield fewer).
  ///   - callbacks: Extra callbacks installed for this call only and removed again on exit.
  ///   - device: The device the model and optimizer are moved to before fitting.
  ///   - differentiableStep: The step used during the training phase of each epoch. The
  ///     default value uses the `differentiableStep` method of `TrainingLoop`.
  public mutating func fit(
    _ model: inout Model, epochs: Int, callbacks: [TrainingLoopCallback<Self>] = [],
    on device: Device = Device.default,
    differentiableStep: (Model, inout Self) throws -> Void = { try $1.differentiableStep(model: $0) }
  ) throws {
    // Install the per-call callbacks; restore the original list on every exit path.
    let callbacksCount = self.callbacks.count
    self.callbacks += callbacks
    defer { self.callbacks = Array(self.callbacks.prefix(callbacksCount)) }

    epochCount = epochs
    model.move(to: device)
    optimizer = Opt(copying: optimizer, to: device)

    do {
      try handleEvent(.fitStart)
      LazyTensorBarrier()
      // `prefix(epochs)` caps the run even if `training` is an infinite sequence of epochs.
      for (i, batches) in training.prefix(epochs).enumerated() {
        epochIndex = i
        do {
          try handleEvent(.epochStart)

          // Training phase
          do {
            Context.local.learningPhase = .training
            try handleEvent(.trainingStart)
            try multipleSteps(
              on: batches,
              step: {
                try $0.trainingStep(model: &model, differentiableStep: differentiableStep)
              })
          } catch TrainingLoopAction.cancelTraining {}
          try handleEvent(.trainingEnd)

          // Validation phase
          do {
            Context.local.learningPhase = .inference
            try handleEvent(.validationStart)
            try multipleSteps(on: validation, step: { try $0.inferenceStep(model: model) })
          } catch TrainingLoopAction.cancelValidation {}
          try handleEvent(.validationEnd)
        } catch TrainingLoopAction.cancelEpoch {}
        try handleEvent(.epochEnd)
      }
    } catch TrainingLoopAction.cancelFit {}
    try handleEvent(.fitEnd)
  }
}