@@ -255,6 +255,202 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
  return result;
}
+Tensor expand_reshape_dim_into(int64_t batch_size, int64_t dst, const Tensor& x) {
+  auto x_ = x.unsqueeze(0);
+  VmapDimVector new_shape(x_.sizes().begin(), x_.sizes().end());
+  new_shape[0] = batch_size;
+  x_ = x_.expand(new_shape);
+  return reshape_dim_into(0, dst, x_);
+}
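+// E.g. with batch_size B and x of shape (N, C, H, W): unsqueeze/expand give a
+// (B, N, C, H, W) view, and reshape_dim_into then folds the new leading dim
+// into `dst`, yielding (B*N, C, H, W) for dst = 0 (the only use below).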
+
+
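+// Common preprocessing for the grid_sampler backward batch rules: fold every
+// present batch dim into the N dim (expanding any unbatched argument to the
+// batch size first) so the plain backward kernel can run once, and record
+// where the batch dim will end up in grad_input / grad_grid.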
+std::tuple<Tensor, Tensor, optional<int64_t>, Tensor, optional<int64_t>, int64_t>
+grid_sample_backward_helper_in(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim) {
+  auto new_grad_output = grad_output;
+  auto new_input = input;
+  auto new_grid = grid;
+
+  optional<int64_t> grad_input_out_bdim = nullopt;
+  optional<int64_t> grad_grid_out_bdim = nullopt;
+  int64_t bdim_size = 0;
+
+  if (grad_output_bdim) {
+    bdim_size = grad_output.sizes()[*grad_output_bdim];
+
+    if (input_bdim && grid_bdim) {
+      // case 1: (grad_output is batched, input is batched, grid is batched)
+      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_input_out_bdim = 0;
+      grad_grid_out_bdim = 0;
+    } else if (input_bdim && !grid_bdim) {
+      // case 2: (grad_output is batched, input is batched, grid is not batched)
+      // NB: folding the batch dim into the channel dim instead would make the
+      // backward produce a wrong grad_grid (grad_grid is accumulated over
+      // channels, so the per-batch contributions would be summed away).
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (!input_bdim && grid_bdim) {
+      // case 3: (grad_output is batched, input is not batched, grid is batched)
+      // NB: folding the batch dim into H_out instead would make the backward
+      // produce a wrong grad_grid.
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> input: (BN)CH_{in}W_{in}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+      // expand input to (BN)CH_{in}W_{in}
+      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
+      grad_input_out_bdim = 0;
+    } else {
+      // case 4: (grad_output is batched, input is not batched, grid is not batched)
+      // NB: folding the batch dim into H_out instead would make the backward
+      // produce a wrong grad_grid.
+      //
+      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grid: (BN)H_{out}W_{out}2
+      // -> input: (BN)CH_{in}W_{in}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+      // expand grid to (BN)H_{out}W_{out}2
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+      // expand input to (BN)CH_{in}W_{in}
+      new_input = expand_reshape_dim_into(bdim_size, 0, input);
+      grad_input_out_bdim = 0;
+    }
+  } else {
+    if (input_bdim && grid_bdim) {
+      // case 5: (grad_output is not batched, input is batched, grid is batched)
+      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      bdim_size = input.sizes()[*input_bdim];
+      // expand new_grad_output to (BN)CH_{out}W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (input_bdim && !grid_bdim) {
+      // case 6: (grad_output is not batched, input is batched, grid is not batched)
+      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // -> grid: (BN)H_{out}W_{out}2
+      // grad_input: (BN)CH_{in}W_{in}
+
+      bdim_size = input.sizes()[*input_bdim];
+      // expand new_grad_output to (BN)CH_{out}W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      new_input = reshape_dim_into(*input_bdim, 0, input);
+      grad_input_out_bdim = 0;
+      // expand new_grid to (BN)H_{out}W_{out}2
+      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
+      grad_grid_out_bdim = 0;
+    } else if (!input_bdim && grid_bdim) {
+      // case 7: (grad_output is not batched, input is not batched, grid is batched)
+      // NB: folding the batch dim into H_out instead would make the backward
+      // produce a wrong grad_grid.
+      //
+      // grad_output: NCH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
+      // -> grad_output: (BN)CH_{out}W_{out}
+      // -> input: (BN)CH_{in}W_{in}
+      // grad_input: (BN)CH_{in}W_{in}
+
+      bdim_size = grid.sizes()[*grid_bdim];
+      // expand new_grad_output to (BN)CH_{out}W_{out}
+      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
+      // expand new_input to (BN)CH_{in}W_{in}
+      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
+      grad_input_out_bdim = 0;
+      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
+      grad_grid_out_bdim = 0;
+    }
+    // case 8 (nothing is batched) can be ignored: the batch rule is only
+    // invoked when at least one argument carries a batch dim
+  }
+  return std::make_tuple(
+      new_grad_output, new_input, grad_input_out_bdim, new_grid, grad_grid_out_bdim, bdim_size);
+}
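+// The returned tuple is (new_grad_output, new_input, grad_input_out_bdim,
+// new_grid, grad_grid_out_bdim, bdim_size); the two optional bdims tell the
+// caller which gradient dims must be unflattened by bdim_size afterwards.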
+
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_helper_out(
+    const std::tuple<Tensor, Tensor>& bw_out,
+    optional<int64_t> grad_input_out_bdim,
+    optional<int64_t> grad_grid_out_bdim,
+    int64_t bdim_size) {
+  auto grad_input = std::get<0>(bw_out);
+  auto grad_grid = std::get<1>(bw_out);
+  if (grad_input_out_bdim) {
+    grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
+  }
+  if (grad_grid_out_bdim) {
+    grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
+  }
+  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
+  return result;
+}
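+// E.g. a (B*N, C, H, W) grad_input with grad_input_out_bdim = 0 is split back
+// into (B, N, C, H, W) by reshape_dim_outof, with the returned bdim still
+// pointing at dim 0 of the result.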
+
+
+template <typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+grid_sample_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    ExtraArgs... extra_args) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto grad_input_out_bdim = std::get<2>(new_bw_input);
+  auto new_grid = std::get<3>(new_bw_input);
+  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
+  int64_t bdim_size = std::get<5>(new_bw_input);
+
+  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
+
+  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
+}
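+// grid_sampler_{2d,3d}_backward take (grad_output, input, grid, ...), while
+// cudnn_grid_sampler_backward takes (input, grid, grad_output) and no extra
+// args; the separate wrapper below only reorders the arguments around the
+// same two helpers.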
+
+template <typename F, F Func>
+std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
+cudnn_grid_sample_backward_batch_rule(
+    const Tensor& input, optional<int64_t> input_bdim,
+    const Tensor& grid, optional<int64_t> grid_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim) {
+
+  auto new_bw_input = grid_sample_backward_helper_in(
+      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);
+
+  auto new_grad_output = std::get<0>(new_bw_input);
+  auto new_input = std::get<1>(new_bw_input);
+  auto grad_input_out_bdim = std::get<2>(new_bw_input);
+  auto new_grid = std::get<3>(new_bw_input);
+  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
+  int64_t bdim_size = std::get<5>(new_bw_input);
+
+  auto bw_out = Func(new_input, new_grid, new_grad_output);
+
+  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
+}
+
std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim,
@@ -370,12 +566,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  }
};

+template <typename A, A a, typename C>
+struct GridSampleBackwardBatchRuleHelper;
+
+template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
+struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      T... extra_args) {
+    return grid_sample_backward_batch_rule<F, Func, T...>(
+        grad_output, grad_output_batch_dim,
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        std::forward<T>(extra_args)...);
+  }
+};
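+// T1, T2, T3 bind to the three Tensor parameters of the op's schema; the
+// remaining T... (interpolation_mode, padding_mode, align_corners for the
+// grid_sampler backwards) are forwarded to the kernel unchanged.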
+
+template <typename F, F Func>
+struct CudnnGridSampleBackwardBatchRuleHelper {
+  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
+      const Tensor& input, optional<int64_t> input_batch_dim,
+      const Tensor& grid, optional<int64_t> grid_batch_dim,
+      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
+    return cudnn_grid_sample_backward_batch_rule<F, Func>(
+        input, input_batch_dim,
+        grid, grid_batch_dim,
+        grad_output, grad_output_batch_dim);
+  }
+};
+
#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

+#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
+    GridSampleBackwardBatchRuleHelper<\
+      decltype(&ATEN_FN(fn)),\
+      &ATEN_FN(fn),\
+      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)
+
+#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
+    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply
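+// Rough expansion for fn = grid_sampler_2d_backward: the helper is
+// instantiated with &ATEN_FN(grid_sampler_2d_backward) and that op's
+// parameter typelist, so `apply` peels off the three (Tensor, bdim) pairs
+// and forwards the trailing scalar arguments.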
+
#define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
    UpsampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +623,12 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  EXISTING_BDIM2(op, vec); \
  EXISTING_BDIM(op);
+Tensor this_grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+  return input;
+}
+

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("convolution", convolution_batch_rule);
  // m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +643,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  EXISTING_BDIM(im2col_backward);

  VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));
+
  VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
+  VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
+  VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));
+
  VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
  VMAP_SUPPORT("cross", cross_batch_rule);