|
23 | 23 | from ..model_config import ModelConfig |
24 | 24 | from ..peft.lora.layer import LoraLayer, LoraModuleType |
25 | 25 | from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs, |
26 | | - is_torch_compiling) |
| 26 | + is_piecewise_running, is_torch_compiling) |
27 | 27 | from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig |
28 | 28 | from .multi_stream_utils import maybe_execute_in_parallel |
29 | 29 | from .rms_norm import RMSNorm |
@@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str): |
76 | 76 | return metadata, attn_layer |
77 | 77 |
|
78 | 78 |
|
79 | | -@torch.compile |
80 | | -def compiled_copy_(dst, src): |
| 79 | +def maybe_compile(func): |
| 80 | + |
| 81 | + def wrapper(*args, **kwargs): |
| 82 | + if is_piecewise_running(): |
| 83 | + # When running piecewise, call the function eagerly to avoid torch.compile host overhead in the attention op. |
| 84 | + return func(*args, **kwargs) |
| 85 | + return torch.compile(func)(*args, **kwargs) |
| 86 | + |
| 87 | + return wrapper |
| 88 | + |
| 89 | + |
| 90 | +@maybe_compile |
| 91 | +def maybe_compiled_copy_(dst, src): |
81 | 92 | dst.copy_(src) |
82 | 93 |
|
83 | 94 |
|
84 | | -@torch.compile |
85 | | -def compiled_cat(tensors, dim): |
| 95 | +@maybe_compile |
| 96 | +def maybe_compiled_cat(tensors, dim): |
86 | 97 | return torch.cat(tensors, dim) |
87 | 98 |
|
88 | 99 |
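For readers skimming the hunk above, here is a self-contained sketch of how the new `maybe_compile` decorator behaves, with `is_piecewise_running` stubbed out (in the actual file it is imported from `..utils`):

```python
import torch

def is_piecewise_running() -> bool:
    # Stand-in for the helper imported from ..utils; assume piecewise
    # execution is off for this illustration.
    return False

def maybe_compile(func):
    def wrapper(*args, **kwargs):
        if is_piecewise_running():
            # When running piecewise, call the function eagerly to avoid
            # torch.compile host overhead in the attention op.
            return func(*args, **kwargs)
        return torch.compile(func)(*args, **kwargs)
    return wrapper

@maybe_compile
def maybe_compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)

# Either way the result matches plain torch.cat.
a, b = torch.ones(2, 3), torch.zeros(2, 3)
assert torch.equal(maybe_compiled_cat((a, b), dim=-1), torch.cat((a, b), dim=-1))
```

Note that `torch.compile(func)` is re-applied on every call in the non-piecewise path; Dynamo normally caches compiled code on the wrapped function, so this does not usually retrigger compilation, though caching the compiled callable in the closure would also avoid re-wrapping.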
|
@@ -1222,8 +1233,9 @@ def forward_context_default( |
1222 | 1233 | ) |
1223 | 1234 |
|
1224 | 1235 | k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim) |
1225 | | - compiled_copy_(k[..., :self.qk_nope_head_dim], |
1226 | | - k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)) |
| 1236 | + maybe_compiled_copy_( |
| 1237 | + k[..., :self.qk_nope_head_dim], |
| 1238 | + k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)) |
1227 | 1239 | if self.apply_rotary_emb: |
1228 | 1240 | k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1, |
1229 | 1241 | self.qk_rope_head_dim) |
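The call-site change above only renames the helper; for context, a minimal sketch of the in-place copy it performs, with illustrative head dimensions (the real values come from the model config):

```python
import torch

# Illustrative shapes only; real values come from the model config.
num_tokens, num_heads = 4, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

q = torch.randn(num_tokens, num_heads * qk_head_dim)
k_nope = torch.randn(num_tokens, num_heads * qk_nope_head_dim)

# Allocate k with q's layout, then fill the no-RoPE slice in place;
# the RoPE slice is written separately from k_pe.
k = torch.empty_like(q).view(-1, num_heads, qk_head_dim)
k[..., :qk_nope_head_dim].copy_(k_nope.view(-1, num_heads, qk_nope_head_dim))
```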
@@ -1317,7 +1329,7 @@ def forward_context_with_cached_kv( |
1317 | 1329 | full_k_nope = full_k_nope.view(-1, self.num_heads, |
1318 | 1330 | self.qk_nope_head_dim) |
1319 | 1331 | full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim) |
1320 | | - full_k = compiled_cat( |
| 1332 | + full_k = maybe_compiled_cat( |
1321 | 1333 | (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1) |
1322 | 1334 | full_k = full_k.view(-1, self.num_heads * self.qk_head_dim) |
1323 | 1335 |
|
@@ -1412,7 +1424,7 @@ def forward_context_with_chunked_prefill( |
1412 | 1424 | chunked_k_nope = chunked_k_nope.view(-1, self.num_heads, |
1413 | 1425 | self.qk_nope_head_dim) |
1414 | 1426 | chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim) |
1415 | | - chunked_k = compiled_cat( |
| 1427 | + chunked_k = maybe_compiled_cat( |
1416 | 1428 | (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)), |
1417 | 1429 | dim=-1) |
1418 | 1430 | chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim) |
@@ -1470,7 +1482,8 @@ def forward_context_with_chunked_prefill( |
1470 | 1482 |
|
1471 | 1483 | k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim) |
1472 | 1484 | k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) |
1473 | | - k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1) |
| 1485 | + k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), |
| 1486 | + dim=-1) |
1474 | 1487 | k = k.view(-1, self.num_heads * self.qk_head_dim) |
1475 | 1488 |
|
1476 | 1489 | # copy q_lens to replace kv_lens_runtime |
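All three `maybe_compiled_cat` call sites follow the same broadcast-and-concatenate pattern; a standalone sketch with assumed shapes:

```python
import torch

# Illustrative shapes only; real values come from the model config.
num_tokens, num_heads = 4, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64

k_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)

# The rotary part is shared across heads, so expand() broadcasts it before
# it is concatenated with the per-head no-RoPE part along the last dim.
k = torch.cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)
k = k.view(-1, num_heads * (qk_nope_head_dim + qk_rope_head_dim))
```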
|