-
Notifications
You must be signed in to change notification settings - Fork 376
Open
Labels
bug: Something isn't working
Description
Bug Description
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
class SampleNetwork(nn.Module):
def __init__(
self,
num_attention_heads: int,
) -> None:
super().__init__()
self.heads = num_attention_heads
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
query = query.unflatten(2, (self.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (self.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (self.heads, -1)).transpose(1, 2)
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
# Use expand operations instead of repeat_interleave for torch_tensorrt compatibility
# key = key.expand(-1, -1, -1, query_idx // key_idx * key.size(3))
# value = value.expand(-1, -1, -1, query_idx // value_idx * value.size(3))
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
# hidden_states = query + key + value
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
return hidden_states
def export_attention(model, query, key, value):
with torch.no_grad():
# Only mark sequence length as dynamic, like run_llm.py does
# Don't mark batch dimension as dynamic to avoid constraint violations
seq_len = torch.export.Dim("seq_len", min=1, max=56320)
print("Trying to export the model using torch.export.export()..")
# strict=False only enables autograd tracing and excludes dynamo.
# Use tuple format like export_llm - only mark sequence length (dim 1) as dynamic
ep = torch.export.export(
model,
args=(query, key, value),
kwargs={},
dynamic_shapes=({1: seq_len}, {1: seq_len}, {1: seq_len}),
strict=False,
)
return ep
def compile_torchtrt(model, query, key, value, min_block_size, debug):
    """Export ``model`` and compile the exported program with Torch-TensorRT in BF16.

    Args:
        model: module to compile; it is first exported via ``export_attention``.
        query, key, value: example inputs used for export and passed as ``inputs``.
        min_block_size: forwarded to ``torch_tensorrt.dynamo.compile``.
        debug: if true, compile inside ``torch_tensorrt.logging.debug()`` and
            pass ``debug=True`` to the compiler.

    Returns:
        The module produced by ``torch_tensorrt.dynamo.compile``.
    """
    ep = export_attention(model, query, key, value)

    # Precision-specific flags, each set exactly once.
    enabled_precisions = {torch.bfloat16}
    use_explicit_typing = False
    use_fp32_acc = False

    log_ctx = torch_tensorrt.logging.debug() if debug else nullcontext()
    with log_ctx:
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            # NOTE(review): concrete example tensors are passed here although
            # `ep` was exported with a dynamic seq_len. Confirm whether
            # Torch-TensorRT builds a static engine profile from these tensors;
            # if so, pass torch_tensorrt.Input(min_shape=..., opt_shape=...,
            # max_shape=...) to cover the exported range.
            inputs=[query, key, value],
            enabled_precisions=enabled_precisions,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )
    return trt_model
if __name__ == "__main__":
    # Reproducer configuration (the script is written for BF16 only).
    min_block_size = 1
    batch_size = 1
    seq_len = 28160
    num_attention_heads = 32
    attention_head_dim = 128
    enable_pytorch_run = True  # run the eager PyTorch reference for comparison
    debug = False
    device = "cuda"
    input_dtype = torch.bfloat16
    hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = SampleNetwork(num_attention_heads=num_attention_heads).to(
            device=device, dtype=input_dtype
        )

        # Allocate the inputs directly on the target device in the target dtype,
        # instead of creating them on the CPU and copying them over.
        query, key, value = (
            torch.randn(
                batch_size, seq_len, hidden_size, dtype=input_dtype, device=device
            )
            for _ in range(3)
        )

        pyt_output = None
        if enable_pytorch_run:
            pyt_output = model(query, key, value)
            print("PyTorch output shape:", pyt_output.shape)
            print("Pytorch output first 10 elements:", pyt_output.flatten()[:10])

        # Compile with Torch-TensorRT. The reported alternative that works is
        # torch.compile(model, backend="torch_tensorrt", options={...}).
        trt_model = compile_torchtrt(model, query, key, value, min_block_size, debug)
        trt_model = trt_model.to(device)
        trt_output = trt_model(query, key, value)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output first 10 elements:", trt_output.flatten()[:10])

        # The expected behavior is that both backends agree, so report the
        # largest element-wise deviation (computed in fp32 to avoid bf16 rounding).
        if pyt_output is not None:
            max_abs_diff = (trt_output.float() - pyt_output.float()).abs().max().item()
            print("Max abs difference (TensorRT vs PyTorch):", max_abs_diff)
Here is the error message.
Traceback (most recent call last):
File "/workspace/tools/llm/minimal_reproducer.py", line 130, in <module>
trt_model = compile_torchtrt(model, query, key, value, min_block_size, debug)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/tools/llm/minimal_reproducer.py", line 66, in compile_torchtrt
ep = export_attention(model, query, key, value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/tools/llm/minimal_reproducer.py", line 54, in export_attention
ep = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 304, in export
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 271, in export
return _export(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1152, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1118, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2165, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1152, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1118, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2028, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1970, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1764, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1893, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1678, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2351, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2283, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2254, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1005, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1283, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1865, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1341, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1582, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1877, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/tools/llm/minimal_reproducer.py", line 38, in forward
hidden_states = F.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1389, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1436, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 1067, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 961, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_library/utils.py", line 286, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1491, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 840, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1361, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2077, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1496, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2625, in _dispatch_impl
r = func.decompose(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 885, in decompose
return self._op_dk(dk, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
r = self.evaluate()
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7239, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7339, in evaluate_expr
return self._inner_evaluate_expr(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7362, in _inner_evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7586, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(128, u0) (unhinted: Eq(128, u0)). (Size-like symbols: u0)
Caused by: (_ops.py:885 in decompose)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/workspace/tools/llm/minimal_reproducer.py", line 38, in forward
hidden_states = F.scaled_dot_product_attention(
To fix the error, insert one of the following checks before this call:
1. torch._check(128 == key.shape[3])
2. torch._check(128 != key.shape[3])
(These suggested fixes were derived by replacing `u0` with key.shape[3] in Eq(128, u0) and its negation.)
To Reproduce
Steps to reproduce the behavior:
- Run the Python script above.
Expected behavior
The script runs successfully, and the Torch-TensorRT output matches the PyTorch output.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
- PyTorch Version (e.g. 1.0): 2.9.0
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip
- Build command you used (if compiling from source): `PYTHON_ONLY=1 pip install -e .`
- Are you using local sources or building from archives: No
- Python version: 3.12
- CUDA version: 13.0
- GPU models and configuration: Nvidia B200
- Any other relevant information: None
Additional context
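The script passes if I replace `repeat_interleave` with `expand`. It also passes if I remove SDPA and use `hidden_states = query + key + value` instead. It passes with `torch.compile()`, but fails with `torch_tensorrt.dynamo.compile`. Note that the traceback shows the error is raised inside `torch.export.export`, before any Torch-TensorRT conversion. The `repeat_interleave` repeat factors are tensors built with `torch.tensor(...)`, so the repeated tensor's size becomes a data-dependent symbol (`u0`), and SDPA then cannot guard on `Eq(128, u0)`.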
Metadata
Metadata
Assignees
Labels
bug: Something isn't working