🐛 [Bug] SymInt placeholder nodes not fully removed in remove_sym_nodes pass, causing issues with TensorRT lowering #3981

@lwlsaysnuaa

Bug Description

When compiling a model using torch.compile(..., backend="tensorrt") with a dynamic batch dimension marked via torch._dynamo.mark_dynamic, the remove_sym_nodes pass in Torch-TensorRT does not fully remove torch.SymInt placeholder nodes.

After the pass runs, a torch.SymInt placeholder node (%s0) is still left in the FX graph. This causes problems for downstream lowering / codegen, since these SymInt placeholders are expected to be eliminated and replaced with tensor-based size queries (e.g., via torch.ops.aten.sym_size).

I believe the pass should instead replace those SymInt placeholders with aten.sym_size calls on the corresponding Tensor placeholders (e.g., the first dimension of the input tensor), and then remove the SymInt placeholders entirely.
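
To make the proposal concrete, here is a minimal sketch of such a rewrite on a hand-built FX graph. It is not the actual remove_sym_nodes implementation: the function name replace_symint_placeholders and the explicit sym_sources mapping are hypothetical, and a real pass would have to derive that mapping from the placeholder metadata. It also uses the torch.ops.aten.sym_size.int overload rather than the bare torch.ops.aten.sym_size packet.

```python
from typing import Dict, Tuple

import torch
import torch.fx as fx


def replace_symint_placeholders(
    gm: fx.GraphModule, sym_sources: Dict[str, Tuple[str, int]]
) -> fx.GraphModule:
    """Replace SymInt placeholders with sym_size calls on tensor placeholders.

    sym_sources is a hypothetical mapping (an assumption of this sketch):
    SymInt placeholder name -> (tensor placeholder name, dimension index).
    """
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    by_name = {n.name: n for n in placeholders}
    # Insert after the last placeholder that will survive the rewrite, so the
    # new node comes after its tensor input and before every existing user.
    anchor = [n for n in placeholders if n.name not in sym_sources][-1]

    for sym_name, (tensor_name, dim) in sym_sources.items():
        sym_node = by_name[sym_name]
        tensor_node = by_name[tensor_name]
        with gm.graph.inserting_after(anchor):
            size_node = gm.graph.call_function(
                torch.ops.aten.sym_size.int, args=(tensor_node, dim)
            )
        # Redirect every user (expand, reshape, ...) to the size query,
        # then drop the now-unused SymInt placeholder.
        sym_node.replace_all_uses_with(size_node)
        gm.graph.erase_node(sym_node)

    gm.graph.lint()
    gm.recompile()
    return gm


# Toy graph standing in for the Dynamo output: forward(s0, x) reshapes x
# using the symbolic batch size s0.
graph = fx.Graph()
s0 = graph.placeholder("s0")
x = graph.placeholder("x")
reshaped = graph.call_method("reshape", args=(x, s0, 2, -1))
graph.output(reshaped)
demo = fx.GraphModule(torch.nn.Module(), graph)

demo = replace_symint_placeholders(demo, {"s0": ("x", 0)})
remaining_inputs = [n.name for n in demo.graph.nodes if n.op == "placeholder"]
result = demo(torch.randn(4, 6))  # only the tensor input is needed now
```

After the rewrite the module takes only the tensor input, and the batch size is recovered from it at run time.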

To Reproduce

Reproduction Code:

import torch
import torch.nn as nn
import torch_tensorrt
import logging
logging.basicConfig(level=logging.DEBUG)

class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(
            batch_size,
            x.size(1),
            3,
            12,
            -1
        )
        return reshaped_qkv

model = ExpandReshapeModel(embed_dim=768).cuda().eval()
model = torch.compile(model, backend="tensorrt")

x = torch.randn(4, 196, 768).cuda()
# Mark the batch dimension as dynamic so it is traced as a SymInt (s0).
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)

out = model(x)
print(out.shape)

FX Graph Before / After remove_sym_nodes
From the debug log of torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:

Current behavior (after the pass, but still with SymInt placeholder):

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %s0 : torch.SymInt [num_users=2] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
    %l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
    %l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
    %clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %s0, -1, -1), kwargs = {})
    %x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
    %x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
    %reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %s0, 197, 3, 12, -1), kwargs = {})
    return (reshaped_qkv,)

As shown, %s0 : torch.SymInt = placeholder[target=s0] is still present and used in expand and reshape.
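
The same check can be done programmatically instead of reading the debug log. The helper below is a rough sketch (find_symint_placeholders is a hypothetical name, not part of Torch-TensorRT): it reports placeholder nodes typed as torch.SymInt together with their users. The node.meta keys it falls back to are an assumption about how Dynamo annotates placeholders.

```python
import torch
import torch.fx as fx


def find_symint_placeholders(graph: fx.Graph) -> list:
    """Return (placeholder name, [user names]) for each SymInt placeholder."""
    leftovers = []
    for node in graph.nodes:
        if node.op != "placeholder":
            continue
        # Assumption: the traced example value, if any, is stored under
        # "example_value" or "val" in node.meta.
        value = node.meta.get("example_value", node.meta.get("val"))
        if node.type is torch.SymInt or isinstance(value, torch.SymInt):
            leftovers.append((node.name, [user.name for user in node.users]))
    return leftovers


# Toy graph mimicking the log above: s0 is typed torch.SymInt and feeds expand.
graph = fx.Graph()
s0 = graph.placeholder("s0", type_expr=torch.SymInt)
x = graph.placeholder("x")
expanded = graph.call_method("expand", args=(x, s0, -1, -1))
graph.output(expanded)

leftovers = find_symint_placeholders(graph)
print(leftovers)
```

For a compiled module this would be called as find_symint_placeholders(gm.graph) right after the pass; a non-empty result means a SymInt placeholder survived.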

Expected / Desired Transformation

The problematic torch.SymInt placeholder should be removed and replaced by a dynamic size extracted from the relevant tensor placeholder (here, the batch dimension of %l_x_), via aten.sym_size.

For example, a correct/desired transformed FX graph would look like:

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
    %sym_size : [num_users=2] = call_function[target=torch.ops.aten.sym_size](args = (%l_x_, 0), kwargs = {})
    %l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
    %l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
    %l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
    %clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %sym_size, -1, -1), kwargs = {})
    %x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
    %x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
    %reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %sym_size, 197, 3, 12, -1), kwargs = {})
    return (reshaped_qkv,)

Here, %sym_size = torch.ops.aten.sym_size(%l_x_, 0) replaces the SymInt placeholder %s0, and %s0 is completely removed.

I am happy to provide additional logs, the full FX graph before/after, or try a patch if you can point to the relevant parts of the code.

Expected behavior

The remove_sym_nodes pass should eliminate the %s0 SymInt placeholder and feed expand and reshape from an aten.sym_size call on %l_x_, producing exactly the graph shown under "Expected / Desired Transformation" above.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 2.6):
  • PyTorch Version (e.g. 2.6):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context


Labels: bug
