Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions Android.bp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ cc_library_shared {
],

include_dirs: [
"packages/modules/NeuralNetworks/common/include",
"packages/modules/NeuralNetworks/common/types/include",
"packages/modules/NeuralNetworks/runtime/include",
"frameworks/ml/nn/runtime/include/",
"frameworks/native/libs/nativewindow/include",
"external/mesa3d/include/android_stub",
"external/grpc-grpc",
Expand Down Expand Up @@ -168,9 +166,8 @@ cc_binary {
srcs: ["service.cpp"],

include_dirs: [
"packages/modules/NeuralNetworks/common/include",
"packages/modules/NeuralNetworks/common/types/include",
"packages/modules/NeuralNetworks/runtime/include",
"frameworks/ml/nn/common/include",
"frameworks/ml/nn/runtime/include/",
"frameworks/native/libs/nativewindow/include",
"external/mesa3d/include/android_stub",
],
Expand All @@ -186,7 +183,6 @@ cc_binary {

shared_libs: [
"libhidlbase",
"libhidltransport",
"libhidlmemory",
"libutils",
"liblog",
Expand Down
96 changes: 67 additions & 29 deletions BasePreparedModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,20 @@ namespace android::hardware::neuralnetworks::nnhal {
using namespace android::nn;

static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
bool mRemoteCheck = false;
std::shared_ptr<DetectionClient> mDetectionClient;
uint32_t BasePreparedModel::mFileId = 0;

void BasePreparedModel::deinitialize() {
    ALOGV("Entering %s", __func__);
    bool is_success = false;
    mModelInfo->unmapRuntimeMemPools();

    // Remove the IR files cached for this model instance; warn (don't fail)
    // if either file was already gone.
    auto ret_xml = std::remove(mXmlFile.c_str());
    auto ret_bin = std::remove(mBinFile.c_str());
    if ((ret_xml != 0) || (ret_bin != 0)) {
        ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin);
    }

    // Release the remote GRPC session only when a client was actually created.
    // checkRemoteConnection() leaves mDetectionClient null when no GRPC
    // endpoint (TCP or unix socket) is configured, so an unconditional
    // release() here would dereference a null shared_ptr.
    if (mDetectionClient) {
        auto reply = mDetectionClient->release(is_success);
        ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
    }
    setRemoteEnabled(false);

    ALOGV("Exiting %s", __func__);
}
Expand All @@ -62,7 +64,23 @@ bool BasePreparedModel::initialize() {
ALOGE("Failed to initialize Model runtime parameters!!");
return false;
}
checkRemoteConnection();

setRemoteEnabled(checkRemoteConnection());
if (mRemoteCheck) {
for (auto i : mModelInfo->getModelInputIndexes()) {
auto& nnapiOperandType = mModelInfo->getOperand(i).type;
switch (nnapiOperandType) {
case OperandType::FLOAT32:
case OperandType::TENSOR_FLOAT32:
break;
default:
ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType);
setRemoteEnabled(false);
break;
}
if (!mRemoteCheck) break;
}
}
mNgraphNetCreator = std::make_shared<NgraphNetworkCreator>(mModelInfo, mTargetDevice);

if (!mNgraphNetCreator->validateOperations()) return false;
Expand Down Expand Up @@ -95,23 +113,29 @@ bool BasePreparedModel::checkRemoteConnection() {
bool is_success = false;
if(getGrpcIpPort(grpc_prop)) {
ALOGV("Attempting GRPC via TCP : %s", grpc_prop);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(INT_MAX);
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateChannel(grpc_prop, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if(mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str());
}
}
if (!is_success && getGrpcSocketPath(grpc_prop)) {
ALOGV("Attempting GRPC via unix : %s", grpc_prop);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(INT_MAX);
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if(mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str());
}
}
mRemoteCheck = is_success;
setRemoteEnabled(is_success);
return is_success;
}

Expand All @@ -121,14 +145,24 @@ bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::st
if(mDetectionClient) {
auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin);
ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str());
if (reply == "status False") {
ALOGE("%s Model Load Failed",__func__);
}
}
else {
ALOGE("%s mDetectionClient is null",__func__);
}
mRemoteCheck = is_success;
setRemoteEnabled(is_success);
return is_success;
}

// Updates the cached remote-inference flag, logging only on an actual
// state transition (acquired <-> released); repeated calls with the same
// value are silent no-ops.
void BasePreparedModel::setRemoteEnabled(bool flag) {
    if (flag == mRemoteCheck) return;  // state unchanged — nothing to do
    ALOGD("GRPC %s Remote Connection", flag ? "ACQUIRED" : "RELEASED");
    mRemoteCheck = flag;
}

static Return<void> notify(const sp<V1_0::IExecutionCallback>& callback, const ErrorStatus& status,
const hidl_vec<OutputShape>&, Timing) {
return callback->notify(status);
Expand Down Expand Up @@ -268,12 +302,12 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
ALOGD("%s Run", __func__);

if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
if(preparedModel->mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
auto reply = preparedModel->mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){
try {
plugin->infer();
} catch (const std::exception& ex) {
Expand Down Expand Up @@ -332,9 +366,9 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
return;
}

if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex));
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
Expand Down Expand Up @@ -399,7 +433,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
}

