Commit d500cb9

Added grid_sample backward batch rule
Description:
- Added grid_sample backward batch rule (CPU and CUDA)
- Updated tests

Notes: In most of the cases I had to expand the unbatched argument on dim 0; I could not reuse the forward-pass tricks of merging the batch dim into the channel or H_out dim, because those produce wrong grid gradients in the backward pass.
1 parent f16519d commit d500cb9
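The Notes above refer to expanding an unbatched argument on dim 0 so that it matches the flattened (B*N) batch that the other arguments are reshaped into. The following is a minimal Python sketch of that shape bookkeeping, not the commit's code; the real implementation is the C++ helper expand_reshape_dim_into below, built on functorch's reshape_dim_into/reshape_dim_outof, and the function name used here is made up for the illustration.

import torch

def expand_then_fold(batch_size, x):
    # hypothetical stand-in for expand_reshape_dim_into(batch_size, 0, x):
    # give the unbatched tensor a leading vmap dim of size batch_size,
    # then fold that dim into dim 0 so the kernel sees one flat batch
    x = x.unsqueeze(0).expand(batch_size, *x.shape)   # (B, N, ...)
    return x.flatten(0, 1)                            # (B*N, ...)

B, N, H_out, W_out = 3, 2, 4, 4
grid = torch.rand(N, H_out, W_out, 2) * 2 - 1         # unbatched grid
big_grid = expand_then_fold(B, grid)                  # (B*N, H_out, W_out, 2)

# after the backward kernel runs on the flat batch, the merged dim is split
# back out again (the C++ side uses reshape_dim_outof for this)
grad_grid_flat = torch.zeros_like(big_grid)           # placeholder gradient
grad_grid = grad_grid_flat.unflatten(0, (B, N))       # (B, N, H_out, W_out, 2), bdim at 0
print(big_grid.shape, grad_grid.shape)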

File tree

4 files changed: +2150 −1093

codegen/codegen_outofplacebatching.py (+1 −3)
@@ -155,7 +155,7 @@ def parse_return(return_t):
     return tuple([x.strip() for x in m.group(1).split(',')])
 
 def parse_args(args_t):
-    args = args_t.split(',')
+    args = args_t.split(', ')
     result = []
     for arg in args:
         split_idx = arg.rfind(' ')
@@ -170,8 +170,6 @@ def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', includ
     for line in lines:
         if 'void' in line:
             continue
-        if 'std::array' in line:
-            continue
         m = re.match(r'(.*) \w+\((.*)\); // {"schema": "aten::(\w+\.?\w*)\(.*', line)
         if m is None:
             continue
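The split change above matters because parameter lists in RegistrationDeclarations.h separate arguments with ", ", while a template argument such as std::array<bool,2> contains a bare ','. Splitting on ',' tears such a parameter in two, which is presumably why signatures containing std::array had to be skipped before this commit. A small illustration; the declaration string is made up for the example:

args_t = "const at::Tensor & grad_output, const at::Tensor & self, ::std::array<bool,2> output_mask"

print(args_t.split(','))
# ['const at::Tensor & grad_output', ' const at::Tensor & self',
#  ' ::std::array<bool', '2> output_mask']      <- the std::array parameter is torn apart

print(args_t.split(', '))
# ['const at::Tensor & grad_output', 'const at::Tensor & self',
#  '::std::array<bool,2> output_mask']          <- every parameter stays intact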

functorch/csrc/BatchRulesModules.cpp (+248)
@@ -255,6 +255,202 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
   return result;
 }
 
+Tensor expand_reshape_dim_into(int64_t batch_size, int64_t dst, const Tensor& x) {
+  auto x_ = x.unsqueeze(0);
+  VmapDimVector new_shape(x_.sizes().begin(), x_.sizes().end());
+  new_shape[0] = batch_size;
+  x_ = x_.expand(new_shape);
+  return reshape_dim_into(0, dst, x_);
+}
+
+
+std::tuple<Tensor, Tensor, optional<int64_t>, Tensor, optional<int64_t>, int64_t>
+grid_sample_backward_helper_in(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim) {
+  auto new_grad_output = grad_output;
+  auto new_input = input;
+  auto new_grid = grid;
+
+  optional<int64_t> grad_input_out_bdim = nullopt;
+  optional<int64_t> grad_grid_out_bdim = nullopt;
+  int64_t bdim_size = 0;
+
+  if (grad_output_bdim) {
+
+    bdim_size = grad_output.sizes()[*grad_output_bdim];
+
+    if (input_bdim && grid_bdim) {
+      // case 1: (grad_output is batched, input is batched, grid is batched)
+      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_input_out_bdim = 0;
+      grad_grid_out_bdim = 0;
+    } else if (input_bdim && !grid_bdim) {
+      // case 2: (grad_output is batched, input is batched, grid is not batched)
+      // IF PUT BATCH DIM TO CHANNEL -> backward produces wrong grad_grid
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (!input_bdim && grid_bdim) {
+      // case 3: (grad_output is batched, input is not batched, grid is batched)
+      // IF PUT BATCH DIM TO H_out -> backward produces wrong grad_grid
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> input: (BN)CH_{in}W_{in}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+      // expand input to (BN)H_{out}W_{out}2
+      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
+      grad_input_out_bdim = 0;
+    } else {
+      // case 4: (grad_output is batched, input is not batched, grid is not batched)
+      // IF PUT BATCH DIM TO H_out -> backward produces wrong grad_grid
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grid: (BN)H_{out}W_{out}2
+      // -> input: (BN)CH_{in}W_{in}
+      // grad_input: NCH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      // expand grid to (BN)H_{out}W_{out}2
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+      // expand input to (BN)CH_{in}W_{in}
+      new_input = expand_reshape_dim_into(bdim_size, 0, input);
+      grad_input_out_bdim = 0;
+    }
+  } else {
+    if (input_bdim && grid_bdim) {
+      // case 5: (grad_output is not batched, input is batched, grid is batched)
+      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      bdim_size = input.sizes()[*input_bdim];
+      // expand new_grad_output to (BN)CH_{out}W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (input_bdim && !grid_bdim) {
+      // case 6: (grad_output is not batched, input is batched, grid is not batched)
+      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // -> grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      bdim_size = input.sizes()[*input_bdim];
+      // expand new_grad_output to (BN)CH_{out}W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      // expand new_grid to (BN)H_{out}W_{out}2
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (!input_bdim && grid_bdim) {
+      // case 7: (grad_output is not batched, input is not batched, grid is batched)
+      // IF PUT BATCH DIM TO H_out -> backward produces wrong grad_grid
+      //
+      // grad_output: NCH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // -> input: (BN)CH_{out}W_{out}
+      // grad_input: NCH_{in}W_{in}
+
+      bdim_size = grid.sizes()[*grid_bdim];
+      // expand new_grad_output to NC(BH_{out})W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      // expand new_input to (BN)CH_{in}W_{in}
+      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
+      grad_input_out_bdim = 0;
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+    } // case 8 can be ignored
+  }
+  return std::make_tuple(
+      new_grad_output, new_input, grad_input_out_bdim, new_grid, grad_grid_out_bdim, bdim_size);
+}
+
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_helper_out(
+    const std::tuple<Tensor, Tensor> & bw_out,
+    optional<int64_t> grad_input_out_bdim,
+    optional<int64_t> grad_grid_out_bdim,
+    int64_t bdim_size) {
+  auto grad_input = std::get<0>(bw_out);
+  auto grad_grid = std::get<1>(bw_out);
+  if (grad_input_out_bdim) {
+    grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
+  }
+  if (grad_grid_out_bdim) {
+    grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
+  }
+  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
+  return result;
+}
+
+
+template<typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    ExtraArgs... extra_args) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto grad_input_out_bdim = std::get<2>(new_bw_input);
+  auto new_grid = std::get<3>(new_bw_input);
+  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
+  int64_t bdim_size = std::get<5>(new_bw_input);
+
+  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
+
+  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
+}
+
+template<typename F, F Func>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+cudnn_grid_sample_backward_batch_rule(
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto grad_input_out_bdim = std::get<2>(new_bw_input);
+  auto new_grid = std::get<3>(new_bw_input);
+  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
+  int64_t bdim_size = std::get<5>(new_bw_input);
+
+  auto bw_out = Func(new_input, new_grid, new_grad_output);
+
+  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
+}
+
 std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim,
@@ -370,12 +566,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   }
 };
 
+template <typename A, A a, typename C>
+struct GridSampleBackwardBatchRuleHelper;
+
+template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
+struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      T... extra_args) {
+    return grid_sample_backward_batch_rule<F, Func, T...>(
+        grad_output, grad_output_batch_dim,
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        std::forward<T>(extra_args)...);
+  }
+};
+
+template <typename F, F Func>
+struct CudnnGridSampleBackwardBatchRuleHelper {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
+    return cudnn_grid_sample_backward_batch_rule<F, Func>(
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        grad_output, grad_output_batch_dim
+    );
+  }
+};
+
 #define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
     GridSampleBatchRuleHelper<\
       decltype(&ATEN_FN(fn)),\
       &ATEN_FN(fn),\
       c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
 
+#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
+    GridSampleBackwardBatchRuleHelper<\
+      decltype(&ATEN_FN(fn)),\
+      &ATEN_FN(fn),\
+      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
+
+#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
+    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply
+
 #define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
     UpsampleBackwardBatchRuleHelper<\
       decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +623,12 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
   EXISTING_BDIM2(op, vec); \
   EXISTING_BDIM(op);
 
+Tensor this_grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+  return input;
+}
+
+
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("convolution", convolution_batch_rule);
   // m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +643,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(im2col_backward);
 
   VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));
+
   VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
+  VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));
+
   VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
   VMAP_SUPPORT("cross", cross_batch_rule);
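With grid_sampler_2d_backward, grid_sampler_3d_backward and cudnn_grid_sampler_backward registered above, differentiating grid_sample under vmap should route through these batch rules. A rough usage sketch against the functorch Python API; the shapes and the vmap/grad composition are illustrative, not taken from the commit's tests:

import torch
import torch.nn.functional as F
from functorch import vmap, grad

def sample_loss(inp, grid):
    # scalar loss so that grad differentiates the 2D grid_sample inputs
    return F.grid_sample(inp, grid, align_corners=True).sum()

B, N, C, H, W = 4, 2, 3, 8, 8
H_out, W_out = 5, 5
inputs = torch.randn(B, N, C, H, W)                  # per-sample batches of inputs
grids = torch.rand(B, N, H_out, W_out, 2) * 2 - 1    # per-sample sampling grids

# per-sample gradients w.r.t. both input and grid; the backward of grid_sample
# now runs under vmap, which is what exercises the grid_sampler_2d_backward rule
grad_inp, grad_grid = vmap(grad(sample_loss, argnums=(0, 1)))(inputs, grids)
print(grad_inp.shape, grad_grid.shape)   # (B, N, C, H, W) and (B, N, H_out, W_out, 2)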
