[torch.compile] Make HiDream torch.compile ready #11477
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
tokens_per_expert = count_freq.cumsum(dim=0)
Just reimplemented it to eliminate the numpy() dependency.
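For context, the replacement keeps the cumulative per-expert token counts entirely in torch so the op can be traced without a numpy round-trip. A minimal sketch (the index values and expert count below are hypothetical stand-ins for `flat_expert_indices` and `self.num_activated_experts` in the model):

```python
import torch

# Hypothetical router output: which expert each token was routed to.
flat_expert_indices = torch.tensor([0, 1, 1, 2, 0, 2, 2])
num_experts = 4  # stand-in for self.num_activated_experts

# Torch-only path: per-expert token counts, then a running offset.
# minlength guarantees one slot per expert even if an expert got no tokens.
count_freq = torch.bincount(flat_expert_indices, minlength=num_experts)
tokens_per_expert = count_freq.cumsum(dim=0)

print(tokens_per_expert.tolist())  # [2, 4, 7, 7] for this toy input
```

Staying in torch also means the result never leaves the device, which matters for compile-friendliness since `.cpu().numpy()` forces a graph break.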
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
Relevant test for this PR.
The graph break seems to be induced by
@anijain2305 is this known?
Even if we remove the decorator, it still fails with the same error.
LGTM
Edit: checked the messages; I missed that there is still a graph break. I can take a look today.
Thanks! Appreciate it.
What does this PR do?
Trying to make the HiDream model fully compatible with torch.compile(), but it currently fails with: https://pastebin.com/EbCFqBvw
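For readers unfamiliar with the setup: "fully compatible" here means the model can be compiled with `fullgraph=True`, which turns any graph break into a hard error instead of a silent fallback. A toy sketch of that pattern (the module below is a hypothetical stand-in for the HiDream transformer; `backend="eager"` is used only so the sketch needs no compiler toolchain):

```python
import torch
import torch.nn as nn

# Toy stand-in for the HiDream transformer block.
toy = nn.Sequential(nn.Linear(8, 8), nn.GELU())

# fullgraph=True makes torch.compile raise on any graph break,
# which is exactly what the PR's test is trying to guard against.
compiled = torch.compile(toy, fullgraph=True, backend="eager")
out = compiled(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
```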
To reproduce, run the following on a GPU machine:
RUN_COMPILE=1 RUN_SLOW=1 pytest tests/models/transformers/test_models_transformer_hidream.py -k "test_torch_compile_recompilation_and_graph_break"
I am on the following env:
@anijain2305 @StrongerXi would you have any pointers?