diff --git a/NeoML/Python/neoml/Dnn/Split.py b/NeoML/Python/neoml/Dnn/Split.py index b8a06acbf..de9004600 100644 --- a/NeoML/Python/neoml/Dnn/Split.py +++ b/NeoML/Python/neoml/Dnn/Split.py @@ -20,7 +20,48 @@ import numpy -class SplitChannels(Layer): +class SplitLayer(Layer): + """The base (abstract) class for a split layer. + """ + def __init__(self, classname, input_layer, sizes, name): + assert hasattr(PythonWrapper, classname), 'Incorrect split layer specified: ' + classname + + if type(input_layer) is getattr(PythonWrapper, classname): + super().__init__(input_layer) + return + + layers, outputs = check_input_layers(input_layer, 1) + + internal = getattr(PythonWrapper, classname)(str(name), layers[0], int(outputs[0]), self.__sizes_to_array(sizes)) + super().__init__(internal) + + @property + def output_sizes(self): + """ + """ + return self._internal.get_output_counts() + + @output_sizes.setter + def output_sizes(self, value): + """ + """ + self._internal.set_output_counts(self.__sizes_to_array(value)) + + @staticmethod + def __sizes_to_array(sizes) -> numpy.ndarray: + sizes = numpy.array(sizes, dtype=numpy.int32) + if sizes.ndim != 1 or sizes.size > 3: + raise ValueError('The `sizes` must be a one-dimentional sequence containing not more than 3 elements.') + + if numpy.any(sizes < 0): + raise ValueError('The `sizes` must contain only positive values.') + + return sizes + +# ---------------------------------------------------------------------------------------------------------------------- + + +class SplitChannels(SplitLayer): """The layer that splits an input blob along the Channels dimension. :param input_layer: The input layer and the number of its output. If no number @@ -49,27 +90,12 @@ class SplitChannels(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitChannels: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitChannels(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitChannels", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitDepth(Layer): +class SplitDepth(SplitLayer): """The layer that splits an input blob along the Depth dimension. :param input_layer: The input layer and the number of its output. If no number @@ -98,27 +124,12 @@ class SplitDepth(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitDepth: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitDepth(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitDepth", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitWidth(Layer): +class SplitWidth(SplitLayer): """The layer that splits an input blob along the Width dimension. :param input_layer: The input layer and the number of its output. If no number @@ -147,27 +158,12 @@ class SplitWidth(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitWidth: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitWidth(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitWidth", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitHeight(Layer): +class SplitHeight(SplitLayer): """The layer that splits an input blob along the Height dimension. :param input_layer: The input layer and the number of its output. If no number @@ -196,27 +192,12 @@ class SplitHeight(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitHeight: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitHeight(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitHeight", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitListSize(Layer): +class SplitListSize(SplitLayer): """The layer that splits an input blob along the ListSize dimension. :param input_layer: The input layer and the number of its output. If no number @@ -245,27 +226,12 @@ class SplitListSize(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitListSize: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitListSize(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitListSize", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitBatchWidth(Layer): +class SplitBatchWidth(SplitLayer): """The layer that splits an input blob along the BatchWidth dimension. :param input_layer: The input layer and the number of its output. If no number @@ -294,27 +260,12 @@ class SplitBatchWidth(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitBatchWidth: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitBatchWidth(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitBatchWidth", input_layer, sizes, name) # ---------------------------------------------------------------------------------------------------------------------- -class SplitBatchLength(Layer): +class SplitBatchLength(SplitLayer): """The layer that splits an input blob along the BatchLength dimension. :param input_layer: The input layer and the number of its output. If no number @@ -343,19 +294,4 @@ class SplitBatchLength(Layer): - all other dimensions are the same as for the input """ def __init__(self, input_layer, sizes, name=None): - if type(input_layer) is PythonWrapper.SplitBatchLength: - super().__init__(input_layer) - return - - layers, outputs = check_input_layers(input_layer, 1) - - s = numpy.array(sizes, dtype=numpy.int32, copy=False) - - if s.size > 3: - raise ValueError('The `sizes` must contain not more than 3 elements.') - - if numpy.any(s < 0): - raise ValueError('The `sizes` must contain only positive values.') - - internal = PythonWrapper.SplitBatchLength(str(name), layers[0], int(outputs[0]), s) - super().__init__(internal) + super().__init__("SplitBatchLength", input_layer, sizes, name) diff --git a/NeoML/Python/neoml/Utils.py b/NeoML/Python/neoml/Utils.py index 57bf48f94..99154598c 100644 --- a/NeoML/Python/neoml/Utils.py +++ b/NeoML/Python/neoml/Utils.py @@ -62,7 +62,7 @@ def check_input_layers(input_layers, layer_count): layers.append(i._internal) outputs.append(0) elif isinstance(i, (list, tuple)) and len(i) == 2 and isinstance(i[0], Layer) and isinstance(i[1], int): - if int(i[1]) < 0 or int(i[1]) >= i[0].output_count(): + if int(i[1]) < 0: raise ValueError('Invalid value `input_layers`.' ' It must be a list of layers or a list of (layer, output).') layers.append(i[0]._internal) diff --git a/NeoML/Python/src/PySplitLayer.cpp b/NeoML/Python/src/PySplitLayer.cpp index de2d4475c..0b659ac3c 100644 --- a/NeoML/Python/src/PySplitLayer.cpp +++ b/NeoML/Python/src/PySplitLayer.cpp @@ -18,11 +18,36 @@ limitations under the License. #include "PySplitLayer.h" -class CPySplitChannelsLayer : public CPyLayer { +py::array CPyBaseSplitLayer::GetOutputCounts() const { + const auto& fineCounts = Layer()->GetOutputCounts(); + + py::array_t counts( fineCounts.Size() ); + auto countsData = counts.mutable_unchecked<>(); + + for( int i = 0; i < fineCounts.Size(); ++i ) { + countsData( i ) = fineCounts[i]; + } + + return counts; +} + +void CPyBaseSplitLayer::SetOutputCounts( py::array counts ) { + NeoAssert( counts.ndim() == 1 ); + NeoAssert( counts.dtype().is( py::dtype::of() ) ); + + CArray fineCounts; + fineCounts.SetSize( static_cast(counts.size()) ); + for( int i = 0; i < fineCounts.Size(); i++ ) { + fineCounts[i] = static_cast(counts.data())[i]; + } + Layer()->SetOutputCounts( fineCounts ); +} + +class CPySplitChannelsLayer : public CPyBaseSplitLayer { public: - explicit CPySplitChannelsLayer( CSplitChannelsLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitChannelsLayer( CSplitChannelsLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitChannels" ); @@ -30,11 +55,11 @@ class CPySplitChannelsLayer : public CPyLayer { } }; -class CPySplitDepthLayer : public CPyLayer { +class CPySplitDepthLayer : public CPyBaseSplitLayer { public: - explicit CPySplitDepthLayer( CSplitDepthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitDepthLayer( CSplitDepthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitDepth" ); @@ -42,11 +67,11 @@ class CPySplitDepthLayer : public CPyLayer { } }; -class CPySplitWidthLayer : public CPyLayer { +class CPySplitWidthLayer : public CPyBaseSplitLayer { public: - explicit CPySplitWidthLayer( CSplitWidthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitWidthLayer( CSplitWidthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitWidth" ); @@ -54,11 +79,11 @@ class CPySplitWidthLayer : public CPyLayer { } }; -class CPySplitHeightLayer : public CPyLayer { +class CPySplitHeightLayer : public CPyBaseSplitLayer { public: - explicit CPySplitHeightLayer( CSplitHeightLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitHeightLayer( CSplitHeightLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitHeight" ); @@ -66,11 +91,11 @@ class CPySplitHeightLayer : public CPyLayer { } }; -class CPySplitListSizeLayer : public CPyLayer { +class CPySplitListSizeLayer : public CPyBaseSplitLayer { public: - explicit CPySplitListSizeLayer( CSplitListSizeLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitListSizeLayer( CSplitListSizeLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitListSize" ); @@ -78,11 +103,11 @@ class CPySplitListSizeLayer : public CPyLayer { } }; -class CPySplitBatchWidthLayer : public CPyLayer { +class CPySplitBatchWidthLayer : public CPyBaseSplitLayer { public: - explicit CPySplitBatchWidthLayer( CSplitBatchWidthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitBatchWidthLayer( CSplitBatchWidthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitBatchWidth" ); @@ -90,11 +115,11 @@ class CPySplitBatchWidthLayer : public CPyLayer { } }; -class CPySplitBatchLengthLayer : public CPyLayer { +class CPySplitBatchLengthLayer : public CPyBaseSplitLayer { public: - explicit CPySplitBatchLengthLayer( CSplitBatchLengthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + explicit CPySplitBatchLengthLayer( CSplitBatchLengthLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyBaseSplitLayer( layer, mathEngineOwner ) {} - py::object CreatePythonObject() const + py::object CreatePythonObject() const override { py::object pyModule = py::module::import( "neoml.Dnn" ); py::object pyConstructor = pyModule.attr( "SplitBatchLength" ); @@ -102,167 +127,97 @@ class CPySplitBatchLengthLayer : public CPyLayer { } }; +template +PythonLayer* InitSplit( const std::string& className, const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) +{ + static_assert( std::is_base_of::value, "PySplitLayer.cpp, InitSplit" ); + static_assert( std::is_constructible::value, "PySplitLayer.cpp, InitSplit" ); + + py::gil_scoped_release release; + CDnn& dnn = layer1.Dnn(); + CPtr split = new Layer( dnn.GetMathEngine() ); + split->SetName( FindFreeLayerName( dnn, className, name ).c_str() ); + dnn.AddLayer( *split ); + split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); + std::unique_ptr layer( new PythonLayer( *split, layer1.MathEngineOwner() ) ); + layer->SetOutputCounts( sizes ); + return layer.release(); +} + void InitializeSplitLayer( py::module& m ) { - py::class_(m, "SplitChannels") + py::class_(m, "BaseSplit") + .def( "get_output_counts", &CPyBaseSplitLayer::GetOutputCounts, py::return_value_policy::move ) + .def( "set_output_counts", &CPyBaseSplitLayer::SetOutputCounts, py::return_value_policy::reference ) + ; + + py::class_(m, "SplitChannels") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitChannelsLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitChannelsLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitChannels", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitChannelsLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitChannels", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitDepth") + py::class_(m, "SplitDepth") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitDepthLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitDepthLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitDepth", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitDepthLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitDepth", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitWidth") + py::class_(m, "SplitWidth") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitWidthLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitWidthLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitWidth", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitWidthLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitWidth", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitHeight") + py::class_(m, "SplitHeight") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitHeightLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitHeightLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitHeight", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitHeightLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitHeight", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitListSize") + py::class_(m, "SplitListSize") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitListSizeLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitListSizeLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitListSize", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitListSizeLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitListSize", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitBatchWidth") + py::class_(m, "SplitBatchWidth") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitBatchWidthLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitBatchWidthLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitBatchWidth", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitBatchWidthLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitBatchWidth", name, layer1, outputNumber1, sizes ); }) ) ; - py::class_(m, "SplitBatchLength") + py::class_(m, "SplitBatchLength") .def( py::init([]( const CPyLayer& layer ) { return new CPySplitBatchLengthLayer( *layer.Layer(), layer.MathEngineOwner() ); })) .def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) { - py::gil_scoped_release release; - CDnn& dnn = layer1.Dnn(); - IMathEngine& mathEngine = dnn.GetMathEngine(); - CPtr split = new CSplitBatchLengthLayer( mathEngine ); - CArray outputs; - outputs.SetSize(static_cast(sizes.size())); - for( int i = 0; i < outputs.Size(); i++ ) { - outputs[i] = reinterpret_cast(sizes.data())[i]; - } - split->SetOutputCounts(outputs); - split->SetName( FindFreeLayerName( dnn, "SplitBatchLength", name ).c_str() ); - dnn.AddLayer( *split ); - split->Connect( 0, layer1.BaseLayer(), outputNumber1 ); - return new CPySplitBatchLengthLayer( *split, layer1.MathEngineOwner() ); + return InitSplit( "SplitBatchLength", name, layer1, outputNumber1, sizes ); }) ) ; - } diff --git a/NeoML/Python/src/PySplitLayer.h b/NeoML/Python/src/PySplitLayer.h index 3253d02a9..3195d497e 100644 --- a/NeoML/Python/src/PySplitLayer.h +++ b/NeoML/Python/src/PySplitLayer.h @@ -17,4 +17,13 @@ limitations under the License. #include "PyLayer.h" -void InitializeSplitLayer( py::module& m ); \ No newline at end of file +class CPyBaseSplitLayer : public CPyLayer { +public: + CPyBaseSplitLayer( CBaseSplitLayer& layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( layer, mathEngineOwner ) {} + + py::array GetOutputCounts() const; + + void SetOutputCounts( py::array counts ); +}; + +void InitializeSplitLayer( py::module& m );