1
- /* Copyright © 2017-2023 ABBYY
1
+ /* Copyright © 2017-2024 ABBYY
2
+
2
3
Licensed under the Apache License, Version 2.0 (the "License");
3
4
you may not use this file except in compliance with the License.
4
5
You may obtain a copy of the License at
@@ -26,6 +27,7 @@ class CLoraSerializer;
26
27
// Interface for setting input to a neural network
27
28
class IDistributedDataset {
28
29
public:
30
+ virtual ~IDistributedDataset () {}
29
31
// This method must set batches for all of the source layers in CDnn
30
32
// Returns the current batch size (or 0, if there is no data for this thread on this run)
31
33
// This batch size affects weights balance between different threads
@@ -57,7 +59,7 @@ class NEOML_API CDistributedTraining {
57
59
CDistributedTraining ( CArchive& archive, const CArray<int >& cudaDevs,
58
60
TDistributedInitializer initializer = TDistributedInitializer::Xavier, int seed = 42 );
59
61
60
- ~CDistributedTraining ();
62
+ virtual ~CDistributedTraining ();
61
63
62
64
// Gets the number of models in distributed training
63
65
int GetModelCount () const { return cnns.Size (); }
@@ -67,6 +69,7 @@ class NEOML_API CDistributedTraining {
67
69
void SetLearningRate ( float rate );
68
70
// Returns the current learning rate
69
71
float GetLearningRate () const ;
72
+
70
73
// Runs the networks without backward and training
71
74
void RunOnce ( IDistributedDataset& data );
72
75
// Runs the networks and performs a backward pass
@@ -75,28 +78,31 @@ class NEOML_API CDistributedTraining {
75
78
void RunAndLearnOnce ( IDistributedDataset& data );
76
79
// Updates the trainable weights of all models (after RunAndBackwardOnce)
77
80
void Train ();
81
+
78
82
// Returns last loss of `layerName` for all models
79
83
// `layerName` should correspond to CLossLayer, CCtcLossLayer or CCrfLossLayer
80
- void GetLastLoss ( const CString& layerName, CArray<float >& losses );
84
+ void GetLastLoss ( const CString& layerName, CArray<float >& losses ) const ;
81
85
// Returns last blobs of `layerName` for all models
82
86
// `layerName` should correspond to CSinkLayer
83
- void GetLastBlob ( const CString& layerName, CObjectArray<CDnnBlob>& blobs );
87
+ void GetLastBlob ( const CString& layerName, CObjectArray<CDnnBlob>& blobs ) const ;
88
+
84
89
// Save trained net
85
90
void Serialize ( CArchive& archive );
86
91
// Save the trained net with the given `index` with its solver state (optional)
87
92
// An archive with solver state can later be passed to CDnn::SerializeCheckpoint to resume training
88
93
void StoreDnn ( CArchive& archive, int index, bool storeSolver );
89
94
90
95
private:
91
- const bool isCpu;
92
- IThreadPool* threadPool;
96
+ enum class TRunType { Invalid, RunOnce, RunBackwardOnce, Train };
97
+ class CParams ;
98
+ CParams* const params = nullptr ;
99
+ IThreadPool* const threadPool = nullptr ;
93
100
CArray<IMathEngine*> mathEngines;
94
101
CArray<CRandom*> rands;
95
102
CArray<CDnn*> cnns;
96
103
CArray<int > batchSize;
97
- bool isFirstRun = true ;
98
- CString errorMessage;
99
104
105
+ void runOnce ( IDistributedDataset*, TRunType );
100
106
void initialize ( CArchive& archive, int count, TDistributedInitializer initializer, int seed );
101
107
102
108
friend class CLoraSerializer ;
0 commit comments