diff --git a/test/test_ops.py b/test/test_ops.py index 3f0d8312c01..c7c415e0ab3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -929,6 +929,7 @@ def test_batched_nms_implementations(self, seed): class TestDeformConv: dtype = torch.float64 + mps_dtype = torch.float32 def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) @@ -1050,12 +1051,11 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) - @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, batch_sz, dtype=None): - dtype = dtype or self.dtype + dtype = self.mps_dtype if device == "mps" else dtype or self.dtype x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) in_channels = 6 out_channels = 2 @@ -1201,13 +1201,50 @@ def test_forward_scriptability(self): torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) -optests.generate_opcheck_tests( - testcase=TestDeformConv, - namespaces=["torchvision"], - failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), - additional_decorators=[], - test_utils=OPTESTS, -) +@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64)) +@pytest.mark.parametrize("device", cpu_and_cuda()) +@pytest.mark.parametrize("requires_grad", (True, False)) +def test_deform_conv2d_opcheck(dtype, device, requires_grad): + batch_size, channels_in, height, width = 1, 6, 10, 10 + kernel_size = (3, 3) + stride = (1, 1) + padding = (1, 1) + dilation = (1, 1) + groups = 2 + out_channels = 4 + out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 + out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 + x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad) + offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w, + dtype=dtype, device=device, requires_grad=requires_grad) + weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1], + dtype=dtype, device=device, requires_grad=requires_grad) + bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad) + use_mask = True + mask = torch.sigmoid(torch.randn( + batch_size, + kernel_size[0] * kernel_size[1], + out_h, + out_w, + dtype=dtype, device=device, requires_grad=requires_grad + )) + kwargs = { + "offset": offset, + "weight": weight, + "bias": bias, + "stride_h": stride[0], + "stride_w": stride[1], + "pad_h": padding[0], + "pad_w": padding[1], + "dilation_h": dilation[0], + "dilation_w": dilation[1], + "groups": groups, + "offset_groups": 1, + "use_mask": use_mask, + "mask": mask, # no modulation in this test + } + optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs) + class TestFrozenBNT: diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm new file mode 100644 index 00000000000..1d390a37f43 --- /dev/null +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -0,0 +1,134 @@ +#include +#include +#include +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + using namespace at::native::mps; + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D"); + TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D"); + TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D"); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true"); + TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + + at::DeviceGuard guard(input_c.device()); + + int batch = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + int out_channels = weight_c.size(0); + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels, + "Input channels (", in_channels, + ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")"); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0, + "Weight tensor's out channels (", weight_c.size(0), + ") must be divisible by n_weight_grps (", n_weight_grps, ")"); + TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w, + "Offset tensor shape[1] is invalid: got ", offset_c.size(1), + ", expected ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w, + "Mask tensor shape[1] is invalid: got ", mask_c.size(1), + ", expected ", n_offset_grps * weight_h * weight_w); + TORCH_CHECK(in_channels % n_offset_grps == 0, + "Input tensor channels (", in_channels, + ") must be divisible by n_offset_grps (", n_offset_grps, ")"); + TORCH_CHECK(offset_c.size(0) == batch, + "Offset tensor batch size (", offset_c.size(0), + ") must match input tensor batch size (", batch, ")"); + TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w, + "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3), + ") must match calculated output dimensions (", out_h, ", ", out_w, ")"); + TORCH_CHECK(!use_mask || mask_c.size(0) == batch, + "Mask tensor batch size (", mask_c.size(0), + ") must match input tensor batch size (", batch, ")"); + TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w), + "Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3), + ") must match calculated output dimensions (", out_h, ", ", out_w, ")"); + TORCH_CHECK(out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); + + auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options()); + + id inputBuffer = getMTLBufferStorage(input_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = use_mask ? getMTLBufferStorage(mask_c) : nil; + id outputBuffer = getMTLBufferStorage(columns); + + id device = MPSDevice::getInstance()->device(); + std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type()); + id pipelineState = mps::visionPipelineState(device, kernelName); + + int num_kernels = in_channels * out_h * out_w * batch; + NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup; + NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup; + MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^{ + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + [computeEncoder setComputePipelineState:pipelineState]; + at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer, + in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, batch, in_channels, n_offset_grps, out_h, out_w, + use_mask, outputBuffer); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + } + }); + int in_channels_per_grp = in_channels / n_weight_grps; + int out_channels_per_grp = out_channels / n_weight_grps; + auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w}); + auto columns_grouped = columns.view({n_weight_grps, + (in_channels * weight_h * weight_w) / n_weight_grps, + batch * out_h * out_w}); + auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1}); + auto out_grouped = at::bmm(weight_reshaped, columns_grouped); + auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w}) + .transpose(0, 1); + return out + bias_c.view({1, out_channels, 1, 1}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); +} + +} // namespace ops +} // namespace vision \ No newline at end of file diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index f85546a6c41..2f24c86c6bf 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -91,6 +91,52 @@ inline T bilinear_interpolate( return val; } +template +inline T bilinear_interpolate_deformable_conv2d( + constant T* input, + integer_t height, + integer_t width, + T y, + T x, + uint index /* index for debug only*/) { + if (y <= -1.0 || y >= height || x <= -1.0 || x >= width) { + return 0; + } + integer_t y_low = static_cast(floor(y)); + integer_t x_low = static_cast(floor(x)); + integer_t y_high = y_low + 1; + integer_t x_high = x_low + 1; + + T ly = y - static_cast(y_low); + T lx = x - static_cast(x_low); + T hh = 1.0 - ly; + T hw = 1.0 - lx; + + T v1 = 0; + if (y_low >= 0 && x_low >= 0) + v1 = input[y_low * width + x_low]; + + T v2 = 0; + if (y_low >= 0 && x_high <= width - 1) + v2 = input[y_low * width + x_high]; + + T v3 = 0; + if (y_high <= height - 1 && x_low >= 0) + v3 = input[y_high * width + x_low]; + + T v4 = 0; + if (y_high <= height - 1 && x_high <= width - 1) + v4 = input[y_high * width + x_high]; + + T w1 = hh * hw; + T w2 = hh * lx; + T w3 = ly * hw; + T w4 = ly * lx; + + T val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + return val; +} + template inline void bilinear_interpolate_gradient( integer_t height, @@ -225,6 +271,117 @@ kernel void nms( \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void deformable_im2col_kernel( + constant T* input_ptr [[ buffer(0) ]], + constant T* offset_ptr [[ buffer(1) ]], + constant T* mask_ptr [[ buffer(2) ]], + constant int& height [[ buffer(3) ]], + constant int& width [[ buffer(4) ]], + constant int& weight_h [[ buffer(5) ]], + constant int& weight_w [[ buffer(6) ]], + constant int& pad_h [[ buffer(7) ]], + constant int& pad_w [[ buffer(8) ]], + constant int& stride_h [[ buffer(9) ]], + constant int& stride_w [[ buffer(10)]], + constant int& dilation_h [[ buffer(11)]], + constant int& dilation_w [[ buffer(12)]], + constant int& batch_size [[ buffer(13)]], + constant int& n_in_channels [[ buffer(14)]], + constant int& n_offset_grps [[ buffer(15)]], + constant int& out_h [[ buffer(16)]], + constant int& out_w [[ buffer(17)]], + constant bool& use_mask [[ buffer(18)]], + device T* columns_ptr [[ buffer(19)]], + uint tid [[ thread_position_in_grid ]], + uint tpg [[ threads_per_grid ]]) +{ + int total = out_w * out_h * batch_size * n_in_channels; + int gridSize = tpg; + if (tid >= total) { + return; + } + + int out_x = tid % out_w; + int out_y = (tid / out_w) % out_h; + int out_b = (tid / (out_w * out_h)) % batch_size; + int in_c = tid / (out_w * out_h * batch_size); + int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + int grp_idx = in_c / c_per_offset_grp; + + int col_offset = out_c * (batch_size * out_h * out_w) + + out_b * (out_h * out_w) + + out_y * out_w + out_x; + device T* local_columns_ptr = columns_ptr + col_offset; + + int input_offset = out_b * (n_in_channels * height * width) + + in_c * (height * width); + constant T* local_input_ptr = input_ptr + input_offset; + + int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + constant T* local_offset_ptr = offset_ptr + offset_offset; + + constant T* local_mask_ptr = nullptr; + if (use_mask) { + int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; + local_mask_ptr = mask_ptr + mask_offset; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + int mask_index = i * weight_w + j; + int offset_index = 2 * mask_index; + + T mask_value = 1; + if (use_mask) { + mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x]; + } + + T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x]; + T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x]; + + T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val; + T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val; + + T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid); + + *local_columns_ptr = mask_value * interp; + + local_columns_ptr += batch_size * out_h * out_w; + } + } +} + +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +kernel void deformable_im2col_kernel( \ + constant DTYPE* input_ptr [[ buffer(0) ]], \ + constant DTYPE* offset_ptr [[ buffer(1) ]], \ + constant DTYPE* mask_ptr [[ buffer(2) ]], \ + constant int& height [[ buffer(3) ]], \ + constant int& width [[ buffer(4) ]], \ + constant int& weight_h [[ buffer(5) ]], \ + constant int& weight_w [[ buffer(6) ]], \ + constant int& pad_h [[ buffer(7) ]], \ + constant int& pad_w [[ buffer(8) ]], \ + constant int& stride_h [[ buffer(9) ]], \ + constant int& stride_w [[ buffer(10)]], \ + constant int& dilation_h [[ buffer(11)]], \ + constant int& dilation_w [[ buffer(12)]], \ + constant int& batch_sz [[ buffer(13)]], \ + constant int& n_in_channels[[ buffer(14)]], \ + constant int& n_offset_grps[[ buffer(15)]], \ + constant int& out_h [[ buffer(16)]], \ + constant int& out_w [[ buffer(17)]], \ + constant bool& use_mask [[ buffer(18)]], \ + device DTYPE* columns_ptr [[ buffer(19)]], \ + uint tid [[ thread_position_in_grid ]], \ + uint tpg [[ threads_per_grid ]]); + template kernel void roi_align( constant T * input [[buffer(0)]], @@ -1013,6 +1170,8 @@ kernel void ps_roi_pool_backward( \ REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); +REGISTER_DEFORMABLE_IM2COL_OP(float); +REGISTER_DEFORMABLE_IM2COL_OP(half); REGISTER_ROI_ALIGN_OP(float, int64_t); REGISTER_ROI_ALIGN_OP(half, int64_t); REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);