-
Notifications
You must be signed in to change notification settings - Fork 533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torch] Add Canonicalize Pattern for embedding op #3277
Conversation
37782f7
to
22465ee
Compare
I can see benefit to this optimization however this is more avoiding the compilation issue we have been encountering rather than preventing the crash. |
Note this could also be a pessimization: if you have your embeddings as f32, gather them, and convert to f16 you really want the conversion to fold into the embeddings so you aren't shipping (and doing the memory transactions) on f32 if you don't need those bits. This may get taken care of later in the pipeline but it's important to note that there are some massive implications of things like this (it's always better to hoist narrowing operations and sink widening operations, almost never the opposite). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid hurting performance we should only perform the swap during the widening case, otherwise we are potentially loading more data just to truncate back down whereas there is benefit to truncating overall.
There's a tradeoff between memory and compute; doing this might take more memory but is less compute-intensive, whereas the one suggested might be compute-intensive since we are not able to fuse both kernels at the backend. I will add the check to perform a swap only during the widening case. |
I've made the necessary changes. Please review. |
Converts PrimConvertOp followed by Embedding -> Embedding followed by PrimConvertOp. We don't need to cast the entire matrix; just the output of the embedding op.
Hi @pashu123, it seems the PR has been there for quite some time. Can you please update it in order to get merged? |
It’s not needed. I’ll close the PR. |
Converts PrimConvertOp followed by Embedding -> Embedding followed by PrimConvertOp. We don't need to cast the entire matrix; just the output of the embedding op.
Issue: iree-org/iree#17226 (comment)