Commit cae59f8

Use weight cache for quantized tensor scale data (#14455)
Summary: When the XNNPACK weight cache is enabled and a model with qb4- or qc8-quantized linear weights is run, an assertion fires that is intended to verify that all constant data lives in the weight cache. This can be reproduced by running the XNNPACK backend linear op tests with the weight cache enabled.

The root cause appears to be that tensor scale data bypasses the weight cache, likely an oversight. This is not a correctness issue, but it does trip the aforementioned assert and uses marginally more memory than necessary.

This PR updates the XNNPACK compileModel call to route scale data through the weight cache instead of placing it in the unpacked_buffers list. With this change, the linear op tests pass with the weight cache enabled.

Differential Revision: D82862629

Co-authored-by: Gregory Comer <[email protected]>
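For context, here is a minimal, self-contained sketch of the lookup pattern this change adopts: prefer the weights cache for named constant data so the cache owns every buffer it hands out, and fall back to a caller-owned freeable-buffers list only when no cache is available. All names in the sketch (SimpleWeightCache, NamedDataSource, load_constant_data) are illustrative stand-ins, not the actual ExecuTorch/XNNPACK APIs.

// Sketch only: simplified stand-ins for the real ExecuTorch/XNNPACK types.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Stand-in for the named data map: name -> raw bytes.
using NamedDataSource = std::map<std::string, std::vector<uint8_t>>;

// Stand-in for XNNWeightsCache: owns every buffer it hands out, so a later
// "is all constant data in the cache?" assertion can be satisfied.
class SimpleWeightCache {
 public:
  explicit SimpleWeightCache(const NamedDataSource* source) : source_(source) {}

  // Returns a pointer to cached bytes for `name`, loading them on first use.
  const uint8_t* get(const std::string& name) {
    auto it = cached_.find(name);
    if (it == cached_.end()) {
      auto src = source_->find(name);
      if (src == source_->end()) {
        return nullptr;  // unknown key
      }
      it = cached_.emplace(name, src->second).first;  // cache takes ownership
    }
    return it->second.data();
  }

 private:
  const NamedDataSource* source_;
  std::map<std::string, std::vector<uint8_t>> cached_;
};

// Before this change, scale data was loaded straight from the named data map
// and parked in a caller-owned "unpacked buffers" list. The fix routes the
// lookup through the cache when one is available, falling back to the
// caller-owned list otherwise.
const uint8_t* load_constant_data(
    const std::string& name,
    const NamedDataSource& named_data,
    std::vector<std::vector<uint8_t>>& freeable_buffers,  // caller-owned fallback
    SimpleWeightCache* weights_cache) {
  if (weights_cache != nullptr) {
    return weights_cache->get(name);  // cache owns the bytes
  }
  auto it = named_data.find(name);
  if (it == named_data.end()) {
    return nullptr;
  }
  freeable_buffers.push_back(it->second);  // fallback: caller keeps it alive
  return freeable_buffers.back().data();
}

int main() {
  NamedDataSource named_data{{"linear.scales", {1, 2, 3, 4}}};
  SimpleWeightCache cache(&named_data);
  std::vector<std::vector<uint8_t>> freeable_buffers;

  // With a cache, the scale bytes end up owned by the cache rather than by
  // freeable_buffers -- mirroring what the diff does for scale_buffer_idx.
  const uint8_t* scales =
      load_constant_data("linear.scales", named_data, freeable_buffers, &cache);
  return (scales != nullptr && freeable_buffers.empty()) ? 0 : 1;
}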
1 parent 2fecf2c commit cae59f8

File tree: 1 file changed, +33 −32 lines


backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 33 additions & 32 deletions
@@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant
 data associated with the tensor value, then returns nullptr.
 */
 const uint8_t* getConstantDataPtr(
-    const fb_xnnpack::XNNTensorValue* tensor_value,
+    uint32_t buffer_idx,
     GraphPtr flatbuffer_graph,
     const uint8_t* constant_data_ptr,
     const NamedDataMap* named_data_map,
     std::vector<FreeableBuffer>& freeable_buffers,
     XNNWeightsCache* weights_cache) {
-  auto buffer_idx = tensor_value->constant_buffer_idx();
   if (buffer_idx) {
     if (!constant_data_ptr) {
       // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
   return nullptr;
 }
 
+const uint8_t* getConstantDataPtr(
+    const fb_xnnpack::XNNTensorValue* tensor_value,
+    GraphPtr flatbuffer_graph,
+    const uint8_t* constant_data_ptr,
+    const NamedDataMap* named_data_map,
+    std::vector<FreeableBuffer>& freeable_buffers,
+    XNNWeightsCache* weights_cache) {
+  return getConstantDataPtr(
+      tensor_value->constant_buffer_idx(),
+      flatbuffer_graph,
+      constant_data_ptr,
+      named_data_map,
+      freeable_buffers,
+      weights_cache);
+}
+
 /**
 Define serialized tensor value into
 the subgraph. While also keeping track of the remapped ids from
@@ -434,22 +449,15 @@ Error defineTensor(
         const float* scale = qparams->scale()->data();
 
         if (qparams->scale_buffer_idx() != 0) {
-          // if scales are stored in named data, then retrieve it
-          ConstantDataOffsetPtr scale_buffer_offset =
-              flatbuffer_graph->constant_data()->Get(
-                  qparams->scale_buffer_idx());
-          const std::string& data_name =
-              scale_buffer_offset->named_key()->str();
-          Result<FreeableBuffer> scale_buffer =
-              named_data_map->get_data(data_name.c_str());
+          scale = reinterpret_cast<const float*>(getConstantDataPtr(
+              qparams->scale_buffer_idx(),
+              flatbuffer_graph,
+              constant_data_ptr,
+              named_data_map,
+              freeable_buffers,
+              weights_cache));
           ET_CHECK_OR_RETURN_ERROR(
-              scale_buffer.ok(),
-              Internal,
-              "Failed to get constant data for key %s from named_data_map. Error code: %u",
-              data_name.c_str(),
-              static_cast<uint32_t>(scale_buffer.error()));
-          scale = reinterpret_cast<const float*>(scale_buffer.get().data());
-          freeable_buffers.push_back(std::move(scale_buffer.get()));
+              scale != nullptr, Internal, "Failed to load scale data.");
         }
         status = xnn_define_channelwise_quantized_tensor_value_v2(
             /*subgraph=*/subgraph_ptr,
@@ -483,22 +491,15 @@ Error defineTensor(
         // Block scales are preferably serialized as bf16 but can also be
         // serialized as fp32 for backwards compatability.
         if (qparams->scale_buffer_idx() != 0) {
-          ConstantDataOffsetPtr scale_buffer_offset =
-              flatbuffer_graph->constant_data()->Get(
-                  qparams->scale_buffer_idx());
-          const std::string& data_name =
-              scale_buffer_offset->named_key()->str();
-          Result<FreeableBuffer> scale_buffer =
-              named_data_map->get_data(data_name.c_str());
+          scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
+              qparams->scale_buffer_idx(),
+              flatbuffer_graph,
+              constant_data_ptr,
+              named_data_map,
+              freeable_buffers,
+              weights_cache));
           ET_CHECK_OR_RETURN_ERROR(
-              scale_buffer.ok(),
-              Internal,
-              "Failed to get constant data for key %s from named_data_map. Error code: %u",
-              data_name.c_str(),
-              static_cast<uint32_t>(scale_buffer.error()));
-          scale_data =
-              reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
-          freeable_buffers.push_back(std::move(scale_buffer.get()));
+              scale_data != nullptr, Internal, "Failed to load scale data.");
           scale_numel = qparams->num_scales();
         } else {
           // Read fp32 scales, convert to bf16.

0 commit comments