
Commit f51572f

yizhang-nv authored and dominicshanshan committed
[https://nvbugs/5550409][fix] Disable torch compile in piecewise attention part to avoid host overhead (NVIDIA#8708)
Signed-off-by: yizhang-nv <[email protected]>
1 parent 138b26d commit f51572f

File tree

  tensorrt_llm/_torch/compilation/piecewise_optimizer.py
  tensorrt_llm/_torch/modules/attention.py
  tensorrt_llm/_torch/utils.py

3 files changed: +53 −12 lines changed

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 19 additions & 2 deletions
@@ -13,7 +13,8 @@
 
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
-                     get_piecewise_cuda_graph_flag, make_weak_ref)
+                     get_piecewise_cuda_graph_flag, make_weak_ref,
+                     set_piecewise_running)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
 
@@ -27,6 +28,7 @@ def __init__(
         compile_time_num_tokens: Union[int | torch.SymInt],
         capture_num_tokens: list[int],
         exclude_modules_id: list[int],
+        piecewise_runner_num: int,
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
         graph=None,
@@ -38,6 +40,8 @@ def __init__(
 
         self.compile_time_num_tokens = compile_time_num_tokens
         self.capture_num_tokens = capture_num_tokens
+        self.piecewise_runner_num = piecewise_runner_num
+        self.piecewise_runner_idx = 0
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -90,8 +94,10 @@ def call_module(self, target, args, kwargs):
                 self.graph_pool_handle,
                 compile_fx(submod, args) if self.enable_inductor else submod,
                 self.enable_inductor,
+                self.piecewise_runner_idx == 0,
+                self.piecewise_runner_idx == self.piecewise_runner_num - 1,
             )
-
+            self.piecewise_runner_idx += 1
         return output
 
 
@@ -124,6 +130,8 @@ def __init__(
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
+        is_first_runner: bool,
+        is_last_runner: bool,
     ):
         if runtime_num_tokens_idx != None:
             assert isinstance(compile_time_num_tokens, torch.SymInt)
@@ -138,6 +146,8 @@ def __init__(
         self.enable_inductor = enable_inductor
 
         self.entries: dict[int, Entry] = {}
+        self.is_first_runner = is_first_runner
+        self.is_last_runner = is_last_runner
 
         for num_tokens in capture_num_tokens:
             self.entries[num_tokens] = Entry(
@@ -161,6 +171,12 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
+        if self.is_first_runner or self.is_last_runner:
+            if self.is_first_runner == self.is_last_runner:
+                set_piecewise_running(False)
+            else:
+                set_piecewise_running(self.is_first_runner)
+
         entry = self.entries[runtime_num_of_token]
 
         if entry.enable_inductor and not entry.compiled:
@@ -267,6 +283,7 @@ def piecewise_optimizer(
         input_num_tokens,
         capture_num_tokens,
         exclude_modules_id,
+        len(set(node_to_graph_id.values())) - len(exclude_modules_id),
         graph_pool_handle,
         max_num_streams=max_num_streams,
     )
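For orientation, each PiecewiseRunner is now told whether it is the first or last of piecewise_runner_num runners, and at call time those flags drive set_piecewise_running. The following is a minimal standalone sketch of that bracketing logic under stated assumptions: ToyRunner and the local _piecewise_running flag are hypothetical stand-ins for PiecewiseRunner and the flag in _torch/utils.py, and the CUDA-graph capture/replay itself is omitted.

_piecewise_running = False  # stand-in for is_piecewise_running_flag in _torch/utils.py


def set_piecewise_running(enable: bool) -> None:
    global _piecewise_running
    _piecewise_running = enable


class ToyRunner:
    """Hypothetical stand-in keeping only the flag logic added in this commit."""

    def __init__(self, is_first_runner: bool, is_last_runner: bool):
        self.is_first_runner = is_first_runner
        self.is_last_runner = is_last_runner

    def __call__(self):
        if self.is_first_runner or self.is_last_runner:
            if self.is_first_runner == self.is_last_runner:
                # A runner that is both first and last never brackets an
                # eager attention region, so the flag stays off.
                set_piecewise_running(False)
            else:
                # The first runner turns the flag on; the last turns it off.
                set_piecewise_running(self.is_first_runner)
        # ...capture or replay the piecewise CUDA graph here...


runners = [ToyRunner(i == 0, i == 2) for i in range(3)]
for r in runners:
    r()
    print(_piecewise_running)  # True, True, False

Attention modules executed between the first and last runner therefore observe the flag as True and can skip torch.compile, which is the point of the attention.py change below.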

tensorrt_llm/_torch/modules/attention.py

Lines changed: 23 additions & 10 deletions
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+
+    def wrapper(*args, **kwargs):
+        if is_piecewise_running():
+            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
+            return func(*args, **kwargs)
+        return torch.compile(func)(*args, **kwargs)
+
+    return wrapper
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
 
 
-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)
 
 
@@ -1222,8 +1233,9 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1317,7 +1329,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = compiled_cat(
+        full_k = maybe_compiled_cat(
            (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1412,7 +1424,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = compiled_cat(
+            chunked_k = maybe_compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1482,8 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
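To try the maybe_compile pattern in isolation, here is a standalone sketch under stated assumptions: the local _piecewise_running flag replaces the real is_piecewise_running() helper, and unlike the committed code it wraps the function with torch.compile once at decoration time rather than on every call; the dispatch-on-flag behaviour is the same.

import torch

_piecewise_running = False  # hypothetical local flag for the sketch


def maybe_compile(func):
    compiled = torch.compile(func)  # compile once, reuse the wrapped callable

    def wrapper(*args, **kwargs):
        if _piecewise_running:
            # Inside a piecewise CUDA-graph region: run eagerly to avoid
            # torch.compile's per-call host-side dispatch overhead.
            return func(*args, **kwargs)
        return compiled(*args, **kwargs)

    return wrapper


@maybe_compile
def maybe_compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)


x, y = torch.randn(4, 8), torch.randn(4, 8)
print(maybe_compiled_cat((x, y), dim=-1).shape)  # torch.Size([4, 16])

Whether to compile at decoration time or per call is a design choice; the committed version defers to torch.compile's internal caching, while the sketch makes the single-compilation intent explicit.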

tensorrt_llm/_torch/utils.py

Lines changed: 11 additions & 0 deletions
@@ -12,6 +12,7 @@
 from tensorrt_llm.quantization.utils import fp4_utils
 
 is_torch_compiling_flag = False
+is_piecewise_running_flag = False
 
 aux_stream_name_list = [
     'Attention',
@@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
     return is_torch_compiling_flag
 
 
+def set_piecewise_running(enable: bool):
+    global is_piecewise_running_flag
+    is_piecewise_running_flag = enable
+
+
+def is_piecewise_running() -> bool:
+    global is_piecewise_running_flag
+    return is_piecewise_running_flag
+
+
 _global_attrs = threading.local()
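A short usage sketch of the new flag helpers; the import path is inferred from the file location tensorrt_llm/_torch/utils.py and is an assumption of the sketch.

from tensorrt_llm._torch.utils import (is_piecewise_running,
                                       set_piecewise_running)

assert not is_piecewise_running()  # module-level default is False
set_piecewise_running(True)        # first piecewise runner enters the region
assert is_piecewise_running()
set_piecewise_running(False)       # last piecewise runner leaves the region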