Skip to content

Add unittest for quantized_relu_out #12499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions backends/cadence/hifi/operators/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ namespace impl {
namespace HiFi {
namespace native {

void dequantize_per_tensor_out(
::executorch::runtime::KernelRuntimeContext& ctx,
const ::executorch::aten::Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
::executorch::aten::ScalarType dtype,
::executorch::aten::Tensor& out);

// Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
// used in any computation.
void quantize_per_tensor_out(
Expand All @@ -42,6 +52,15 @@ ::executorch::aten::Tensor& div_out_mode(
std::optional<std::string_view> mode,
::executorch::aten::Tensor& out);

void quantized_relu_out(
::executorch::runtime::KernelRuntimeContext& ctx,
const ::executorch::aten::Tensor& input,
const ::executorch::aten::Tensor& in_zero_point,
const int64_t out_zero_point,
const ::executorch::aten::Tensor& out_multiplier,
const ::executorch::aten::Tensor& out_shift,
::executorch::aten::Tensor& output);

void quantized_linear_out(
__ET_UNUSED KernelRuntimeContext& ctx,
const Tensor& in,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>
#include <sys/times.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/backends/cadence/hifi/operators/operators.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
namespace {

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::runtime_init;
using ::executorch::runtime::testing::TensorFactory;
using std::optional;
using std::string_view;

class HiFiDequantizePerTensorTest : public OperatorTest {
public:
protected:
void dequantize_per_tensor_out(
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
return ::cadence::impl::HiFi::native::dequantize_per_tensor_out(
context_, input, scale, zero_point, quant_min, quant_max, dtype, out);
}
};

TEST_F(HiFiDequantizePerTensorTest, MultiDimensionalTest) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Char> tf_chars;
const std::vector<int32_t> sizes{2, 3, 5, 6};
Tensor quantized_tensor = tf_chars.full(sizes, -128);
Tensor output_float = tf_float.zeros(sizes);
double dequant_scale = 0.000244140625;
int64_t dequant_zero_point = -128;
int64_t quant_min = -128;
int64_t quant_max = 127;

dequantize_per_tensor_out(
quantized_tensor,
dequant_scale,
dequant_zero_point,
quant_min,
quant_max,
ScalarType::Float,
output_float);

EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}

TEST_F(HiFiDequantizePerTensorTest, OneDimensionalTest) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Char> tf_chars;
const std::vector<int32_t> sizes{56};
Tensor quantized_tensor = tf_chars.full(sizes, -128);
Tensor output_float = tf_float.zeros(sizes);
double dequant_scale = 0.000244140625;
int64_t dequant_zero_point = -128;
int64_t quant_min = -128;
int64_t quant_max = 127;

dequantize_per_tensor_out(
quantized_tensor,
dequant_scale,
dequant_zero_point,
quant_min,
quant_max,
ScalarType::Float,
output_float);

EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}

} // namespace
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
109 changes: 109 additions & 0 deletions backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>
#include <sys/times.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/backends/cadence/hifi/operators/operators.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
namespace {

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::runtime_init;
using ::executorch::runtime::testing::TensorFactory;
using std::optional;
using std::string_view;

class HiFiQuantizedReluTest : public OperatorTest {
public:
protected:
void quantized_relu_out(
const Tensor& input,
const Tensor& in_zero_point,
const int64_t out_zero_point,
const Tensor& out_multiplier,
const Tensor& out_shift,
Tensor& output) {
return ::cadence::impl::HiFi::native::quantized_relu_out(
context_,
input,
in_zero_point,
out_zero_point,
out_multiplier,
out_shift,
output);
}
};

TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
TensorFactory<ScalarType::Char> tf_chars;
const std::vector<int32_t> sizes{2, 3, 5, 6};
Tensor quantized_input = tf_chars.full(sizes, -128);
Tensor quantized_output = tf_chars.full(sizes, 100);
Tensor in_zero_point = tf_chars.full({1}, 127);
int64_t out_zero_point = -128;
Tensor out_multiplier =
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);

quantized_relu_out(
quantized_input,
in_zero_point,
out_zero_point,
out_multiplier,
out_shift,
quantized_output);

Tensor expected_output = tf_chars.full(sizes, -128);
EXPECT_TENSOR_EQ(quantized_output, expected_output);
}

TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) {
TensorFactory<ScalarType::Char> tf_chars;
const std::vector<int32_t> sizes{56};
Tensor quantized_input = tf_chars.full(sizes, -128);
Tensor quantized_output = tf_chars.full(sizes, 100);
Tensor in_zero_point = tf_chars.full({1}, 127);
int64_t out_zero_point = -128;
Tensor out_multiplier =
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);

quantized_relu_out(
quantized_input,
in_zero_point,
out_zero_point,
out_multiplier,
out_shift,
quantized_output);

Tensor expected_output = tf_chars.full(sizes, -128);
EXPECT_TENSOR_EQ(quantized_output, expected_output);
}

} // namespace
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
Loading