forked from intel/nn-hal
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path Add.cpp
More file actions
58 lines (47 loc) · 1.84 KB
/
Add.cpp
File metadata and controls
58 lines (47 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <Add.hpp>
#undef LOG_TAG
#define LOG_TAG "Add"
namespace android {
namespace hardware {
namespace neuralnetworks {
namespace nnhal {

// Add converts an NNAPI ADD operation into nGraph nodes.
// Operands of the NNAPI operation used in this file:
//   input 0  - first addend
//   input 1  - second addend
//   input 2  - fused activation code; must be a constant operand (see validate())
//   output 0 - the operation's result (recorded as mDefaultOutputIndex)

// Constructs the converter for the NNAPI operation at `operationIndex` and records
// the operation's output 0 as its default output.
Add::Add(int operationIndex) : OperationsBase(operationIndex) {
    mDefaultOutputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
}

// Checks whether this ADD operation can be converted.
// Returns false (and logs an error) when the fused activation operand (input 2) is
// not a constant. createNode() reads that operand with ParseOperationInput while
// building the graph, so presumably its value must be known at that point.
// Returns true otherwise.
bool Add::validate() {
    const auto& activationIndex = sModelInfo->getOperationInput(mNnapiOperationIndex, 2);
    if (!sModelInfo->isOperandLifeTimeConst(activationIndex)) {
        ALOGE("%s Only Constant supported for specifying Activation", __func__);
        return false;
    }
    ALOGV("%s PASSED", __func__);
    return true;
}

// Builds the default nGraph subgraph: element-wise input0 + input1 with NumPy-style
// broadcasting, followed by the fused activation (input 2) applied via applyActivation().
std::shared_ptr<ngraph::Node> Add::createNode() {
    std::shared_ptr<ngraph::Node> input1 = getInputNode(0);
    std::shared_ptr<ngraph::Node> input2 = getInputNode(1);
    auto activationFn = sModelInfo->ParseOperationInput<uint32_t>(mNnapiOperationIndex, 2);

    auto addNode = std::make_shared<ngraph::opset3::Add>(input1, input2,
                                                         ngraph::op::AutoBroadcastType::NUMPY);
    return applyActivation(addNode, activationFn);
}

// Builds a plugin-specific subgraph.
// For every plugin other than VPU this is exactly createNode().
// For VPU it emits input0 + C, where C is an all-zero f32 constant with input0's
// shape, passed through transpose(NHWC_NCHW, ...).
// NOTE(review): on the VPU path input 1 and the fused activation (input 2) are not
// used -- confirm this is intentional. Also, after the NHWC->NCHW transpose, C's
// shape is presumably a permutation of input0's shape, so the NUMPY broadcast only
// succeeds when the permuted dims are compatible -- confirm the expected input shape.
std::shared_ptr<ngraph::Node> Add::createNodeForPlugin() {
    if (sPluginType != IntelDeviceType::VPU) {
        return createNode();
    }

    auto input = mNgraphNodes->getOperationOutput(
        sModelInfo->getOperationInput(mNnapiOperationIndex, 0));
    // Fill the constant explicitly with zeros. A Constant built from only
    // (type, shape) carries no values, and nGraph documents that constructor as
    // creating an uninitialized constant. Depending on the library version, its
    // buffer may contain arbitrary data that would be added to the input.
    std::shared_ptr<ngraph::Node> constantOp =
        ngraph::opset3::Constant::create(ngraph::element::f32, input.get_shape(), {0.0f});
    auto transposedOp = transpose(NHWC_NCHW, constantOp);
    return std::make_shared<ngraph::opset3::Add>(input, transposedOp,
                                                 ngraph::op::AutoBroadcastType::NUMPY);
}

}  // namespace nnhal
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android