@@ -44,10 +44,11 @@ void BasePreparedModel::deinitialize() {
4444 if ((ret_xml != 0 ) || (ret_bin != 0 )) {
4545 ALOGW (" %s Deletion status of xml:%d, bin:%d" , __func__, ret_xml, ret_bin);
4646 }
47- auto reply = mDetectionClient ->release (is_success);
48- ALOGI (" GRPC release response is %d : %s" , is_success, reply.c_str ());
49- setRemoteEnabled (false );
50-
47+ if (mRemoteCheck && mDetectionClient ) {
48+ auto reply = mDetectionClient ->release (is_success);
49+ ALOGI (" GRPC release response is %d : %s" , is_success, reply.c_str ());
50+ setRemoteEnabled (false );
51+ }
5152 ALOGV (" Exiting %s" , __func__);
5253}
5354
@@ -64,12 +65,10 @@ bool BasePreparedModel::initialize() {
6465 ALOGE (" Failed to initialize Model runtime parameters!!" );
6566 return false ;
6667 }
67-
68- setRemoteEnabled (checkRemoteConnection ());
6968 mNgraphNetCreator = std::make_shared<NgraphNetworkCreator>(mModelInfo , mTargetDevice );
7069
7170 if (!mNgraphNetCreator ->validateOperations ()) return false ;
72- ALOGI (" Generating IR Graph" );
71+ ALOGI (" Generating IR Graph for Model %u " , mFileId );
7372 auto ov_model = mNgraphNetCreator ->generateGraph ();
7473 if (ov_model == nullptr ) {
7574 ALOGE (" %s Openvino model generation failed" , __func__);
@@ -78,17 +77,29 @@ bool BasePreparedModel::initialize() {
7877 try {
7978 mPlugin = std::make_unique<IENetwork>(mTargetDevice , ov_model);
8079 mPlugin ->loadNetwork (mXmlFile , mBinFile );
81- if (mRemoteCheck ) {
82- auto resp = loadRemoteModel (mXmlFile , mBinFile );
83- ALOGD (" %s Load Remote Model returns %d" , __func__, resp);
84- } else {
85- ALOGD (" %s Remote connection unavailable" , __func__);
86- }
8780 } catch (const std::exception& ex) {
8881 ALOGE (" %s Exception !!! %s" , __func__, ex.what ());
8982 return false ;
9083 }
91-
84+ bool disableOffload = false ;
85+ for (auto i : mModelInfo ->getModelInputIndexes ()) {
86+ auto & nnapiOperandType = mModelInfo ->getOperand (i).type ;
87+ switch (nnapiOperandType) {
88+ case OperandType::FLOAT32:
89+ case OperandType::TENSOR_FLOAT32:
90+ case OperandType::TENSOR_FLOAT16:
91+ case OperandType::TENSOR_INT32:
92+ break ;
93+ default :
94+ ALOGD (" GRPC Remote Infer not enabled for %d" , nnapiOperandType);
95+ disableOffload = true ;
96+ break ;
97+ }
98+ if (disableOffload) break ;
99+ }
100+ if (!disableOffload) {
101+ loadRemoteModel (mXmlFile , mBinFile );
102+ }
92103 ALOGV (" Exiting %s" , __func__);
93104 return true ;
94105}
@@ -103,9 +114,11 @@ bool BasePreparedModel::checkRemoteConnection() {
103114 args.SetMaxSendMessageSize (INT_MAX);
104115 mDetectionClient = std::make_shared<DetectionClient>(
105116 grpc::CreateCustomChannel (grpc_prop, grpc::InsecureChannelCredentials (), args), mFileId );
106- if (mDetectionClient ) {
117+ if (mDetectionClient ) {
107118 auto reply = mDetectionClient ->prepare (is_success);
108119 ALOGI (" GRPC(TCP) prepare response is %d : %s" , is_success, reply.c_str ());
120+ } else {
121+ ALOGE (" %s mDetectionClient is null" , __func__);
109122 }
110123 }
111124 if (!is_success && getGrpcSocketPath (grpc_prop)) {
@@ -115,30 +128,27 @@ bool BasePreparedModel::checkRemoteConnection() {
115128 args.SetMaxSendMessageSize (INT_MAX);
116129 mDetectionClient = std::make_shared<DetectionClient>(
117130 grpc::CreateCustomChannel (std::string (" unix:" ) + grpc_prop, grpc::InsecureChannelCredentials (), args), mFileId );
118- if (mDetectionClient ) {
131+ if (mDetectionClient ) {
119132 auto reply = mDetectionClient ->prepare (is_success);
120133 ALOGI (" GRPC(unix) prepare response is %d : %s" , is_success, reply.c_str ());
134+ } else {
135+ ALOGE (" %s mDetectionClient is null" , __func__);
121136 }
122137 }
123- setRemoteEnabled (is_success);
124138 return is_success;
125139}
126140
127- bool BasePreparedModel::loadRemoteModel (const std::string& ir_xml, const std::string& ir_bin) {
128- ALOGI (" Entering %s" , __func__);
141+ void BasePreparedModel::loadRemoteModel (const std::string& ir_xml, const std::string& ir_bin) {
142+ ALOGI (" Entering %s for Model %u " , __func__, mFileId );
129143 bool is_success = false ;
130- if (mDetectionClient ) {
144+ if (checkRemoteConnection () && mDetectionClient ) {
131145 auto reply = mDetectionClient ->sendIRs (is_success, ir_xml, ir_bin);
132146 ALOGI (" sendIRs response GRPC %d %s" , is_success, reply.c_str ());
133147 if (reply == " status False" ) {
134148 ALOGE (" %s Model Load Failed" ,__func__);
135149 }
150+ setRemoteEnabled (is_success);
136151 }
137- else {
138- ALOGE (" %s mDetectionClient is null" ,__func__);
139- }
140- setRemoteEnabled (is_success);
141- return is_success;
142152}
143153
144154void BasePreparedModel::setRemoteEnabled (bool flag) {
@@ -287,20 +297,14 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
287297 ALOGD (" %s Run" , __func__);
288298
289299 if (measure == MeasureTiming::YES) deviceStart = now ();
290- if (preparedModel->mRemoteCheck ) {
291- ALOGI (" %s GRPC Remote Infer" , __func__);
292- auto reply = preparedModel->mDetectionClient ->remote_infer ();
293- ALOGI (" ***********GRPC server response************* %s" , reply.c_str ());
294- }
295- if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient ->get_status ()){
296- try {
297- plugin->infer ();
298- } catch (const std::exception& ex) {
299- ALOGE (" %s Exception !!! %s" , __func__, ex.what ());
300- notify (callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming );
301- return ;
302- }
300+ try {
301+ plugin->infer ();
302+ } catch (const std::exception& ex) {
303+ ALOGE (" %s Exception !!! %s" , __func__, ex.what ());
304+ notify (callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming );
305+ return ;
303306 }
307+
304308 if (measure == MeasureTiming::YES) deviceEnd = now ();
305309
306310 tensorIndex = 0 ;
@@ -351,50 +355,45 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
351355 return ;
352356 }
353357
354- if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient ->get_status ()) {
355- preparedModel->mDetectionClient ->get_output_data (std::to_string (i), (uint8_t *)destPtr,
356- ngraphNw->getOutputShape (outIndex), expectedLength);
357- } else {
358- switch (operandType) {
359- case OperandType::TENSOR_INT32:
360- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int32_t >(),
361- srcTensor.get_byte_size ());
362- break ;
363- case OperandType::TENSOR_FLOAT32:
364- std::memcpy ((uint8_t *)destPtr, srcTensor.data <float >(),
365- srcTensor.get_byte_size ());
366- break ;
367- case OperandType::TENSOR_BOOL8:
368- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <bool >(),
369- srcTensor.get_byte_size ());
370- break ;
371- case OperandType::TENSOR_QUANT8_ASYMM:
372- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint8_t >(),
373- srcTensor.get_byte_size ());
374- break ;
375- case OperandType::TENSOR_QUANT8_SYMM:
376- case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
377- case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
378- std::memcpy ((int8_t *)destPtr, (int8_t *)srcTensor.data <int8_t >(),
379- srcTensor.get_byte_size ());
380- break ;
381- case OperandType::TENSOR_FLOAT16:
382- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <ov::float16>(),
383- srcTensor.get_byte_size ());
384- break ;
385- case OperandType::TENSOR_QUANT16_SYMM:
386- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int16_t >(),
387- srcTensor.get_byte_size ());
388- break ;
389- case OperandType::TENSOR_QUANT16_ASYMM:
390- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint16_t >(),
391- srcTensor.get_byte_size ());
392- break ;
393- default :
394- std::memcpy ((uint8_t *)destPtr, srcTensor.data <uint8_t >(),
395- srcTensor.get_byte_size ());
396- break ;
397- }
358+ switch (operandType) {
359+ case OperandType::TENSOR_INT32:
360+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int32_t >(),
361+ srcTensor.get_byte_size ());
362+ break ;
363+ case OperandType::TENSOR_FLOAT32:
364+ std::memcpy ((uint8_t *)destPtr, srcTensor.data <float >(),
365+ srcTensor.get_byte_size ());
366+ break ;
367+ case OperandType::TENSOR_BOOL8:
368+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <bool >(),
369+ srcTensor.get_byte_size ());
370+ break ;
371+ case OperandType::TENSOR_QUANT8_ASYMM:
372+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint8_t >(),
373+ srcTensor.get_byte_size ());
374+ break ;
375+ case OperandType::TENSOR_QUANT8_SYMM:
376+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
377+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
378+ std::memcpy ((int8_t *)destPtr, (int8_t *)srcTensor.data <int8_t >(),
379+ srcTensor.get_byte_size ());
380+ break ;
381+ case OperandType::TENSOR_FLOAT16:
382+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <ov::float16>(),
383+ srcTensor.get_byte_size ());
384+ break ;
385+ case OperandType::TENSOR_QUANT16_SYMM:
386+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int16_t >(),
387+ srcTensor.get_byte_size ());
388+ break ;
389+ case OperandType::TENSOR_QUANT16_ASYMM:
390+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint16_t >(),
391+ srcTensor.get_byte_size ());
392+ break ;
393+ default :
394+ std::memcpy ((uint8_t *)destPtr, srcTensor.data <uint8_t >(),
395+ srcTensor.get_byte_size ());
396+ break ;
398397 }
399398 }
400399
@@ -843,19 +842,12 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
843842
844843 time_point deviceStart, deviceEnd;
845844 if (measure == MeasureTiming::YES) deviceStart = now ();
846- if (mRemoteCheck ) {
847- ALOGI (" %s GRPC Remote Infer" , __func__);
848- auto reply = mDetectionClient ->remote_infer ();
849- ALOGI (" ***********GRPC server response************* %s" , reply.c_str ());
850- }
851- if (!mRemoteCheck || !mDetectionClient ->get_status ()){
852- try {
853- mPlugin ->infer ();
854- } catch (const std::exception& ex) {
855- ALOGE (" %s Exception !!! %s" , __func__, ex.what ());
856- cb (V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle (nullptr ), nullptr );
857- return Void ();
858- }
845+ try {
846+ mPlugin ->infer ();
847+ } catch (const std::exception& ex) {
848+ ALOGE (" %s Exception !!! %s" , __func__, ex.what ());
849+ cb (V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle (nullptr ), nullptr );
850+ return Void ();
859851 }
860852 if (measure == MeasureTiming::YES) deviceEnd = now ();
861853
@@ -893,50 +885,45 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
893885 mModelInfo ->updateOutputshapes (i, outDims);
894886 }
895887
896- if (mRemoteCheck && mDetectionClient && mDetectionClient ->get_status ()) {
897- mDetectionClient ->get_output_data (std::to_string (i), (uint8_t *)destPtr,
898- mNgraphNetCreator ->getOutputShape (outIndex), expectedLength);
899- } else {
900- switch (operandType) {
901- case OperandType::TENSOR_INT32:
902- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int32_t >(),
903- srcTensor.get_byte_size ());
904- break ;
905- case OperandType::TENSOR_FLOAT32:
906- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <float >(),
907- srcTensor.get_byte_size ());
908- break ;
909- case OperandType::TENSOR_BOOL8:
910- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <bool >(),
911- srcTensor.get_byte_size ());
912- break ;
913- case OperandType::TENSOR_QUANT8_ASYMM:
914- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint8_t >(),
915- srcTensor.get_byte_size ());
916- break ;
917- case OperandType::TENSOR_QUANT8_SYMM:
918- case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
919- case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
920- std::memcpy ((int8_t *)destPtr, (int8_t *)srcTensor.data <int8_t >(),
921- srcTensor.get_byte_size ());
922- break ;
923- case OperandType::TENSOR_FLOAT16:
924- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <ov::float16>(),
925- srcTensor.get_byte_size ());
926- break ;
927- case OperandType::TENSOR_QUANT16_SYMM:
928- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int16_t >(),
929- srcTensor.get_byte_size ());
930- break ;
931- case OperandType::TENSOR_QUANT16_ASYMM:
932- std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint16_t >(),
933- srcTensor.get_byte_size ());
934- break ;
935- default :
936- std::memcpy ((uint8_t *)destPtr, srcTensor.data <uint8_t >(),
937- srcTensor.get_byte_size ());
938- break ;
939- }
888+ switch (operandType) {
889+ case OperandType::TENSOR_INT32:
890+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int32_t >(),
891+ srcTensor.get_byte_size ());
892+ break ;
893+ case OperandType::TENSOR_FLOAT32:
894+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <float >(),
895+ srcTensor.get_byte_size ());
896+ break ;
897+ case OperandType::TENSOR_BOOL8:
898+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <bool >(),
899+ srcTensor.get_byte_size ());
900+ break ;
901+ case OperandType::TENSOR_QUANT8_ASYMM:
902+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint8_t >(),
903+ srcTensor.get_byte_size ());
904+ break ;
905+ case OperandType::TENSOR_QUANT8_SYMM:
906+ case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
907+ case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
908+ std::memcpy ((int8_t *)destPtr, (int8_t *)srcTensor.data <int8_t >(),
909+ srcTensor.get_byte_size ());
910+ break ;
911+ case OperandType::TENSOR_FLOAT16:
912+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <ov::float16>(),
913+ srcTensor.get_byte_size ());
914+ break ;
915+ case OperandType::TENSOR_QUANT16_SYMM:
916+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <int16_t >(),
917+ srcTensor.get_byte_size ());
918+ break ;
919+ case OperandType::TENSOR_QUANT16_ASYMM:
920+ std::memcpy ((uint8_t *)destPtr, (uint8_t *)srcTensor.data <uint16_t >(),
921+ srcTensor.get_byte_size ());
922+ break ;
923+ default :
924+ std::memcpy ((uint8_t *)destPtr, srcTensor.data <uint8_t >(),
925+ srcTensor.get_byte_size ());
926+ break ;
940927 }
941928 }
942929
0 commit comments