[PTX] Enable migration of mma (m16n8k16) #2746
Conversation
/// \param [in] item The sycl::nd_item index space class
template <typename MulType, typename ABType, typename CDType, typename ItemT>
__attribute__((optnone)) void
mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
According to the PTX spec (https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-multiply-accumulate-operation-using-mma-instruction), there are 11 shapes defined for mma:
m8n8k4
m8n8k16
m8n8k32
m8n8k128
m16n8k4
m16n8k8
m16n8k16
m16n8k32
m16n8k64
m16n8k128
m16n8k256
Please make sure your helper function is able to support these shapes in addition to "m16n8k16".
As this is the initial PR, I've added support for the shape used in the top applications.
I will add support for the remaining shapes soon.
This PR adds support for migrating the mma PTX ASM API.