16 changes: 16 additions & 0 deletions examples/online_serving/afd/deepseek-v2-lite/readme.md
@@ -0,0 +1,16 @@
# P2P Connector
The P2P connector is used for testing the AFD (attention-FFN disaggregation) implementation with DeepSeek-V2-Lite models. It uses torch.distributed to send/receive intermediate tensors between the attention and FFN instances; a minimal sketch of this pairing follows the commands below. In `afd_extra_config`, `"afd_size": "2A2F"` means 2 attention ranks and 2 FFN ranks.

1. Attn

```
vllm serve "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"1","afd_extra_config":{"afd_size":"2A2F"}}'
```

2. FFN

```
vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"1", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'
```
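
Both commands must be launched together: the two sides rendezvous over `tcp://afd_host:afd_port` and form paired process groups. As a rough, standalone illustration of the send/recv pairing the connector performs (a hypothetical script, not vLLM code; shapes and the stand-in FFN are made up):

```
# Minimal sketch of the attention <-> FFN exchange over torch.distributed.
# Run with: torchrun --nproc-per-node=2 sketch.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")
    if dist.get_rank() == 0:
        # Attention side: ship hidden states out, wait for the FFN result.
        hidden = torch.randn(4, 8)
        dist.send(hidden, dst=1)
        ffn_out = torch.empty(4, 8)
        dist.recv(ffn_out, src=1)
    else:
        # FFN side: receive, apply a stand-in FFN, send the result back.
        hidden = torch.empty(4, 8)
        dist.recv(hidden, src=0)
        dist.send(torch.relu(hidden), dst=0)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```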

29 changes: 29 additions & 0 deletions examples/online_serving/afd/step3/README.md
@@ -0,0 +1,29 @@
# Dummy Connector
The dummy connector is used for testing basic functionality. The attention and FFN servers are never actually connected: the dummy connector immediately returns the input tensors, as sketched after the commands below.

1. Attn

```
vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```

2. FFN

```
vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```
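
A hypothetical sketch of that echo behavior (illustrative only, not the actual `DummyAFDConnector` implementation):

```
import torch


class DummyEchoSketch:
    """Stands in for a connector: returns attention output unchanged,
    as if the FFN side had processed and sent it back."""

    def send_attn_output(self, hidden_states: torch.Tensor,
                         metadata=None) -> None:
        # Nothing leaves the process; just remember the tensor locally.
        self._last = hidden_states

    def recv_ffn_output(self, handle=None) -> torch.Tensor:
        # Immediately hand back the stored input.
        return self._last
```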

# StepMesh Connector
The StepMesh connector is intended for production deployment. Make sure [StepMesh](https://github.com/stepfun-ai/StepMesh) is installed and that `afd_host` and `afd_port` are set correctly.

1. Attn

```
vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```

2. FFN

```
vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}'
```
7 changes: 6 additions & 1 deletion vllm/attention/layer.py
@@ -617,14 +617,19 @@ def unified_attention_with_output(
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        afd_stage_idx = forward_context.afd_metadata.afd_stage_idx
    if isinstance(attn_metadata, dict) and afd_stage_idx > 1:
        attn_metadata = attn_metadata[layer_name]
    if forward_context.afd_metadata:
        afd_stage_idx = forward_context.afd_metadata.afd_stage_idx
        if afd_stage_idx < len(attn_metadata):
            attn_metadata = attn_metadata[afd_stage_idx]
        else:
            attn_metadata = None  # padding
    else:
        attn_metadata = attn_metadata[
            layer_name] if attn_metadata is not None else None

    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
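The hunk above changes how per-layer attention metadata is selected when AFD metadata is present on the forward context: metadata becomes indexable by `afd_stage_idx`, and an out-of-range stage index means the batch is padding for this rank. A simplified, hand-written sketch of that selection logic (not the actual vLLM code):

```
from typing import Any, Optional


def select_attn_metadata(attn_metadata: Any, layer_name: str,
                         afd_stage_idx: Optional[int]) -> Optional[Any]:
    if attn_metadata is None:
        return None
    if afd_stage_idx is not None:
        # AFD enabled: metadata is indexed by pipeline stage.
        if afd_stage_idx < len(attn_metadata):
            return attn_metadata[afd_stage_idx]
        return None  # padding: no real tokens for this stage
    # AFD disabled: keep the original per-layer-name lookup.
    return attn_metadata[layer_name]
```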
5 changes: 5 additions & 0 deletions vllm/distributed/afd_transfer/afd_connector/factory.py
@@ -89,3 +89,8 @@ def get_connector_class(cls,
AFDConnectorFactory.register_connector(
    "dummy", "vllm.distributed.afd_transfer.afd_connector.dummy_connector",
    "DummyAFDConnector")

AFDConnectorFactory.register_connector(
    "p2pconnector",
    "vllm.distributed.afd_transfer.afd_connector.p2p_connector",
    "P2PAFDConnector")
39 changes: 37 additions & 2 deletions vllm/distributed/afd_transfer/afd_connector/metadata.py
@@ -6,9 +6,23 @@
import time
from dataclasses import dataclass
from typing import Any, Optional

import torch


class FFNNeedForwardData:
    """Data the FFN side needs in order to run its forward pass."""

    def __init__(self,
                 moe_comm_method: Any,
                 num_input_tokens: int,
                 with_prefill: bool,
                 total_num_scheduled_tokens: Optional[int],
                 is_dummy_run: bool = False):
        self.moe_comm_method = moe_comm_method
        self.num_input_tokens = num_input_tokens
        self.with_prefill = with_prefill
        self.total_num_scheduled_tokens = total_num_scheduled_tokens
        self.is_dummy_run = is_dummy_run


@dataclass
class AFDConnectorMetadata:
@@ -21,7 +35,26 @@ class AFDConnectorMetadata:
    # multiple sequences
    dtype: torch.dtype
    device: torch.device
    topk_idx: Optional[torch.Tensor]
    # indicates which expert each token is routed to
    topk_weights: Optional[torch.Tensor]
    # the expert weights
    moe_expert_num: Optional[int]
    # number of MoE experts
    shared_expert_num: Optional[int]
    # number of shared experts
    scale: Optional[torch.Tensor]
    # quantization scale
    expertTokenNumsOut: Optional[torch.Tensor]
    # number of tokens received by each expert, used as input for GMM
    handle: Optional[torch.Tensor]
    # the communication handle returned by the recv_attn_output function

    # Optional fields for debugging and extensibility
    request_id: Optional[str] = None
    timestamp: Optional[float] = None
    # data the FFN side needs in order to run its forward pass
    ffn_need_forward_data: Optional[FFNNeedForwardData] = None
@@ -61,14 +94,16 @@ def create_attention_metadata(
            seq_len: int,
            dtype: torch.dtype,
            device: torch.device,
            request_id: Optional[str] = None,
            ffn_need_forward_data: Optional[FFNNeedForwardData] = None,
    ) -> "AFDConnectorMetadata":
        """Create metadata for attention side (single sequence)."""
        return cls(layer_idx=layer_idx,
                   stage_idx=stage_idx,
                   seq_lens=[seq_len],
                   dtype=dtype,
                   device=device,
                   request_id=request_id,
                   ffn_need_forward_data=ffn_need_forward_data,
                   timestamp=time.time())

    @classmethod
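Taken together, the attention side can bundle everything the FFN side needs into the metadata it sends alongside the hidden states. A hypothetical construction using the classes above (all values illustrative):

```
import torch

from vllm.distributed.afd_transfer.afd_connector.metadata import (
    AFDConnectorMetadata, FFNNeedForwardData)

ffn_data = FFNNeedForwardData(
    moe_comm_method=None,  # would carry the runner's MoE comm object
    num_input_tokens=256,
    with_prefill=True,
    total_num_scheduled_tokens=256,
)
meta = AFDConnectorMetadata.create_attention_metadata(
    layer_idx=0,
    stage_idx=0,
    seq_len=256,
    dtype=torch.bfloat16,
    device=torch.device("cuda:0"),
    request_id="req-0",
    ffn_need_forward_data=ffn_data,
)
```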
205 changes: 205 additions & 0 deletions vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -0,0 +1,205 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re
from datetime import timedelta

import torch
from torch.distributed.distributed_c10d import (
    _update_default_pg,
    _get_default_group,
)
from typing import Any, Optional

from .base import AFDConnectorBase
from .metadata import AFDConnectorMetadata
from vllm.distributed.parallel_state import (
    init_afd_process_group,
    init_model_parallel_group,
)
from vllm.sequence import IntermediateTensors
from vllm.logger import init_logger
from vllm.config import VllmConfig
from vllm.platforms import current_platform

logger = init_logger(__name__)


class DefaultProcessGroupSwitcher:

    def __init__(self, default_group, new_default_group):
        self.default_group = default_group
        self.new_default_group = new_default_group

    def __enter__(self):
        _update_default_pg(self.new_default_group)

    def __exit__(self, exc_type, exc_value, traceback):
        _update_default_pg(self.default_group)


class P2PAFDConnector(AFDConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ) -> None:
        self.rank = rank
        self.local_rank = local_rank
        self._initialized = False
        self.config = config

    def close(self) -> None:
        """Close the connector and release resources."""
        # destroy process group
        pass

    def init_afd_connector(self) -> None:
        """Initialize the AFD connector."""
        afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
        role = self.config.afd_config.afd_role
        host = self.config.afd_config.afd_host
        port = self.config.afd_config.afd_port
        # e.g. "2A2F" -> attn_size=2, ffn_size=2
        attn_size, ffn_size = map(
            int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
        # Attention ranks occupy [0, attn_size); FFN ranks come after.
        world_rank = self.rank if role == "attention" else self.rank + attn_size

        logger.info("world_size = %d, world_rank = %d", ffn_size + attn_size,
                    world_rank)
        backend = current_platform.dist_backend
        afd_pg = init_afd_process_group(
            backend=backend,
            init_method=f"tcp://{host}:{port}",
            world_size=ffn_size + attn_size,
            rank=world_rank,
            group_name="afd",
            timeout=timedelta(minutes=2),
        )
        # FFN ranks follow the attention ranks in the global numbering.
        ffn_ranks = list(range(attn_size, ffn_size + attn_size))
        attn_ranks = list(range(attn_size))

        default_pg_switcher = DefaultProcessGroupSwitcher(
            _get_default_group(), afd_pg)
        with default_pg_switcher:
            # Pair attention rank i with FFN rank i, one subgroup per pair.
            sub_group_ranks = [[attn_ranks[i], ffn_ranks[i]]
                               for i in range(len(ffn_ranks))]
            self.process_group = init_model_parallel_group(
                sub_group_ranks, self.rank, backend=backend, group_name="ae")

        logger.info("p2p connector initialized")

        self._initialized = True

    def is_initialized(self) -> bool:
        """Check if the connector is initialized and ready to use.

        Returns:
            bool: True if the connector is initialized, False otherwise.
        """
        return self._initialized

    def send_attn_output(
        self,
        hidden_states: torch.Tensor,
        metadata: "AFDConnectorMetadata",
    ) -> Any:
        """
        This method will be called by the ATTN side.

        * To send the intermediate tensors generated by ATTN instances to FFN.
        """
        intermediate_tensors = IntermediateTensors({
            "hidden_states": hidden_states,
        })
        try:
            self.process_group.send_tensor_dict(
                intermediate_tensors.tensors,
                all_gather_group=None,
            )
            # Within each two-member subgroup, the peer is the other rank.
            dst = (self.process_group.rank_in_group +
                   1) % self.process_group.world_size
            self.process_group.send_object(metadata, dst)
        except Exception as e:
            raise RuntimeError(f"Communication error: {e}") from e

    def recv_attn_output(
        self,
        timeout_ms: Optional[int] = None,
    ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]:
        """
        This method will be called by the FFN side.

        * To receive the intermediate tensors from ATTN.
        * And (maybe) dispatch them from the receiver to other GPUs.
        """
        intermediate_tensors = self.process_group.recv_tensor_dict(
            all_gather_group=None)
        src = (self.process_group.rank_in_group -
               1) % self.process_group.world_size
        metadata = self.process_group.recv_object(src)
        return intermediate_tensors["hidden_states"], metadata

    # -------------------------------------------------------------------------
    # attn <- ffn
    # -------------------------------------------------------------------------
    def send_ffn_output(
        self,
        ffn_output: torch.Tensor,
        metadata: "AFDConnectorMetadata",
    ) -> None:
        """
        This method will be called by the FFN side.

        * To send the intermediate tensors generated by FFN instances back to
          the sender (this should be the same GPU as it comes from).
        """
        intermediate_tensors = IntermediateTensors({
            "hidden_states": ffn_output,
        })
        self.process_group.send_tensor_dict(intermediate_tensors.tensors)
        dst = (self.process_group.rank_in_group +
               1) % self.process_group.world_size
        self.process_group.send_object(metadata, dst)

    def recv_ffn_output(
        self,
        handle: Any,
    ) -> torch.Tensor:
        """
        This method will be called by the ATTN side.

        * To receive the MOE output intermediate tensors.
        * And (maybe) dispatch them from the receiver to other GPUs
          (this should be the same GPU as it comes from).
        """
        intermediate_tensors = self.process_group.recv_tensor_dict(
            all_gather_group=None)
        src = (self.process_group.rank_in_group -
               1) % self.process_group.world_size
        # Drain the paired metadata object sent by send_ffn_output;
        # it is not needed on this path.
        self.process_group.recv_object(src)
        return intermediate_tensors["hidden_states"]
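
As a quick sanity check on the `afd_size` parsing in `init_afd_connector` above (the regex splits the two counts around any non-digit separator, so `"2A2F"` yields 2 attention ranks and 2 FFN ranks):

```
import re

attn_size, ffn_size = map(
    int, re.match(r"(\d+)\D+(\d+)", "2A2F").groups())
assert (attn_size, ffn_size) == (2, 2)
```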