21 changes: 19 additions & 2 deletions tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -13,7 +13,8 @@
 
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
-                     get_piecewise_cuda_graph_flag, make_weak_ref)
+                     get_piecewise_cuda_graph_flag, make_weak_ref,
+                     set_piecewise_running)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
@@ -27,6 +28,7 @@ def __init__(
         compile_time_num_tokens: Union[int | torch.SymInt],
         capture_num_tokens: list[int],
         exclude_modules_id: list[int],
+        piecewise_runner_num: int,
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
         graph=None,
@@ -38,6 +40,8 @@
 
         self.compile_time_num_tokens = compile_time_num_tokens
         self.capture_num_tokens = capture_num_tokens
+        self.piecewise_runner_num = piecewise_runner_num
+        self.piecewise_runner_idx = 0
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -90,8 +94,10 @@ def call_module(self, target, args, kwargs):
             self.graph_pool_handle,
             compile_fx(submod, args) if self.enable_inductor else submod,
             self.enable_inductor,
+            self.piecewise_runner_idx == 0,
+            self.piecewise_runner_idx == self.piecewise_runner_num - 1,
         )
 
+        self.piecewise_runner_idx += 1
         return output


@@ -124,6 +130,8 @@ def __init__(
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
+        is_first_runner: bool,
+        is_last_runner: bool,
     ):
         if runtime_num_tokens_idx != None:
             assert isinstance(compile_time_num_tokens, torch.SymInt)
@@ -138,6 +146,8 @@
         self.enable_inductor = enable_inductor
 
         self.entries: dict[int, Entry] = {}
+        self.is_first_runner = is_first_runner
+        self.is_last_runner = is_last_runner
 
         for num_tokens in capture_num_tokens:
             self.entries[num_tokens] = Entry(
@@ -161,6 +171,12 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
+        if self.is_first_runner or self.is_last_runner:
+            if self.is_first_runner == self.is_last_runner:
+                set_piecewise_running(False)
+            else:
+                set_piecewise_running(self.is_first_runner)
+
         entry = self.entries[runtime_num_of_token]
 
         if entry.enable_inductor and not entry.compiled:
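Note on the toggling above: the first runner of a forward pass raises the global piecewise-running flag and the last one clears it, so only the work between them (notably the attention ops excluded from the piecewise CUDA graphs) observes the flag; a runner that is both first and last never raises it. A minimal, self-contained sketch of that lifecycle (simulate is a hypothetical helper; only the two utils functions come from this change):

from tensorrt_llm._torch.utils import (is_piecewise_running,
                                       set_piecewise_running)


def simulate(num_runners: int) -> list[bool]:
    """Record the flag seen right after each runner applies its toggle."""
    states = []
    for idx in range(num_runners):
        is_first, is_last = idx == 0, idx == num_runners - 1
        if is_first or is_last:
            # Same branch structure as PiecewiseRunner.__call__ above.
            set_piecewise_running(False if is_first == is_last else is_first)
        states.append(is_piecewise_running())
    return states


assert simulate(1) == [False]                    # a lone runner never raises the flag
assert simulate(4) == [True, True, True, False]  # raised by the first, cleared by the last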
@@ -267,6 +283,7 @@ def piecewise_optimizer(
         input_num_tokens,
         capture_num_tokens,
         exclude_modules_id,
+        len(set(node_to_graph_id.values())) - len(exclude_modules_id),
         graph_pool_handle,
         max_num_streams=max_num_streams,
     )
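The new constructor argument above is the number of piecewise runners the interpreter will create: the count of distinct split subgraphs minus the excluded ones (e.g. attention), which is what lets each runner know whether it is first or last. A small illustration with made-up ids:

# Made-up values; the expression mirrors the argument passed to the interpreter above.
node_to_graph_id = {"node_a": 0, "node_b": 0, "node_c": 1, "node_d": 2}  # three subgraphs
exclude_modules_id = [1]  # the subgraph kept out of piecewise CUDA graphs
piecewise_runner_num = len(set(node_to_graph_id.values())) - len(exclude_modules_id)
assert piecewise_runner_num == 2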
33 changes: 23 additions & 10 deletions tensorrt_llm/_torch/modules/attention.py
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+
+    def wrapper(*args, **kwargs):
+        if is_piecewise_running():
+            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
+            return func(*args, **kwargs)
+        return torch.compile(func)(*args, **kwargs)
+
+    return wrapper
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
 
 
-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)
 
 
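With maybe_compile the compile decision moves to call time: while the piecewise flag is set, the wrapped op runs eagerly to avoid torch.compile host overhead around the attention op; otherwise it still goes through torch.compile. An illustrative call sequence, assuming these modules import cleanly outside the full runtime:

import torch

from tensorrt_llm._torch.modules.attention import maybe_compiled_cat
from tensorrt_llm._torch.utils import set_piecewise_running

a, b = torch.randn(2, 4), torch.randn(2, 4)

set_piecewise_running(True)   # piecewise CUDA graphs active: plain torch.cat, no compile step
eager_out = maybe_compiled_cat((a, b), dim=-1)

set_piecewise_running(False)  # regular execution: routed through torch.compile
compiled_out = maybe_compiled_cat((a, b), dim=-1)
assert torch.equal(eager_out, compiled_out)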
@@ -1222,8 +1233,9 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1317,7 +1329,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = compiled_cat(
+        full_k = maybe_compiled_cat(
             (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1412,7 +1424,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = compiled_cat(
+            chunked_k = maybe_compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1482,8 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/utils.py
@@ -12,6 +12,7 @@
 from tensorrt_llm.quantization.utils import fp4_utils
 
 is_torch_compiling_flag = False
+is_piecewise_running_flag = False
 
 aux_stream_name_list = [
     'Attention',
@@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
     return is_torch_compiling_flag
 
 
+def set_piecewise_running(enable: bool):
+    global is_piecewise_running_flag
+    is_piecewise_running_flag = enable
+
+
+def is_piecewise_running() -> bool:
+    global is_piecewise_running_flag
+    return is_piecewise_running_flag
+
+
 _global_attrs = threading.local()
 
 
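The two helpers mirror the existing is_torch_compiling_flag / is_torch_compiling() pair: a module-level boolean plus a setter and getter, defaulting to False at import time. A quick sanity check:

from tensorrt_llm._torch.utils import is_piecewise_running, set_piecewise_running

assert is_piecewise_running() is False  # default at import time
set_piecewise_running(True)
assert is_piecewise_running() is True
set_piecewise_running(False)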