
🐛 [Bug] Encountered bug when using Torch-TensorRT #3406

Open
cehongwang opened this issue Feb 20, 2025 · 1 comment
bug Something isn't working
cehongwang (Collaborator) commented Feb 20, 2025

Bug Description

When using CUDA Graphs, if a graph break produces a PyTorch subgraph and the input contains a nested dictionary, CUDA graph capture fails.
In the following example, the forward function takes 2 inputs, but the graph module receives 3 (the nested dictionary is flattened), which raises an error.

To Reproduce

import torch
import torch_tensorrt

class TestModel(torch.nn.Module):
    def forward(self, x, additional_param: dict):
        x = x + additional_param['y']
        return x * additional_param['z'] + 5

device = "cuda:0"
inputs = (torch.rand(1).to(device),)
kwarg_inputs = {
    "additional_param": {
        'y': torch.rand(1).to(device),
        'z': torch.rand(1).to(device),
    }
}
model = TestModel().to(device)
compiled_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    arg_inputs=inputs,
    kwarg_inputs=kwarg_inputs,
    min_block_size=1,
    # Force a graph break so the mul runs as a PyTorch subgraph
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

with torch_tensorrt.runtime.enable_cudagraphs(
    compiled_model
) as cudagraphs_module:
    cudagraphs_module(*inputs, **kwarg_inputs)
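For intuition on the input-count mismatch: pytree-style flattening turns the nested kwarg dict into a flat list of leaf tensors, so forward() declares 2 parameters while the graph module receives 3 flat inputs. A minimal pure-Python sketch of that flattening (the real API is torch.utils._pytree.tree_flatten; this helper is hypothetical and simplified):

```python
# Hypothetical sketch of pytree-style flattening: recursively unpack
# dicts and tuples into a flat list of leaf values.
def flatten(obj):
    if isinstance(obj, dict):
        leaves = []
        for key in sorted(obj):
            leaves.extend(flatten(obj[key]))
        return leaves
    if isinstance(obj, (list, tuple)):
        leaves = []
        for item in obj:
            leaves.extend(flatten(item))
        return leaves
    return [obj]  # any non-container value is a leaf

args = ("x_tensor",)  # stand-in for torch.rand(1)
kwargs = {"additional_param": {"y": "y_tensor", "z": "z_tensor"}}
leaves = flatten((args, kwargs))
print(len(leaves))  # 3 flat leaves, though forward() takes 2 parameters
```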

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): nightly
  • PyTorch Version (e.g. 1.0): nightly
  • CPU Architecture: x86
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • GPU models and configuration: A40
@cehongwang cehongwang added the bug Something isn't working label Feb 20, 2025
keehyuna (Collaborator) commented

@cehongwang Thanks for reporting this issue. It seems torch.utils._pytree.tree_unflatten is required to reconstruct the original args/kwargs. I'm working on it.
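A rough sketch of that direction, assuming the runtime records a tree spec when flattening so the flat leaves can be rebuilt into the original nested args/kwargs before calling the Torch subgraph (the real APIs are torch.utils._pytree.tree_flatten/tree_unflatten; these helpers are hypothetical stand-ins):

```python
# Hypothetical flatten/unflatten pair mirroring the pytree pattern:
# flatten returns (leaves, spec); unflatten rebuilds the structure.
def tree_flatten(obj):
    if isinstance(obj, dict):
        keys = sorted(obj)
        children = [tree_flatten(obj[k]) for k in keys]
        leaves = [leaf for ls, _ in children for leaf in ls]
        return leaves, ("dict", keys, [spec for _, spec in children])
    if isinstance(obj, tuple):
        children = [tree_flatten(v) for v in obj]
        leaves = [leaf for ls, _ in children for leaf in ls]
        return leaves, ("tuple", None, [spec for _, spec in children])
    return [obj], ("leaf", None, None)

def _num_leaves(spec):
    kind, _, child_specs = spec
    return 1 if kind == "leaf" else sum(_num_leaves(cs) for cs in child_specs)

def tree_unflatten(leaves, spec):
    kind, keys, child_specs = spec
    if kind == "leaf":
        return leaves[0]
    values, i = [], 0
    for cs in child_specs:
        n = _num_leaves(cs)
        values.append(tree_unflatten(leaves[i:i + n], cs))
        i += n
    return dict(zip(keys, values)) if kind == "dict" else tuple(values)

kwargs = {"additional_param": {"y": 1.0, "z": 2.0}}
leaves, spec = tree_flatten(kwargs)
assert leaves == [1.0, 2.0]                     # flat inputs for the graph
assert tree_unflatten(leaves, spec) == kwargs   # round-trips the nested dict
```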
