@@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174
174
data associated with the tensor value, then returns nullptr.
175
175
*/
176
176
const uint8_t * getConstantDataPtr (
177
- const fb_xnnpack::XNNTensorValue* tensor_value ,
177
+ uint32_t buffer_idx ,
178
178
GraphPtr flatbuffer_graph,
179
179
const uint8_t * constant_data_ptr,
180
180
const NamedDataMap* named_data_map,
181
181
std::vector<FreeableBuffer>& freeable_buffers,
182
182
XNNWeightsCache* weights_cache) {
183
- auto buffer_idx = tensor_value->constant_buffer_idx ();
184
183
if (buffer_idx) {
185
184
if (!constant_data_ptr) {
186
185
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
230
229
return nullptr ;
231
230
}
232
231
232
+ const uint8_t * getConstantDataPtr (
233
+ const fb_xnnpack::XNNTensorValue* tensor_value,
234
+ GraphPtr flatbuffer_graph,
235
+ const uint8_t * constant_data_ptr,
236
+ const NamedDataMap* named_data_map,
237
+ std::vector<FreeableBuffer>& freeable_buffers,
238
+ XNNWeightsCache* weights_cache) {
239
+ return getConstantDataPtr (
240
+ tensor_value->constant_buffer_idx (),
241
+ flatbuffer_graph,
242
+ constant_data_ptr,
243
+ named_data_map,
244
+ freeable_buffers,
245
+ weights_cache);
246
+ }
247
+
233
248
/* *
234
249
Define serialized tensor value into
235
250
the subgraph. While also keeping track of the remapped ids from
@@ -434,22 +449,15 @@ Error defineTensor(
434
449
const float * scale = qparams->scale ()->data ();
435
450
436
451
if (qparams->scale_buffer_idx () != 0 ) {
437
- // if scales are stored in named data, then retrieve it
438
- ConstantDataOffsetPtr scale_buffer_offset =
439
- flatbuffer_graph->constant_data ()->Get (
440
- qparams->scale_buffer_idx ());
441
- const std::string& data_name =
442
- scale_buffer_offset->named_key ()->str ();
443
- Result<FreeableBuffer> scale_buffer =
444
- named_data_map->get_data (data_name.c_str ());
452
+ scale = reinterpret_cast <const float *>(getConstantDataPtr (
453
+ qparams->scale_buffer_idx (),
454
+ flatbuffer_graph,
455
+ constant_data_ptr,
456
+ named_data_map,
457
+ freeable_buffers,
458
+ weights_cache));
445
459
ET_CHECK_OR_RETURN_ERROR (
446
- scale_buffer.ok (),
447
- Internal,
448
- " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
449
- data_name.c_str (),
450
- static_cast <uint32_t >(scale_buffer.error ()));
451
- scale = reinterpret_cast <const float *>(scale_buffer.get ().data ());
452
- freeable_buffers.push_back (std::move (scale_buffer.get ()));
460
+ scale != nullptr , Internal, " Failed to load scale data." );
453
461
}
454
462
status = xnn_define_channelwise_quantized_tensor_value_v2 (
455
463
/* subgraph=*/ subgraph_ptr,
@@ -483,22 +491,15 @@ Error defineTensor(
483
491
// Block scales are preferably serialized as bf16 but can also be
484
492
// serialized as fp32 for backwards compatability.
485
493
if (qparams->scale_buffer_idx () != 0 ) {
486
- ConstantDataOffsetPtr scale_buffer_offset =
487
- flatbuffer_graph-> constant_data ()-> Get (
488
- qparams-> scale_buffer_idx ());
489
- const std::string& data_name =
490
- scale_buffer_offset-> named_key ()-> str ();
491
- Result<FreeableBuffer> scale_buffer =
492
- named_data_map-> get_data (data_name. c_str ( ));
494
+ scale_data = reinterpret_cast < const uint16_t *>( getConstantDataPtr (
495
+ qparams-> scale_buffer_idx (),
496
+ flatbuffer_graph,
497
+ constant_data_ptr,
498
+ named_data_map,
499
+ freeable_buffers,
500
+ weights_cache ));
493
501
ET_CHECK_OR_RETURN_ERROR (
494
- scale_buffer.ok (),
495
- Internal,
496
- " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
497
- data_name.c_str (),
498
- static_cast <uint32_t >(scale_buffer.error ()));
499
- scale_data =
500
- reinterpret_cast <const uint16_t *>(scale_buffer.get ().data ());
501
- freeable_buffers.push_back (std::move (scale_buffer.get ()));
502
+ scale_data != nullptr , Internal, " Failed to load scale data." );
502
503
scale_numel = qparams->num_scales ();
503
504
} else {
504
505
// Read fp32 scales, convert to bf16.
0 commit comments