Commit 4ab7d6e

Disable torch compile in piecewise to avoid host overhead
Signed-off-by: yizhang-nv <[email protected]>
1 parent 0e36484 · commit 4ab7d6e

File tree: 3 files changed, +46 −12 lines

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -13,7 +13,8 @@
 
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
-                     get_piecewise_cuda_graph_flag, make_weak_ref)
+                     get_piecewise_cuda_graph_flag, make_weak_ref,
+                     set_piecewise_running)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
 
@@ -27,6 +28,7 @@ def __init__(
         compile_time_num_tokens: Union[int | torch.SymInt],
         capture_num_tokens: list[int],
         exclude_modules_id: list[int],
+        piecewise_runner_num: int,
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
         graph=None,
@@ -38,6 +40,8 @@ def __init__(
 
         self.compile_time_num_tokens = compile_time_num_tokens
         self.capture_num_tokens = capture_num_tokens
+        self.piecewise_runner_num = piecewise_runner_num
+        self.piecewise_runner_idx = 0
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -90,8 +94,10 @@ def call_module(self, target, args, kwargs):
                 self.graph_pool_handle,
                 compile_fx(submod, args) if self.enable_inductor else submod,
                 self.enable_inductor,
+                self.piecewise_runner_idx == 0,
+                self.piecewise_runner_idx == self.piecewise_runner_num - 1,
             )
-
+            self.piecewise_runner_idx += 1
         return output
 
 
@@ -124,6 +130,8 @@ def __init__(
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
+        is_first_runner: bool,
+        is_last_runner: bool,
     ):
         if runtime_num_tokens_idx != None:
             assert isinstance(compile_time_num_tokens, torch.SymInt)
@@ -138,6 +146,8 @@ def __init__(
         self.enable_inductor = enable_inductor
 
         self.entries: dict[int, Entry] = {}
+        self.is_first_runner = is_first_runner
+        self.is_last_runner = is_last_runner
 
         for num_tokens in capture_num_tokens:
             self.entries[num_tokens] = Entry(
@@ -161,6 +171,9 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
+        if self.is_first_runner or self.is_last_runner:
+            set_piecewise_running(self.is_first_runner)
+
         entry = self.entries[runtime_num_of_token]
 
         if entry.enable_inductor and not entry.compiled:
@@ -267,6 +280,7 @@ def piecewise_optimizer(
         input_num_tokens,
         capture_num_tokens,
         exclude_modules_id,
+        len(set(node_to_graph_id.values())) - len(exclude_modules_id),
        graph_pool_handle,
        max_num_streams=max_num_streams,
    )
```
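
To make the first/last-runner bookkeeping above easier to follow, here is a minimal standalone sketch; the `PiecewiseRunner` stub, the copied `set_piecewise_running`, and the driver loop are illustrative stand-ins, not code from this commit. The first runner raises the global flag and the last one lowers it, so the flag stays set while the ops between the piecewise graph segments (the excluded attention modules) execute.

```python
# Standalone sketch of the first/last-runner flag toggling (illustrative only).
is_piecewise_running_flag = False


def set_piecewise_running(enable: bool):
    global is_piecewise_running_flag
    is_piecewise_running_flag = enable


class PiecewiseRunner:

    def __init__(self, name: str, is_first_runner: bool, is_last_runner: bool):
        self.name = name
        self.is_first_runner = is_first_runner
        self.is_last_runner = is_last_runner

    def __call__(self):
        # The first runner turns the flag on, the last runner turns it off,
        # so it is True for everything that runs between the graph segments.
        if self.is_first_runner or self.is_last_runner:
            set_piecewise_running(self.is_first_runner)


num_runners = 3  # analogous to piecewise_runner_num
runners = [
    PiecewiseRunner(f"submod_{i}", i == 0, i == num_runners - 1)
    for i in range(num_runners)
]
for runner in runners:
    runner()
    print(runner.name, is_piecewise_running_flag)
# submod_0 True
# submod_1 True
# submod_2 False
```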

tensorrt_llm/_torch/modules/attention.py

Lines changed: 19 additions & 10 deletions
```diff
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,20 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+    if is_piecewise_running():
+        # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
+        return func
+    return torch.compile(func)
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
 
 
-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)
 
 
@@ -1222,8 +1229,9 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1317,7 +1325,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = compiled_cat(
+        full_k = maybe_compiled_cat(
             (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1412,7 +1420,7 @@ def forward_context_with_chunked_prefill(
         chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                              self.qk_nope_head_dim)
         chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        chunked_k = compiled_cat(
+        chunked_k = maybe_compiled_cat(
             (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
             dim=-1)
         chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1478,8 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
```
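
To see the new `maybe_compile` decorator in isolation, here is a self-contained sketch; the module-level `piecewise_running` variable stands in for `is_piecewise_running_flag` from `utils.py`, and the example assumes a PyTorch build where `torch.compile` is usable. Note that, as written, the compile-or-not decision is made once, at the moment the decorator is applied.

```python
import torch

piecewise_running = False  # stand-in for is_piecewise_running_flag in utils.py


def maybe_compile(func):
    # Skip torch.compile while piecewise CUDA graphs are in use, so the
    # attention op does not pay Dynamo's per-call host overhead.
    if piecewise_running:
        return func
    return torch.compile(func)


@maybe_compile
def maybe_compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)


x, y = torch.randn(2, 4), torch.randn(2, 4)
print(maybe_compiled_cat((x, y), dim=-1).shape)  # torch.Size([2, 8])
```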

tensorrt_llm/_torch/utils.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 from tensorrt_llm.quantization.utils import fp4_utils
 
 is_torch_compiling_flag = False
+is_piecewise_running_flag = False
 
 aux_stream_name_list = [
     'Attention',
@@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
     return is_torch_compiling_flag
 
 
+def set_piecewise_running(enable: bool):
+    global is_piecewise_running_flag
+    is_piecewise_running_flag = enable
+
+
+def is_piecewise_running() -> bool:
+    global is_piecewise_running_flag
+    return is_piecewise_running_flag
+
+
 _global_attrs = threading.local()
 
 
```
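
The commit only adds a bare getter/setter pair, so callers are responsible for resetting the flag themselves. A small, hypothetical convenience wrapper (not part of this commit) could scope the flag with a context manager, assuming `tensorrt_llm` with these helpers is importable:

```python
from contextlib import contextmanager

from tensorrt_llm._torch.utils import (is_piecewise_running,
                                        set_piecewise_running)


@contextmanager
def piecewise_running(enable: bool = True):
    # Hypothetical helper: set the flag for a region of code and restore the
    # previous value afterwards, even if the body raises.
    prev = is_piecewise_running()
    set_piecewise_running(enable)
    try:
        yield
    finally:
        set_piecewise_running(prev)


with piecewise_running():
    assert is_piecewise_running()
# Flag restored to its previous value (False by default).
assert not is_piecewise_running()
```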
