-
Notifications
You must be signed in to change notification settings - Fork 534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FxImporter]: support parsing nodes which return list #3031
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution. If you wouldn't mind, could you add a test? I know the testing in tree is thin for this (we have a bunch of tests out of tree that have not been ported yet), but would like to not keep adding features without tests for structural things.
Existing tests are here: https://github.com/llvm/torch-mlir/tree/main/test/python/fx_importer
Could you create a new op_forms.py and add a case. It should follow how basic_test.py does it. Checks don't need to be super detailed.
Thanks for the kind advice. A simple test case for list_return_test.py which contains only one case of torch.unbind.int op is added. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the contribution. Since this stuff gets a bit twisty, do you mind if I take a bit of time and poke at this so that I understand? If I make changes, I'll extend your patch and land as co-author. These things are hard to review in the abstract without actually trying things.
Thank you for the test -- that helps a lot.
@@ -888,6 +888,20 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT | |||
tensor_meta = node.meta.get("tensor_meta") | |||
val = node.meta.get("val") | |||
sparsity = node.meta.get("sparsity", None) | |||
# Some nodes returns a list, like torch.ops.aten.unbind.int | |||
if isinstance(tensor_meta, List) or isinstance(val, List): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am pretty sure that the second clause should be isinstance(val, list)
: List
is a type descriptor that is never instantiated. list
is the runtime type.
OK. |
This PR allows FxImporter to parse operators that return list of tensors, such as aten,unbind.int and aten.split. NOTE: This is a temporary patch, we should remove it once [this torch-mlir PR](llvm/torch-mlir#3031) is merged.
@Vremold Thank you for your work! I've encountered a similar issue myself, and this pull request neatly addresses my problem. I noticed that this PR has not yet been merged. Would it be possible to resolve the merge conflicts and proceed with the merge? Your prompt attention would be greatly appreciated. Many thanks! |
@rog93 No problem. Done. |
eb5adae
to
5d5302d
Compare
Appreciate it~ Seems like there are code conflicts that need resolving before the merge can proceed. |
May I ask if this commit is ready to be merged now, or do we still need approvals from other reviewers? |
Hi @Vremold, can you please rebase this PR? |
As described in title. The context is that I'm using torch.export + torch_mlir.fx_importer to trace a Mixtral MoE layer. It has aten.unbind.int op, which returns a list of tensor as the output.