[ONNX] Scan outputs for onnx.Loop #3988

Open
javidcf opened this issue Jan 28, 2025 · 6 comments

javidcf commented Jan 28, 2025

The current implementation of the onnx.Loop operator (PR #3408) does not support scan outputs. These are additional outputs that are concatenated across all iterations (see the operator spec for more information).
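
For illustration, here is a rough sketch of a minimal model with one scan output, built with the onnx Python helpers (the body, names, and shapes here are made up; the operator spec is the authoritative reference):

```python
# Hypothetical sketch only. Builds a Loop whose body emits one scan output
# (the running value x), which the Loop concatenates across iterations into
# the extra "x_scan" result.
import onnx
from onnx import TensorProto, helper

body = helper.make_graph(
    nodes=[
        helper.make_node("Identity", ["cond_in"], ["cond_out"]),
        helper.make_node("Add", ["x_in", "one"], ["x_out"]),
        helper.make_node("Identity", ["x_out"], ["x_scan_elem"]),
    ],
    name="loop_body",
    inputs=[
        helper.make_tensor_value_info("iter_num", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("x_in", TensorProto.FLOAT, [1]),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("x_out", TensorProto.FLOAT, [1]),
        # Third body output = scan output, stacked by Loop across iterations.
        helper.make_tensor_value_info("x_scan_elem", TensorProto.FLOAT, [1]),
    ],
    initializer=[helper.make_tensor("one", TensorProto.FLOAT, [1], [1.0])],
)

loop = helper.make_node(
    "Loop",
    inputs=["trip_count", "cond", "x_init"],
    outputs=["x_final", "x_scan"],  # x_scan collects x_scan_elem from every iteration
    body=body,
)

graph = helper.make_graph(
    [loop],
    "loop_with_scan_output",
    inputs=[
        helper.make_tensor_value_info("trip_count", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
        helper.make_tensor_value_info("x_init", TensorProto.FLOAT, [1]),
    ],
    outputs=[
        helper.make_tensor_value_info("x_final", TensorProto.FLOAT, [1]),
        # Leading dimension is the (runtime) number of iterations.
        helper.make_tensor_value_info("x_scan", TensorProto.FLOAT, [None, 1]),
    ],
)
onnx.save(helper.make_model(graph), "loop_with_scan_outputs_example.onnx")
```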


javidcf commented Jan 29, 2025

Example of loop with a scan output:
onnx_loop_with_scan_outputs_example.py

@zjgarvey (Collaborator) commented:

I'm going to assign this to someone. In the meantime, some reproduction instructions:

  1. Download the above Python script.
  2. Run it.
  3. Run: python -m torch_mlir.tools.import_onnx loop_with_scan_outputs_example.onnx -o repro.mlir
  4. Run: torch-mlir-opt repro.mlir --convert-torch-onnx-to-torch

We don't seem to have support for scan_outputs, as indicated by the OP. These should be the result of successively concatenating the outputs of each iteration. Since we likely cannot initialize a tensor before the loop (the shape of these scan outputs might not be knowable at compile time), it might be a good idea to return a PrimListConstruct containing the values from each iteration and then concatenate the list at the end, roughly as in the sketch below.
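
Roughly, the semantics I have in mind (a Python-level sketch only, not torch-mlir code; the body below is a placeholder):

```python
import torch

def loop_with_scan_output(trip_count: int, x: torch.Tensor):
    scan_values = []                      # would lower to a prim.ListConstruct plus appends
    for _ in range(int(trip_count)):
        x = x + 1                         # placeholder for the loop body computation
        scan_values.append(x.unsqueeze(0))
    # Single concatenation after the loop; the size along dim 0 is only known at runtime.
    return x, torch.cat(scan_values, dim=0)
```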

rkayaith self-assigned this Feb 13, 2025

javidcf commented Feb 13, 2025

Thank you @zjgarvey and @rkayaith. I actually intended to submit a PR for this if I could crack it myself, but I didn't commit to it because I was not familiar with the code base.

Looking into it, I came to the same conclusion and tried to implement this using PrimListConstruct ("for-like" loops could probably be done better with a single allocation in advance, modifying one "row" at a time, and if the number of iterations is statically known maybe even the result type dimension could be fixed, but I thought those improvements could be left for later). Here is what I put together. I think that is sound, but I found the process would still fail further down the line: ConvertTorchPrimLoopWhileLikeOp, and I suppose ConvertTorchPrimLoopForLikeOp too, expects only tensor types as iteration variables. I tried to "bypass" that (accept either tensors or "containers" of tensors) but it would still fail later. While trying to figure this out, I also found other issues:

  • While loops (i.e. not "for-like") do not seem to work correctly, even without scan outputs. I think the iteration index parameter is converted from a Torch tensor type to a native tensor type earlier than expected.
  • Detection of "for-like" loops seems limited at the moment: it expects exactly the pattern "input condition -> identity -> output condition", and I'm not sure this is how ONNX always encodes for-like loops.
  • I think the current code would not detect a loop operation with an initially false condition (that is, one that should never enter the loop). Perhaps the loop would need to be wrapped in a conditional.
  • Lists are defined as a type, but they are not well supported and, as far as I can tell, they are mostly expected to have size and values known at compile time (for things like concatenating a given list of tensors or reducing across multiple axes where the list is typically a literal).

I do not intend to be critical; I appreciate the complexity of this project. I just wanted to point these out in case you want to take them into consideration for task prioritisation.

@zjgarvey (Collaborator) commented:

@javidcf Thanks for the details; this seems quite complicated as a first issue for @rkayaith.

How high-priority is this issue for you? Do you have a model you are trying to support that relies on such an op?


zjgarvey commented Feb 13, 2025

Taking another look, it appears that the optional max_trip_count input could be used to initialize a tensor for the torch loop; slice_scatter can then be used to insert each iteration's output into the result, and we can slice to the appropriate size after the loop (if termination completed early). (Edit: I think there might be issues lower in the IREE compilation pipeline, around resolution of dynamic dims leaving the flow dialect, but we can try to fix that directly.)

I think we can support this if max_trip_count is present, and we should make sure to have appropriate match failure messaging for unsupportable cases.
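
At the torch level the idea would look roughly like this (a sketch assuming max_trip_count is available; the body and condition below are placeholders):

```python
import torch

def loop_with_preallocated_scan(max_trip_count: int, cond: bool, x: torch.Tensor):
    # Preallocate max_trip_count rows for the scan output up front.
    scan = torch.zeros(max_trip_count, *x.shape, dtype=x.dtype)
    i = 0
    while cond and i < max_trip_count:
        x = x + 1                                              # placeholder body computation
        # Insert this iteration's value into row i of the preallocated buffer.
        scan = scan.slice_scatter(x.unsqueeze(0), dim=0, start=i, end=i + 1)
        cond = bool((x < 10).all())                            # placeholder loop condition
        i += 1
    # Trim to the actual number of iterations if the loop terminated early.
    return x, scan[:i]
```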


javidcf commented Feb 13, 2025

Thank you @zjgarvey. Yes, I had an ONNX model I was trying to compile with IREE that was failing on this operation. However, I still found other issues even after I removed the operation (I think the next error was in the "Expand" operation, where the number of output dimensions was not known at compile time).

I eventually decided to start rewriting the model in PyTorch (from TensorFlow), possibly reducing the amount of pre- and post-processing work I was embedding in the model.
