Skip to content

Commit f3a6357

Browse files
committed
[NeoML] DnnDistributed -- remove code copy-paste
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 5f6d2c9 commit f3a6357

File tree

4 files changed

+184
-194
lines changed

4 files changed

+184
-194
lines changed

NeoML/Python/src/PyDnnDistributed.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at

NeoML/Python/src/PyDnnDistributed.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
@@ -20,7 +21,7 @@ limitations under the License.
2021

2122
class CPyDistributedDataset : public IDistributedDataset {
2223
public:
23-
CPyDistributedDataset( const py::object& data ) : getData( data ) {};
24+
CPyDistributedDataset( const py::object& data ) : getData( data ) {}
2425
int SetInputBatch( CDnn& dnn, int thread ) override;
2526
private:
2627
py::object getData;
@@ -29,13 +30,14 @@ class CPyDistributedDataset : public IDistributedDataset {
2930
class CPyDistributedTraining : public CDistributedTraining {
3031
public:
3132
CPyDistributedTraining( CDnn& dnn, int count, TDistributedInitializer initializer, int seed )
32-
: CDistributedTraining( dnn, count, initializer, seed ) {};
33+
: CDistributedTraining( dnn, count, initializer, seed ) {}
3334
CPyDistributedTraining( CArchive& archive, int count, TDistributedInitializer initializer, int seed )
34-
: CDistributedTraining( archive, count, initializer, seed ) {};
35+
: CDistributedTraining( archive, count, initializer, seed ) {}
3536
CPyDistributedTraining( CDnn& dnn, const CArray<int>& cudaDevs, TDistributedInitializer initializer, int seed )
36-
: CDistributedTraining( dnn, cudaDevs, initializer, seed ) {};
37+
: CDistributedTraining( dnn, cudaDevs, initializer, seed ) {}
3738
CPyDistributedTraining( CArchive& archive, const CArray<int>& cudaDevs, TDistributedInitializer initializer, int seed )
38-
: CDistributedTraining( archive, cudaDevs, initializer, seed ) {};
39+
: CDistributedTraining( archive, cudaDevs, initializer, seed ) {}
40+
3941
void Run( const py::object& data );
4042
void RunAndBackward( const py::object& data );
4143
void Learn( const py::object& data );
@@ -46,4 +48,4 @@ class CPyDistributedTraining : public CDistributedTraining {
4648
void Save( const std::string& path );
4749
};
4850

49-
void InitializeDistributedTraining(py::module& m);
51+
void InitializeDistributedTraining( py::module& m );

NeoML/include/NeoML/Dnn/DnnDistributed.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
@@ -26,6 +27,7 @@ class CLoraSerializer;
2627
// Interface for setting input to a neural network
2728
class IDistributedDataset {
2829
public:
30+
virtual ~IDistributedDataset() {}
2931
// This method must set batches for all of the source layers in CDnn
3032
// Returns the current batch size (or 0, if there is no data for this thread on this run)
3133
// This batch size affects weights balance between different threads
@@ -57,7 +59,7 @@ class NEOML_API CDistributedTraining {
5759
CDistributedTraining( CArchive& archive, const CArray<int>& cudaDevs,
5860
TDistributedInitializer initializer = TDistributedInitializer::Xavier, int seed = 42 );
5961

60-
~CDistributedTraining();
62+
virtual ~CDistributedTraining();
6163

6264
// Gets the number of models in distributed training
6365
int GetModelCount() const { return cnns.Size(); }
@@ -67,6 +69,7 @@ class NEOML_API CDistributedTraining {
6769
void SetLearningRate( float rate );
6870
// Returns the current learning rate
6971
float GetLearningRate() const;
72+
7073
// Runs the networks without backward and training
7174
void RunOnce( IDistributedDataset& data );
7275
// Runs the networks and performs a backward pass
@@ -75,28 +78,31 @@ class NEOML_API CDistributedTraining {
7578
void RunAndLearnOnce( IDistributedDataset& data );
7679
// Updates the trainable weights of all models (after RunAndBackwardOnce)
7780
void Train();
81+
7882
// Returns last loss of `layerName` for all models
7983
// `layerName` should correspond to CLossLayer, CCtcLossLayer or CCrfLossLayer
80-
void GetLastLoss( const CString& layerName, CArray<float>& losses );
84+
void GetLastLoss( const CString& layerName, CArray<float>& losses ) const;
8185
// Returns last blobs of `layerName` for all models
8286
// `layerName` should correspond to CSinkLayer
83-
void GetLastBlob( const CString& layerName, CObjectArray<CDnnBlob>& blobs );
87+
void GetLastBlob( const CString& layerName, CObjectArray<CDnnBlob>& blobs ) const;
88+
8489
// Save trained net
8590
void Serialize( CArchive& archive );
8691
// Save the trained net with the given `index` with its solver state (optional)
8792
// An archive with solver state can later be passed to CDnn::SerializeCheckpoint to resume training
8893
void StoreDnn( CArchive& archive, int index, bool storeSolver );
8994

9095
private:
91-
const bool isCpu;
92-
IThreadPool* threadPool;
96+
enum class TRunType { Invalid, RunOnce, RunBackwardOnce, Train };
97+
class CParams;
98+
CParams* const params = nullptr;
99+
IThreadPool* const threadPool = nullptr;
93100
CArray<IMathEngine*> mathEngines;
94101
CArray<CRandom*> rands;
95102
CArray<CDnn*> cnns;
96103
CArray<int> batchSize;
97-
bool isFirstRun = true;
98-
CString errorMessage;
99104

105+
void runOnce( IDistributedDataset*, TRunType );
100106
void initialize( CArchive& archive, int count, TDistributedInitializer initializer, int seed );
101107

102108
friend class CLoraSerializer;

0 commit comments

Comments
 (0)