Skip to content

Commit

Permalink
working cpu prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
EiffL committed Jun 20, 2021
1 parent 016087c commit 849ceab
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 37 deletions.
67 changes: 34 additions & 33 deletions tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace functor {

SamplingKernelType SamplingKernelTypeFromString(const StringPiece str) {
const string lower_case = absl::AsciiStrToLower(str);
if (lower_case == "lanczos1") return Lanczos1Kernel;
if (lower_case == "lanczos3") return Lanczos3Kernel;
if (lower_case == "lanczos5") return Lanczos5Kernel;
if (lower_case == "gaussian") return GaussianKernel;
if (lower_case == "box") return BoxKernel;
if (lower_case == "triangle") return TriangleKernel;
if (lower_case == "keyscubic") return KeysCubicKernel;
if (lower_case == "mitchellcubic") return MitchellCubicKernel;
return SamplingKernelTypeEnd;
}

} // namespace functor
} // namespace tensorflow



namespace tensorflow {

Expand All @@ -52,7 +72,8 @@ struct Resampler2DFunctor<CPUDevice, T> {
const T zero = static_cast<T>(0.0);
const T one = static_cast<T>(1.0);

auto kernel = tensorflow::functor::CreateTriangleKernel();
// Creating the interpolation kernel
auto kernel = tensorflow::functor::CreateKeysCubicKernel();

auto resample_batches = [&](const int start, const int limit) {
for (int batch_id = start; batch_id < limit; ++batch_id) {
Expand Down Expand Up @@ -93,40 +114,20 @@ struct Resampler2DFunctor<CPUDevice, T> {
const int fx = std::floor(static_cast<float>(x));
const int fy = std::floor(static_cast<float>(y));

// Custom Linear interpolation
if(kernel_type == tensorflow::functor::TriangleKernel){
const int span_size = static_cast<int>(std::ceil(kernel.Radius()));

for (int chan = 0; chan < data_channels; ++chan) {
T res = zero;

for(int inx=-span_size; inx <= span_size; inx++){
for(int iny=-span_size; iny <= span_size; iny++){
const int cx = fx + inx;
const int cy = fy + iny;
const float dx = static_cast<float>(cx) - static_cast<float>(x);
const float dy = static_cast<float>(cy) - static_cast<float>(y);
res += get_data_point(cx, cy, chan) * static_cast<T>(kernel(dx) * kernel(dy));
}
const int span_size = static_cast<int>(std::ceil(kernel.Radius()));
for (int chan = 0; chan < data_channels; ++chan) {
T res = zero;

for(int inx=-span_size; inx <= span_size; inx++){
for(int iny=-span_size; iny <= span_size; iny++){
const int cx = fx + inx;
const int cy = fy + iny;
const float dx = static_cast<float>(cx) - static_cast<float>(x);
const float dy = static_cast<float>(cy) - static_cast<float>(y);
res += get_data_point(cx, cy, chan) * static_cast<T>(kernel(dx) * kernel(dy));
}
set_output(sample_id, chan, res);
}

}else{
const int cx = fx + 1;
const int cy = fy + 1;
const T dx = static_cast<T>(cx) - x;
const T dy = static_cast<T>(cy) - y;

for (int chan = 0; chan < data_channels; ++chan) {
const T img_fxfy = dx * dy * get_data_point(fx, fy, chan);
const T img_cxcy =
(one - dx) * (one - dy) * get_data_point(cx, cy, chan);
const T img_fxcy = dx * (one - dy) * get_data_point(fx, cy, chan);
const T img_cxfy = (one - dx) * dy * get_data_point(cx, fy, chan);
set_output(sample_id, chan,
img_fxfy + img_cxcy + img_fxcy + img_cxfy);
}
set_output(sample_id, chan, res);
}

} else {
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
#define __restrict__ __restrict
#endif

#include <string>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/image/sampling_kernels.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"


namespace tensorflow {
namespace addons {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ REGISTER_OP("Addons>Resampler")
.Input("warp: T")
.Output("output: T")
.Attr("T: {half, float, double}")
.Attr("kernel_type: string = 'bilinear'")
.Attr("kernel_type: string = 'triangle'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle data;
ShapeHandle warp;
Expand Down
13 changes: 10 additions & 3 deletions tensorflow_addons/image/resampler_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO

from typing import Optional
from typing import Optional, Type

_resampler_so = LazySO("custom_ops/image/_resampler_ops.so")


@tf.function
def resampler(
data: types.TensorLike, warp: types.TensorLike, name: Optional[str] = None
data: types.TensorLike,
warp: types.TensorLike,
method: Type[tf.image.ResizeMethod] = tf.image.ResizeMethod.BILINEAR,
name: Optional[str] = None
) -> tf.Tensor:
"""Resamples input data at user defined coordinates.
Expand Down Expand Up @@ -54,7 +57,11 @@ def resampler(
with tf.name_scope(name or "resampler"):
data_tensor = tf.convert_to_tensor(data, name="data")
warp_tensor = tf.convert_to_tensor(warp, name="warp")
return _resampler_so.ops.addons_resampler(data_tensor, warp_tensor, 'triangle')
if method == tf.image.ResizeMethod.BILINEAR:
kernel_type = 'triangle'
else:
kernel_type = 'keyscubic'
return _resampler_so.ops.addons_resampler(data_tensor, warp_tensor, kernel_type)


@tf.RegisterGradient("Addons>Resampler")
Expand Down

0 comments on commit 849ceab

Please sign in to comment.