From c659f9e882f3be7c15f9b22be5eeeb22cd635c8a Mon Sep 17 00:00:00 2001 From: Ratnesh Kumar Rai Date: Tue, 28 Feb 2023 08:40:00 +0530 Subject: [PATCH 1/9] Compatibility changes to build nn-hal for Android R Tracked-On: OAM-105850 Signed-off-by: Anoob Anto K Signed-off-by: Ratnesh Kumar Rai Signed-off-by: Jaikrishna, Nemallapudi --- Android.bp | 10 +++------- ngraph_creator/Android.bp | 3 --- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/Android.bp b/Android.bp index c4d807098..35e1f6a4f 100644 --- a/Android.bp +++ b/Android.bp @@ -21,9 +21,7 @@ cc_library_shared { ], include_dirs: [ - "packages/modules/NeuralNetworks/common/include", - "packages/modules/NeuralNetworks/common/types/include", - "packages/modules/NeuralNetworks/runtime/include", + "frameworks/ml/nn/runtime/include/", "frameworks/native/libs/nativewindow/include", "external/mesa3d/include/android_stub", "external/grpc-grpc", @@ -168,9 +166,8 @@ cc_binary { srcs: ["service.cpp"], include_dirs: [ - "packages/modules/NeuralNetworks/common/include", - "packages/modules/NeuralNetworks/common/types/include", - "packages/modules/NeuralNetworks/runtime/include", + "frameworks/ml/nn/common/include", + "frameworks/ml/nn/runtime/include/", "frameworks/native/libs/nativewindow/include", "external/mesa3d/include/android_stub", ], @@ -186,7 +183,6 @@ cc_binary { shared_libs: [ "libhidlbase", - "libhidltransport", "libhidlmemory", "libutils", "liblog", diff --git a/ngraph_creator/Android.bp b/ngraph_creator/Android.bp index 5a4175aa6..7f4f08397 100755 --- a/ngraph_creator/Android.bp +++ b/ngraph_creator/Android.bp @@ -99,9 +99,6 @@ cc_library_static { ], include_dirs: [ - "packages/modules/NeuralNetworks/common/include", - "packages/modules/NeuralNetworks/common/types/include", - "packages/modules/NeuralNetworks/runtime/include", "external/mesa3d/include/android_stub", ], From a13d58fc302e5edbc50ced8d2753d53656107d34 Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Fri, 12 May 2023 08:18:20 +0000 Subject: [PATCH 2/9] Disable parallel attempts for remote inference Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 71 ++++++++++++++++++++++++------------------- BasePreparedModel.h | 3 +- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 31c04b49c..425bfcc01 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -33,8 +33,8 @@ namespace android::hardware::neuralnetworks::nnhal { using namespace android::nn; static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX}; -bool mRemoteCheck = false; -std::shared_ptr mDetectionClient; +bool gRemoteCheck = false; +std::shared_ptr gDetectionClient; uint32_t BasePreparedModel::mFileId = 0; void BasePreparedModel::deinitialize() { @@ -45,6 +45,10 @@ void BasePreparedModel::deinitialize() { if ((ret_xml != 0) || (ret_bin != 0)) { ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin); } + if(mRemoteCheck) { + ALOGD("GRPC RELEASED Remote Connection"); + gRemoteCheck = false; + } ALOGV("Exiting %s", __func__); } @@ -62,7 +66,7 @@ bool BasePreparedModel::initialize() { ALOGE("Failed to initialize Model runtime parameters!!"); return false; } - checkRemoteConnection(); + mRemoteCheck = checkRemoteConnection(); mNgraphNetCreator = std::make_shared(mModelInfo, mTargetDevice); if (!mNgraphNetCreator->validateOperations()) return false; @@ -91,39 +95,44 @@ bool BasePreparedModel::initialize() { } bool BasePreparedModel::checkRemoteConnection() { + if(gRemoteCheck) { + ALOGD("GRPC Remote Connection already under use"); + return false; + } char grpc_prop[PROPERTY_VALUE_MAX] = ""; bool is_success = false; if(getGrpcIpPort(grpc_prop)) { ALOGV("Attempting GRPC via TCP : %s", grpc_prop); - mDetectionClient = std::make_shared( + gDetectionClient = std::make_shared( grpc::CreateChannel(grpc_prop, grpc::InsecureChannelCredentials())); - if(mDetectionClient) { - auto reply = mDetectionClient->prepare(is_success); + if(gDetectionClient) { + auto reply = gDetectionClient->prepare(is_success); ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str()); } } if (!is_success && getGrpcSocketPath(grpc_prop)) { ALOGV("Attempting GRPC via unix : %s", grpc_prop); - mDetectionClient = std::make_shared( + gDetectionClient = std::make_shared( grpc::CreateChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials())); - if(mDetectionClient) { - auto reply = mDetectionClient->prepare(is_success); + if(gDetectionClient) { + auto reply = gDetectionClient->prepare(is_success); ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str()); } } - mRemoteCheck = is_success; + gRemoteCheck = is_success; + ALOGD("GRPC ACQUIRED Remote Connection"); return is_success; } bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) { ALOGI("Entering %s", __func__); bool is_success = false; - if(mDetectionClient) { - auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin); + if(gDetectionClient) { + auto reply = gDetectionClient->sendIRs(is_success, ir_xml, ir_bin); ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str()); } else { - ALOGE("%s mDetectionClient is null",__func__); + ALOGE("%s gDetectionClient is null",__func__); } mRemoteCheck = is_success; return is_success; @@ -268,12 +277,12 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod ALOGD("%s Run", __func__); if (measure == MeasureTiming::YES) deviceStart = now(); - if(mRemoteCheck) { + if(preparedModel->mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = mDetectionClient->remote_infer(); + auto reply = gDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!mRemoteCheck || !mDetectionClient->get_status()){ + if (!preparedModel->mRemoteCheck || !gDetectionClient->get_status()){ try { plugin->infer(); } catch (const std::exception& ex) { @@ -332,8 +341,8 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod return; } - if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) { - mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { + gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, ngraphNw->getOutputShape(outIndex)); } else { switch (operandType) { @@ -427,8 +436,8 @@ static std::tuple, Timing> executeSynch ALOGV("Input index: %d layername : %s", inIndex, inputNodeName.c_str()); //check if remote infer is available //TODO: Need to add FLOAT16 support for remote inferencing - if(mRemoteCheck && mDetectionClient) { - mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len); + if(preparedModel->mRemoteCheck && gDetectionClient) { + gDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len); } else { ov::Tensor destTensor; try { @@ -493,12 +502,12 @@ static std::tuple, Timing> executeSynch ALOGV("%s Run", __func__); if (measure == MeasureTiming::YES) deviceStart = now(); - if(mRemoteCheck) { + if(preparedModel->mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = mDetectionClient->remote_infer(); + auto reply = gDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!mRemoteCheck || !mDetectionClient->get_status()){ + if (!preparedModel->mRemoteCheck || !gDetectionClient->get_status()){ try { ALOGV("%s Client Infer", __func__); plugin->infer(); @@ -555,8 +564,8 @@ static std::tuple, Timing> executeSynch } //copy output from remote infer //TODO: Add support for other OperandType - if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) { - mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { + gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, ngraphNw->getOutputShape(outIndex)); } else { switch (operandType) { @@ -606,8 +615,8 @@ static std::tuple, Timing> executeSynch ALOGE("Failed to update the request pool infos"); return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; } - if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) { - mDetectionClient->clear_data(); + if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { + gDetectionClient->clear_data(); } if (measure == MeasureTiming::YES) { @@ -822,10 +831,10 @@ Return BasePreparedModel::executeFenced(const V1_3::Request& request1_3, if (measure == MeasureTiming::YES) deviceStart = now(); if(mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = mDetectionClient->remote_infer(); + auto reply = gDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!mRemoteCheck || !mDetectionClient->get_status()){ + if (!mRemoteCheck || !gDetectionClient->get_status()){ try { mPlugin->infer(); } catch (const std::exception& ex) { @@ -870,8 +879,8 @@ Return BasePreparedModel::executeFenced(const V1_3::Request& request1_3, mModelInfo->updateOutputshapes(i, outDims); } - if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) { - mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { + gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, mNgraphNetCreator->getOutputShape(outIndex)); } else { switch (operandType) { diff --git a/BasePreparedModel.h b/BasePreparedModel.h index 9fbdd1abd..cc8a884ed 100755 --- a/BasePreparedModel.h +++ b/BasePreparedModel.h @@ -49,10 +49,9 @@ namespace android::hardware::neuralnetworks::nnhal { template using vec = std::vector; typedef uint8_t* memory; -extern bool mRemoteCheck; -extern std::shared_ptr mDetectionClient; class BasePreparedModel : public V1_3::IPreparedModel { public: + bool mRemoteCheck = false; BasePreparedModel(const IntelDeviceType device, const Model& model) : mTargetDevice(device) { mModelInfo = std::make_shared(model); mXmlFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".xml"); From bb639c90dc17f8e24de5dba2f0990087b0981763 Mon Sep 17 00:00:00 2001 From: Ratnesh Kumar Rai Date: Fri, 12 May 2023 17:28:02 +0530 Subject: [PATCH 3/9] Added load model and data_type for remote infer loadmodel rpc call added after sending IR files Included data_type parameter for input data Tracked-On: OAM-109729 Signed-off-by: Ratnesh Kumar Rai --- BasePreparedModel.cpp | 12 +++-- DetectionClient.cpp | 77 ++++++++++++++++++++++++++++-- DetectionClient.h | 6 ++- proto/nnhal_object_detection.proto | 20 ++++++++ 4 files changed, 105 insertions(+), 10 deletions(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 425bfcc01..3cf11023b 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -130,6 +130,9 @@ bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::st if(gDetectionClient) { auto reply = gDetectionClient->sendIRs(is_success, ir_xml, ir_bin); ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str()); + if (reply == "status False") { + ALOGE("%s Model Load Failed",__func__); + } } else { ALOGE("%s gDetectionClient is null",__func__); @@ -343,7 +346,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, - ngraphNw->getOutputShape(outIndex)); + ngraphNw->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { case OperandType::TENSOR_INT32: @@ -437,7 +440,8 @@ static std::tuple, Timing> executeSynch //check if remote infer is available //TODO: Need to add FLOAT16 support for remote inferencing if(preparedModel->mRemoteCheck && gDetectionClient) { - gDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len); + auto inOperandType = modelInfo->getOperandType(inIndex); + gDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len, inOperandType); } else { ov::Tensor destTensor; try { @@ -566,7 +570,7 @@ static std::tuple, Timing> executeSynch //TODO: Add support for other OperandType if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, - ngraphNw->getOutputShape(outIndex)); + ngraphNw->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { case OperandType::TENSOR_INT32: @@ -881,7 +885,7 @@ Return BasePreparedModel::executeFenced(const V1_3::Request& request1_3, if (mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, - mNgraphNetCreator->getOutputShape(outIndex)); + mNgraphNetCreator->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { case OperandType::TENSOR_INT32: diff --git a/DetectionClient.cpp b/DetectionClient.cpp index 4d0716180..bcf0fc6d2 100644 --- a/DetectionClient.cpp +++ b/DetectionClient.cpp @@ -47,6 +47,22 @@ Status DetectionClient::sendFile(std::string fileName, return writer->Finish(); } +bool DetectionClient::isModelLoaded(std::string fileName) { + ReplyStatus reply; + ClientContext context; + RequestString request; + request.set_value(fileName); + time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(20000); + context.set_deadline(deadline); + status = stub_->loadModel(&context, request, &reply); + if(status.ok()) { + return reply.status(); + } else { + ALOGE("Model Load failure: %s", status.error_message().c_str()); + } + return false; +} + std::string DetectionClient::sendIRs(bool& flag, const std::string& ir_xml, const std::string& ir_bin) { ReplyStatus reply; ClientContext context; @@ -62,25 +78,75 @@ std::string DetectionClient::sendIRs(bool& flag, const std::string& ir_xml, cons status = sendFile(ir_bin, writerBin); if (status.ok()) { flag = reply.status(); - return (flag ? "status True" : "status False"); + //if model is sent succesfully trigger model loading + if (flag && isModelLoaded(ir_xml) ) { + flag = true; + return ("status True"); + } else { + flag = false; + ALOGE("Model Loading Failed!!!"); + return ("status False"); + } + } else { + return ("status False"); } } return std::string(status.error_message()); } -void DetectionClient::add_input_data(std::string label, const uint8_t* buffer, std::vector shape, uint32_t size) { +void DetectionClient::add_input_data(std::string label, const uint8_t* buffer, std::vector shape, uint32_t size, android::hardware::neuralnetworks::nnhal::OperandType operandType) { const float* src; size_t index; DataTensor* input = request.add_data_tensors(); input->set_node_name(label); + switch(operandType) { + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_INT32: { + input->set_data_type(DataTensor::i32); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_FLOAT16: { + input->set_data_type(DataTensor::f16); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_FLOAT32: { + input->set_data_type(DataTensor::f32); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_BOOL8: { + input->set_data_type(DataTensor::boolean); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT8_ASYMM: { + input->set_data_type(DataTensor::u8); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT8_SYMM: + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT8_ASYMM_SIGNED: { + input->set_data_type(DataTensor::i8); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT16_SYMM: { + input->set_data_type(DataTensor::i16); + break; + } + case android::hardware::neuralnetworks::nnhal::OperandType::TENSOR_QUANT16_ASYMM: { + input->set_data_type(DataTensor::u16); + break; + } + default: { + input->set_data_type(DataTensor::u8); + break; + } + } for (index = 0; index < shape.size(); index++) { input->add_tensor_shape(shape[index]); } input->set_data(buffer, size); } -void DetectionClient::get_output_data(std::string label, uint8_t* buffer, std::vector shape) { +void DetectionClient::get_output_data(std::string label, uint8_t* buffer, std::vector shape, uint32_t expectedLength) { std::string src; size_t index; size_t size = 1; @@ -91,6 +157,9 @@ void DetectionClient::get_output_data(std::string label, uint8_t* buffer, std::v for (index = 0; index < reply.data_tensors_size(); index++) { if (label.compare(reply.data_tensors(index).node_name()) == 0) { src = reply.data_tensors(index).data(); + if(expectedLength != src.length()) { + ALOGE("Length Mismatch error: expected length %d , actual length %d", expectedLength, src.length()); + } memcpy(buffer, src.data(), src.length()); break; } @@ -104,7 +173,7 @@ void DetectionClient::clear_data() { std::string DetectionClient::remote_infer() { ClientContext context; - time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(20000); + time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(5000); context.set_deadline(deadline); status = stub_->getInferResult(&context, request, &reply); diff --git a/DetectionClient.h b/DetectionClient.h index dece36ae4..f26306b6e 100644 --- a/DetectionClient.h +++ b/DetectionClient.h @@ -8,6 +8,7 @@ #include #include #include "nnhal_object_detection.grpc.pb.h" +#include "Driver.h" using grpc::Channel; using grpc::ClientContext; @@ -32,9 +33,10 @@ class DetectionClient { std::unique_ptr >& writer); std::string sendIRs(bool& flag, const std::string& ir_xml, const std::string& ir_bin); + bool isModelLoaded(std::string fileName); - void add_input_data(std::string label, const uint8_t* buffer, std::vector shape, uint32_t size); - void get_output_data(std::string label, uint8_t* buffer, std::vector shape); + void add_input_data(std::string label, const uint8_t* buffer, std::vector shape, uint32_t size, android::hardware::neuralnetworks::nnhal::OperandType operandType); + void get_output_data(std::string label, uint8_t* buffer, std::vector shape, uint32_t expectedLength); void clear_data(); std::string remote_infer(); bool get_status(); diff --git a/proto/nnhal_object_detection.proto b/proto/nnhal_object_detection.proto index e0f14722a..293d6c56c 100644 --- a/proto/nnhal_object_detection.proto +++ b/proto/nnhal_object_detection.proto @@ -27,6 +27,7 @@ service Detection { rpc getInferResult (RequestDataTensors) returns (ReplyDataTensors) {} rpc sendXml (stream RequestDataChunks) returns (ReplyStatus) {} rpc sendBin (stream RequestDataChunks) returns (ReplyStatus) {} + rpc loadModel(RequestString) returns (ReplyStatus) {} rpc prepare (RequestString) returns (ReplyStatus) {} //Placeholder for any future support : RequestString } @@ -47,6 +48,25 @@ message DataTensor { bytes data = 1; string node_name = 2; repeated int32 tensor_shape = 3; + enum DATA_TYPE { + boolean = 0; + bf16 = 1; + f16 = 2; + f32 = 3; + f64 = 4; + i4 = 5; + i8 = 6; + i16 = 7; + i32 = 8; + i64 = 9; + u1 = 10; + u4 = 11; + u8 = 12; + u16 = 13; + u32 = 14; + u64 = 15; + } + DATA_TYPE data_type = 4; } // Reply message containing the Output Data Tensors(blobs) From 591503771434ad1d7b3263a13bdd2c8885329a3a Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Mon, 15 May 2023 05:24:42 +0000 Subject: [PATCH 4/9] Only F32 for remote inference Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 3cf11023b..958c2e6e5 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -66,7 +66,22 @@ bool BasePreparedModel::initialize() { ALOGE("Failed to initialize Model runtime parameters!!"); return false; } - mRemoteCheck = checkRemoteConnection(); + + mRemoteCheck = true; + for (auto i : mModelInfo->getModelInputIndexes()) { + auto& nnapiOperandType = mModelInfo->getOperand(i).type; + switch (nnapiOperandType) { + case OperandType::FLOAT32: + case OperandType::TENSOR_FLOAT32: + break; + default: + ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType); + mRemoteCheck = false; + break; + } + if (!mRemoteCheck) break; + } + if (mRemoteCheck) mRemoteCheck = checkRemoteConnection(); mNgraphNetCreator = std::make_shared(mModelInfo, mTargetDevice); if (!mNgraphNetCreator->validateOperations()) return false; From f58d01ca4545bfa971711e290bbcee9940443efd Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Mon, 15 May 2023 05:25:22 +0000 Subject: [PATCH 5/9] Enable debug logs Signed-off-by: Anoob Anto K --- Driver.cpp | 1 + ModelManager.h | 1 - utils.h | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Driver.cpp b/Driver.cpp index 36fcec5a1..df1bc3729 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -403,6 +403,7 @@ Return Driver::prepareModel_1_3( // TODO: make asynchronous later sp driverPreparedModel = ModelFactory(mDeviceType, model); + for (auto& opn : model.main.operations) dumpOperation(opn); if (!driverPreparedModel->initialize()) { ALOGI("Failed to initialize prepared model"); cb->notify_1_3(convertToV1_3(ErrorStatus::INVALID_ARGUMENT), nullptr); diff --git a/ModelManager.h b/ModelManager.h index 67e1a4b3d..49168aa92 100755 --- a/ModelManager.h +++ b/ModelManager.h @@ -133,7 +133,6 @@ class NnapiModelInfo { const auto value = GetConstOperand(inputIndex); ALOGV("Operation input index: %d, operand index: %d", index, inputIndex); ALOGV("Operation: %s", toString(mModel.main.operations[operationIndex]).c_str()); - printHelper::print(value, toString(operand).c_str()); return value; } diff --git a/utils.h b/utils.h index aea1cafc5..a67d051e9 100644 --- a/utils.h +++ b/utils.h @@ -96,12 +96,12 @@ enum PaddingScheme { #define dumpOperand(index, model) \ do { \ const auto op = model.operands[index]; \ - ALOGV("Operand (%zu) %s", index, toString(op).c_str()); \ + ALOGD("Operand (%zu) %s", index, toString(op).c_str()); \ } while (0) #define dumpOperation(operation) \ do { \ - ALOGV("Operation: %s", toString(operation).c_str()); \ + ALOGD("Operation: %s", toString(operation).c_str()); \ } while (0) #define WRONG_DIM (-1) From cbdd6fb047a75d5f9f11c2e94e1b62c5de0843d9 Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Tue, 16 May 2023 09:23:08 +0000 Subject: [PATCH 6/9] Increase GRPC message size and improved remote checks Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 61 +++++++++++++++++++++++++++---------------- BasePreparedModel.h | 1 + 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 958c2e6e5..2bfb57739 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -45,10 +45,7 @@ void BasePreparedModel::deinitialize() { if ((ret_xml != 0) || (ret_bin != 0)) { ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin); } - if(mRemoteCheck) { - ALOGD("GRPC RELEASED Remote Connection"); - gRemoteCheck = false; - } + setRemoteEnabled(false); ALOGV("Exiting %s", __func__); } @@ -67,21 +64,22 @@ bool BasePreparedModel::initialize() { return false; } - mRemoteCheck = true; - for (auto i : mModelInfo->getModelInputIndexes()) { - auto& nnapiOperandType = mModelInfo->getOperand(i).type; - switch (nnapiOperandType) { - case OperandType::FLOAT32: - case OperandType::TENSOR_FLOAT32: - break; - default: - ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType); - mRemoteCheck = false; - break; + setRemoteEnabled(checkRemoteConnection()); + if (mRemoteCheck) { + for (auto i : mModelInfo->getModelInputIndexes()) { + auto& nnapiOperandType = mModelInfo->getOperand(i).type; + switch (nnapiOperandType) { + case OperandType::FLOAT32: + case OperandType::TENSOR_FLOAT32: + break; + default: + ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType); + setRemoteEnabled(false); + break; + } + if (!mRemoteCheck) break; } - if (!mRemoteCheck) break; } - if (mRemoteCheck) mRemoteCheck = checkRemoteConnection(); mNgraphNetCreator = std::make_shared(mModelInfo, mTargetDevice); if (!mNgraphNetCreator->validateOperations()) return false; @@ -111,15 +109,18 @@ bool BasePreparedModel::initialize() { bool BasePreparedModel::checkRemoteConnection() { if(gRemoteCheck) { - ALOGD("GRPC Remote Connection already under use"); + ALOGD("%s GRPC Remote Connection Busy", __func__); return false; } char grpc_prop[PROPERTY_VALUE_MAX] = ""; bool is_success = false; if(getGrpcIpPort(grpc_prop)) { ALOGV("Attempting GRPC via TCP : %s", grpc_prop); + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(INT_MAX); + args.SetMaxSendMessageSize(INT_MAX); gDetectionClient = std::make_shared( - grpc::CreateChannel(grpc_prop, grpc::InsecureChannelCredentials())); + grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args)); if(gDetectionClient) { auto reply = gDetectionClient->prepare(is_success); ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str()); @@ -127,15 +128,17 @@ bool BasePreparedModel::checkRemoteConnection() { } if (!is_success && getGrpcSocketPath(grpc_prop)) { ALOGV("Attempting GRPC via unix : %s", grpc_prop); + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(INT_MAX); + args.SetMaxSendMessageSize(INT_MAX); gDetectionClient = std::make_shared( - grpc::CreateChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials())); + grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args)); if(gDetectionClient) { auto reply = gDetectionClient->prepare(is_success); ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str()); } } - gRemoteCheck = is_success; - ALOGD("GRPC ACQUIRED Remote Connection"); + setRemoteEnabled(is_success); return is_success; } @@ -152,10 +155,22 @@ bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::st else { ALOGE("%s gDetectionClient is null",__func__); } - mRemoteCheck = is_success; + setRemoteEnabled(is_success); return is_success; } +void BasePreparedModel::setRemoteEnabled(bool flag) { + if (gRemoteCheck && flag) { + ALOGD("%s GRPC Remote Connection Busy", __func__); + return; + } + if(mRemoteCheck != flag) { + ALOGD("GRPC %s Remote Connection", flag ? "ACQUIRED" : "RELEASED"); + gRemoteCheck = flag; + mRemoteCheck = flag; + } +} + static Return notify(const sp& callback, const ErrorStatus& status, const hidl_vec&, Timing) { return callback->notify(status); diff --git a/BasePreparedModel.h b/BasePreparedModel.h index cc8a884ed..44f6be27b 100755 --- a/BasePreparedModel.h +++ b/BasePreparedModel.h @@ -89,6 +89,7 @@ class BasePreparedModel : public V1_3::IPreparedModel { virtual bool initialize(); virtual bool checkRemoteConnection(); virtual bool loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin); + virtual void setRemoteEnabled(bool flag); std::shared_ptr getModelInfo() { return mModelInfo; } From d2484a0c5bb2be36fafab582b4f3f541c0ddaa0e Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Wed, 17 May 2023 14:01:26 +0000 Subject: [PATCH 7/9] Disable remote inference if a request fails Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 2bfb57739..dd3160da3 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -542,6 +542,9 @@ static std::tuple, Timing> executeSynch ALOGI("***********GRPC server response************* %s", reply.c_str()); } if (!preparedModel->mRemoteCheck || !gDetectionClient->get_status()){ + if(preparedModel->mRemoteCheck) { + preparedModel->setRemoteEnabled(false); + } try { ALOGV("%s Client Infer", __func__); plugin->infer(); From ddb22bb730d2d503053e507abeb5eb660d2228c2 Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Thu, 18 May 2023 05:55:33 +0000 Subject: [PATCH 8/9] Parallel Remote Inference Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 68 +++++++++++++----------------- BasePreparedModel.h | 5 ++- DetectionClient.cpp | 22 +++++++++- DetectionClient.h | 6 ++- proto/nnhal_object_detection.proto | 39 ++++++++++------- 5 files changed, 82 insertions(+), 58 deletions(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index dd3160da3..79c2c13eb 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -33,18 +33,19 @@ namespace android::hardware::neuralnetworks::nnhal { using namespace android::nn; static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX}; -bool gRemoteCheck = false; -std::shared_ptr gDetectionClient; uint32_t BasePreparedModel::mFileId = 0; void BasePreparedModel::deinitialize() { ALOGV("Entering %s", __func__); + bool is_success = false; mModelInfo->unmapRuntimeMemPools(); auto ret_xml = std::remove(mXmlFile.c_str()); auto ret_bin = std::remove(mBinFile.c_str()); if ((ret_xml != 0) || (ret_bin != 0)) { ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin); } + auto reply = mDetectionClient->release(is_success); + ALOGI("GRPC release response is %d : %s", is_success, reply.c_str()); setRemoteEnabled(false); ALOGV("Exiting %s", __func__); @@ -108,10 +109,6 @@ bool BasePreparedModel::initialize() { } bool BasePreparedModel::checkRemoteConnection() { - if(gRemoteCheck) { - ALOGD("%s GRPC Remote Connection Busy", __func__); - return false; - } char grpc_prop[PROPERTY_VALUE_MAX] = ""; bool is_success = false; if(getGrpcIpPort(grpc_prop)) { @@ -119,10 +116,10 @@ bool BasePreparedModel::checkRemoteConnection() { grpc::ChannelArguments args; args.SetMaxReceiveMessageSize(INT_MAX); args.SetMaxSendMessageSize(INT_MAX); - gDetectionClient = std::make_shared( - grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args)); - if(gDetectionClient) { - auto reply = gDetectionClient->prepare(is_success); + mDetectionClient = std::make_shared( + grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId); + if(mDetectionClient) { + auto reply = mDetectionClient->prepare(is_success); ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str()); } } @@ -131,10 +128,10 @@ bool BasePreparedModel::checkRemoteConnection() { grpc::ChannelArguments args; args.SetMaxReceiveMessageSize(INT_MAX); args.SetMaxSendMessageSize(INT_MAX); - gDetectionClient = std::make_shared( - grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args)); - if(gDetectionClient) { - auto reply = gDetectionClient->prepare(is_success); + mDetectionClient = std::make_shared( + grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId); + if(mDetectionClient) { + auto reply = mDetectionClient->prepare(is_success); ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str()); } } @@ -145,28 +142,23 @@ bool BasePreparedModel::checkRemoteConnection() { bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) { ALOGI("Entering %s", __func__); bool is_success = false; - if(gDetectionClient) { - auto reply = gDetectionClient->sendIRs(is_success, ir_xml, ir_bin); + if(mDetectionClient) { + auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin); ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str()); if (reply == "status False") { ALOGE("%s Model Load Failed",__func__); } } else { - ALOGE("%s gDetectionClient is null",__func__); + ALOGE("%s mDetectionClient is null",__func__); } setRemoteEnabled(is_success); return is_success; } void BasePreparedModel::setRemoteEnabled(bool flag) { - if (gRemoteCheck && flag) { - ALOGD("%s GRPC Remote Connection Busy", __func__); - return; - } if(mRemoteCheck != flag) { ALOGD("GRPC %s Remote Connection", flag ? "ACQUIRED" : "RELEASED"); - gRemoteCheck = flag; mRemoteCheck = flag; } } @@ -312,10 +304,10 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod if (measure == MeasureTiming::YES) deviceStart = now(); if(preparedModel->mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = gDetectionClient->remote_infer(); + auto reply = preparedModel->mDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!preparedModel->mRemoteCheck || !gDetectionClient->get_status()){ + if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){ try { plugin->infer(); } catch (const std::exception& ex) { @@ -374,8 +366,8 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod return; } - if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { - gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) { + preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, ngraphNw->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { @@ -469,9 +461,9 @@ static std::tuple, Timing> executeSynch ALOGV("Input index: %d layername : %s", inIndex, inputNodeName.c_str()); //check if remote infer is available //TODO: Need to add FLOAT16 support for remote inferencing - if(preparedModel->mRemoteCheck && gDetectionClient) { + if(preparedModel->mRemoteCheck && preparedModel->mDetectionClient) { auto inOperandType = modelInfo->getOperandType(inIndex); - gDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len, inOperandType); + preparedModel->mDetectionClient->add_input_data(std::to_string(i), (uint8_t*)srcPtr, ngraphNw->getOutputShape(inIndex), len, inOperandType); } else { ov::Tensor destTensor; try { @@ -538,10 +530,10 @@ static std::tuple, Timing> executeSynch if (measure == MeasureTiming::YES) deviceStart = now(); if(preparedModel->mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = gDetectionClient->remote_infer(); + auto reply = preparedModel->mDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!preparedModel->mRemoteCheck || !gDetectionClient->get_status()){ + if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){ if(preparedModel->mRemoteCheck) { preparedModel->setRemoteEnabled(false); } @@ -601,8 +593,8 @@ static std::tuple, Timing> executeSynch } //copy output from remote infer //TODO: Add support for other OperandType - if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { - gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) { + preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, ngraphNw->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { @@ -652,8 +644,8 @@ static std::tuple, Timing> executeSynch ALOGE("Failed to update the request pool infos"); return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; } - if (preparedModel->mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { - gDetectionClient->clear_data(); + if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) { + preparedModel->mDetectionClient->clear_data(); } if (measure == MeasureTiming::YES) { @@ -868,10 +860,10 @@ Return BasePreparedModel::executeFenced(const V1_3::Request& request1_3, if (measure == MeasureTiming::YES) deviceStart = now(); if(mRemoteCheck) { ALOGI("%s GRPC Remote Infer", __func__); - auto reply = gDetectionClient->remote_infer(); + auto reply = mDetectionClient->remote_infer(); ALOGI("***********GRPC server response************* %s", reply.c_str()); } - if (!mRemoteCheck || !gDetectionClient->get_status()){ + if (!mRemoteCheck || !mDetectionClient->get_status()){ try { mPlugin->infer(); } catch (const std::exception& ex) { @@ -916,8 +908,8 @@ Return BasePreparedModel::executeFenced(const V1_3::Request& request1_3, mModelInfo->updateOutputshapes(i, outDims); } - if (mRemoteCheck && gDetectionClient && gDetectionClient->get_status()) { - gDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, + if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) { + mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr, mNgraphNetCreator->getOutputShape(outIndex), expectedLength); } else { switch (operandType) { diff --git a/BasePreparedModel.h b/BasePreparedModel.h index 44f6be27b..ec8f5dcc6 100755 --- a/BasePreparedModel.h +++ b/BasePreparedModel.h @@ -54,8 +54,8 @@ class BasePreparedModel : public V1_3::IPreparedModel { bool mRemoteCheck = false; BasePreparedModel(const IntelDeviceType device, const Model& model) : mTargetDevice(device) { mModelInfo = std::make_shared(model); - mXmlFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".xml"); - mBinFile = std::string("/data/vendor/neuralnetworks/") + std::to_string(mFileId) + std::string(".bin"); + mXmlFile = MODEL_DIR + std::to_string(mFileId) + std::string(".xml"); + mBinFile = MODEL_DIR + std::to_string(mFileId) + std::string(".bin"); mFileId++; } @@ -98,6 +98,7 @@ class BasePreparedModel : public V1_3::IPreparedModel { std::shared_ptr getPlugin() { return mPlugin; } std::shared_ptr modelPtr; + std::shared_ptr mDetectionClient; protected: virtual void deinitialize(); diff --git a/DetectionClient.cpp b/DetectionClient.cpp index bcf0fc6d2..de08783db 100644 --- a/DetectionClient.cpp +++ b/DetectionClient.cpp @@ -5,7 +5,7 @@ std::string DetectionClient::prepare(bool& flag) { RequestString request; - request.set_value(""); + request.mutable_token()->set_data(mToken); ReplyStatus reply; ClientContext context; time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(100); @@ -21,9 +21,26 @@ std::string DetectionClient::prepare(bool& flag) { } } +std::string DetectionClient::release(bool& flag) { + RequestString request; + request.mutable_token()->set_data(mToken); + ReplyStatus reply; + ClientContext context; + + Status status = stub_->release(&context, request, &reply); + + if (status.ok()) { + flag = reply.status(); + return (flag ? "status True" : "status False"); + } else { + return std::string(status.error_message()); + } +} + Status DetectionClient::sendFile(std::string fileName, std::unique_ptr >& writer) { RequestDataChunks request; + request.mutable_token()->set_data(mToken); uint32_t CHUNK_SIZE = 1024 * 1024; std::ifstream fin(fileName, std::ifstream::binary); std::vector buffer(CHUNK_SIZE, 0); @@ -51,7 +68,7 @@ bool DetectionClient::isModelLoaded(std::string fileName) { ReplyStatus reply; ClientContext context; RequestString request; - request.set_value(fileName); + request.mutable_token()->set_data(mToken); time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(20000); context.set_deadline(deadline); status = stub_->loadModel(&context, request, &reply); @@ -176,6 +193,7 @@ std::string DetectionClient::remote_infer() { time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(5000); context.set_deadline(deadline); + request.mutable_token()->set_data(mToken); status = stub_->getInferResult(&context, request, &reply); if (status.ok()) { if (reply.data_tensors_size() == 0) ALOGE("GRPC reply empty, ovms failure ?"); diff --git a/DetectionClient.h b/DetectionClient.h index f26306b6e..dcfe2ec70 100644 --- a/DetectionClient.h +++ b/DetectionClient.h @@ -23,11 +23,14 @@ using objectDetection::RequestDataTensors; using objectDetection::RequestString; using time_point = std::chrono::system_clock::time_point; +#define MODEL_DIR std::string("/data/vendor/neuralnetworks/") + class DetectionClient { public: - DetectionClient(std::shared_ptr channel) : stub_(Detection::NewStub(channel)){} + DetectionClient(std::shared_ptr channel, uint32_t token) : stub_(Detection::NewStub(channel)), mToken(token) {} std::string prepare(bool& flag); + std::string release(bool& flag); Status sendFile(std::string fileName, std::unique_ptr >& writer); @@ -46,6 +49,7 @@ class DetectionClient { RequestDataTensors request; ReplyDataTensors reply; Status status; + uint32_t mToken; }; #endif \ No newline at end of file diff --git a/proto/nnhal_object_detection.proto b/proto/nnhal_object_detection.proto index 293d6c56c..d129ac575 100644 --- a/proto/nnhal_object_detection.proto +++ b/proto/nnhal_object_detection.proto @@ -1,16 +1,18 @@ -// Copyright 2015 gRPC authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* +* Copyright (c) 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ syntax = "proto3"; @@ -29,15 +31,21 @@ service Detection { rpc sendBin (stream RequestDataChunks) returns (ReplyStatus) {} rpc loadModel(RequestString) returns (ReplyStatus) {} rpc prepare (RequestString) returns (ReplyStatus) {} //Placeholder for any future support : RequestString + rpc release (RequestString) returns (ReplyStatus) {} } +message Token { + uint32 data = 1; +} + message RequestDataChunks { bytes data = 1; + Token token = 2; } message RequestString { - string value = 1; + Token token = 1; } message ReplyStatus { bool status = 1; @@ -77,4 +85,5 @@ message ReplyDataTensors { // Request message containing the Input Data Tensors(blobs) message RequestDataTensors { repeated DataTensor data_tensors = 1; -} \ No newline at end of file + Token token = 2; +} From ee4097739b8e5dd64807f17d14977c384d9b01ab Mon Sep 17 00:00:00 2001 From: Anoob Anto K Date: Fri, 19 May 2023 09:53:36 +0000 Subject: [PATCH 9/9] Prevent non-compliant upcast of model, operand Fix the following errors : Upcasting non-compliant model Upcasting non-compliant operand type TENSOR_QUANT8_ASYMM_SIGNED from V1_3::OperandType to V1_2::OperandType Signed-off-by: Anoob Anto K --- BasePreparedModel.cpp | 10 ++++----- ModelManager.cpp | 52 +++++++++++++++++++++++++++++++++++++++++++ ModelManager.h | 2 ++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/BasePreparedModel.cpp b/BasePreparedModel.cpp index 79c2c13eb..59e9d9737 100644 --- a/BasePreparedModel.cpp +++ b/BasePreparedModel.cpp @@ -433,7 +433,7 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod } static std::tuple, Timing> executeSynchronouslyBase( - const Request& request, MeasureTiming measure, BasePreparedModel* preparedModel, + const V1_3::Request& request, MeasureTiming measure, BasePreparedModel* preparedModel, time_point driverStart) { ALOGV("Entering %s", __func__); auto modelInfo = preparedModel->getModelInfo(); @@ -442,7 +442,7 @@ static std::tuple, Timing> executeSynch time_point driverEnd, deviceStart, deviceEnd; std::vector requestPoolInfos; auto errorStatus = modelInfo->setRunTimePoolInfosFromHidlMemories(request.pools); - if (errorStatus != ErrorStatus::NONE) { + if (errorStatus != V1_3::ErrorStatus::NONE) { ALOGE("Failed to set runtime pool info from HIDL memories"); return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; } @@ -669,7 +669,7 @@ Return BasePreparedModel::executeSynchronously(const Request& request, Mea return Void(); } auto [status, outputShapes, timing] = - executeSynchronouslyBase(request, measure, this, driverStart); + executeSynchronouslyBase(convertToV1_3(request), measure, this, driverStart); cb(status, std::move(outputShapes), timing); ALOGV("Exiting %s", __func__); return Void(); @@ -684,12 +684,12 @@ Return BasePreparedModel::executeSynchronously_1_3(const V1_3::Request& re time_point driverStart; if (measure == MeasureTiming::YES) driverStart = now(); - if (!validateRequest(convertToV1_0(request), convertToV1_2(mModelInfo->getModel()))) { + if (!validateRequest(request, mModelInfo->getModel())) { cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming); return Void(); } auto [status, outputShapes, timing] = - executeSynchronouslyBase(convertToV1_0(request), measure, this, driverStart); + executeSynchronouslyBase(request, measure, this, driverStart); cb(convertToV1_3(status), std::move(outputShapes), timing); ALOGV("Exiting %s", __func__); return Void(); diff --git a/ModelManager.cpp b/ModelManager.cpp index e162ec55a..e1d0e6ebe 100644 --- a/ModelManager.cpp +++ b/ModelManager.cpp @@ -233,6 +233,31 @@ void* NnapiModelInfo::getBlobFromMemoryPoolIn(const Request& request, uint32_t i return (r.buffer + arg.location.offset); } +void* NnapiModelInfo::getBlobFromMemoryPoolIn(const V1_3::Request& request, uint32_t index, + uint32_t& rBufferLength) { + RunTimeOperandInfo& operand = mOperands[mModel.main.inputIndexes[index]]; + const V1_0::RequestArgument& arg = request.inputs[index]; + auto poolIndex = arg.location.poolIndex; + nnAssert(poolIndex < mRequestPoolInfos.size()); + auto& r = mRequestPoolInfos[poolIndex]; + + if (arg.dimensions.size() > 0) { + // It's the responsibility of the caller to validate that + // from.dimensions only modifies the dimensions that were + // unspecified in the model. That's the case in SampleDriver.cpp + // with the call to validateRequest(). + operand.dimensions = arg.dimensions; + } + + operand.buffer = r.buffer + arg.location.offset; + operand.length = arg.location.length; + ALOGV("%s Operand length:%d pointer:%p offset:%d pool index: %d", __func__, operand.length, + (r.buffer + arg.location.offset), arg.location.offset, poolIndex); + rBufferLength = operand.length; + + return (r.buffer + arg.location.offset); +} + void* NnapiModelInfo::getBlobFromMemoryPoolOut(const Request& request, uint32_t index, uint32_t& rBufferLength) { RunTimeOperandInfo& operand = mOperands[mModel.main.outputIndexes[index]]; @@ -260,6 +285,33 @@ void* NnapiModelInfo::getBlobFromMemoryPoolOut(const Request& request, uint32_t return (r.buffer + arg.location.offset); } +void* NnapiModelInfo::getBlobFromMemoryPoolOut(const V1_3::Request& request, uint32_t index, + uint32_t& rBufferLength) { + RunTimeOperandInfo& operand = mOperands[mModel.main.outputIndexes[index]]; + const V1_0::RequestArgument& arg = request.outputs[index]; + auto poolIndex = arg.location.poolIndex; + nnAssert(poolIndex < mRequestPoolInfos.size()); + auto& r = mRequestPoolInfos[poolIndex]; + + ALOGV("%s lifetime:%d location offset:%d length:%d pool index:%d", __func__, operand.lifetime, + arg.location.offset, arg.location.length, poolIndex); + + if (arg.dimensions.size() > 0) { + // It's the responsibility of the caller to validate that + // from.dimensions only modifies the dimensions that were + // unspecified in the model. That's the case in SampleDriver.cpp + // with the call to validateRequest(). + operand.dimensions = arg.dimensions; + } + + operand.buffer = r.buffer + arg.location.offset; + operand.length = arg.location.length; + rBufferLength = operand.length; + ALOGV("%s Operand length:%d pointer:%p", __func__, operand.length, + (r.buffer + arg.location.offset)); + return (r.buffer + arg.location.offset); +} + bool NnapiModelInfo::isOmittedInput(int operationIndex, uint32_t index) { uint32_t inputIndex = mModel.main.operations[operationIndex].inputs[index]; const auto op = mModel.main.operands[inputIndex]; diff --git a/ModelManager.h b/ModelManager.h index 49168aa92..cf853ab3e 100755 --- a/ModelManager.h +++ b/ModelManager.h @@ -166,7 +166,9 @@ class NnapiModelInfo { T GetConstFromBuffer(const uint8_t* buf, uint32_t len); void* getBlobFromMemoryPoolIn(const Request& request, uint32_t index, uint32_t& rBufferLength); + void* getBlobFromMemoryPoolIn(const V1_3::Request& request, uint32_t index, uint32_t& rBufferLength); void* getBlobFromMemoryPoolOut(const Request& request, uint32_t index, uint32_t& rBufferLength); + void* getBlobFromMemoryPoolOut(const V1_3::Request& request, uint32_t index, uint32_t& rBufferLength); Model getModel() { return mModel; }