
add conv2d conv1d forward function #166

Open · wants to merge 17 commits into master from dev_xcoresigma_jiangbin_conv

Conversation

@FatJhon (Collaborator) commented Aug 16, 2024

We have completed the forward function of the conv1d, conv2d, and conv2d_depthwise operators.

@tongxin (Contributor) commented Aug 19, 2024

There are new conflicts that need to be resolved. Please see to that.

@Bowen12992 (Collaborator) left a comment

Update the code to the latest base so that all the checks pass.

@StrongSpoon (Collaborator) left a comment

review ongoing

src/flag_gems/ops/conv1d.py (outdated; resolved)
+ (tl.arange(0, BLOCK_C_IN) * padded_width_input)[:, None]
+ i * stride_width
+ tl.arange(0, BLOCK_W)
)
Collaborator

So it doesn't support a width_kernel that is not a power of 2?

Collaborator Author

This is due to the contiguous data loading: a block currently only supports power-of-2 sizes, so the kernel's height and width are restricted to powers of 2 as well.
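For context, the usual workaround in Triton is to pad the block extent up to the next power of 2 and mask off the extra lanes, so a non-power-of-2 kernel width could still be loaded. A minimal, self-contained sketch (illustrative only, not code from this PR; the names are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def masked_row_sum(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # BLOCK must be a power of 2; lanes beyond n are disabled by the mask.
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr, tl.sum(x, axis=0))

def row_sum(x):
    n = x.numel()
    out = torch.empty(1, device=x.device, dtype=x.dtype)
    # e.g. a kernel width of 17 is padded to a block of 32
    masked_row_sum[(1,)](x, out, n, BLOCK=triton.next_power_of_2(n))
    return out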

weight_value = tl.load(weight + weight_offset, mask=mask_weight, other=0)
input_value = tl.reshape(input_value, (BLOCK_N, BLOCK_OUT_WEIGHT))
weight_value = tl.reshape(weight_value, (BLOCK_OUT_WEIGHT, BLOCK_O))
accumulator = tl.dot(input_value, weight_value, allow_tf32=False)
Collaborator

accumulator redefined
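If the intent is to reduce over a loop, the usual pattern is to define the accumulator once before the loop and add each tl.dot result into it, rather than rebinding it each iteration. A self-contained sketch (illustrative only, not this PR's kernel; it assumes the matrices exactly fit one block):

import torch
import triton
import triton.language as tl

@triton.jit
def small_matmul(a_ptr, b_ptr, c_ptr, K, BLOCK_M: tl.constexpr,
                 BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  # defined once
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
        acc += tl.dot(a, b, allow_tf32=False)  # accumulate, don't overwrite
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)

a = torch.randn(16, 32, device="cuda")
b = torch.randn(32, 16, device="cuda")
c = torch.empty(16, 16, device="cuda")
small_matmul[(1,)](a, b, c, 32, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16)
assert torch.allclose(c, a @ b, atol=1e-3)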

@pytest.mark.parametrize("kernel", [(17, 2, 2)])
@pytest.mark.parametrize("stride", [2])
@pytest.mark.parametrize("padding", [1])
@pytest.mark.parametrize("dtype", [torch.float32])
Collaborator

more test cases

Collaborator Author

added

@StrongSpoon (Collaborator) left a comment

please implement the performance tests.

)
input_ci_offset = offset_ci[:, None, None] * height_input * width_input
input_group_offset = (
pid_group[None, None, None, None] * c_input * width_input * height_input
Collaborator

initializing input_group_offset as a scalar is okay

input_col = mm(out_grad, weight_reshape.T)

# return dx,None,None,None,None,None,None
conv2d_col2img[grid](
Collaborator

Maybe we could fuse mm and col2img together into one backward kernel.

Collaborator

but not necessarily

Collaborator Author

done

@FatJhon force-pushed the dev_xcoresigma_jiangbin_conv branch from 121285e to 5adc6e6 on October 6, 2024 02:50
@@ -16,6 +16,9 @@
 from .bmm import bmm
 from .cat import cat
 from .clamp import clamp, clamp_tensor
+from .conv1d import conv1d
+from .conv2d import conv2d
+from .conv_depthwise2d import _conv_depthwise2d
Collaborator

Is conv_depthwise2d implemented?

Collaborator Author

added

    padding_width = padding[0]
else:
    padding_width = padding
return conv2d(
Collaborator

I suggest implementing a kernel function for conv1d specifically. Calling conv2d will incur additional runtime overhead.
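For reference, this is the pattern under discussion, sketched with torch's reference ops (illustrative only, not the PR's Triton code): a conv1d can be expressed through conv2d by inserting a dummy height dimension, which is correct but pays conv2d's extra indexing and dispatch cost.

import torch
import torch.nn.functional as F

def conv1d_via_conv2d(x, w, bias=None, stride=1, padding=0):
    # x: (N, C_in, W), w: (C_out, C_in, K)
    x2 = x.unsqueeze(2)                       # (N, C_in, 1, W)
    w2 = w.unsqueeze(2)                       # (C_out, C_in, 1, K)
    y2 = F.conv2d(x2, w2, bias, stride=(1, stride), padding=(0, padding))
    return y2.squeeze(2)                      # (N, C_out, W_out)

x = torch.randn(2, 3, 32)
w = torch.randn(4, 3, 5)
assert torch.allclose(conv1d_via_conv2d(x, w, padding=2),
                      F.conv1d(x, w, padding=2), atol=1e-4)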

accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)

for h in range(kernel_height):
    for w in range(kernel_width):
Collaborator

Since kernel_height and kernel_width are processed as loop iteration ranges, why not support non-power-of-2 values for them?


class Conv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
Collaborator

Dilation is not used at all, so it does not support dilation other than 1?

Collaborator Author

Dilation is not used.

Collaborator

I think we should raise an error if dilation > 1.
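A minimal sketch of the suggested guard (hypothetical, not the PR's final code):

def check_dilation(dilation):
    # dilation may be an int or an (h, w) pair, as in torch.nn.functional.conv2d
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    if any(d != 1 for d in dilation):
        raise NotImplementedError(
            f"dilation={dilation} is not supported; only dilation=1 is implemented"
        )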

Collaborator Author

added dilation function

    Returns:
        Output size of 2D convolution.
    """
    return (in_size + 2 * padding - kernel_size) // stride + 1
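As a quick sanity check of the quoted formula (illustrative only, not part of the thread): with in_size=32, kernel_size=3, padding=1, stride=2 it gives (32 + 2 - 3) // 2 + 1 = 16, which matches torch.

import torch
import torch.nn.functional as F

def conv_out_size(in_size, kernel_size, padding, stride):
    return (in_size + 2 * padding - kernel_size) // stride + 1

y = F.conv2d(torch.randn(1, 1, 32, 32), torch.randn(1, 1, 3, 3),
             stride=2, padding=1)
assert y.shape[-1] == conv_out_size(32, 3, 1, 2)  # 16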
@iclementine (Collaborator) commented Nov 6, 2024

Does it support asymmetric padding? I suppose not. Alright, torch's convolution does not support asymmetric padding.


return (
    input,
    None,
Collaborator

Gradients of the weight and bias should also be computed.
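For reference, the torch.autograd.Function contract this refers to, shown on a tiny made-up op (illustrative only, not the conv kernel): backward must return one gradient per forward argument, so gradients for weight and bias belong alongside the input gradient, with None for non-differentiable arguments such as stride or padding.

import torch

class ScaleShift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        return input * weight + bias

    @staticmethod
    def backward(ctx, out_grad):
        input, weight = ctx.saved_tensors
        input_grad = out_grad * weight
        weight_grad = (out_grad * input).sum()
        bias_grad = out_grad.sum()
        # one return value per forward argument
        return input_grad, weight_grad, bias_grad

x = torch.randn(4, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)
ScaleShift.apply(x, w, b).sum().backward()
assert torch.allclose(w.grad, x.detach().sum())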

# default conv shapes for input and weight, stride, padding, groups
# fields: Ni Ci Hi Wi Co Hk Wk stride padding groups
ConvBenchmark:
  shapes:
Collaborator

5 shapes are enough for core mode.
