Skip to content

Commit fad628a

Browse files
committed
Disable quant for remote infer
Tracked-On:
Signed-off-by: Ratnesh Kumar Rai <ratnesh.kumar.rai@intel.com>
1 parent 4139a4c commit fad628a

3 files changed

Lines changed: 128 additions & 141 deletions

File tree

BasePreparedModel.cpp

Lines changed: 126 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ void BasePreparedModel::deinitialize() {
4444
if ((ret_xml != 0) || (ret_bin != 0)) {
4545
ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin);
4646
}
47-
auto reply = mDetectionClient->release(is_success);
48-
ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
49-
setRemoteEnabled(false);
50-
47+
if (mRemoteCheck && mDetectionClient) {
48+
auto reply = mDetectionClient->release(is_success);
49+
ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
50+
setRemoteEnabled(false);
51+
}
5152
ALOGV("Exiting %s", __func__);
5253
}
5354

@@ -64,12 +65,10 @@ bool BasePreparedModel::initialize() {
6465
ALOGE("Failed to initialize Model runtime parameters!!");
6566
return false;
6667
}
67-
68-
setRemoteEnabled(checkRemoteConnection());
6968
mNgraphNetCreator = std::make_shared<NgraphNetworkCreator>(mModelInfo, mTargetDevice);
7069

7170
if (!mNgraphNetCreator->validateOperations()) return false;
72-
ALOGI("Generating IR Graph");
71+
ALOGI("Generating IR Graph for Model %u", mFileId);
7372
auto ov_model = mNgraphNetCreator->generateGraph();
7473
if (ov_model == nullptr) {
7574
ALOGE("%s Openvino model generation failed", __func__);
@@ -78,17 +77,29 @@ bool BasePreparedModel::initialize() {
7877
try {
7978
mPlugin = std::make_unique<IENetwork>(mTargetDevice, ov_model);
8079
mPlugin->loadNetwork(mXmlFile, mBinFile);
81-
if(mRemoteCheck) {
82-
auto resp = loadRemoteModel(mXmlFile, mBinFile);
83-
ALOGD("%s Load Remote Model returns %d", __func__, resp);
84-
} else {
85-
ALOGD("%s Remote connection unavailable", __func__);
86-
}
8780
} catch (const std::exception& ex) {
8881
ALOGE("%s Exception !!! %s", __func__, ex.what());
8982
return false;
9083
}
91-
84+
bool disableOffload = false;
85+
for (auto i : mModelInfo->getModelInputIndexes()) {
86+
auto& nnapiOperandType = mModelInfo->getOperand(i).type;
87+
switch (nnapiOperandType) {
88+
case OperandType::FLOAT32:
89+
case OperandType::TENSOR_FLOAT32:
90+
case OperandType::TENSOR_FLOAT16:
91+
case OperandType::TENSOR_INT32:
92+
break;
93+
default :
94+
ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType);
95+
disableOffload = true;
96+
break;
97+
}
98+
if (disableOffload) break;
99+
}
100+
if (!disableOffload) {
101+
loadRemoteModel(mXmlFile, mBinFile);
102+
}
92103
ALOGV("Exiting %s", __func__);
93104
return true;
94105
}
@@ -103,9 +114,11 @@ bool BasePreparedModel::checkRemoteConnection() {
103114
args.SetMaxSendMessageSize(INT_MAX);
104115
mDetectionClient = std::make_shared<DetectionClient>(
105116
grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
106-
if(mDetectionClient) {
117+
if (mDetectionClient) {
107118
auto reply = mDetectionClient->prepare(is_success);
108119
ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str());
120+
} else {
121+
ALOGE("%s mDetectionClient is null", __func__);
109122
}
110123
}
111124
if (!is_success && getGrpcSocketPath(grpc_prop)) {
@@ -115,30 +128,27 @@ bool BasePreparedModel::checkRemoteConnection() {
115128
args.SetMaxSendMessageSize(INT_MAX);
116129
mDetectionClient = std::make_shared<DetectionClient>(
117130
grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
118-
if(mDetectionClient) {
131+
if (mDetectionClient) {
119132
auto reply = mDetectionClient->prepare(is_success);
120133
ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str());
134+
} else {
135+
ALOGE("%s mDetectionClient is null", __func__);
121136
}
122137
}
123-
setRemoteEnabled(is_success);
124138
return is_success;
125139
}
126140

127-
bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
128-
ALOGI("Entering %s", __func__);
141+
void BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
142+
ALOGI("Entering %s for Model %u", __func__, mFileId);
129143
bool is_success = false;
130-
if(mDetectionClient) {
144+
if(checkRemoteConnection() && mDetectionClient) {
131145
auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin);
132146
ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str());
133147
if (reply == "status False") {
134148
ALOGE("%s Model Load Failed",__func__);
135149
}
150+
setRemoteEnabled(is_success);
136151
}
137-
else {
138-
ALOGE("%s mDetectionClient is null",__func__);
139-
}
140-
setRemoteEnabled(is_success);
141-
return is_success;
142152
}
143153

