diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 13f160ca53c..a8873dbc955 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -47,6 +47,7 @@ PythonOpsResolver::PythonOpsResolver() { AddDequantize(); AddDetectionPostprocess(); AddDiv(); + AddDynamicUpdateSlice(); AddElu(); AddEmbeddingLookup(); AddEnergy(); diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 2009d99ca85..79b99cee1a9 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -262,6 +262,7 @@ tflm_kernel_cc_library( "dequantize_common.cc", "detection_postprocess.cc", "div.cc", + "dynamic_update_slice.cc", "elementwise.cc", "elu.cc", "embedding_lookup.cc", @@ -352,6 +353,7 @@ tflm_kernel_cc_library( "decode_state_prune.h", "depthwise_conv.h", "dequantize.h", + "dynamic_update_slice.h", "ethosu.h", "fully_connected.h", "hard_swish.h", @@ -824,6 +826,21 @@ tflm_cc_test( ], ) +tflm_cc_test( + name = "dynamic_update_slice_test", + srcs = [ + "dynamic_update_slice_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflm_cc_test( name = "elementwise_test", srcs = ["elementwise_test.cc"], diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 6b6a3a72974..daa8c2b95d6 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -131,6 +131,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div_test.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup_test.cc \ diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc new file mode 100644 index 00000000000..42ccdd14842 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -0,0 +1,231 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/micro/kernels/dynamic_update_slice.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { + +constexpr int kMaxDimensions = RuntimeShape::kMaxSmallSize; + +namespace { + +void CalculateClampedStartIndices(int num_dims, const int64_t* raw_indices_data, + const int32_t* input_dims_data, + const int32_t* update_dims_data, + int32_t* clamped_start_indices_output) { + for (int i = 0; i < num_dims; ++i) { + clamped_start_indices_output[i] = static_cast( + std::min(std::max(0, raw_indices_data[i]), + input_dims_data[i] - update_dims_data[i])); + } + return; +} + +// Recursive helper for N-dimensional slice update. +template +void UpdateSliceRecursive(int current_dim, int max_dims, + const int32_t* output_strides, + const int32_t* update_strides, + const int32_t* update_dims_data, + const T* update_tensor_data, + const int32_t* clamped_start_indices, + T* output_tensor_data) { + if (current_dim == max_dims) return; + output_tensor_data += + clamped_start_indices[current_dim] * output_strides[current_dim]; + if (current_dim == max_dims - 1) { + std::memcpy(output_tensor_data, update_tensor_data, + update_dims_data[max_dims - 1] * sizeof(T)); + } else { + for (int i = 0; i < update_dims_data[current_dim]; ++i) { + UpdateSliceRecursive(current_dim + 1, max_dims, output_strides, + update_strides, update_dims_data, + update_tensor_data, clamped_start_indices, + output_tensor_data); + output_tensor_data += output_strides[current_dim]; + update_tensor_data += update_strides[current_dim]; + } + } +} + +// Main dispatch function for Eval, templated on data type. +template +void EvalImpl(const TfLiteEvalTensor* operand_eval, + const TfLiteEvalTensor* update_eval, const int64_t* indices_eval, + TfLiteEvalTensor* output_eval) { + const RuntimeShape operand_shape = + tflite::micro::GetTensorShape(operand_eval); + const RuntimeShape update_shape = tflite::micro::GetTensorShape(update_eval); + const T* update_tensor_data = tflite::micro::GetTensorData(update_eval); + T* output_tensor_data = tflite::micro::GetTensorData(output_eval); + + const int num_dims = operand_shape.DimensionsCount(); + if (operand_shape.FlatSize() == update_shape.FlatSize()) { + std::memcpy(output_tensor_data, update_tensor_data, + ElementCount(*operand_eval->dims) * sizeof(T)); + return; + } + + // If the operation is not done in-place, copy the input data to the output. + if (operand_eval->data.data != output_eval->data.data) { + std::memcpy(output_eval->data.data, operand_eval->data.data, + ElementCount(*operand_eval->dims) * sizeof(T)); + } + + // If update tensor is empty, no actual update is needed after operand copy. + if (ElementCount(*update_eval->dims) == 0) { + return; + } + + // Calculate clamped start indices (stack-allocated) + int32_t clamped_start_indices[kMaxDimensions]; + CalculateClampedStartIndices(num_dims, indices_eval, operand_shape.DimsData(), + update_shape.DimsData(), clamped_start_indices); + + // Calculate strides (stack-allocated) + int32_t output_stride[kMaxDimensions]; + int32_t update_stride[kMaxDimensions]; + output_stride[num_dims - 1] = 1; + update_stride[num_dims - 1] = 1; + for (int i = num_dims - 2; i >= 0; --i) { + output_stride[i] = output_stride[i + 1] * operand_shape.Dims(i + 1); + update_stride[i] = update_stride[i + 1] * update_shape.Dims(i + 1); + } + + // Perform the N-dimensional update + // The recursive function needs base pointers and initial offsets. + UpdateSliceRecursive( + /*current_dim=*/0, num_dims, output_stride, update_stride, + update_shape.DimsData(), update_tensor_data, clamped_start_indices, + output_tensor_data); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Use MicroContext to allocate temporary tensors for inspection + // This is a robust pattern shown in EMBEDDING_LOOKUP. + TfLiteTensor* operand = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceOperandTensor); + TF_LITE_ENSURE(context, operand != nullptr); + + TfLiteTensor* update = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceUpdateTensor); + TF_LITE_ENSURE(context, update != nullptr); + + TfLiteTensor* start_indices = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceStartIndicesTensor); + TF_LITE_ENSURE(context, start_indices != nullptr); + + TfLiteTensor* output = micro_context->AllocateTempOutputTensor( + node, kDynamicUpdateSliceOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + + // Type checks + TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type); + TF_LITE_ENSURE(context, start_indices->type == kTfLiteInt32 || + start_indices->type == kTfLiteInt64); + + TF_LITE_ENSURE_EQ(context, NumDimensions(start_indices), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(start_indices, 0), + NumDimensions(operand)); + + TF_LITE_ENSURE_EQ(context, NumDimensions(update), NumDimensions(operand)); + // Check that update dimensions are not larger than operand dimensions + for (int i = 0; i < NumDimensions(operand); ++i) { + TF_LITE_ENSURE(context, + SizeOfDimension(update, i) <= SizeOfDimension(operand, i)); + } + + // Deallocate temporary tensors + micro_context->DeallocateTempTfLiteTensor(operand); + micro_context->DeallocateTempTfLiteTensor(update); + micro_context->DeallocateTempTfLiteTensor(start_indices); + micro_context->DeallocateTempTfLiteTensor(output); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* operand_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceOperandTensor); + const TfLiteEvalTensor* update_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceUpdateTensor); + const TfLiteEvalTensor* indices_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceStartIndicesTensor); + TfLiteEvalTensor* output_eval = tflite::micro::GetEvalOutput( + context, node, kDynamicUpdateSliceOutputTensor); + + const auto& input_shape = tflite::micro::GetTensorShape(operand_eval); + const int input_dims = input_shape.DimensionsCount(); + int64_t indices_data_i64[kMaxDimensions]; + if (indices_eval->type == kTfLiteInt32) { + for (int i = 0; i < input_dims; i++) + indices_data_i64[i] = static_cast(indices_eval->data.i32[i]); + } else if (indices_eval->type == kTfLiteInt64) { + for (int i = 0; i < input_dims; i++) + indices_data_i64[i] = indices_eval->data.i64[i]; + } else { + TF_LITE_KERNEL_LOG(context, + "DynamicUpdateSlice only currently supports " + "int32 or int64 indices type, got %d.", + indices_eval->type); + return kTfLiteError; + } + // Dispatch based on tensor type + switch (operand_eval->type) { + case kTfLiteFloat32: + EvalImpl(operand_eval, update_eval, indices_data_i64, output_eval); + break; + case kTfLiteInt8: + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; + case kTfLiteInt16: + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; + case kTfLiteInt32: + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; + default: + MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.", + TfLiteTypeGetName(operand_eval->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE() { + return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/Prepare, + /*invoke=*/Eval); +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.h b/tensorflow/lite/micro/kernels/dynamic_update_slice.h new file mode 100644 index 00000000000..89546110b72 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.h @@ -0,0 +1,36 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/micro/micro_common.h" + +namespace tflite { + +constexpr int kDynamicUpdateSliceOperandTensor = 0; +constexpr int kDynamicUpdateSliceUpdateTensor = 1; +constexpr int kDynamicUpdateSliceStartIndicesTensor = 2; +constexpr int kDynamicUpdateSliceOutputTensor = 0; + +TfLiteStatus PrepareDynamicUpdateSlice(TfLiteContext* context, + TfLiteNode* node); + +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE(); + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc new file mode 100644 index 00000000000..0bfd6c89740 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +constexpr int kNumInputs = 3; +constexpr int kNumOutputs = 1; +constexpr int kInputTensorIndex_0 = 0; +constexpr int kInputTensorIndex_1 = 1; +constexpr int kInputTensorIndex_2 = 2; +constexpr int kOutputTensorIndex = 3; + +void ExecuteDynamicUpdateSliceTest(TfLiteTensor* tensors, int tensors_count) { + int kInputArrayData[] = {kNumInputs, kInputTensorIndex_0, kInputTensorIndex_1, + kInputTensorIndex_2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData); + int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex}; + TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData); + + const TFLMRegistration registration = tflite::Register_DYNAMIC_UPDATE_SLICE(); + micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array, + outputs_array, nullptr); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); +} + +template +void TestDynamicUpdateSlice(int* input_dims_data[kNumInputs], + const T* input_data_0, const T* input_data_1, + const U* input_data_2, const T* golden_data, + int* expected_dims, T* output_data) { + TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]); + TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]); + TfLiteIntArray* input_dims_2 = IntArrayFromInts(input_dims_data[2]); + TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims); + const int output_count = ElementCount(*output_dims); + + TfLiteTensor tensors[] = { + CreateTensor(input_data_0, input_dims_0), + CreateTensor(input_data_1, input_dims_1), + CreateTensor(input_data_2, input_dims_2), + CreateTensor(output_data, output_dims), + }; + constexpr int tensors_count = std::extent::value; + ExecuteDynamicUpdateSliceTest(tensors, tensors_count); + + // check output data against expected + for (int i = 0; i < output_count; i++) { + TF_LITE_MICRO_EXPECT_EQ(golden_data[i], output_data[i]); + } + + // check output dimensions (relocated) against original dimensions + TF_LITE_MICRO_EXPECT_EQ(output_dims->size, + tensors[kOutputTensorIndex].dims->size); + for (int i = 0; i < output_dims->size; i++) { + TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i], + tensors[kOutputTensorIndex].dims->data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr float kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr float kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int8_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int8_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int16_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int16_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int16_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int16_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int32_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int32_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int32_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int32_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int8_t kInput_1[] = {-1, -2}; + constexpr int64_t kInput_2[] = {1, 1}; + constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int8_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 2}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr float kInput_1[] = {-1, -2, -3, -4}; + constexpr int32_t kInput_2[] = {2, 2}; + constexpr float kExpect[] = {1, 2, 3, 4, -1, -2, 7, -3, -4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 264af300a02..4f8c3c068f1 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -58,6 +58,7 @@ TFLMRegistration Register_DEPTH_TO_SPACE(); TFLMRegistration Register_DEPTHWISE_CONV_2D(); TFLMRegistration Register_DEQUANTIZE(); TFLMRegistration Register_DIV(); +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE(); TFLMRegistration Register_ELU(); TFLMRegistration Register_EMBEDDING_LOOKUP(); TFLMRegistration Register_EQUAL(); diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index cf28f8ccf2c..c5540ea669a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -251,6 +251,11 @@ class MicroMutableOpResolver : public MicroOpResolver { return AddBuiltin(BuiltinOperator_DIV, registration, ParseDiv); } + TfLiteStatus AddDynamicUpdateSlice() { + return AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE, + Register_DYNAMIC_UPDATE_SLICE(), ParseDynamicUpdateSlice); + } + TfLiteStatus AddEmbeddingLookup( const TFLMRegistration& registration = Register_EMBEDDING_LOOKUP()) { return AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, registration, diff --git a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h index 42063dcca7e..71cea217285 100644 --- a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h +++ b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h @@ -23,7 +23,7 @@ limitations under the License. namespace tflite { -using TflmOpResolver = MicroMutableOpResolver<115>; +using TflmOpResolver = MicroMutableOpResolver<116>; inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddAbs()); @@ -52,6 +52,7 @@ inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddDequantize()); TF_LITE_ENSURE_STATUS(op_resolver.AddDetectionPostprocess()); TF_LITE_ENSURE_STATUS(op_resolver.AddDiv()); + TF_LITE_ENSURE_STATUS(op_resolver.AddDynamicUpdateSlice()); TF_LITE_ENSURE_STATUS(op_resolver.AddElu()); TF_LITE_ENSURE_STATUS(op_resolver.AddEmbeddingLookup()); TF_LITE_ENSURE_STATUS(op_resolver.AddEnergy()); diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index b77bf010dbb..21f21a1ce05 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -400,6 +400,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/detection_postprocess.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dynamic_update_slice.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup.cc \