static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynchronouslyBase(
const Request& request, MeasureTiming measure, BasePreparedModel* preparedModel,
const V1_3::Request& request, MeasureTiming measure, BasePreparedModel* preparedModel,
time_point driverStart) {
ALOGV("Entering %s", __func__);
auto modelInfo = preparedModel->getModelInfo();
Expand All @@ -408,7 +442,7 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
time_point driverEnd, deviceStart, deviceEnd;
std::vector<RunTimePoolInfo> requestPoolInfos;
auto errorStatus = modelInfo->setRunTimePoolInfosFromHidlMemories(request.pools);
if (errorStatus != ErrorStatus::NONE) {
if (errorStatus != V1_3::ErrorStatus::NONE) {
ALOGE("Failed to set runtime pool info from HIDL memories");
return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
Expand All @@ -427,8 +461,9 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGV("Input index: %d layername : %s", inIndex, inputNodeName.c_str());
//check if remote infer is available
//TODO: Need to add FLOAT16 support for remote inferencing
if(mRemoteCheck && mDetectionClient) {
mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len);
if(preparedModel->mRemoteCheck && preparedModel->mDetectionClient) {
auto inOperandType = modelInfo->getOperandType(inIndex);
preparedModel->mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len, inOperandType);
} else {
ov::Tensor destTensor;
try {
Expand Down Expand Up @@ -493,12 +528,15 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGV("%s Run", __func__);

if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
if(preparedModel->mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
auto reply = preparedModel->mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){
if(preparedModel->mRemoteCheck) {
preparedModel->setRemoteEnabled(false);
}
try {
ALOGV("%s Client Infer", __func__);
plugin->infer();
Expand Down Expand Up @@ -555,9 +593,9 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
}
//copy output from remote infer
//TODO: Add support for other OperandType
if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex));
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
Expand Down Expand Up @@ -606,8 +644,8 @@ static std::tuple<ErrorStatus, hidl_vec<V1_2::OutputShape>, Timing> executeSynch
ALOGE("Failed to update the request pool infos");
return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
}
if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->clear_data();
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->clear_data();
}

if (measure == MeasureTiming::YES) {
Expand All @@ -631,7 +669,7 @@ Return<void> BasePreparedModel::executeSynchronously(const Request& request, Mea
return Void();
}
auto [status, outputShapes, timing] =
executeSynchronouslyBase(request, measure, this, driverStart);
executeSynchronouslyBase(convertToV1_3(request), measure, this, driverStart);
cb(status, std::move(outputShapes), timing);
ALOGV("Exiting %s", __func__);
return Void();
Expand All @@ -646,12 +684,12 @@ Return<void> BasePreparedModel::executeSynchronously_1_3(const V1_3::Request& re
time_point driverStart;
if (measure == MeasureTiming::YES) driverStart = now();

if (!validateRequest(convertToV1_0(request), convertToV1_2(mModelInfo->getModel()))) {
if (!validateRequest(request, mModelInfo->getModel())) {
cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
return Void();
}
auto [status, outputShapes, timing] =
executeSynchronouslyBase(convertToV1_0(request), measure, this, driverStart);
executeSynchronouslyBase(request, measure, this, driverStart);
cb(convertToV1_3(status), std::move(outputShapes), timing);
ALOGV("Exiting %s", __func__);
return Void();
Expand Down Expand Up @@ -872,7 +910,7 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,

if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
mNgraphNetCreator->getOutputShape(outIndex));
mNgraphNetCreator->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
Expand Down
9 changes: 5 additions & 4 deletions BasePreparedModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,13 @@ namespace android::hardware::neuralnetworks::nnhal {
template <class T>
using vec = std::vector<T>;
typedef uint8_t* memory;
extern bool mRemoteCheck;
extern std::shared_ptr<DetectionClient> mDetectionClient;
class BasePreparedModel : public V1_3::IPreparedModel {
public:
bool mRemoteCheck = false;
BasePreparedModel(const IntelDeviceType device, const Model& model) : mTargetDevice(device) {
mModelInfo = std::make_shared<NnapiModelInfo>(model);
mXmlFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".xml");
mBinFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".bin");
mXmlFile = MODEL_DIR + std::to_string(mFileId) + std::string(".xml");
mBinFile = MODEL_DIR + std::to_string(mFileId) + std::string(".bin");
mFileId++;
}

Expand Down Expand Up @@ -90,6 +89,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {
virtual bool initialize();
virtual bool checkRemoteConnection();
virtual bool loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
virtual void setRemoteEnabled(bool flag);

std::shared_ptr<NnapiModelInfo> getModelInfo() { return mModelInfo; }

Expand All @@ -98,6 +98,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {
std::shared_ptr<IIENetwork> getPlugin() { return mPlugin; }

std::shared_ptr<ov::Model> modelPtr;
std::shared_ptr<DetectionClient> mDetectionClient;

protected:
virtual void deinitialize();
Expand Down
Loading