Conversation

@keshavvinayak01 commented Oct 10, 2025

This PR adds support for emitting graphs for PyTorch higher-order operators (HOPs), beginning with torch._higher_order_ops.while_loop.

The proposed change is to modify import_program to call a new function, _import_all_child_modules, which recursively imports the stateless graphs of all child modules.

Since HOP graphs are stateless and perform no mutation, it is correct to import them as stateless graphs, although the method import_stateless_graph is marked as deprecated.
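
For context, a minimal sketch (not from the PR; the module, shapes, and values are illustrative) of the kind of program this change targets:

```python
import torch
from torch._higher_order_ops import while_loop


class CountDown(torch.nn.Module):
    def forward(self, i, x):
        # cond_fn must return a scalar boolean tensor; body_fn returns new
        # carried values with the same types and shapes as its inputs.
        def cond_fn(i, x):
            return i > 0

        def body_fn(i, x):
            return i - 1, x + 1.0

        return while_loop(cond_fn, body_fn, (i, x))


# torch.export traces cond_fn/body_fn into child GraphModules, which the
# fx_importer must import alongside the parent graph.
prog = torch.export.export(CountDown(), (torch.tensor(5), torch.zeros(3)))
```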

@keshavvinayak01 changed the title from "[WIP] Modified fx_imoprter to support hop_while_loop" to "[WIP] Modified fx_importer to support hop_while_loop" on Oct 10, 2025
@keshavvinayak01 marked this pull request as ready for review on October 13, 2025 09:33
@keshavvinayak01 marked this pull request as draft on October 13, 2025 19:36
@keshavvinayak01 marked this pull request as ready for review on October 15, 2025 15:14
@keshavvinayak01 changed the title from "[WIP] Modified fx_importer to support hop_while_loop" to "[TORCH] Modified fx_importer to support hop_while_loop" on Oct 15, 2025
@zjgarvey (Collaborator)

I'm working on resolving CI issues right now. Once that lands, I'll ping here to have you sync your branch with main so I can review.

@keshavvinayak01 (Author)

> I'm working on resolving CI issues right now. Once that lands, I'll ping here to have you sync your branch with main so I can review.

Sounds good, thanks.

@zjgarvey (Collaborator)

@keshavvinayak01 can you sync with main?

@keshavvinayak01 force-pushed the keshavvinayak01/torch-hop-while branch from 3eb338f to c8c711c on October 22, 2025 16:42
@keshavvinayak01 (Author)

Should be synced. Had to force push.

@zjgarvey (Collaborator) left a comment


There are some name inconsistencies, and even after resolving those, the generated IR does not lower out of the torch dialect (some issue with the conversion to scf.while). I'd double-check the correctness of the generated prim loop, and possibly consider adding an e2e test.

```python
for child_name, child_module in prog.graph.owning_module.named_children():
    if isinstance(child_module, GraphModule) and hasattr(child_module, 'graph'):
        # Generate function name: parent_childname
        child_func_name = f"{parent_name}_{child_name}_{id(child_module)}"
```

This is not consistent with the _hop_while_loop callee names (hence the CI failure). Since you can't easily pass the parent name to this function, it would make sense to simply use the child name plus an additional uniquifier (I'm not a fan of id, since it will not be reproducible between runs).

You might be better off defining a mapping between graph modules and MLIR func names as an attribute of the FxImporter, and handling name collisions as necessary there.
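
A minimal sketch of that suggestion; the attribute and method names here (_module_func_names, _fresh_func_name) are hypothetical, not the actual FxImporter API:

```python
import torch


class FxImporter:
    def __init__(self):
        # Maps each imported fx.GraphModule to its MLIR func name.
        self._module_func_names: dict[torch.fx.GraphModule, str] = {}
        self._used_func_names: set[str] = set()

    def _fresh_func_name(self, child_name: str, module: torch.fx.GraphModule) -> str:
        # Reuse the recorded name if this module was already imported.
        if module in self._module_func_names:
            return self._module_func_names[module]
        # Deterministic collision handling: a counter suffix instead of id(),
        # so callee names are reproducible between runs.
        name, i = child_name, 0
        while name in self._used_func_names:
            i += 1
            name = f"{child_name}_{i}"
        self._used_func_names.add(name)
        self._module_func_names[module] = name
        return name
```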


@zjgarvey (Collaborator) left a comment


I added some more comments for now; thanks for addressing the earlier ones.

Please add an e2e test for this op and debug it, since I wasn't able to lower the test's output IR to linalg/tensor/scf (even with consistent naming).

@keshavvinayak01 (Author)

I've addressed your comments. Specifically, about the generated while_loop IR: you're right, prim.Loop will not support function calls in its body. I'm going to modify the implementation to inline the function within the body.
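
In fx terms, inlining amounts to replaying the child graph's nodes into the parent loop body instead of emitting a call; a rough sketch, where value_map and _import_node are illustrative names rather than the actual fx_importer API:

```python
def _inline_graph_module(self, child_gm, loop_body_args):
    # Map the child graph's placeholders to the loop body's block arguments.
    placeholders = [n for n in child_gm.graph.nodes if n.op == "placeholder"]
    value_map = dict(zip(placeholders, loop_body_args))
    for node in child_gm.graph.nodes:
        if node.op == "placeholder":
            continue
        if node.op == "output":
            # The output node's operands become the loop's yielded values.
            return [value_map[v] for v in node.args[0]]
        # Emit the node's op directly into the current (loop-body) block,
        # recording its result for later uses.
        value_map[node] = self._import_node(node, value_map)
```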

Change 1: Convert builtin tensors → Torch tensors when entering the loop body.
Change 2: Convert Torch tensors → builtin tensors when yielding back to the loop condition.
Without these fixes, the conversion would fail whenever a while loop carries tensor values.

Also modified the FileCheck statements in basic_test.py.
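
For illustration, a hedged example of the case those materializations cover — a while loop whose carried value is a tensor (shapes and thresholds are made up):

```python
import torch
from torch._higher_order_ops import while_loop


class TensorCarry(torch.nn.Module):
    def forward(self, x):
        # The carried value x is a tensor, so converting torch.prim.Loop to
        # scf.while has to materialize builtin tensor -> !torch.vtensor on
        # entry to the body (Change 1) and the reverse on yield (Change 2).
        def cond_fn(x):
            return x.sum() < 100.0

        def body_fn(x):
            return (x * 2.0,)

        return while_loop(cond_fn, body_fn, (x,))
```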

@keshavvinayak01 (Author)

Had to modify the TorchToScf conversion to support these changes; in general, the current state of TorchToScf would error out for any tensor loop-carried values. @zjgarvey please take a look and let me know. I also saw that PR https://github.com/llvm/torch-mlir/pull/3040/files solved this issue, but it looks like those changes were removed?

@zjgarvey (Collaborator)

> Had to modify the TorchToScf conversion to support these changes; in general, the current state of TorchToScf would error out for any tensor loop-carried values. @zjgarvey please take a look and let me know. I also saw that PR https://github.com/llvm/torch-mlir/pull/3040/files solved this issue, but it looks like those changes were removed?

That PR added tensor arg support for the "for-like" loop conversion, but not the "while-like" loop conversion.
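
For reference, a rough illustration of the distinction (function names and bounds are made up): a for-like loop has a trip count fixed at entry, while a while-like loop exits on a data-dependent condition:

```python
import torch


def for_like(x, n: int):
    # Trip count known at entry: imports as a for-like prim.Loop and can
    # lower to scf.for (the case PR #3040 handled for tensor args).
    for _ in range(n):
        x = x + 1.0
    return x


def while_like(x):
    # Data-dependent exit condition: imports as a while-like prim.Loop and
    # needs the scf.while path this PR extends.
    while bool(x.sum() < 100.0):
        x = x * 2.0
    return x
```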


@zjgarvey (Collaborator) left a comment


This looks good to me, but please add at least one e2e test in projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py. Let me know if you want some pointers on adding one of these.
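
For reference, a sketch of what such a test might look like, following the suite's existing conventions (the module name, shapes, and the while_loop import path are assumptions, not from this PR):

```python
import torch
from torch._higher_order_ops import while_loop

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class HopWhileLoopModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1], torch.float32, True)])
    def forward(self, x):
        def cond_fn(x):
            return x.sum() < 100.0

        def body_fn(x):
            return (x + 1.0,)

        return while_loop(cond_fn, body_fn, (x,))


@register_test_case(module_factory=lambda: HopWhileLoopModule())
def HopWhileLoopModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3))
```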
