Bug Description
When compiling a model with torch.compile(..., backend="tensorrt") and a dynamic batch dimension marked via torch._dynamo.mark_dynamic, the remove_sym_nodes lowering pass in Torch-TensorRT does not fully remove torch.SymInt placeholder nodes.
After the pass runs, a torch.SymInt placeholder node (%s0) is still left in the FX graph. This causes problems for downstream lowering / codegen, since these SymInt placeholders are expected to be eliminated and replaced with tensor-based size queries (e.g., via torch.ops.aten.sym_size).
I believe the pass should instead replace each such SymInt placeholder with an aten.sym_size call on the corresponding Tensor placeholder (here, the first dimension of the input tensor), and then remove the SymInt placeholder entirely; a rough sketch of what I mean follows.
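To make the proposal concrete, below is a minimal sketch of the kind of replacement I have in mind. It is not based on the actual remove_sym_nodes implementation; the meta keys ("example_value" / "val"), the string-based symbol matching, and the use of the aten.sym_size.int overload are assumptions for illustration only.

import torch
from torch.fx import GraphModule


def replace_symint_placeholders(gm: GraphModule) -> GraphModule:
    # Sketch only: for every SymInt placeholder, find a tensor placeholder whose
    # symbolic shape carries the same symbol, insert an aten.sym_size.int call on
    # that tensor/dim, reroute all users, and erase the SymInt input.
    graph = gm.graph

    def traced_value(node):
        # Dynamo-level graphs store the traced value under "example_value",
        # aten-level graphs under "val" (assumption; adjust to the actual graph).
        return node.meta.get("example_value", node.meta.get("val"))

    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    symint_inputs = [n for n in placeholders if isinstance(traced_value(n), torch.SymInt)]
    tensor_inputs = [n for n in placeholders if isinstance(traced_value(n), torch.Tensor)]

    for sym_node in symint_inputs:
        sym_val = traced_value(sym_node)
        replacement = None
        for t_node in tensor_inputs:
            for dim, size in enumerate(traced_value(t_node).shape):
                # Naive symbol match via the string form ("s0" == "s0"); a real
                # implementation would compare the underlying symbolic expressions.
                if isinstance(size, torch.SymInt) and str(size) == str(sym_val):
                    with graph.inserting_after(t_node):
                        replacement = graph.call_function(
                            torch.ops.aten.sym_size.int, (t_node, dim)
                        )
                    break
            if replacement is not None:
                break
        if replacement is not None:
            sym_node.replace_all_uses_with(replacement)
            graph.erase_node(sym_node)

    gm.recompile()
    return gm

Applied to the reproduction below, this should turn the %s0 uses in expand and reshape into uses of the new sym_size node, matching the desired graph shown later in this report.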
To Reproduce
Reproduction Code:
import torch
import torch.nn as nn
import torch_tensorrt
import logging

logging.basicConfig(level=logging.DEBUG)


class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(
            batch_size,
            x.size(1),
            3,
            12,
            -1,
        )
        return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
model = torch.compile(model, backend="tensorrt")

x = torch.randn(4, 196, 768).cuda()
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)

out = model(x)
print(out.shape)
FX Graph Before / After remove_sym_nodes
From the debug log of torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:
Current behavior (after the pass, but still with SymInt placeholder):
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
%s0 : torch.SymInt [num_users=2] = placeholder[target=s0]
%l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
%l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
%l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
%l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
%clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
%clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
%cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %s0, -1, -1), kwargs = {})
%x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
%x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
%reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %s0, 197, 3, 12, -1), kwargs = {})
return (reshaped_qkv,)
As shown, %s0 : torch.SymInt = placeholder[target=s0] is still present and used in expand and reshape.
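For reference, a small snippet that flags leftover SymInt placeholders, given access to the lowered GraphModule; the meta keys used are assumptions about where the traced value is stored, not part of the Torch-TensorRT API.

import torch


def leftover_symint_placeholders(gm: torch.fx.GraphModule):
    # Return all placeholder nodes whose traced value is a SymInt; after a
    # complete remove_sym_nodes pass this list should be empty.
    leftovers = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        val = node.meta.get("example_value", node.meta.get("val"))
        if isinstance(val, torch.SymInt):
            leftovers.append(node)
    return leftovers

On the graph shown above this should report %s0, whereas I would expect an empty result once the pass has done its job.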
Expected / Desired Transformation
The problematic torch.SymInt placeholder should be removed and replaced by a dynamic size extracted from the relevant tensor placeholder (here, the batch dimension of %l_x_), via aten.sym_size.
For example, a correct/desired transformed FX graph would look like:
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
%l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
%sym_size : [num_users=2] = call_function[target=torch.ops.aten.sym_size](args = (%l_x_, 0), kwargs = {})
%l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
%l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
%l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
%clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
%clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
%cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %sym_size, -1, -1), kwargs = {})
%x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
%x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
%reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %sym_size, 197, 3, 12, -1), kwargs = {})
return (reshaped_qkv,)
Here, %sym_size = torch.ops.aten.sym_size(%l_x_, 0) replaces the SymInt placeholder %s0, and %s0 is completely removed.
I am happy to provide additional logs, the full FX graph before/after, or try a patch if you can point to the relevant parts of the code.
Expected behavior
Same as the desired transformed FX graph shown above: the SymInt placeholder %s0 is removed entirely and all of its uses are rerouted to %sym_size = torch.ops.aten.sym_size(%l_x_, 0).
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 2.6):
- PyTorch Version (e.g. 2.6):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information: