Skip to content

🐛 [Bug] Symbolic shape error when using repeat_interleave for K and V followed by SDPA #3964

@zhaoyuanh

Description

@zhaoyuanh

Bug Description

from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class SampleNetwork(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
    ) -> None:
        super().__init__()

        self.heads = num_attention_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        query = query.unflatten(2, (self.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (self.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (self.heads, -1)).transpose(1, 2)

        query_idx = torch.tensor(query.size(3), device=query.device)
        key_idx = torch.tensor(key.size(3), device=key.device)
        value_idx = torch.tensor(value.size(3), device=value.device)
        # Use expand operations instead of repeat_interleave for torch_tensorrt compatibility
        # key = key.expand(-1, -1, -1, query_idx // key_idx * key.size(3))
        # value = value.expand(-1, -1, -1, query_idx // value_idx * value.size(3))
        key = key.repeat_interleave(query_idx // key_idx, dim=3)
        value = value.repeat_interleave(query_idx // value_idx, dim=3)

        # hidden_states = query + key + value
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)

        return hidden_states


def export_attention(model, query, key, value):
    with torch.no_grad():
        # Only mark sequence length as dynamic, like run_llm.py does
        # Don't mark batch dimension as dynamic to avoid constraint violations
        seq_len = torch.export.Dim("seq_len", min=1, max=56320)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only mark sequence length (dim 1) as dynamic
        ep = torch.export.export(
            model,
            args=(query, key, value),
            kwargs={},
            dynamic_shapes=({1: seq_len}, {1: seq_len}, {1: seq_len}), 
            strict=False,
        )

    return ep


def compile_torchtrt(model, query, key, value, min_block_size, debug):
    ep = export_attention(model, query, key, value)
    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}
    use_fp32_acc = False

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[query, key, value],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )

    return trt_model


if __name__ == "__main__":
    precision = "BF16"
    min_block_size = 1
    batch_size = 1
    seq_len = 28160
    num_attention_heads = 32
    attention_head_dim = 128
    enable_pytorch_run = True
    debug = False
    device = "cuda"

    with torch.inference_mode():
        hidden_size = num_attention_heads * attention_head_dim

        model = SampleNetwork(
            num_attention_heads=num_attention_heads,
        ).to(device)

        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        query = torch.randn(
            batch_size, seq_len, hidden_size, dtype=input_dtype
        ).to(device)
        key = torch.randn(
            batch_size, seq_len, hidden_size, dtype=input_dtype
        ).to(device)
        value = torch.randn(
            batch_size, seq_len, hidden_size, dtype=input_dtype
        ).to(device)

        # Pyt
        pyt_output = model(query, key, value)
        print("PyTorch output shape:", pyt_output.shape)
        print("Pytorch output first 10 elements:", pyt_output.flatten()[:10])

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, query, key, value, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)

        trt_output = trt_model(query, key, value)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output first 10 elements:", trt_output.flatten()[:10])

Here is the error message.

Traceback (most recent call last):
  File "/workspace/tools/llm/minimal_reproducer.py", line 130, in <module>
    trt_model = compile_torchtrt(model, query, key, value, min_block_size, debug)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tools/llm/minimal_reproducer.py", line 66, in compile_torchtrt
    ep = export_attention(model, query, key, value)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tools/llm/minimal_reproducer.py", line 54, in export_attention
    ep = torch.export.export(
         ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 304, in export
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 271, in export
    return _export(
           ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1152, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1118, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2165, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1152, in wrapper
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1118, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2028, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1970, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1764, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1893, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1678, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2351, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2283, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2254, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1005, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1283, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1865, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1341, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1582, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1877, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tools/llm/minimal_reproducer.py", line 38, in forward
    hidden_states = F.scaled_dot_product_attention(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1389, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1436, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 1067, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 961, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_library/utils.py", line 286, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1491, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 840, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1361, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2077, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1496, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2625, in _dispatch_impl
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 885, in decompose
    return self._op_dk(dk, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
    r = self.evaluate()
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7239, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7339, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7362, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7586, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(128, u0) (unhinted: Eq(128, u0)).  (Size-like symbols: u0)

Caused by: (_ops.py:885 in decompose)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/workspace/tools/llm/minimal_reproducer.py", line 38, in forward
    hidden_states = F.scaled_dot_product_attention(

To fix the error, insert one of the following checks before this call:
  1. torch._check(128 == key.shape[3])
  2. torch._check(128 != key.shape[3])

(These suggested fixes were derived by replacing `u0` with key.shape[3] in Eq(128, u0) and its negation.)

To Reproduce

Steps to reproduce the behavior:

  1. Run Python script above

Expected behavior

Passed and the Torch-TRT output matches the Torch output.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
  • PyTorch Version (e.g. 1.0): 2.9.0
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
  • Are you using local sources or building from archives: No
  • Python version: 3.12
  • CUDA version: 13.0
  • GPU models and configuration: Nvidia B200
  • Any other relevant information: None

Additional context

The script can pass if I switch from repeat_interleave to expand. It also passes if I disable SDPA and replace it with hidden_states = query + key + value. The script can also pass with torch.compile() but failed with torch_tensorrt.dynamo.compile.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions