Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tensorflow/compiler/mlir/lite/schema/schema_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"

#include <algorithm>
#include <complex>
#include <cstddef>
#include <cstdint>

#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

namespace tflite {

Expand Down Expand Up @@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) {
op_code->deprecated_builtin_code));
}

// Returns the per-element size in bytes for `data_type`, or 0 when the type
// has no fixed element size known here (e.g. strings, unsupported types).
size_t TensorTypeGetSize(::tflite::TensorType data_type) {
  // Compile-time checks for the types whose width is not already guaranteed
  // by the fixed-width integer definitions in <cstdint>.
  static_assert(sizeof(float) == 4, "");
  static_assert(sizeof(double) == 8, "");
  static_assert(sizeof(std::complex<float>) == 8, "");
  static_assert(sizeof(std::complex<double>) == 16, "");
  switch (data_type) {
    // 1-byte element types.
    case ::tflite::TensorType_UINT8:
    case ::tflite::TensorType_INT8:
      return 1;
    // 2-byte element types (FLOAT16 is stored as two raw bytes).
    case ::tflite::TensorType_FLOAT16:
    case ::tflite::TensorType_INT16:
    case ::tflite::TensorType_UINT16:
      return 2;
    // 4-byte element types.
    case ::tflite::TensorType_FLOAT32:
    case ::tflite::TensorType_INT32:
    case ::tflite::TensorType_UINT32:
      return 4;
    // 8-byte element types (COMPLEX64 is two 4-byte floats).
    case ::tflite::TensorType_FLOAT64:
    case ::tflite::TensorType_INT64:
    case ::tflite::TensorType_UINT64:
    case ::tflite::TensorType_COMPLEX64:
      return 8;
    // 16-byte element type (two 8-byte doubles).
    case ::tflite::TensorType_COMPLEX128:
      return 16;
    case ::tflite::TensorType_BOOL:
      return sizeof(bool);
    default:
      // Unknown or variable-size type: caller treats 0 as "unsupported".
      return 0;
  }
}
} // namespace tflite
7 changes: 7 additions & 0 deletions tensorflow/compiler/mlir/lite/schema/schema_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_

#include <cstddef>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

Expand All @@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code);

BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code);

// Returns the size of the given TensorType in bytes, or 0 if the TensorType
// is not supported. This function should be kept in sync with
// TfLiteTypeGetSize in lite/kernels/kernel_util.h.
size_t TensorTypeGetSize(::tflite::TensorType data_type);

} // namespace tflite

#endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
5 changes: 4 additions & 1 deletion tensorflow/lite/kernels/internal/reference/broadcast_to.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_

#include <cstddef>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

Expand Down Expand Up @@ -83,7 +85,8 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
// If non-broadcasting, just copy data from input to output tensor.
if (last_broadcast_dim == -1) {
memcpy(output_data, input_data,
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
static_cast<size_t>(unextended_input_shape.FlatSize()) *
static_cast<size_t>(TfLiteTypeGetSize(data_type)));
return;
}

Expand Down
28 changes: 28 additions & 0 deletions tensorflow/lite/kernels/internal/reference/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_

#include <cstdint>
#include <vector>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
Expand Down Expand Up @@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params,
return Slice(op_params, input_shape, output_shape, &writer);
}

// Slice for 4-bit packed tensors: widens the packed input to one int8 per
// element, runs the generic int8 Slice, then re-packs the sliced result into
// the 4-bit output buffer.
inline void SliceInt4(const tflite::SliceParams& op_params,
                      const RuntimeShape& input_shape,
                      const TfLiteTensor* input,
                      const RuntimeShape& output_shape, TfLiteTensor* output) {
  const int input_count = input_shape.FlatSize();
  const int output_count = output_shape.FlatSize();

  // Scratch buffers holding one int8 per logical 4-bit element.
  std::vector<int8_t> widened_input(input_count);
  std::vector<int8_t> widened_output(output_count);

  tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input),
                                      input_count, 4, widened_input.data());

  reference_ops::Slice<int8_t>(op_params, input_shape, widened_input.data(),
                               output_shape, widened_output.data());

  tensor_utils::PackInt8IntoDenseInt(widened_output.data(), output_count, 4,
                                     GetTensorData<int8_t>(output));
}

} // namespace reference_ops
} // namespace tflite

Expand Down
Loading