@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,20 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+    if is_piecewise_running():
+        # When running piecewise, skip compilation to avoid host overhead in the attention op.
+        return func
+    return torch.compile(func)
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)


-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)

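For reference, here is a minimal, self-contained sketch of the conditional-compile pattern introduced above. The `_PIECEWISE_RUNNING` flag and `scale_add` are hypothetical stand-ins for illustration; in the real code the check is `is_piecewise_running()` imported from `..utils`:

```python
# Sketch of the maybe_compile pattern; _PIECEWISE_RUNNING and scale_add
# are hypothetical stand-ins, not part of the actual change.
import torch

_PIECEWISE_RUNNING = False  # stand-in for the runtime piecewise flag


def is_piecewise_running() -> bool:
    return _PIECEWISE_RUNNING


def maybe_compile(func):
    # Under piecewise execution, return the function unchanged so the
    # attention path avoids torch.compile's host-side overhead;
    # otherwise compile as before.
    if is_piecewise_running():
        return func
    return torch.compile(func)


@maybe_compile
def scale_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return 2 * a + b


print(scale_add(torch.ones(4), torch.ones(4)))  # tensor([3., 3., 3., 3.])
```

Note that `is_piecewise_running()` is evaluated once, when the decorator is applied at module import, not on every call.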
@@ -1222,8 +1229,9 @@ def forward_context_default(
         )

         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
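As a shape illustration of the copy above (sizes hypothetical, chosen for readability): the positionless part of K is written in place into the leading `qk_nope_head_dim` channels of the preallocated buffer, and the rotary tail is filled right after.

```python
# Shape sketch of the slice copy in forward_context_default;
# T, H, D_NOPE, D_ROPE are illustrative sizes, not from the PR.
import torch

T, H, D_NOPE, D_ROPE = 3, 4, 128, 64
k = torch.empty(T, H, D_NOPE + D_ROPE)
k_nope = torch.randn(T * H * D_NOPE)

# In-place write of the positionless component into the leading channels;
# the rotary component then fills k[..., D_NOPE:].
k[..., :D_NOPE].copy_(k_nope.view(-1, H, D_NOPE))
```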
@@ -1317,7 +1325,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = compiled_cat(
+        full_k = maybe_compiled_cat(
             (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)

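The same `expand`/`cat` pattern recurs in both chunked-prefill sites below. A small shape sketch (illustrative sizes) of what `maybe_compiled_cat` assembles:

```python
# Shape sketch of the K assembly; sizes are illustrative, not from the PR.
import torch

T, H, D_NOPE, D_ROPE = 3, 4, 128, 64
full_k_nope = torch.randn(T, H, D_NOPE)  # per-head positionless component
full_k_pe = torch.randn(T, 1, D_ROPE)    # rotary component shared across heads

# expand() broadcasts the shared rotary part over heads without copying;
# the concatenation materializes K with head dim D_NOPE + D_ROPE.
full_k = torch.cat((full_k_nope, full_k_pe.expand(-1, H, -1)), dim=-1)
assert full_k.shape == (T, H, D_NOPE + D_ROPE)
```

`expand` creates a broadcast view rather than a copy, so the only materialization happens inside the (optionally compiled) concatenation.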
@@ -1412,7 +1420,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = compiled_cat(
+            chunked_k = maybe_compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1478,8 @@ def forward_context_with_chunked_prefill(

         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

         # copy q_lens to replace kv_lens_runtime