2 changes: 1 addition & 1 deletion tools/llm/README.md
@@ -38,7 +38,7 @@ We have officially verified support for the following models:
#### Text-only LLMs: `run_llm.py`

```bash
-python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
+python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
```

#### Vision Language Models: `run_vlm.py`
3 changes: 2 additions & 1 deletion tools/llm/run_llm.py
@@ -107,8 +107,9 @@ def compile_torchtrt(model, input_ids, args):
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}
+        use_explicit_typing = True

-    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    with torch_tensorrt.dynamo.Debugger() if args.debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[input_ids, position_ids],
46 changes: 46 additions & 0 deletions tools/llm/test_trt_sdpa.py
@@ -0,0 +1,46 @@
import torch
import torch_tensorrt
from torchtrt_ext import register_sdpa


class ModelNoCache(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        return torch._C._nn.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=True, scale=1.0
        )


model = ModelNoCache().cuda().eval().to(torch.float16)
q = torch.randn(1, 32, 6, 64).cuda().to(torch.float16)
k = torch.randn(1, 32, 6, 64).cuda().to(torch.float16)
v = torch.randn(1, 32, 6, 64).cuda().to(torch.float16)
pyt_outputs = model(q, k, v)

register_sdpa.enable_sdpa_converter("default", None)
seq_len_query = torch.export.Dim("seq_len_query", min=2, max=128)
seq_len_key = torch.export.Dim("seq_len_key", min=2, max=128)
dynamic_shapes = {"q": {2: seq_len_key}, "k": {2: seq_len_key}, "v": {2: seq_len_key}}
ep = torch.export.export(model, (q, k, v), dynamic_shapes=dynamic_shapes, strict=False)

with torch_tensorrt.dynamo.Debugger():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=(q, k, v),
        enabled_precisions={torch.float32},
        min_block_size=1,
        disable_tf32=True,
        use_explicit_typing=True,
    )

trt_outputs = trt_gm(q, k, v)
print("Diff between pyt and trt: ", torch.mean(torch.abs(pyt_outputs - trt_outputs)))
# Decode-phase comparison (single-token query vs. longer key/value), kept disabled:
# as written, a query length of 1 falls below the min=2 declared for the dynamic dims above.
# q = torch.randn(1, 32, 1, 64).cuda().to(torch.float16)
# k = torch.randn(1, 32, 10, 64).cuda().to(torch.float16)
# v = torch.randn(1, 32, 10, 64).cuda().to(torch.float16)
# pyt_outputs = model(q, k, v)
# trt_outputs = trt_gm(q, k, v)
# print("Diff between pyt and trt: ", torch.mean(torch.abs(pyt_outputs - trt_outputs)))
3 changes: 1 addition & 2 deletions tools/llm/torchtrt_ext/register_sdpa.py
@@ -15,7 +15,7 @@
)
from transformers import AutoConfig, Gemma3TextConfig

-from .sdpa_converter import *
+from .trt_sdpa_converter import *

logger = logging.getLogger(__name__)

@@ -138,7 +138,6 @@ def _process_sdpa_node(
        dropout_p,
        is_causal,
    )

    # Create a new node with torch.nn.functional.scaled_dot_product_attention
    with gm.graph.inserting_after(node):
        new_node = gm.graph.call_function(
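For context on the hunk above: `_process_sdpa_node` builds a replacement node that calls `torch.nn.functional.scaled_dot_product_attention`, which the new TRT converter below then matches. A minimal sketch of that FX graph-rewrite pattern, with simplified names and argument handling (this is illustrative, not the repo's exact implementation):

```python
import torch

def replace_with_functional_sdpa(gm: torch.fx.GraphModule, node: torch.fx.Node,
                                 query, key, value, dropout_p, is_causal):
    # Insert a call to the functional SDPA right after the original aten node.
    with gm.graph.inserting_after(node):
        new_node = gm.graph.call_function(
            torch.nn.functional.scaled_dot_product_attention,
            args=(query, key, value),
            kwargs={"dropout_p": dropout_p, "is_causal": is_causal},
        )
        new_node.meta = node.meta.copy()  # keep shape/dtype metadata for later passes

    # Point all users at the new node, drop the old one, and recompile the module.
    node.replace_all_uses_with(new_node)
    gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return new_node
```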
193 changes: 193 additions & 0 deletions tools/llm/torchtrt_ext/trt_sdpa_converter.py
@@ -0,0 +1,193 @@
import logging
import math
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    SourceIR,
    cast_trt_tensor,
    get_trt_tensor,
    prepend_ones,
)
from torch_tensorrt.dynamo.types import TRTTensor

logger = logging.getLogger(__name__)


def tril(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    row: TRTTensor,
    col: TRTTensor,
    sliding_window_size: Optional[int] = None,
) -> TRTTensor:
"""
Create a lower triangular mask tensor for attention mechanisms.

This function generates a lower triangular mask that can be used in attention
operations to enforce causal attention (each position can only attend to itself
and previous positions). It optionally supports sliding window attention by
limiting the attention span to a specified window size.

The function creates the mask by:
1. Generating row and column index tensors
2. Computing the difference between row and column indices
3. Creating a mask where row >= col (lower triangular)
4. Optionally applying sliding window constraints

Args:
ctx: TensorRT conversion context for managing the conversion process
target: Target operation identifier (usually the operation being converted)
source_ir: Source IR type (e.g., ATEN, TRT) - can be None
name: Base name for generated TensorRT operations (will be extended with suffixes)
row: Tensor representing the number of rows (sequence length dimension)
col: Tensor representing the number of columns (sequence length dimension)
sliding_window_size: Optional sliding window size for attention span limitation.
If None, creates a full lower triangular mask.
If specified, creates a sliding window mask where each position
can only attend to positions within the window.

Returns:
        TRTTensor: A boolean mask tensor of shape [query_len, key_len], where True values
            indicate allowed attention positions (the caller prepends singleton batch/head dims).

    Example:
        # Create a full lower triangular mask for causal attention
        mask = tril(ctx, target, source_ir, "causal_mask", seq_len, seq_len)

        # Create a sliding window mask with window size 3
        mask = tril(ctx, target, source_ir, "sliding_mask", seq_len, seq_len, 3)

    Mask Examples:
        Without sliding window (sliding_window_size=None):
        For seq_len=5, returns:
            [[ True, False, False, False, False],
             [ True,  True, False, False, False],
             [ True,  True,  True, False, False],
             [ True,  True,  True,  True, False],
             [ True,  True,  True,  True,  True]]

        With sliding window (sliding_window_size=3):
        For seq_len=5, returns:
            [[ True, False, False, False, False],
             [ True,  True, False, False, False],
             [ True,  True,  True, False, False],
             [False,  True,  True,  True, False],
             [False, False,  True,  True,  True]]

    Note:
        This function is specifically designed for attention mechanisms in transformer
        models and is used internally by the scaled_dot_product_attention converter.
        The sliding window functionality is particularly useful for models like Gemma3
        that use sliding window attention to reduce computational complexity.
    """
    row_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
    )
    col_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
    )
    row_arange_tensor = impl.unsqueeze.unsqueeze(
        ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1
    )
    col_arange_tensor = impl.unsqueeze.unsqueeze(
        ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0
    )
    # sub returns the row - col difference tensor:
    # [[0, -1, -2, -3],
    #  [1,  0, -1, -2],
    #  [2,  1,  0, -1],
    #  [3,  2,  1,  0]]
    mask = impl.elementwise.sub(
        ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor
    )
    ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0)
    if sliding_window_size is None:
        # return the lower triangular mask, which includes the main diagonal:
        # 0 ■ ⬚ ⬚ ⬚ ⬚    [[ True, False, False, False, False],
        # 1 ■ ■ ⬚ ⬚ ⬚     [ True,  True, False, False, False],
        # 2 ■ ■ ■ ⬚ ⬚     [ True,  True,  True, False, False],
        # 3 ■ ■ ■ ■ ⬚     [ True,  True,  True,  True, False],
        # 4 ■ ■ ■ ■ ■     [ True,  True,  True,  True,  True]]
        return ge_0_mask

    lt_window_mask = impl.elementwise.lt(
        ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size
    )
    mask = impl.elementwise.logical_and(
        ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask
    )
    # return the following mask if sliding_window_size is 3:
    # 0 ■ ⬚ ⬚ ⬚ ⬚    [[ True, False, False, False, False],
    # 1 ■ ■ ⬚ ⬚ ⬚     [ True,  True, False, False, False],
    # 2 ■ ■ ■ ⬚ ⬚     [ True,  True,  True, False, False],
    # 3 ⬚ ■ ■ ■ ⬚     [False,  True,  True,  True, False],
    # 4 ⬚ ⬚ ■ ■ ■     [False, False,  True,  True,  True]]
    return mask


@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    torch.nn.functional.scaled_dot_product_attention,
    enabled=True,
    supports_dynamic_shapes=True,
)
def scaled_dot_product_attention(
    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
    target: Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> TRTTensor:
    source_ir = SourceIR.ATEN

    # Always build our own causal attn_mask below; any attn_mask passed in args is ignored.
    query, key, value, mask, dropout_p, is_causal = args

    # The exported graphs of LLM models have -1 in the attention-heads dimension of the query
    # tensor, while that dimension is static for the key and value tensors.
    # TODO: We assume the attention-heads dimension is the same for the query, key, and value
    # tensors. A lowering pass could read the number of attention heads from the model config
    # (similar to Gemma3); for now, we use key.shape[1] directly as the attention-heads dimension.
    query = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        name + "_query_reshape",
        input=query,
        shape=[query.shape[0], key.shape[1], query.shape[2], query.shape[3]],
    )
    # L (query length) and S (key length), computed as dynamic shape tensors
    query_len = impl.shape.shape(ctx, target, source_ir, name + "_query_len", query, -2)
    key_len = impl.shape.shape(ctx, target, source_ir, name + "_key_len", key, -2)
    mask_tensor = tril(
        ctx,
        target,
        source_ir,
        name + "_tril",
        query_len,
        key_len,
    )

    diff = len(query.shape) - len(mask_tensor.shape)

    mask_tensor = prepend_ones(ctx, mask_tensor, name + "_prepend_ones", diff)
    attention_layer = ctx.net.add_attention(
        query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False
    )
    assert attention_layer is not None, "attention layer is None"
    attention_layer.decomposable = True

    if is_causal:
        attention_layer.mask = mask_tensor

    attention_output = attention_layer.get_output(0)

    return attention_output
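For reference, a minimal eager-PyTorch sketch of the mask that `tril` builds and of the causal behavior the converter targets. The integer lengths, the helper name, and the equivalence check are illustrative assumptions; the converter itself operates on dynamic shape tensors and TensorRT layers instead:

```python
import torch

def make_causal_mask(query_len: int, key_len: int, sliding_window_size=None) -> torch.Tensor:
    # row - col >= 0 keeps the lower triangle: each query position attends to itself
    # and to earlier key positions.
    row = torch.arange(query_len).unsqueeze(-1)   # [query_len, 1]
    col = torch.arange(key_len).unsqueeze(0)      # [1, key_len]
    diff = row - col
    mask = diff >= 0
    if sliding_window_size is not None:
        # Additionally restrict each position to the last `sliding_window_size` positions.
        mask = mask & (diff < sliding_window_size)
    return mask  # [query_len, key_len]; the converter prepends singleton batch/head dims

# Sanity check: for equal lengths, the boolean mask matches is_causal=True.
q = torch.randn(1, 8, 5, 64)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
mask = make_causal_mask(5, 5)
out_masked = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_causal = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_masked, out_causal, atol=1e-6))
```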