Bug Description
When using CUDA Graphs, if a graph break produces a PyTorch subgraph and the input contains a nested dictionary, CUDA Graphs execution fails.
In the following example, `forward` takes 2 inputs, but the graph module expects 3, which raises an error.
To Reproduce
```python
import torch
import torch_tensorrt


class TestModel(torch.nn.Module):
    def forward(self, x, additional_param: dict):
        x = x + additional_param['y']
        return x * additional_param['z'] + 5


device = "cuda:0"
inputs = (torch.rand(1).to(device),)
kwarg_inputs = {
    "additional_param": {
        'y': torch.rand(1).to(device),
        'z': torch.rand(1).to(device),
    }
}

model = TestModel().to(device)
compiled_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    arg_inputs=inputs,
    kwarg_inputs=kwarg_inputs,
    min_block_size=1,
    # Run aten.mul in PyTorch so the graph is split into a PyTorch subgraph
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

with torch_tensorrt.runtime.enable_cudagraphs(
    compiled_model
) as cudagraphs_module:
    cudagraphs_module(*inputs, **kwarg_inputs)
```
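The 2-versus-3 input count is consistent with the nested dictionary being flattened into its leaf tensors somewhere between `forward` and the graph module. The sketch below illustrates that arithmetic without a GPU. It assumes pytree-style flattening (leaves collected from nested tuples and dicts); `flatten_inputs` is a hypothetical helper written for illustration, not Torch-TensorRT's actual code path, and floats stand in for the tensors.

```python
from typing import Any, List


def flatten_inputs(obj: Any) -> List[Any]:
    """Hypothetical helper: recursively collect leaf values from nested
    tuples, lists and dicts (dict values in insertion order)."""
    if isinstance(obj, dict):
        leaves: List[Any] = []
        for value in obj.values():
            leaves.extend(flatten_inputs(value))
        return leaves
    if isinstance(obj, (list, tuple)):
        leaves = []
        for item in obj:
            leaves.extend(flatten_inputs(item))
        return leaves
    # Anything that is not a container is treated as a single leaf input.
    return [obj]


# Plain floats stand in for the CUDA tensors in the reproducer.
args = (0.1,)
kwargs = {"additional_param": {"y": 0.2, "z": 0.3}}

# forward(self, x, additional_param) sees two user-level inputs ...
num_forward_inputs = len(args) + len(kwargs)
# ... but flattening the nested dict yields three leaves: x, y and z.
num_flat_inputs = len(flatten_inputs((args, kwargs)))

print(num_forward_inputs, num_flat_inputs)  # 2 3
```

If the CUDA Graphs wrapper counts inputs at the `forward` level while the graph module expects the flattened leaves, a mismatch like this one would surface exactly when a nested dictionary is passed.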
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages.
- Torch-TensorRT Version (e.g. 1.0.0): nightly
- PyTorch Version (e.g. 1.0): nightly
- CPU Architecture: x86
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- GPU models and configuration: A40