Skip to content

Commit cf1c4bc

Browse files
authored
Use weight cache for quantized tensor scale data
Differential Revision: D82862629 Pull Request resolved: #14448
1 parent 46d7591 commit cf1c4bc

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174
data is associated with the tensor value, then returns nullptr.
175175
*/
176176
const uint8_t* getConstantDataPtr(
177-
const fb_xnnpack::XNNTensorValue* tensor_value,
177+
uint32_t buffer_idx,
178178
GraphPtr flatbuffer_graph,
179179
const uint8_t* constant_data_ptr,
180180
const NamedDataMap* named_data_map,
181181
std::vector<FreeableBuffer>& freeable_buffers,
182182
XNNWeightsCache* weights_cache) {
183-
auto buffer_idx = tensor_value->constant_buffer_idx();
184183
if (buffer_idx) {
185184
if (!constant_data_ptr) {
186185
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
230229
return nullptr;
231230
}
232231

232+
const uint8_t* getConstantDataPtr(
233+
const fb_xnnpack::XNNTensorValue* tensor_value,
234+
GraphPtr flatbuffer_graph,
235+
const uint8_t* constant_data_ptr,
236+
const NamedDataMap* named_data_map,
237+
std::vector<FreeableBuffer>& freeable_buffers,
238+
XNNWeightsCache* weights_cache) {
239+
return getConstantDataPtr(
240+
tensor_value->constant_buffer_idx(),
241+
flatbuffer_graph,
242+
constant_data_ptr,
243+
named_data_map,
244+
freeable_buffers,
245+
weights_cache);
246+
}
247+
233248
/**
234249
Define serialized tensor value into
235250
the subgraph. While also keeping track of the remapped ids from
@@ -434,22 +449,15 @@ Error defineTensor(
434449
const float* scale = qparams->scale()->data();
435450

436451
if (qparams->scale_buffer_idx() != 0) {
437-
// if scales are stored in named data, then retrieve it
438-
ConstantDataOffsetPtr scale_buffer_offset =
439-
flatbuffer_graph->constant_data()->Get(
440-
qparams->scale_buffer_idx());
441-
const std::string& data_name =
442-
scale_buffer_offset->named_key()->str();
443-
Result<FreeableBuffer> scale_buffer =
444-
named_data_map->get_data(data_name.c_str());
452+
scale = reinterpret_cast<const float*>(getConstantDataPtr(
453+
qparams->scale_buffer_idx(),
454+
flatbuffer_graph,
455+
constant_data_ptr,
456+
named_data_map,
457+
freeable_buffers,
458+
weights_cache));
445459
ET_CHECK_OR_RETURN_ERROR(
446-
scale_buffer.ok(),
447-
Internal,
448-
"Failed to get constant data for key %s from named_data_map. Error code: %u",
449-
data_name.c_str(),
450-
static_cast<uint32_t>(scale_buffer.error()));
451-
scale = reinterpret_cast<const float*>(scale_buffer.get().data());
452-
freeable_buffers.push_back(std::move(scale_buffer.get()));
460+
scale != nullptr, Internal, "Failed to load scale data.");
453461
}
454462
status = xnn_define_channelwise_quantized_tensor_value_v2(
455463
/*subgraph=*/subgraph_ptr,
@@ -483,22 +491,15 @@ Error defineTensor(
483491
// Block scales are preferably serialized as bf16 but can also be
484492
// serialized as fp32 for backwards compatibility.
485493
if (qparams->scale_buffer_idx() != 0) {
486-
ConstantDataOffsetPtr scale_buffer_offset =
487-
flatbuffer_graph->constant_data()->Get(
488-
qparams->scale_buffer_idx());
489-
const std::string& data_name =
490-
scale_buffer_offset->named_key()->str();
491-
Result<FreeableBuffer> scale_buffer =
492-
named_data_map->get_data(data_name.c_str());
494+
scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
495+
qparams->scale_buffer_idx(),
496+
flatbuffer_graph,
497+
constant_data_ptr,
498+
named_data_map,
499+
freeable_buffers,
500+
weights_cache));
493501
ET_CHECK_OR_RETURN_ERROR(
494-
scale_buffer.ok(),
495-
Internal,
496-
"Failed to get constant data for key %s from named_data_map. Error code: %u",
497-
data_name.c_str(),
498-
static_cast<uint32_t>(scale_buffer.error()));
499-
scale_data =
500-
reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
501-
freeable_buffers.push_back(std::move(scale_buffer.get()));
502+
scale_data != nullptr, Internal, "Failed to load scale data.");
502503
scale_numel = qparams->num_scales();
503504
} else {
504505
// Read fp32 scales, convert to bf16.

0 commit comments

Comments
 (0)