diff --git a/examples/online_serving/afd/deepseek-v2-lite/readme.md b/examples/online_serving/afd/deepseek-v2-lite/readme.md
new file mode 100644
index 000000000000..3216670b01ee
--- /dev/null
+++ b/examples/online_serving/afd/deepseek-v2-lite/readme.md
@@ -0,0 +1,15 @@
+# P2P Connector
+The P2P connector is used for testing the AFD (attention-FFN disaggregation) implementation with DeepSeek-V2-Lite models. It uses `torch.distributed` to send/recv intermediate tensors between the attention and FFN instances. A minimal sketch of this exchange is shown at the end of this file.
+
+1. Attn
+
+```
+vllm serve "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"1","afd_extra_config":{"afd_size":"2A2F"}}'
+```
+
+2. FFN
+
+```
+vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"1", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'
+```
+
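+## How the tensors move (illustrative sketch)
+
+The snippet below is not the vLLM connector implementation; it is a minimal,
+self-contained sketch of the idea behind `p2pconnector`: the `afd_size` string
+(for example `2A2F`) is parsed into the number of attention and FFN ranks, both
+groups join one `torch.distributed` world, and each attention rank exchanges an
+intermediate tensor with its paired FFN rank via send/recv. The `gloo` backend
+and the tensor shape are assumptions made for this example only.
+
+```python
+import re
+
+import torch
+import torch.distributed as dist
+
+
+def parse_afd_size(afd_size: str) -> tuple[int, int]:
+    """Split a string like "2A2F" into (attention ranks, FFN ranks)."""
+    attn_size, ffn_size = map(int,
+                              re.match(r"(\d+)\D+(\d+)", afd_size).groups())
+    return attn_size, ffn_size
+
+
+def exchange_once(role: str, rank: int, afd_size: str = "2A2F") -> None:
+    """Run one attention <-> FFN round trip for a single rank."""
+    attn_size, ffn_size = parse_afd_size(afd_size)
+    assert attn_size == ffn_size, "1:1 pairing assumed in this sketch"
+    # Attention ranks occupy [0, attn_size); FFN ranks follow them.
+    world_rank = rank if role == "attention" else rank + attn_size
+    dist.init_process_group(backend="gloo",
+                            init_method="tcp://127.0.0.1:29500",
+                            world_size=attn_size + ffn_size,
+                            rank=world_rank)
+    peer = (world_rank + attn_size
+            if role == "attention" else world_rank - attn_size)
+    # Hypothetical [num_tokens, hidden_size] intermediate tensor.
+    hidden = torch.zeros(4, 8)
+    if role == "attention":
+        dist.send(hidden, dst=peer)  # attention output -> paired FFN rank
+        dist.recv(hidden, src=peer)  # FFN output comes back
+    else:
+        dist.recv(hidden, src=peer)  # receive the attention output
+        dist.send(hidden, dst=peer)  # return the (would-be) FFN output
+    dist.destroy_process_group()
+```
+
+Launch one process per rank (two attention and two FFN processes for `2A2F`)
+and call `exchange_once` with the matching `role` and local `rank`.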
FFN + +``` +vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' +``` \ No newline at end of file diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3d352257a931..008c7ef3c00c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -617,7 +617,8 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): + afd_stage_idx = forward_context.afd_metadata.afd_stage_idx + if isinstance(attn_metadata, dict) and afd_stage_idx > 1: attn_metadata = attn_metadata[layer_name] if forward_context.afd_metadata: afd_stage_idx = forward_context.afd_metadata.afd_stage_idx @@ -625,6 +626,10 @@ def unified_attention_with_output( attn_metadata = attn_metadata[afd_stage_idx] else: attn_metadata = None # padding + else: + attn_metadata = attn_metadata[ + layer_name] if attn_metadata is not None else None + self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/distributed/afd_transfer/afd_connector/factory.py b/vllm/distributed/afd_transfer/afd_connector/factory.py index cd397457c860..862aba71d1b1 100644 --- a/vllm/distributed/afd_transfer/afd_connector/factory.py +++ b/vllm/distributed/afd_transfer/afd_connector/factory.py @@ -89,3 +89,8 @@ def get_connector_class(cls, AFDConnectorFactory.register_connector( "dummy", "vllm.distributed.afd_transfer.afd_connector.dummy_connector", "DummyAFDConnector") + +AFDConnectorFactory.register_connector( + "p2pconnector", + "vllm.distributed.afd_transfer.afd_connector.p2p_connector", + "P2PAFDConnector") diff --git a/vllm/distributed/afd_transfer/afd_connector/metadata.py b/vllm/distributed/afd_transfer/afd_connector/metadata.py index d05c1aa8f913..bdcc57fb877e 100644 --- a/vllm/distributed/afd_transfer/afd_connector/metadata.py +++ b/vllm/distributed/afd_transfer/afd_connector/metadata.py @@ -6,9 +6,23 @@ import time from dataclasses import dataclass from typing import Optional - +import typing import torch +class FFNNeedForwardData: + + def __init__(self, + moe_comm_method: typing.Any, + num_input_tokens: int, + with_prefill: bool, + total_num_scheduled_tokens: Optional[int], + is_dummy_run:bool = False): + self.moe_comm_method = moe_comm_method + self.num_input_tokens = num_input_tokens + self.with_prefill = with_prefill + self.total_num_scheduled_tokens = total_num_scheduled_tokens + self.is_dummy_run = is_dummy_run + @dataclass class AFDConnectorMetadata: @@ -21,7 +35,26 @@ class AFDConnectorMetadata: # multiple sequences dtype: torch.dtype device: torch.device + topk_idx: Optional[torch.Tensor] + # indices token which expert to be sended + topk_weights: Optional[torch.Tensor] + # the expert weights + moe_expert_num: Optional[int] + # number of moe experts + shared_expert_num: Optional[int] + # number of share experts + scale: Optional[torch.Tensor] + # quant scale + expertTokenNumsOut: Optional[torch.Tensor] + # The number of tokens received by each expert used as input for GMM + handle: Optional[torch.Tensor] + # the communication handle given by the recv_attn_output function + # Optional fields for debugging and extensibility + request_id: Optional[str] = None + timestamp: Optional[float] = None + """ffn 
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 3d352257a931..008c7ef3c00c 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -617,7 +617,9 @@ def unified_attention_with_output(
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
+    afd_metadata = forward_context.afd_metadata
+    afd_stage_idx = afd_metadata.afd_stage_idx if afd_metadata else 0
+    if isinstance(attn_metadata, dict) and afd_stage_idx > 1:
         attn_metadata = attn_metadata[layer_name]
     if forward_context.afd_metadata:
         afd_stage_idx = forward_context.afd_metadata.afd_stage_idx
@@ -625,6 +627,10 @@ def unified_attention_with_output(
             attn_metadata = attn_metadata[afd_stage_idx]
         else:
             attn_metadata = None  # padding
+    else:
+        attn_metadata = attn_metadata[
+            layer_name] if attn_metadata is not None else None
+
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
diff --git a/vllm/distributed/afd_transfer/afd_connector/factory.py b/vllm/distributed/afd_transfer/afd_connector/factory.py
index cd397457c860..862aba71d1b1 100644
--- a/vllm/distributed/afd_transfer/afd_connector/factory.py
+++ b/vllm/distributed/afd_transfer/afd_connector/factory.py
@@ -89,3 +89,8 @@ def get_connector_class(cls,
 AFDConnectorFactory.register_connector(
     "dummy", "vllm.distributed.afd_transfer.afd_connector.dummy_connector",
     "DummyAFDConnector")
+
+AFDConnectorFactory.register_connector(
+    "p2pconnector",
+    "vllm.distributed.afd_transfer.afd_connector.p2p_connector",
+    "P2PAFDConnector")
diff --git a/vllm/distributed/afd_transfer/afd_connector/metadata.py b/vllm/distributed/afd_transfer/afd_connector/metadata.py
index d05c1aa8f913..bdcc57fb877e 100644
--- a/vllm/distributed/afd_transfer/afd_connector/metadata.py
+++ b/vllm/distributed/afd_transfer/afd_connector/metadata.py
@@ -6,9 +6,23 @@
 import time
 from dataclasses import dataclass
 from typing import Optional
-
+import typing
 import torch
+
+class FFNNeedForwardData:
+
+    def __init__(self,
+                 moe_comm_method: typing.Any,
+                 num_input_tokens: int,
+                 with_prefill: bool,
+                 total_num_scheduled_tokens: Optional[int],
+                 is_dummy_run: bool = False):
+        self.moe_comm_method = moe_comm_method
+        self.num_input_tokens = num_input_tokens
+        self.with_prefill = with_prefill
+        self.total_num_scheduled_tokens = total_num_scheduled_tokens
+        self.is_dummy_run = is_dummy_run
 
 
 @dataclass
 class AFDConnectorMetadata:
@@ -21,7 +35,23 @@ class AFDConnectorMetadata:
     # multiple sequences
     dtype: torch.dtype
     device: torch.device
+    topk_idx: Optional[torch.Tensor] = None
+    # indices of the experts each token is routed to
+    topk_weights: Optional[torch.Tensor] = None
+    # routing weights of the selected experts
+    moe_expert_num: Optional[int] = None
+    # number of routed MoE experts
+    shared_expert_num: Optional[int] = None
+    # number of shared experts
+    scale: Optional[torch.Tensor] = None
+    # quantization scale
+    expertTokenNumsOut: Optional[torch.Tensor] = None
+    # number of tokens received by each expert, used as input for GMM
+    handle: Optional[torch.Tensor] = None
+    # communication handle returned by recv_attn_output
+    ffn_need_forward_data: Optional[FFNNeedForwardData] = None
+    # data the FFN side needs to set up its forward pass
 
     # Optional fields for debugging and extensibility
     request_id: Optional[str] = None
     timestamp: Optional[float] = None
@@ -61,7 +91,9 @@ def create_attention_metadata(
             seq_len: int,
             dtype: torch.dtype,
             device: torch.device,
-            request_id: Optional[str] = None) -> "AFDConnectorMetadata":
+            request_id: Optional[str] = None,
+            ffn_need_forward_data: Optional[FFNNeedForwardData] = None
+    ) -> "AFDConnectorMetadata":
         """Create metadata for attention side (single sequence)."""
         return cls(layer_idx=layer_idx,
                    stage_idx=stage_idx,
@@ -69,6 +101,7 @@ def create_attention_metadata(
                    dtype=dtype,
                    device=device,
                    request_id=request_id,
+                   ffn_need_forward_data=ffn_need_forward_data,
                    timestamp=time.time())
 
     @classmethod
diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
new file mode 100644
index 000000000000..0fa0adeb24aa
--- /dev/null
+++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -0,0 +1,205 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import re
+from datetime import timedelta
+from typing import Any, Optional
+
+import torch
+from torch.distributed.distributed_c10d import (
+    _update_default_pg,
+    _get_default_group,
+)
+
+from .base import AFDConnectorBase
+from .metadata import AFDConnectorMetadata
+from vllm.distributed.parallel_state import (
+    init_afd_process_group,
+    init_model_parallel_group,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.logger import init_logger
+from vllm.config import VllmConfig
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+class DefaultProcessGroupSwitcher:
+    """Temporarily swap torch.distributed's default process group."""
+
+    def __init__(self, default_group, new_default_group):
+        self.default_group = default_group
+        self.new_default_group = new_default_group
+
+    def __enter__(self):
+        _update_default_pg(self.new_default_group)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _update_default_pg(self.default_group)
+
+
+class P2PAFDConnector(AFDConnectorBase):
+
+    def __init__(
+        self,
+        rank: int,
+        local_rank: int,
+        config: "VllmConfig",
+    ) -> None:
+        self.rank = rank
+        self.local_rank = local_rank
+        self._initialized = False
+        self.config = config
+
+    def close(self) -> None:
+        """Close the connector and release resources."""
+        # TODO: destroy the process group.
+        pass
+
+    def init_afd_connector(self) -> None:
+        """Initialize the AFD connector."""
+        afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
+        role = self.config.afd_config.afd_role
+        host = self.config.afd_config.afd_host
+        port = self.config.afd_config.afd_port
+        # e.g. "2A2F" -> 2 attention ranks and 2 FFN ranks.
+        attn_size, ffn_size = map(
+            int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
+        # Attention ranks come first in the global AFD group.
+        world_rank = self.rank if role == "attention" else self.rank + attn_size
+
+        logger.info("world_size = %d, world_rank = %d",
+                    ffn_size + attn_size, world_rank)
+        backend = current_platform.dist_backend
+        afd_pg = init_afd_process_group(
+            backend=backend,
+            init_method=f"tcp://{host}:{port}",
+            world_size=ffn_size + attn_size,
+            rank=world_rank,
+            group_name="afd",
+            timeout=timedelta(minutes=2),
+        )
+        ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)]
+        attn_ranks = [i for i in range(attn_size)]
+
+        default_pg_switcher = DefaultProcessGroupSwitcher(
+            _get_default_group(), afd_pg)
+        with default_pg_switcher:
+            sub_group_ranks = []
+            for i in range(len(ffn_ranks)):
+                ranks = list([attn_ranks[i],
ffn_ranks[i]]) + sub_group_ranks.append(ranks) + self.process_group = init_model_parallel_group( + sub_group_ranks, self.rank, backend=backend, group_name="ae" + ) + + logger.info("p2p connector initialized") + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + return self._initialized + + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> Any: + """ + This method will be called by the ATTN side. + + + * To send the intermediate tensors generated by ATTN instances to FFN. + """ + + intermediate_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) + try: + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + all_gather_group=None, + ) + dst = ( + self.process_group.rank_in_group + 1 + ) % self.process_group.world_size + self.process_group.send_object(metadata, dst) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_attn_output( + self, + timeout_ms: Optional[int] = None, + ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: + """ + This method will be called by the FFN side. + + + * To receive the intermediate tensors from ATTN. + * And (Maybe) dispatch them from the receiver to other GPUs. + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + src = ( + self.process_group.rank_in_group - 1 + ) % self.process_group.world_size + metadata = self.process_group.recv_object(src) + return intermediate_tensors["hidden_states"], metadata + + # ------------------------------------------------------------------------- + # attn <- ffn + # ------------------------------------------------------------------------- + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> None: + """ + This method will be called by the FFN side. + + + * To send the intermediate tensors generated by FFN instances back to + the sender (this should be the same GPU as it comes from) + """ + intermediate_tensors = IntermediateTensors( + { + "hidden_states": ffn_output, + } + ) + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + ) + dst = ( + self.process_group.rank_in_group + 1 + ) % self.process_group.world_size + + self.process_group.send_object(metadata, dst) + + def recv_ffn_output( + self, + handle: Any, + ) -> torch.Tensor: + """ + This method will be called by the ATTN side. + + + * To receive the MOE output intermediate tensors. + * And (Maybe) dispatch them from the receiver to other GPUs. 
+ (this should be the same GPU as it comes from) + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + src = ( + self.process_group.rank_in_group - 1 + ) % self.process_group.world_size + + self.process_group.recv_object(src) + return intermediate_tensors["hidden_states"] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 12571afaa4c1..2e7331f0a70b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,6 +37,10 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import (PrefixStore, Store, + _new_process_group_helper, + _world, default_pg_timeout, + rendezvous) from typing_extensions import deprecated import vllm.envs as envs @@ -893,6 +897,55 @@ def combine(self, hidden_states) -> torch.Tensor: return hidden_states +def init_afd_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store.") + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, + rank, + world_size, + timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = PrefixStore(group_name, store) + + pg_options_param_name = ("backend_options" if str(torch.__version__) + >= "2.6" else "pg_options") + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg + + _WORLD: Optional[GroupCoordinator] = None _NODE_COUNT: Optional[int] = None diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e4a21febc5bd..9868952ba338 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -35,11 +35,14 @@ import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -66,6 +69,11 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata,FFNNeedForwardData) + +logger = init_logger(__name__) + class 
DeepseekV2MLP(nn.Module): @@ -626,6 +634,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + afd_config = vllm_config.afd_config + self.role = afd_config.afd_role self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -639,41 +649,44 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( + if self.role is None or self.role == "attention": + self.self_attn = attn_cls( config=config, - parallel_config=parallel_config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank if hasattr( + config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.self_attn", ) + + if self.role is None or self.role == "ffn": + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -687,6 +700,10 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention + forward_ctx = get_forward_context() + afd_metadata = (forward_ctx.afd_metadata + if forward_ctx is not None else None) + afd_connector = afd_metadata.afd_connector if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -711,6 +728,28 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + # ---------ascend ffn need data + if forward_ctx.moe_comm_method_name is not None: + moe_comm_method = forward_ctx.moe_comm_method_name + num_tokens = hidden_states.shape[0] + with_prefill = forward_ctx.with_prefill + + 
ffn_need_forward_data = FFNNeedForwardData(moe_comm_method,num_tokens,with_prefill) + num_stages = 0 + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=self.layer_idx, + stage_idx=num_stages, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + ffn_need_forward_data=ffn_need_forward_data + ) + else: + metadata = None + if self.role == "attention": + afd_connector.send_attn_output(hidden_states, metadata) + hidden_states = afd_connector.recv_ffn_output(None) + return hidden_states, residual hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, @@ -724,6 +763,52 @@ def forward( return hidden_states, residual + def compute_attn_output( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def compute_ffn_output(self, hidden_states): + assert self.role == "ffn" + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + return hidden_states + @support_torch_compile class DeepseekV2Model(nn.Module): @@ -782,8 +867,29 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + forward_ctx = get_forward_context() + afd_metadata = (forward_ctx.afd_metadata + if forward_ctx is not None else None) + for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + if (afd_metadata is not None + and isinstance(afd_metadata.afd_tokens_start_loc, list) + and len(afd_metadata.afd_tokens_start_loc) - 1 > 1): + num_stages = len(afd_metadata.afd_tokens_start_loc) - 1 + stage_hidden_states: list[torch.Tensor] = [] + stage_residual: list[Optional[torch.Tensor]] = [] + stage_positions: list[torch.Tensor] = [] + for stage_idx in range(num_stages): + start = afd_metadata.afd_tokens_start_loc[stage_idx] + end = start + afd_metadata.afd_tokens_lens[stage_idx] + stage_hidden_states.append( + hidden_states[start:end].clone()) + stage_residual.append(residual[start:end].clone( + ) if residual is not None else None) + stage_positions.append(positions[start:end]) + else: + hidden_states, residual = layer(positions, hidden_states, + residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -794,6 +900,13 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, hidden_states, + layer_idx) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.layers[layer_idx].compute_ffn_output( + hidden_states) + return hidden_states + class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): @@ -807,7 +920,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - + self.afd_config = vllm_config.afd_config # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the @@ -845,11 +958,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): continue assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): + if (self.afd_config.afd_role is None or self.afd_config.afd_role + == "ffn") and isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. 
example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) + if self.afd_config.afd_role == "attention": + return if example_moe is None: raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") @@ -908,6 +1024,13 @@ def forward( inputs_embeds) return hidden_states + def compute_ffn_output( + self, current_layer_idx, + hidden_states) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model.compute_ffn_output(hidden_states, + current_layer_idx) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, @@ -929,19 +1052,27 @@ def load_weights(self, weights: Iterable[tuple[str, # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + if self.afd_config.afd_role == "attention": + vllm_config = get_current_vllm_config() + num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts) + else: + num_redundant_experts = self.num_redundant_experts expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - + if self.afd_config.afd_role == "attention" and self.is_moe_weight( + name): + continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model @@ -989,7 +1120,9 @@ def load_weights(self, weights: Iterable[tuple[str, # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True - + if (self.afd_config.afd_role is not None + and self.afd_config.afd_role == "attention"): + continue # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) @@ -1013,6 +1146,12 @@ def load_weights(self, weights: Iterable[tuple[str, name = name_mapped break else: + if ( + self.afd_config.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank @@ -1039,6 +1178,18 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params + def is_moe_weight(self, name): + return bool("shared_experts" in name or "experts" in name + or "gate" in name or "up" in name or "down" in name) + + def is_common_weight(self, name): + if ("lm_head" in name or "model.norm.weight" in name + or "embed_tokens" in name or "input_layernorm" in name + or "post_attention_layernorm" in name): + # or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache + return True + return False + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1ac3d4a02f47..432c5feeca13 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2165,6 +2165,8 @@ def execute_model( cudagraph_runtime_mode, batch_descriptor = \ self.cudagraph_dispatcher.dispatch(batch_descriptor) + if afd_metadata is None: + afd_metadata = AFDMetadata(0, 0, 0, self.afd_connector, 0) # Run the model. 
        # Use persistent buffers for CUDA graphs.
        with (set_forward_context(