Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tflite_micro/python_ops_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddConv2D();
AddCos();
AddCumSum();
AddDecode();
AddDelay();
AddDepthToSpace();
AddDepthwiseConv2D();
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ tflm_kernel_cc_library(
"conv.cc",
"conv_common.cc",
"cumsum.cc",
"decode.cc",
"decode_state.cc",
"decode_state_huffman.cc",
"decode_state_lut.cc",
"decode_state_prune.cc",
"depth_to_space.cc",
"depthwise_conv.cc",
"depthwise_conv_common.cc",
Expand Down Expand Up @@ -327,6 +332,10 @@ tflm_kernel_cc_library(
"batch_matmul.h",
"circular_buffer.h",
"conv.h",
"decode_state.h",
"decode_state_huffman.h",
"decode_state_lut.h",
"decode_state_prune.h",
"depthwise_conv.h",
"dequantize.h",
"ethosu.h",
Expand Down Expand Up @@ -643,6 +652,21 @@ tflm_cc_test(
],
)

tflm_cc_test(
name = "decode_test",
srcs = [
"decode_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)

tflm_cc_test(
name = "decompress_test",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/Makefile.inc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
Expand Down
187 changes: 187 additions & 0 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/decode_state.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
size_t tensor_output_index,
TfLiteTensor* output) {
// If alternate decompression memory is available, set the tensor data
// pointer now to preclude allocation by the memory planner.
void* alternate_decompress_mem =
GetMicroContext(context)->AllocateDecompressionMemory(
output->bytes, MicroArenaBufferAlignment());
if (alternate_decompress_mem != nullptr) {
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, tensor_output_index);
TF_LITE_ENSURE(context, output_eval != nullptr);
output_eval->data.data = alternate_decompress_mem;
output->data.data = alternate_decompress_mem;
}

return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
TF_LITE_ENSURE(context, num_outputs > 0);
TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2);

MicroContext* const micro_context = GetMicroContext(context);

node->user_data = micro_context->AllocatePersistentBuffer(
num_outputs * sizeof(DecodeState*));
TF_LITE_ENSURE(context, node->user_data != nullptr);
DecodeState** const dsp_arr =
reinterpret_cast<DecodeState**>(node->user_data);

TfLiteTensor* input = nullptr;
TfLiteTensor* ancillary = nullptr;
TfLiteTensor* output = nullptr;
TfLiteStatus status = kTfLiteOk;

micro_context->ResetDecompressionMemoryAllocations();

for (size_t i = 0; i < num_inputs; i += 2) {
input = micro_context->AllocateTempInputTensor(node, i);
if (input == nullptr) {
MicroPrintf("failed to allocate input tensor %u", i);
status = kTfLiteError;
break;
}
ancillary = micro_context->AllocateTempInputTensor(node, i + 1);
if (ancillary == nullptr) {
MicroPrintf("failed to allocate ancillary tensor %u", i + 1);
status = kTfLiteError;
break;
}
output = micro_context->AllocateTempOutputTensor(node, i / 2);
if (output == nullptr) {
MicroPrintf("failed to allocate output tensor %u", i / 2);
status = kTfLiteError;
break;
}

TF_LITE_ENSURE(context, IsConstantTensor(input));
TF_LITE_ENSURE(context, IsConstantTensor(ancillary));

if (DecodeState::Version(*ancillary) != 1) {
MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
status = kTfLiteError;
break;
}

DecodeState* dsp = nullptr;
switch (DecodeState::Type(*ancillary)) {
case DecodeState::kDcmTypeLUT:
dsp = DecodeState::CreateDecodeStateLUT(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypePrune:
dsp = DecodeState::CreateDecodeStatePrune(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeHuffman:
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
}
dsp_arr[i / 2] = dsp;
} else {
MicroPrintf("failed to allocate DecodeState[%u]", i / 2);
status = kTfLiteError;
break;
}

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(ancillary);
micro_context->DeallocateTempTfLiteTensor(output);
input = nullptr;
ancillary = nullptr;
output = nullptr;
}

if (input != nullptr) {
micro_context->DeallocateTempTfLiteTensor(input);
}
if (ancillary != nullptr) {
micro_context->DeallocateTempTfLiteTensor(ancillary);
}
if (output != nullptr) {
micro_context->DeallocateTempTfLiteTensor(output);
}

return status;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
DecodeState** const dsp_arr =
reinterpret_cast<DecodeState**>(node->user_data);

for (size_t i = 0; i < num_inputs; i += 2) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, i);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteEvalTensor* ancillary =
tflite::micro::GetEvalInput(context, node, i + 1);
TF_LITE_ENSURE(context, ancillary != nullptr);
const TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, i / 2);
TF_LITE_ENSURE(context, output != nullptr);

TfLiteStatus status = dsp_arr[i / 2]->Decode(*input, *ancillary, *output);
TF_LITE_ENSURE(context, status == kTfLiteOk);
}

return kTfLiteOk;
}

} // namespace

TFLMRegistration Register_DECODE() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

} // namespace tflite
64 changes: 64 additions & 0 deletions tensorflow/lite/micro/kernels/decode_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/decode_state.h"

#include "tensorflow/lite/micro/kernels/decode_state_huffman.h"
#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

DecodeState* DecodeState::CreateDecodeStateLUT(
const TfLiteContext* context, MicroProfilerInterface* profiler) {
MicroContext* const micro_context = GetMicroContext(context);
void* buffer =
micro_context->AllocatePersistentBuffer(sizeof(DecodeStateLUT));
if (buffer == nullptr) {
return nullptr;
}
DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler);

return dsp;
}

DecodeState* DecodeState::CreateDecodeStatePrune(
const TfLiteContext* context, MicroProfilerInterface* profiler) {
MicroContext* const micro_context = GetMicroContext(context);
void* buffer =
micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune));
if (buffer == nullptr) {
return nullptr;
}
DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler);

return dsp;
}

DecodeState* DecodeState::CreateDecodeStateHuffman(
const TfLiteContext* context, MicroProfilerInterface* profiler) {
MicroContext* const micro_context = GetMicroContext(context);
void* buffer =
micro_context->AllocatePersistentBuffer(sizeof(DecodeStateHuffman));
if (buffer == nullptr) {
return nullptr;
}
DecodeState* dsp = new (buffer) DecodeStateHuffman(context, profiler);

return dsp;
}

} // namespace tflite
Loading
Loading