|
| 1 | +/* Copyright © 2017-2021 ABBYY Production LLC |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +--------------------------------------------------------------------------------------------------------------*/ |
| 15 | + |
| 16 | +#include <common.h> |
| 17 | +#pragma hdrstop |
| 18 | + |
| 19 | +#include "PyDataLayer.h" |
| 20 | +#include "PyDnnBlob.h" |
| 21 | +#include "PyDnn.h" |
| 22 | +#include "PyMathEngine.h" |
| 23 | + |
| 24 | +class CPyDataLayer : public CPyLayer { |
| 25 | +public: |
| 26 | + CPyDataLayer( CDataLayer* layer, CPyMathEngineOwner& mathEngineOwner ) : CPyLayer( *layer, mathEngineOwner ) {} |
| 27 | + |
| 28 | + void SetBlob( const CPyBlob& blob ) |
| 29 | + { |
| 30 | + Layer<CDataLayer>()->SetBlob( blob.Blob() ); |
| 31 | + } |
| 32 | + |
| 33 | + CPyBlob GetBlob() const |
| 34 | + { |
| 35 | + return CPyBlob( MathEngineOwner(), Layer<CDataLayer>()->GetBlob() ); |
| 36 | + } |
| 37 | + |
| 38 | + py::object CreatePythonObject() const |
| 39 | + { |
| 40 | + py::object pyModule = py::module::import( "neoml.Dnn" ); |
| 41 | + py::object pyConstructor = pyModule.attr( "Data" ); |
| 42 | + return pyConstructor( py::cast(this) ); |
| 43 | + } |
| 44 | +}; |
| 45 | + |
| 46 | +void InitializeDataLayer( py::module& m ) |
| 47 | +{ |
| 48 | + py::class_<CPyDataLayer, CPyLayer>(m, "Data") |
| 49 | + .def( py::init([]( const CPyLayer& layer ) |
| 50 | + { |
| 51 | + return new CPyDataLayer( layer.Layer<CDataLayer>(), layer.MathEngineOwner() ); |
| 52 | + })) |
| 53 | + .def( py::init([]( const CPyDnn& dnn, const std::string& name ) |
| 54 | + { |
| 55 | + py::gil_scoped_release release; |
| 56 | + CPtr<CDataLayer> dataLayer( new CDataLayer( dnn.MathEngine() ) ); |
| 57 | + dataLayer->SetName( FindFreeLayerName( dnn.Dnn(), "Data", name ).c_str() ); |
| 58 | + dnn.Dnn().AddLayer( *dataLayer ); |
| 59 | + |
| 60 | + return new CPyDataLayer( dataLayer, dnn.MathEngineOwner() ); |
| 61 | + })) |
| 62 | + .def( "set_blob", &CPyDataLayer::SetBlob, py::return_value_policy::reference ) |
| 63 | + .def( "get_blob", &CPyDataLayer::GetBlob, py::return_value_policy::reference ) |
| 64 | + ; |
| 65 | +} |
0 commit comments