add conv2d conv1d forward function #166
base: master
Conversation
There are new merge conflicts that need to be resolved. Please take care of them.
Please update the branch to the latest code so that all the checks pass.
review ongoing
src/flag_gems/ops/conv1d.py
Outdated
+ (tl.arange(0, BLOCK_C_IN) * padded_width_input)[:, None]
+ i * stride_width
+ tl.arange(0, BLOCK_W)
)
So it doesn't support a width_kernel that is not a power of 2?
This comes from loading the data contiguously. Currently the block sizes only support powers of 2, so the kernel height and width are restricted to powers of 2 as well.
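For reference, a minimal sketch of how a non-power-of-2 kernel width could still be handled: round the compile-time block up to the next power of 2 and mask the extra lanes. Names such as `load_kernel_row`, `x_ptr` and `width_kernel` are illustrative only and not taken from this PR.

```python
# Illustrative sketch, not the PR's code: tl.arange still needs a power-of-2
# BLOCK_W, but masking lets an arbitrary runtime width_kernel be loaded.
import triton
import triton.language as tl

@triton.jit
def load_kernel_row(x_ptr, width_kernel, BLOCK_W: tl.constexpr):
    # device-side helper, meant to be called from another @triton.jit function
    offs = tl.arange(0, BLOCK_W)          # BLOCK_W must be a power of 2
    mask = offs < width_kernel            # disable lanes beyond the real width
    return tl.load(x_ptr + offs, mask=mask, other=0.0)

# Host side: the block size can be derived from the runtime width, e.g.
# BLOCK_W = triton.next_power_of_2(width_kernel)
```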
src/flag_gems/ops/conv1d.py
Outdated
weight_value = tl.load(weight + weight_offset, mask=mask_weight, other=0)
input_value = tl.reshape(input_value, (BLOCK_N, BLOCK_OUT_WEIGHT))
weight_value = tl.reshape(weight_value, (BLOCK_OUT_WEIGHT, BLOCK_O))
accumulator = tl.dot(input_value, weight_value, allow_tf32=False)
accumulator redefined
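A hypothetical sketch of what accumulating instead of rebinding could look like; `k_total`, `input_ptrs`, `weight_ptrs` and `mask_input` are placeholder names, not identifiers from this PR.

```python
# Sketch only: initialize the accumulator once before the reduction loop and
# add each tl.dot result into it, instead of redefining `accumulator` per step.
accumulator = tl.zeros((BLOCK_N, BLOCK_O), dtype=tl.float32)
for k in range(0, k_total, BLOCK_OUT_WEIGHT):
    input_value = tl.load(input_ptrs, mask=mask_input, other=0)
    weight_value = tl.load(weight_ptrs, mask=mask_weight, other=0)
    accumulator += tl.dot(
        tl.reshape(input_value, (BLOCK_N, BLOCK_OUT_WEIGHT)),
        tl.reshape(weight_value, (BLOCK_OUT_WEIGHT, BLOCK_O)),
        allow_tf32=False,
    )
```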
tests/test_reduction_ops.py
Outdated
@pytest.mark.parametrize("kernel", [(17, 2, 2)]) | ||
@pytest.mark.parametrize("stride", [2]) | ||
@pytest.mark.parametrize("padding", [1]) | ||
@pytest.mark.parametrize("dtype", [torch.float32]) |
more test cases
added
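One possible way to broaden the parametrization; the specific shapes, strides, paddings and dtypes below are illustrative only, not the values actually added in the PR.

```python
# Illustrative only: a broader sweep for the conv1d test. Kernel widths are
# kept as powers of 2 to match the current block-size restriction.
@pytest.mark.parametrize("kernel", [(17, 2, 2), (32, 4, 4), (8, 8, 8)])
@pytest.mark.parametrize("stride", [1, 2, 3])
@pytest.mark.parametrize("padding", [0, 1, 2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
```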
please implement the performance tests.
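For orientation, a minimal standalone timing sketch; this is not the FlagGems benchmark harness, and it assumes `flag_gems.use_gems()` dispatches aten convolution to these kernels.

```python
# Rough timing sketch (assumptions noted above), not the project's perf suite.
import torch
import triton
import flag_gems

def bench_conv2d(shape=(32, 16, 64, 64), co=32, k=3):
    x = torch.randn(shape, device="cuda")
    w = torch.randn(co, shape[1], k, k, device="cuda")

    ms_torch = triton.testing.do_bench(lambda: torch.nn.functional.conv2d(x, w))
    with flag_gems.use_gems():
        ms_gems = triton.testing.do_bench(lambda: torch.nn.functional.conv2d(x, w))
    print(f"torch: {ms_torch:.3f} ms  flag_gems: {ms_gems:.3f} ms")
```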
src/flag_gems/ops/conv2d.py
Outdated
)
input_ci_offset = offset_ci[:, None, None] * height_input * width_input
input_group_offset = (
    pid_group[None, None, None, None] * c_input * width_input * height_input
initializing input_group_offset as a scalar is okay
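A short sketch of what this suggestion amounts to; the combined `offset` name is hypothetical.

```python
# Illustrative: a plain scalar broadcasts when added to tensor offsets, so the
# explicit [None, None, None, None] expansion can be dropped.
input_group_offset = pid_group * c_input * height_input * width_input
offset = input_group_offset + input_ci_offset  # scalar + tensor broadcasts
```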
src/flag_gems/ops/conv2d.py
Outdated
input_col = mm(out_grad, weight_reshape.T)

# return dx,None,None,None,None,None,None
conv2d_col2img[grid](
maybe we could fuse mm and col2img together as bwd kernel
but not necessarily
done
121285e to 5adc6e6 (Compare)
@@ -16,6 +16,9 @@
from .bmm import bmm
from .cat import cat
from .clamp import clamp, clamp_tensor
from .conv1d import conv1d
from .conv2d import conv2d
from .conv_depthwise2d import _conv_depthwise2d
Is conv_depthwise2d implemented?
added
    padding_width = padding[0]
else:
    padding_width = padding
return conv2d(
I suggest implementing a kernel function for conv1d specifically. Calling conv2d will cost additional runtime.
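A hedged sketch of the wrapper pattern being discussed (not the PR's exact code): conv1d expressed via conv2d by adding a dummy height dimension. A dedicated conv1d kernel would avoid the extra reshapes and the 2D indexing overhead.

```python
# Sketch of the "call conv2d" approach the comment refers to; the exact conv2d
# signature here is assumed from the forward() shown elsewhere in this PR.
def conv1d_via_conv2d(input, weight, bias, stride, padding, groups):
    x = input.unsqueeze(2)            # (N, C, L)      -> (N, C, 1, L)
    w = weight.unsqueeze(2)           # (Co, Ci/g, K)  -> (Co, Ci/g, 1, K)
    out = conv2d(x, w, bias, (1, stride), (0, padding), (1, 1), groups)
    return out.squeeze(2)             # (N, Co, 1, Lo) -> (N, Co, Lo)
```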
src/flag_gems/ops/conv2d.py
Outdated
accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)

for h in range(kernel_height):
    for w in range(kernel_width):
Since kernel_height and kernel_width are only used as loop ranges, why not also support non-power-of-2 values for them?

class Conv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
Dilation is not used at all, so it does not support dilation other than 1?
Dilation is not used.
I think we should raise an error if dilation > 1.
Added dilation support.
src/flag_gems/ops/conv2d.py
Outdated
    Returns:
        Output size of 2D convolution.
    """
    return (in_size + 2 * padding - kernel_size) // stride + 1
Does it support asymmetric padding? I suppose not. That's alright; torch's convolution does not support asymmetric padding either.
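Since dilation support was added later in this thread, note that the output-size helper would need the standard dilation-aware form (as documented for torch.nn.Conv2d); a sketch of the generalized helper:

```python
# Standard dilation-aware output size; padding remains symmetric per dimension.
# With dilation=1 this reduces to (in_size + 2*padding - kernel_size)//stride + 1.
def conv_output_size(in_size, kernel_size, stride, padding, dilation=1):
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
```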

return (
    input,
    None,
Gradients of weight and bias should also be computed.
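A reference-level sketch (not a Triton kernel, groups=1, no dilation) of what the backward would also need to produce; the function name and layout are illustrative only.

```python
# Reference im2col-based grads for weight and bias, for comparison/testing only.
import torch

def conv2d_weight_bias_grads(input, weight, grad_output, stride, padding):
    N, C_out, H_out, W_out = grad_output.shape
    kh, kw = weight.shape[2], weight.shape[3]
    # im2col: (N, C_in*kh*kw, H_out*W_out)
    cols = torch.nn.functional.unfold(input, (kh, kw), stride=stride, padding=padding)
    go = grad_output.reshape(N, C_out, H_out * W_out)
    # (N, C_out, L) x (N, L, C_in*kh*kw), summed over N -> (C_out, C_in*kh*kw)
    grad_weight = torch.bmm(go, cols.transpose(1, 2)).sum(dim=0).reshape_as(weight)
    grad_bias = grad_output.sum(dim=(0, 2, 3))
    return grad_weight, grad_bias
```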
# default conv shape for input and weight stride padding groups
# default Ni Ci Hi Wi Co Hk Wk stride padding groups
ConvBenchmark:
  shapes:
5 shapes are enough for core mode.
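For illustration only, five entries following the column order in the comment above (Ni Ci Hi Wi Co Hk Wk stride padding groups); the actual values and schema should match whatever the benchmark framework expects.

```yaml
ConvBenchmark:
  shapes:
    - [4, 16, 32, 32, 32, 2, 2, 1, 0, 1]
    - [8, 32, 64, 64, 64, 2, 2, 2, 1, 1]
    - [16, 64, 56, 56, 64, 4, 4, 1, 1, 1]
    - [16, 64, 56, 56, 128, 2, 2, 2, 0, 2]
    - [32, 128, 28, 28, 128, 4, 4, 1, 1, 4]
```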
We have completed the forward functions of the conv1d, conv2d, and conv_depthwise2d operators.