[CK_TILE] Multiple-D GEMM example #2008
base: develop
Conversation
I haven't finished reviewing all the files yet, but I've spotted a few things that look very suspicious to me. Please verify them.
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
    CK_TILE_ERROR(
        "Can't support N that is not a multiple of NPerBlock without padding!");
Please note this is wrt tensor D.
@@ -399,6 +467,29 @@ struct GemmKernel
    }
}();

// TODO: enable vector write for D in ColMajor
This comment is misleading. Please remove.
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
Please fix this. This is currently a RowMajor layout, not ColumnMajor.
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
return make_tuple(a_tensor_view,
                  b_tensor_view,
                  generate_tuple(d_tensor_view, number<NumDTensor>{}),
Can we keep coherent style? Please create a variable and use it here.
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                               const OAccTile& o_acc_tile,
                               const DDramWindow& ds_dram_window,
Please static assert its size.
@@ -154,6 +181,14 @@ struct CShuffleEpilogue
    tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();

auto d_dram_small_window = generate_tuple(
Suggested change:
- auto d_dram_small_window = generate_tuple(
+ auto d_dram_windows = generate_tuple(
@@ -154,6 +181,14 @@ struct CShuffleEpilogue
    tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();

auto d_dram_small_window = generate_tuple(
    [&](auto idx) { return make_tile_window(ds_dram_window[idx], dram_tile_distribution); },
You have to set the tile window lengths here, as is done for the LDS windows. Otherwise you get a window of size MPerBlock x NPerBlock here.
using elemenet_wise_output_t =
    decltype(load_tile(make_tile_window(out_lds_window, dram_tile_distribution)));
elemenet_wise_output_t elemenet_wise_output;
Why is this actually needed? Why not just overwrite c_out_tensor? That would use far fewer registers.
constexpr int M = 3840;
constexpr int N = 4096;
constexpr int K = 4096;
Please don't use such large inputs in unit-tests.
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
[ ] clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered