Added grid_sample backward batch rule #284

Merged (5 commits) on Dec 2, 2021
6 changes: 3 additions & 3 deletions codegen/codegen_outofplacebatching.py
@@ -157,7 +157,9 @@ def parse_return(return_t):
return tuple([x.strip() for x in m.group(1).split(',')])

def parse_args(args_t):
args = args_t.split(',')
# We assume args are separated by comma-space, and that types like
# std::array<bool,2> contain no space after their inner comma
args = args_t.split(', ')
result = []
for arg in args:
split_idx = arg.rfind(' ')
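
For illustration, a minimal sketch of why the comma-space split matters; the (type, name) pair layout returned below is an assumption based on the visible lines, not copied from the PR:

def parse_args(args_t):
    result = []
    for arg in args_t.split(', '):  # ', ' keeps std::array<bool,2> in one piece
        split_idx = arg.rfind(' ')
        # (type, name) layout is assumed here
        result.append((arg[:split_idx].strip(), arg[split_idx:].strip()))
    return tuple(result)

parse_args('const Tensor & self, std::array<bool,2> output_mask')
# -> (('const Tensor &', 'self'), ('std::array<bool,2>', 'output_mask'))
# A plain split(',') would have produced the broken type 'std::array<bool'.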
@@ -172,8 +174,6 @@ def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', includ
for line in lines:
if 'void' in line:
continue
if 'std::array' in line:
continue
m = re.match(r'(.*) \w+\((.*)\); // {"schema": "aten::(\w+\.?\w*)\(.*', line)
if m is None:
continue
127 changes: 127 additions & 0 deletions functorch/csrc/BatchRulesModules.cpp
@@ -255,6 +255,86 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
return result;
}

std::tuple<Tensor, Tensor, Tensor, int64_t>
grid_sample_backward_helper_in(
const Tensor& grad_output, optional<int64_t> grad_output_bdim,
const Tensor& input, optional<int64_t> input_bdim,
const Tensor& grid, optional<int64_t> grid_bdim) {

auto batch_size = get_bdim_size3(
grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size);
grad_output_ = reshape_dim_into(0, 0, grad_output_);

auto input_ = moveBatchDimToFront(input, input_bdim);
input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
input_ = reshape_dim_into(0, 0, input_);

auto grid_ = moveBatchDimToFront(grid, grid_bdim);
grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size);
grid_ = reshape_dim_into(0, 0, grid_);

return std::make_tuple(grad_output_, input_, grid_, batch_size);
}

std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
grid_sample_backward_helper_out(
const std::tuple<Tensor, Tensor> & bw_out,
optional<int64_t> grad_input_out_bdim,
optional<int64_t> grad_grid_out_bdim,
int64_t bdim_size) {
auto grad_input = std::get<0>(bw_out);
auto grad_grid = std::get<1>(bw_out);
grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
return result;
}
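
On the way back out, reshape_dim_outof splits the flattened dimension apart again so the caller can report a batch dim of 0. A self-contained sketch under the same illustrative shapes:

import torch

B, N, C, H, W = 3, 4, 2, 5, 5
grad_input = torch.randn(B * N, C, H, W)        # what the raw backward returns
grad_input = grad_input.reshape(B, N, C, H, W)  # reshape_dim_outof(0, B, grad_input)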


template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
grid_sample_backward_batch_rule(
const Tensor& grad_output, optional<int64_t> grad_output_bdim,
const Tensor& input, optional<int64_t> input_bdim,
const Tensor& grid, optional<int64_t> grid_bdim,
ExtraArgs... extra_args) {

auto new_bw_input = grid_sample_backward_helper_in(
grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

auto new_grad_output = std::get<0>(new_bw_input);
auto new_input = std::get<1>(new_bw_input);
auto new_grid = std::get<2>(new_bw_input);
int64_t batch_size = std::get<3>(new_bw_input);

auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);

return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
}

template<typename F, F Func>
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
cudnn_grid_sample_backward_batch_rule(
const Tensor& input, optional<int64_t> input_bdim,
const Tensor& grid, optional<int64_t> grid_bdim,
const Tensor& grad_output, optional<int64_t> grad_output_bdim) {

auto new_bw_input = grid_sample_backward_helper_in(
grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

auto new_grad_output = std::get<0>(new_bw_input);
auto new_input = std::get<1>(new_bw_input);
auto new_grid = std::get<2>(new_bw_input);
int64_t bdim_size = std::get<3>(new_bw_input);

auto bw_out = Func(new_input, new_grid, new_grad_output);

return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
}

std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& other, optional<int64_t> other_bdim,
@@ -370,12 +450,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
}
};

template <typename A, A a, typename C>
struct GridSampleBackwardBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
const Tensor& input, optional<int64_t> input_batch_dim,
const Tensor& grid, optional<int64_t> grid_batch_dim,
T... extra_args) {
return grid_sample_backward_batch_rule<F, Func, T...>(
grad_output, grad_output_batch_dim,
input, input_batch_dim,
grid, grid_batch_dim,
std::forward<T>(extra_args)...);
}
};

template <typename F, F Func>
struct CudnnGridSampleBackwardBatchRuleHelper {
static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
const Tensor& input, optional<int64_t> input_batch_dim,
const Tensor& grid, optional<int64_t> grid_batch_dim,
const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
return cudnn_grid_sample_backward_batch_rule<F, Func>(
input, input_batch_dim,
grid, grid_batch_dim,
grad_output, grad_output_batch_dim
);
}
};

#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
GridSampleBatchRuleHelper<\
decltype(&ATEN_FN(fn)),\
&ATEN_FN(fn),\
c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
GridSampleBackwardBatchRuleHelper<\
decltype(&ATEN_FN(fn)),\
&ATEN_FN(fn),\
c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply

#define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
UpsampleBackwardBatchRuleHelper<\
decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +507,7 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
EXISTING_BDIM2(op, vec); \
EXISTING_BDIM(op);


TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("convolution", convolution_batch_rule);
// m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +522,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
EXISTING_BDIM(im2col_backward);

VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));

VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));
Comment on lines +528 to +529
Contributor: Do both these get exercised in the tests?

Contributor Author (@vfdev-5, Nov 30, 2021): Yes. cudnn_grid_sampler_backward is exercised by test/test_ops.py::TestOperatorsCUDA::test_vmapvjp_has_batch_rule_nn_functional_grid_sample_cuda_float32, and grid_sampler_3d_backward by test/test_ops.py::TestOperatorsCPU::test_vmapvjp_has_batch_rule_nn_functional_grid_sample_cpu_float32.
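
For reference, a minimal sketch of the pattern those tests exercise; the shapes are illustrative, not copied from the test suite:

import torch
from functorch import vmap, vjp

def f(x, grid):
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)

def f_vjp(x, grid):
    out, vjp_fn = vjp(f, x, grid)
    return vjp_fn(torch.ones_like(out))      # calls the grid sampler backward

x = torch.randn(4, 2, 3, 8, 8)               # leading dim of size 4 is vmapped
grid = torch.randn(4, 2, 5, 5, 2)
grad_x, grad_grid = vmap(f_vjp)(x, grid)     # hits the new backward batch rule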


VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
VMAP_SUPPORT("cross", cross_batch_rule);

36 changes: 4 additions & 32 deletions functorch/csrc/BatchRulesScatterOps.cpp
@@ -158,34 +158,6 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
return self;
}

int64_t bdim_size(
const Tensor& a, optional<int64_t> a_bdim,
const Tensor& b, optional<int64_t> b_bdim,
const Tensor& c, optional<int64_t> c_bdim) {
if (a_bdim) {
return a.size(*a_bdim);
}
if (b_bdim) {
return b.size(*b_bdim);
}
if (c_bdim) {
return c.size(*c_bdim);
}
TORCH_INTERNAL_ASSERT(false);
}

int64_t bdim_size(
const Tensor& a, optional<int64_t> a_bdim,
const Tensor& b, optional<int64_t> b_bdim) {
if (a_bdim) {
return a.size(*a_bdim);
}
if (b_bdim) {
return b.size(*b_bdim);
}
TORCH_INTERNAL_ASSERT(false);
}
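
(These local helpers are superseded by get_bdim_size2 / get_bdim_size3, used below and in BatchRulesModules.cpp above; presumably they were hoisted into a shared helper header as part of this PR.)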

namespace {

template<typename Func, typename ...Args>
@@ -197,7 +169,7 @@ std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
const Scalar& value, Args... args) {
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -230,7 +202,7 @@ inline std::tuple<Tensor,optional<int64_t>> scatter_batch_rule(
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
auto batch_size = bdim_size(self, self_bdim, index, index_bdim, src, src_bdim);
auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);

auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -314,7 +286,7 @@ std::tuple<Tensor,optional<int64_t>> gather_batch_rule(
bool sparse_grad) {
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
auto batch_size = bdim_size(self, self_bdim, index, index_bdim);
auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);
@@ -343,7 +315,7 @@ std::tuple<Tensor,optional<int64_t>> gather_backward_batch_rule(
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
bool sparse_grad) {
auto batch_size = bdim_size(grad, grad_bdim, self, self_bdim, index, index_bdim);
auto batch_size = get_bdim_size3(grad, grad_bdim, self, self_bdim, index, index_bdim);
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto index_ = moveBatchDimToFront(index, index_bdim);