Commit 08441c7

Lijiachen1018 authored and lijiachen19 committed

[patch] separate sparse patch (ModelEngine-Group#417)

separate sparse patch

Co-authored-by: lijiachen19 <[email protected]>
1 parent f5302fc commit 08441c7

File tree

9 files changed: +45 -1357 lines changed


ucm/__init__.py

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,17 @@
 from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1
 from ucm.integration.vllm.ucm_connector import UCMConnector

-__all__ = ["UnifiedCacheConnectorV1", "UCMConnector"]
+try:
+    from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied
+
+    ensure_patches_applied()
+except Exception as e:
+    # Don't fail if patches can't be applied - might be running in environment without vLLM
+    import warnings
+
+    warnings.warn(
+        f"Failed to apply vLLM patches: {e}. "
+        f"If you're using vLLM, ensure it's installed and patches are compatible."
+    )
+
+__all__ = ["UCMConnector"]
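With this change, `import ucm` applies the vLLM patches as a side effect, and the try/except keeps the package importable in environments without vLLM. For context, here is a minimal sketch of what an idempotent `ensure_patches_applied` entry point could look like; the body below is an assumption for illustration, not code from this commit (the real implementation lives in ucm/integration/vllm/patch/apply_patch.py):

    # Hypothetical sketch; only the _patches_applied flag and
    # apply_all_patches() are visible in this commit's diff.
    _patches_applied = False

    def ensure_patches_applied() -> None:
        """Apply the vLLM patches at most once per process."""
        global _patches_applied
        if _patches_applied:
            return  # repeated `import ucm` must not re-patch vLLM
        apply_all_patches()  # the version-dispatching function patched below
        _patches_applied = True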

ucm/integration/vllm/patch/apply_patch.py

Lines changed: 12 additions & 24 deletions
@@ -88,19 +88,19 @@ def apply_all_patches() -> None:
     supported_versions = get_supported_versions()
     if version not in supported_versions:
         logger.warning(
-            f"vLLM version {version} is not explicitly supported. "
+            f"vLLM version {version} is not explicitly supported to apply UCM patches. "
             f"Supported versions: {', '.join(supported_versions)}. "
-            f"Attempting to apply 0.9.2 patches..."
         )
-        raise ValueError(f"vLLM version {version} is not explicitly supported")

     # Apply version-specific patches
-    if version == "0.9.1":
-        _apply_patches_v091()
-    elif version == "0.9.2":
-        _apply_patches_v092()
-    else:
-        raise ValueError(f"Unsupported vLLM version: {version}")
+    match version:
+        case "0.9.2":
+            _apply_patches_v092()
+        case _:
+            logger.warning(
+                f"Unsupported vLLM version: {version} to apply UCM patches. "
+                f"Supported versions: {', '.join(supported_versions)}."
+            )

     _patches_applied = True
     logger.info(f"All vLLM patches applied successfully for version {version}")

@@ -109,25 +109,13 @@ def apply_all_patches() -> None:
         raise


-def _apply_patches_v091() -> None:
-    """Apply patches for vLLM 0.9.1."""
-    from .patch_funcs.v091.vllm_adapt import _apply_adapt_patch
-
-    _apply_adapt_patch()  # apply vllm-adapt-pc.patch
-    if _patch_ascend():
-        from .patch_funcs.v091.vllm_ascend_adapt import _apply_ascend_patch
-
-        _apply_ascend_patch()  # apply vllm-ascend-adapt.patch
-
-
 def _apply_patches_v092() -> None:
     """Apply patches for vLLM 0.9.2."""
-    from .patch_funcs.v092.vllm_adapt import _apply_adapt_patches
-
-    _apply_adapt_patches()
+    from .patch_funcs.v092.vllm_patch import _apply_sparse_adapt

+    _apply_sparse_adapt()  # apply vllm-sparse-adapt.patch
     if _patch_ascend():
-        from .patch_funcs.v092.vllm_ascend_adapt import _apply_ascend_patch
+        from .patch_funcs.v092.vllm_ascend_patch import _apply_ascend_patch

         _apply_ascend_patch()  # apply vllm-ascend-adapt.patch
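Two behavioral changes are worth noting here: the hard `raise ValueError` for unknown versions becomes warn-and-continue (so `import ucm` no longer fails on an unsupported vLLM), and dispatch moves to a `match` statement, which requires Python 3.10+. A self-contained sketch of the resulting control flow; `logger` and `get_supported_versions` below are stand-ins for the module's own, and the printed body is illustrative:

    # Stand-alone reproduction of the new dispatch; names mirror apply_patch.py
    # but the bodies here are illustrative stand-ins.
    import logging

    logger = logging.getLogger("ucm.patch")

    def get_supported_versions() -> list[str]:
        # after this commit only 0.9.2 has patch functions left
        return ["0.9.2"]

    def apply_patches_for(version: str) -> None:
        match version:
            case "0.9.2":
                print("applying vllm-sparse-adapt.patch ...")
            case _:
                # warn-and-continue replaces the old `raise ValueError(...)`
                logger.warning(
                    f"Unsupported vLLM version: {version} to apply UCM patches. "
                    f"Supported versions: {', '.join(get_supported_versions())}."
                )

    apply_patches_for("0.9.1")  # now logs a warning instead of raising
    apply_patches_for("0.9.2")  # applies the sparse patch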

ucm/integration/vllm/patch/patch_funcs/v091/__init__.py

Whitespace-only changes.

ucm/integration/vllm/patch/patch_funcs/v091/vllm_adapt.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

ucm/integration/vllm/patch/patch_funcs/v091/vllm_ascend_adapt.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py renamed to ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py

Lines changed: 0 additions & 76 deletions
@@ -44,44 +44,11 @@ def _patch_attention_v1() -> None:
     from typing import List

     import torch
-    from vllm.distributed.kv_transfer import (
-        get_kv_transfer_group,
-        has_kv_transfer_group,
-        is_v1_kv_transfer_group,
-    )
     from vllm.forward_context import ForwardContext, get_forward_context
     from vllm_ascend.attention import attention_v1

     from ucm.sparse.state import get_ucm_sparse, has_ucm_sparse

-    def wait_for_kv_layer_from_connector(layer_name: str):
-        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-            return
-
-        connector = get_kv_transfer_group()
-        forward_context: ForwardContext = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
-        if attn_metadata is None:
-            return
-        connector.wait_for_layer_load(layer_name)
-
-    attention_v1.wait_for_kv_layer_from_connector = wait_for_kv_layer_from_connector
-
-    def maybe_save_kv_layer_to_connector(
-        layer_name: str,
-        kv_cache_layer: List[torch.Tensor],
-    ):
-        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-            return
-        connector = get_kv_transfer_group()
-        forward_context: ForwardContext = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
-        if attn_metadata is None:
-            return
-        connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
-
-    attention_v1.maybe_save_kv_layer_to_connector = maybe_save_kv_layer_to_connector
-
     def maybe_execute_sparse_attention_begin(
         query: torch.Tensor,
         key: torch.Tensor,

@@ -149,7 +116,6 @@ def unified_ascend_attention_with_output_impl(
         output: torch.Tensor,
         layer_name: str,
     ) -> None:
-        wait_for_kv_layer_from_connector(layer_name)

         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata

@@ -173,7 +139,6 @@ def unified_ascend_attention_with_output_impl(
             maybe_execute_sparse_attention_finished(
                 query, key, value, output, layer_name, forward_context
             )
-        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
         return

     vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload(

@@ -205,8 +170,6 @@ def _patch_mla_v1() -> None:
     from vllm.forward_context import ForwardContext, get_forward_context
     from vllm_ascend.attention.attention_v1 import (
         AscendAttentionState,
-        maybe_save_kv_layer_to_connector,
-        wait_for_kv_layer_from_connector,
     )
     from vllm_ascend.attention.mla_v1 import AscendMLAImpl
     from vllm_ascend.multistream.context import get_multistream_comm_context

@@ -406,7 +369,6 @@ def forward(
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
             # TODO: use an elegant way to overlap
-            wait_for_kv_layer_from_connector(layer.layer_name)
             maybe_execute_sparse_attention_begin(
                 prefill_q,
                 prefill_k_c_normed,

@@ -434,9 +396,7 @@ def forward(
                 forward_context,
                 "prefill",
             )
-            maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)
         if has_decode:
-            wait_for_kv_layer_from_connector(layer.layer_name)
             maybe_execute_sparse_attention_begin(
                 torch.cat([decode_ql_nope, decode_q_pe], dim=-1),
                 decode_ql_nope,

@@ -480,7 +440,6 @@ def forward(
                 forward_context,
                 "decode",
             )
-            maybe_save_kv_layer_to_connector(layer.layer_name, kv_cache)

         return output_padded

@@ -530,7 +489,6 @@ def _patch_model_runner_v1() -> None:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.distributed.kv_transfer import (
         get_kv_transfer_group,
-        has_kv_transfer_group,
     )
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
     from vllm.forward_context import get_forward_context, set_forward_context

@@ -1041,7 +999,6 @@ def _process_reqs(
         positions = self.positions[:padded_batch_size]

         # Run forward pass
-        finished_dumping = None
         with set_forward_context(
             attn_metadata, self.vllm_config, num_tokens=num_input_tokens
         ):

@@ -1070,7 +1027,6 @@ def _process_reqs(
                 maybe_converting_weight_acl_format(
                     self.model, ACL_FORMAT_FRACTAL_ND
                 )
-                self.maybe_setup_kv_connector(scheduler_output)
                 self.maybe_execute_ucm_sparse_begin(
                     scheduler_output, attn_metadata
                 )

@@ -1082,7 +1038,6 @@ def _process_reqs(
                     inputs_embeds=inputs_embeds,
                     **model_kwargs,
                 )
-                finished_dumping = self.maybe_wait_for_kv_save()
                 self.maybe_execute_ucm_sparse_finished()

         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0

@@ -1123,7 +1078,6 @@ def _process_reqs(
             logits_indices,
             aux_hidden_states,
             num_scheduled_tokens,
-            finished_dumping,
         )

     NPUModelRunner._process_reqs = _process_reqs

@@ -1148,7 +1102,6 @@ def execute_model(
             logits_indices,
             aux_hidden_states,
             num_scheduled_tokens_np,
-            finished_dumping,
         ) = self._process_reqs(scheduler_output, intermediate_tensors)

         with ProfileExecuteDuration().capture_async("post process"):

@@ -1320,7 +1273,6 @@ def execute_model(
                 logprobs=logprobs_lists,
                 prompt_logprobs_dict=prompt_logprobs_dict,
                 pooler_output=[],
-                finished_dumping=finished_dumping,
             )

         durations = ProfileExecuteDuration().pop_captured_sync()

@@ -1341,27 +1293,6 @@ def execute_model(

     NPUModelRunner.execute_model = execute_model

-    @staticmethod
-    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
-        # Update KVConnector with the KVConnector metadata forward().
-        if has_kv_transfer_group():
-            kv_connector = get_kv_transfer_group()
-            assert isinstance(kv_connector, KVConnectorBase_V1)
-            assert scheduler_output.kv_connector_metadata is not None
-            kv_connector.bind_connector_metadata(
-                scheduler_output.kv_connector_metadata
-            )
-            # Background KV cache transfers happen here.
-            # These transfers are designed to be async and the requests
-            # involved may be disjoint from the running requests.
-            # Do this here to save a collective_rpc.
-            kv_connector.start_load_kv(get_forward_context())
-
-    @staticmethod
-    def maybe_wait_for_kv_save():
-        if has_kv_transfer_group():
-            return get_kv_transfer_group().wait_for_save()
-
     def maybe_execute_ucm_sparse_begin(
         self,
         scheduler_output: "SchedulerOutput",

@@ -1387,8 +1318,6 @@ def ucm_sparse_request_finished_in_worker(self, request_id: str | int):
         ucm_sparse = get_ucm_sparse()
         ucm_sparse.request_finished_in_worker(request_id)

-    NPUModelRunner.maybe_setup_kv_connector = maybe_setup_kv_connector
-    NPUModelRunner.maybe_wait_for_kv_save = maybe_wait_for_kv_save
     NPUModelRunner.maybe_execute_ucm_sparse_begin = maybe_execute_ucm_sparse_begin
     NPUModelRunner.maybe_execute_ucm_sparse_finished = (
         maybe_execute_ucm_sparse_finished

@@ -1408,9 +1337,6 @@ def _patch_worker_v1() -> None:
     import copy
     from typing import Optional

-    from vllm.distributed.kv_transfer import (
-        has_kv_transfer_group,
-    )
     from vllm.distributed.parallel_state import get_pp_group, get_tp_group
     from vllm.logger import logger
     from vllm.sequence import IntermediateTensors

@@ -1442,8 +1368,6 @@ def execute_model(
             get_pp_group().send_tensor_dict(
                 output.tensors, all_gather_group=get_tp_group()
             )
-        if not has_kv_transfer_group():
-            return None

         kv_connector_output = output.kv_connector_output
         finished_sending = kv_connector_output.finished_sending
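All of these patches use the same mechanism: define a replacement function inside a `_patch_*()` helper and rebind a module or class attribute, as in `NPUModelRunner._process_reqs = _process_reqs` above. A minimal, self-contained illustration of that idiom with stand-in names (nothing below is vLLM's actual API):

    # Attribute-assignment monkey-patching, the idiom this file relies on.
    # `FakeRunner` stands in for NPUModelRunner; nothing vLLM-specific here.
    class FakeRunner:
        def process(self) -> str:
            return "original"

    def _apply_example_patch() -> None:
        def process(self) -> str:
            # keep the original signature so existing call sites are unaffected
            return "patched"

        # Rebinding the class attribute redirects every future lookup,
        # including on instances created before the patch was applied.
        FakeRunner.process = process

    runner = FakeRunner()
    _apply_example_patch()
    assert runner.process() == "patched"  # pre-existing instances see the patch

Because method lookup goes through the class at call time, applying the patch once at import (as ucm/__init__.py now does via ensure_patches_applied) is enough to affect every runner instance in the process.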