144154
void BasePreparedModel::setRemoteEnabled(bool flag) {
@@ -287,20 +297,14 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
287297
ALOGD("%s Run", __func__);
288298

289299
if (measure == MeasureTiming::YES) deviceStart = now();
290-
if(preparedModel->mRemoteCheck) {
291-
ALOGI("%s GRPC Remote Infer", __func__);
292-
auto reply = preparedModel->mDetectionClient->remote_infer();
293-
ALOGI("***********GRPC server response************* %s", reply.c_str());
294-
}
295-
if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){
296-
try {
297-
plugin->infer();
298-
} catch (const std::exception& ex) {
299-
ALOGE("%s Exception !!! %s", __func__, ex.what());
300-
notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
301-
return;
302-
}
300+
try {
301+
plugin->infer();
302+
} catch (const std::exception& ex) {
303+
ALOGE("%s Exception !!! %s", __func__, ex.what());
304+
notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
305+
return;
303306
}
307+
304308
if (measure == MeasureTiming::YES) deviceEnd = now();
305309

306310
tensorIndex = 0;
@@ -351,50 +355,45 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
351355
return;
352356
}
353357

354-
if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
355-
preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
356-
ngraphNw->getOutputShape(outIndex), expectedLength);
357-
} else {
358-
switch (operandType) {
359-
case OperandType::TENSOR_INT32:
360-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
361-
srcTensor.get_byte_size());
362-
break;
363-
case OperandType::TENSOR_FLOAT32:
364-
std::memcpy((uint8_t*)destPtr, srcTensor.data<float>(),
365-
srcTensor.get_byte_size());
366-
break;
367-
case OperandType::TENSOR_BOOL8:
368-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
369-
srcTensor.get_byte_size());
370-
break;
371-
case OperandType::TENSOR_QUANT8_ASYMM:
372-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
373-
srcTensor.get_byte_size());
374-
break;
375-
case OperandType::TENSOR_QUANT8_SYMM:
376-
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
377-
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
378-
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
379-
srcTensor.get_byte_size());
380-
break;
381-
case OperandType::TENSOR_FLOAT16:
382-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
383-
srcTensor.get_byte_size());
384-
break;
385-
case OperandType::TENSOR_QUANT16_SYMM:
386-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
387-
srcTensor.get_byte_size());
388-
break;
389-
case OperandType::TENSOR_QUANT16_ASYMM:
390-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
391-
srcTensor.get_byte_size());
392-
break;
393-
default:
394-
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
395-
srcTensor.get_byte_size());
396-
break;
397-
}
358+
switch (operandType) {
359+
case OperandType::TENSOR_INT32:
360+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
361+
srcTensor.get_byte_size());
362+
break;
363+
case OperandType::TENSOR_FLOAT32:
364+
std::memcpy((uint8_t*)destPtr, srcTensor.data<float>(),
365+
srcTensor.get_byte_size());
366+
break;
367+
case OperandType::TENSOR_BOOL8:
368+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
369+
srcTensor.get_byte_size());
370+
break;
371+
case OperandType::TENSOR_QUANT8_ASYMM:
372+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
373+
srcTensor.get_byte_size());
374+
break;
375+
case OperandType::TENSOR_QUANT8_SYMM:
376+
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
377+
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
378+
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
379+
srcTensor.get_byte_size());
380+
break;
381+
case OperandType::TENSOR_FLOAT16:
382+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
383+
srcTensor.get_byte_size());
384+
break;
385+
case OperandType::TENSOR_QUANT16_SYMM:
386+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
387+
srcTensor.get_byte_size());
388+
break;
389+
case OperandType::TENSOR_QUANT16_ASYMM:
390+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
391+
srcTensor.get_byte_size());
392+
break;
393+
default:
394+
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
395+
srcTensor.get_byte_size());
396+
break;
398397
}
399398
}
400399

