@@ -24,49 +24,75 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
2424
2525 namespace raw {
2626
27- template <class T >
27+ template <class T , std:: size_t CHANNELS_PER_ITER >
2828 __global__ void roi_pooling (
2929 Span<T> output, size_type pooled_height, size_type pooled_width,
3030 View<T> input, size_type in_height, size_type in_width,
31- View<T> rois, size_type num_channels, T spatial_scale)
31+ View<T> rois, size_type num_channels, float spatial_scale)
3232 {
3333 // input: [1, num_channels, in_height, in_width]
34+ const auto in_image_size = in_height * in_width;
35+
3436 // rois: [num_rois, 5]
37+ auto num_rois = rois.size () / 5 ;
3538
3639 // output: [num_rois, num_channels, pooled_height, pooled_width]
3740 const auto out_spatial_size = pooled_height * pooled_width;
3841 const auto out_roi_size = num_channels * out_spatial_size;
3942
40- /* every element in the output is mapped to a window in the input and each thread processes several windows */
41- for (auto idx : grid_stride_range (output.size ()))
43+ /* we have to compute the output value for every combination of (roi, c, y, x) in the output
44+ *
45+ * the computation involving (y, x) are identical for all non-spatial dimensions
46+ * the computation and memory requests involving the roi are identical for remaining three axes
47+ *
48+ * we process multiple channels every iteration to reuse the identical computation
49+ * and memory requests involved with the roi and spatial dimensions
50+ */
51+ /*
52+ * if we are processing `CHANNELS_PER_ITER` channels per iteration, we will need
53+ * (num_channels / CHANNELS_PER_ITER) iterations per (roi, x, y)
54+ */
55+ auto num_channel_iters_per_roi_xy = num_channels / CHANNELS_PER_ITER;
56+
57+ /* we need `num_channel_iters_per_roi_xy` iterations per (roi, x, y) and there are
58+ * `num_rois` rois and `out_spatial_size` combinations of (x, y)
59+ */
60+ auto iters_per_roi = num_channel_iters_per_roi_xy * out_spatial_size;
61+ auto iters_required = num_rois * iters_per_roi;
62+
63+ for (auto iter : grid_stride_range (iters_required))
4264 {
43- const auto n = idx / out_roi_size;
44- const auto c = (idx % out_roi_size) / out_spatial_size;
45- const auto y = (idx % out_spatial_size) / pooled_width;
46- const auto x = idx % pooled_width;
65+ const index_type roi_no = iter / iters_per_roi;
66+ const index_type c_start = ((iter % iters_per_roi) / out_spatial_size) * CHANNELS_PER_ITER;
67+
68+ /* note here that consecutive `iter` values will often have consecutive `x` values
69+ * => stores into output will be coalesced across threads
70+ */
71+ const index_type y = (iter % out_spatial_size) / pooled_width;
72+ const index_type x = iter % pooled_width;
4773
48- const index_type roi_offset = n * 5 ;
74+ const index_type roi_offset = roi_no * 5 ;
4975
5076 using device::round;
5177 const index_type batch_id = rois[roi_offset + 0 ];
52- const index_type x_start_roi = round (rois[roi_offset + 1 ] * spatial_scale);
53- const index_type y_start_roi = round (rois[roi_offset + 2 ] * spatial_scale);
54- const index_type x_end_roi = round (rois[roi_offset + 3 ] * spatial_scale);
55- const index_type y_end_roi = round (rois[roi_offset + 4 ] * spatial_scale);
78+ const index_type x_start_roi = round (static_cast < float >( rois[roi_offset + 1 ]) * spatial_scale);
79+ const index_type y_start_roi = round (static_cast < float >( rois[roi_offset + 2 ]) * spatial_scale);
80+ const index_type x_end_roi = round (static_cast < float >( rois[roi_offset + 3 ]) * spatial_scale);
81+ const index_type y_end_roi = round (static_cast < float >( rois[roi_offset + 4 ]) * spatial_scale);
5682
5783 using device::max;
5884 const auto roi_width = max<index_type>(x_end_roi - x_start_roi + 1 , 1 );
5985 const auto roi_height = max<index_type>(y_end_roi - y_start_roi + 1 , 1 );
6086
61- const auto roi_width_ratio = static_cast <T >(roi_width) / static_cast <T>( pooled_width) ;
62- const auto roi_height_ratio = static_cast <T >(roi_height) / static_cast <T>( pooled_height) ;
87+ const auto roi_width_ratio = static_cast <float >(roi_width) / pooled_width;
88+ const auto roi_height_ratio = static_cast <float >(roi_height) / pooled_height;
6389
64- auto x_start = x_start_roi + static_cast <index_type>(static_cast <T>(x) * roi_width_ratio);
65- auto y_start = y_start_roi + static_cast <index_type>(static_cast <T>(y) * roi_height_ratio);
90+ auto x_start = x_start_roi + static_cast <index_type>(x * roi_width_ratio);
91+ auto y_start = y_start_roi + static_cast <index_type>(y * roi_height_ratio);
6692
6793 using device::ceil;
68- auto x_end = x_start_roi + static_cast <index_type>(ceil (static_cast <T> (x + 1 ) * roi_width_ratio));
69- auto y_end = y_start_roi + static_cast <index_type>(ceil (static_cast <T> (y + 1 ) * roi_height_ratio));
94+ auto x_end = x_start_roi + static_cast <index_type>(ceil ((x + 1 ) * roi_width_ratio));
95+ auto y_end = y_start_roi + static_cast <index_type>(ceil ((y + 1 ) * roi_height_ratio));
7096
7197 using device::max;
7298 x_start = max<index_type>(x_start, 0 );
@@ -76,29 +102,48 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
76102 x_end = min<index_type>(x_end, in_width);
77103 y_end = min<index_type>(y_end, in_height);
78104
79- /* We have to set the output to zero if (x_start >= x_end) or (y_start >= y_end). If either
80- * condition is true, the loops below won't execute even a single iteration. Hence, by setting
81- * `max_val` to zero in this case, we can combine it with the `else` code.
82- */
83- T max_val = (x_start >= x_end || y_start >= y_end) ? T (0 ) : device::numeric_limits<T>::lowest ();
105+ index_type in_offset = (batch_id * num_channels + c_start) * in_height * in_width;
106+ index_type out_idx = roi_no * out_roi_size + c_start * out_spatial_size + y * pooled_width + x;
84107
85- const index_type in_offset = (batch_id * num_channels + c) * in_height * in_width;
86- for (auto iy = y_start; iy < y_end; iy++)
108+ for (int i = 0 ; i < CHANNELS_PER_ITER; i++)
87109 {
88- for (auto ix = x_start; ix < x_end; ix++)
110+ /* We have to set the output to zero if (x_start >= x_end) or (y_start >= y_end). If either
111+ * condition is true, the loops below won't execute even a single iteration. Hence, by setting
112+ * `max_val` to zero in this case, we can combine it with the `else` code.
113+ */
114+ T max_val = (x_start >= x_end || y_start >= y_end) ? T (0 ) : device::numeric_limits<T>::lowest ();
115+
116+ for (auto iy = y_start; iy < y_end; iy++)
89117 {
90- const auto in_idx = in_offset + iy * in_width + ix;
91- max_val = max (max_val, input[in_idx]);
118+ const auto in_idx = in_offset + iy * in_width;
119+ for (auto ix = x_start; ix < x_end; ix++)
120+ {
121+ max_val = max (max_val, input[in_idx + ix]);
122+ }
92123 }
93- }
94124
95- output[idx] = max_val;
125+ output[out_idx] = max_val;
126+
127+ in_offset += in_image_size;
128+ out_idx += out_spatial_size;
129+ }
96130 }
97131 }
98132 }
99133
134+ template <class T , std::size_t CHANNELS_PER_ITER> static // file-local helper: launches raw::roi_pooling<T, CHANNELS_PER_ITER> on `stream`; CHANNELS_PER_ITER is the number of channels one kernel loop iteration handles
135+ void launch_multichannel_roi_pooling (const Stream& stream, // stream handed to make_policy below
136+ Span<T> output, size_type pooled_height, size_type pooled_width, // output laid out as [num_rois, num_channels, pooled_height, pooled_width] per the kernel's comments
137+ View<T> input, size_type in_height, size_type in_width, // input laid out as [1, num_channels, in_height, in_width] per the kernel's comments
138+ View<T> rois, size_type num_channels, float spatial_scale) // rois laid out as [num_rois, 5]; the kernel computes num_channels / CHANNELS_PER_ITER with integer division, so num_channels is expected to be a multiple of CHANNELS_PER_ITER
139+ {
140+ auto kernel = raw::roi_pooling<T, CHANNELS_PER_ITER>; // pick the kernel instantiation matching this CHANNELS_PER_ITER
141+ auto policy = make_policy (kernel, output.size () / CHANNELS_PER_ITER, 0 , stream); // work count is output.size() / CHANNELS_PER_ITER because each kernel iteration writes CHANNELS_PER_ITER outputs; NOTE(review): the `0` is presumably the dynamic shared-memory size - confirm against make_policy
142+ launch_kernel (kernel, policy, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale); // arguments are forwarded unchanged, in the kernel's parameter order
143+ }
144+
100145 template <class T >
101- void roi_pooling (const Stream& stream, TensorSpan<T> output, TensorView<T> input, View<T> rois, T spatial_scale)
146+ void roi_pooling (const Stream& stream, TensorSpan<T> output, TensorView<T> input, View<T> rois, float spatial_scale)
102147 {
103148 CV_Assert (input.get_axis_size (1 ) == output.get_axis_size (1 ));
104149
@@ -110,13 +155,25 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
110155 size_type in_height = input.get_axis_size (2 );
111156 size_type in_width = input.get_axis_size (3 );
112157
113- auto kernel = raw::roi_pooling<T>;
114- auto policy = make_policy (kernel, output.size (), 0 , stream);
115- launch_kernel (kernel, policy, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
158+ if (num_channels % 64 == 0 ) {
159+ launch_multichannel_roi_pooling<T, 64 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
160+ } else if (num_channels % 32 == 0 ) {
161+ launch_multichannel_roi_pooling<T, 32 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
162+ } else if (num_channels % 16 == 0 ) {
163+ launch_multichannel_roi_pooling<T, 16 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
164+ } else if (num_channels % 8 == 0 ) {
165+ launch_multichannel_roi_pooling<T, 8 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
166+ } else if (num_channels % 4 == 0 ) {
167+ launch_multichannel_roi_pooling<T, 4 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
168+ } else if (num_channels % 2 == 0 ) {
169+ launch_multichannel_roi_pooling<T, 2 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
170+ } else {
171+ launch_multichannel_roi_pooling<T, 1 >(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
172+ }
116173 }
117174
118175#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
119- template void roi_pooling (const Stream& stream, TensorSpan<__half> output, TensorView<__half> input, View<__half> rois, __half spatial_scale);
176+ template void roi_pooling (const Stream& stream, TensorSpan<__half> output, TensorView<__half> input, View<__half> rois, float spatial_scale);
120177#endif
121178 template void roi_pooling (const Stream& stream, TensorSpan<float > output, TensorView<float > input, View<float > rois, float spatial_scale);
122179
0 commit comments