@@ -843,19 +842,12 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
843842

844843
time_point deviceStart, deviceEnd;
845844
if (measure == MeasureTiming::YES) deviceStart = now();
846-
if(mRemoteCheck) {
847-
ALOGI("%s GRPC Remote Infer", __func__);
848-
auto reply = mDetectionClient->remote_infer();
849-
ALOGI("***********GRPC server response************* %s", reply.c_str());
850-
}
851-
if (!mRemoteCheck || !mDetectionClient->get_status()){
852-
try {
853-
mPlugin->infer();
854-
} catch (const std::exception& ex) {
855-
ALOGE("%s Exception !!! %s", __func__, ex.what());
856-
cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
857-
return Void();
858-
}
845+
try {
846+
mPlugin->infer();
847+
} catch (const std::exception& ex) {
848+
ALOGE("%s Exception !!! %s", __func__, ex.what());
849+
cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
850+
return Void();
859851
}
860852
if (measure == MeasureTiming::YES) deviceEnd = now();
861853

@@ -893,50 +885,45 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
893885
mModelInfo->updateOutputshapes(i, outDims);
894886
}
895887

896-
if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
897-
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
898-
mNgraphNetCreator->getOutputShape(outIndex), expectedLength);
899-
} else {
900-
switch (operandType) {
901-
case OperandType::TENSOR_INT32:
902-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
903-
srcTensor.get_byte_size());
904-
break;
905-
case OperandType::TENSOR_FLOAT32:
906-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<float>(),
907-
srcTensor.get_byte_size());
908-
break;
909-
case OperandType::TENSOR_BOOL8:
910-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
911-
srcTensor.get_byte_size());
912-
break;
913-
case OperandType::TENSOR_QUANT8_ASYMM:
914-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
915-
srcTensor.get_byte_size());
916-
break;
917-
case OperandType::TENSOR_QUANT8_SYMM:
918-
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
919-
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
920-
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
921-
srcTensor.get_byte_size());
922-
break;
923-
case OperandType::TENSOR_FLOAT16:
924-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
925-
srcTensor.get_byte_size());
926-
break;
927-
case OperandType::TENSOR_QUANT16_SYMM:
928-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
929-
srcTensor.get_byte_size());
930-
break;
931-
case OperandType::TENSOR_QUANT16_ASYMM:
932-
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
933-
srcTensor.get_byte_size());
934-
break;
935-
default:
936-
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
937-
srcTensor.get_byte_size());
938-
break;
939-
}
888+
switch (operandType) {
889+
case OperandType::TENSOR_INT32:
890+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
891+
srcTensor.get_byte_size());
892+
break;
893+
case OperandType::TENSOR_FLOAT32:
894+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<float>(),
895+
srcTensor.get_byte_size());
896+
break;
897+
case OperandType::TENSOR_BOOL8:
898+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
899+
srcTensor.get_byte_size());
900+
break;
901+
case OperandType::TENSOR_QUANT8_ASYMM:
902+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
903+
srcTensor.get_byte_size());
904+
break;
905+
case OperandType::TENSOR_QUANT8_SYMM:
906+
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
907+
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
908+
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
909+
srcTensor.get_byte_size());
910+
break;
911+
case OperandType::TENSOR_FLOAT16:
912+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
913+
srcTensor.get_byte_size());
914+
break;
915+
case OperandType::TENSOR_QUANT16_SYMM:
916+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
917+
srcTensor.get_byte_size());
918+
break;
919+
case OperandType::TENSOR_QUANT16_ASYMM:
920+
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
921+
srcTensor.get_byte_size());
922+
break;
923+
default:
924+
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
925+
srcTensor.get_byte_size());
926+
break;
940927
}
941928
}
942929

BasePreparedModel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {
8888

8989
virtual bool initialize();
9090
virtual bool checkRemoteConnection();
91-
virtual bool loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
91+
virtual void loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
9292
virtual void setRemoteEnabled(bool flag);
9393

9494
std::shared_ptr<NnapiModelInfo> getModelInfo() { return mModelInfo; }

DetectionClient.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ std::string DetectionClient::prepare(bool& flag) {
88
request.mutable_token()->set_data(mToken);
99
ReplyStatus reply;
1010
ClientContext context;
11-
time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(100);
11+
time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(10000);
1212
context.set_deadline(deadline);
1313

1414
Status status = stub_->prepare(&context, request, &reply);

0 commit comments

Comments (0)