forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
AFD: update DeepSeek support #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
hsliuustc0106
wants to merge
2
commits into
Oliver-ss:afd-step3
Choose a base branch
from
JiusiServe:afd-dev
base: afd-step3
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # P2P Connector | ||
| P2P connector is used for testing the afd implementation for deepseek-v2-lite models. It uses torch.distributed to send/recv intermediate tensors between attn and ffn instances. | ||
|
|
||
| 1. Attn | ||
|
|
||
| ``` | ||
| vllm serve "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"1","afd_extra_config":{"afd_size":"2A2F"}}' | ||
|
|
||
| ``` | ||
|
|
||
| 2. FFN | ||
|
|
||
| ``` | ||
| vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"1", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}' | ||
| ``` | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # Dummy Connector | ||
| Dummy connector is used for testing basic functions. Attn and FFN server would not be connected as dummy connector would intermediately return input tensors. | ||
|
|
||
| 1. Attn | ||
|
|
||
| ``` | ||
| vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' | ||
| ``` | ||
|
|
||
| 2. FFN | ||
|
|
||
| ``` | ||
| vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' | ||
| ``` | ||
|
|
||
| # StepMesh Connector | ||
| StepMesh connector is used for production deployment. Make sure [stepmesh](https://github.com/stepfun-ai/StepMesh) is installed and `afd_host` and `afd_port` are correctly set. | ||
|
|
||
| 1. Attn | ||
|
|
||
| ``` | ||
| vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' | ||
| ``` | ||
|
|
||
| 2. FFN | ||
|
|
||
| ``` | ||
| vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' | ||
| ``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
205 changes: 205 additions & 0 deletions
205
vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import re | ||
| from datetime import timedelta | ||
|
|
||
| import torch | ||
| from torch.distributed.distributed_c10d import ( | ||
| _update_default_pg, | ||
| _get_default_group, | ||
| ) | ||
| from typing import Any, Optional | ||
|
|
||
| from .base import AFDConnectorBase | ||
| from .metadata import AFDConnectorMetadata | ||
| from vllm.distributed.parallel_state import ( | ||
| init_afd_process_group, | ||
| init_model_parallel_group, | ||
| ) | ||
| from vllm.sequence import IntermediateTensors | ||
| from vllm.logger import init_logger | ||
| from vllm.config import VllmConfig | ||
| from vllm.platforms import current_platform | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class DefaultProcessGroupSwitcher: | ||
| def __init__(self, default_group, new_default_group): | ||
| self.default_group = default_group | ||
| self.new_default_group = new_default_group | ||
|
|
||
| def __enter__(self): | ||
| _update_default_pg(self.new_default_group) | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| _update_default_pg(self.default_group) | ||
|
|
||
|
|
||
| class P2PAFDConnector(AFDConnectorBase): | ||
| def __init__( | ||
| self, | ||
| rank: int, | ||
| local_rank: int, | ||
| config: "VllmConfig", | ||
| ) -> None: | ||
| self.rank = rank | ||
| self.local_rank = local_rank | ||
| self._initialized = False | ||
| self.config = config | ||
|
|
||
| def close(self) -> None: | ||
| """Close the connector and release resources.""" | ||
| # destroy process group | ||
| pass | ||
|
|
||
| def init_afd_connector(self) -> None: | ||
| """Initialize the AFD connector.""" | ||
| afd_size = self.config.afd_config.afd_extra_config.get("afd_size") | ||
| role = self.config.afd_config.afd_role | ||
| host = self.config.afd_config.afd_host | ||
| port = self.config.afd_config.afd_port | ||
| attn_size, ffn_size = map( | ||
| int, re.match(r"(\d+)\D+(\d+)", afd_size).groups() | ||
| ) | ||
| world_rank = self.rank if role == "attention" else self.rank + attn_size | ||
|
|
||
| logger.info( | ||
| "world_size = %d, world_rank = %d", ffn_size + attn_size, world_rank | ||
| ) | ||
| backend = current_platform.dist_backend | ||
| afd_pg = init_afd_process_group( | ||
| backend=backend, | ||
| init_method=f"tcp://{host}:{port}", | ||
| world_size=ffn_size + attn_size, | ||
| rank=world_rank, | ||
| group_name="afd", | ||
| timeout=timedelta(minutes=2), | ||
| ) | ||
| ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] | ||
| attn_ranks = [i for i in range(attn_size)] | ||
|
|
||
| default_pg_switcher = DefaultProcessGroupSwitcher( | ||
| _get_default_group(), afd_pg | ||
| ) | ||
| with default_pg_switcher: | ||
| sub_group_ranks = [] | ||
| for i in range(len(ffn_ranks)): | ||
| ranks = list([attn_ranks[i], ffn_ranks[i]]) | ||
| sub_group_ranks.append(ranks) | ||
| self.process_group = init_model_parallel_group( | ||
| sub_group_ranks, self.rank, backend=backend, group_name="ae" | ||
| ) | ||
|
|
||
| logger.info("p2p connector initialized") | ||
|
|
||
| self._initialized = True | ||
|
|
||
| def is_initialized(self) -> bool: | ||
| """Check if the connector is initialized and ready to use. | ||
|
|
||
| Returns: | ||
| bool: True if the connector is initialized, False otherwise. | ||
| """ | ||
| return self._initialized | ||
|
|
||
| def send_attn_output( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| metadata: "AFDConnectorMetadata", | ||
| ) -> Any: | ||
| """ | ||
| This method will be called by the ATTN side. | ||
|
|
||
|
|
||
| * To send the intermediate tensors generated by ATTN instances to FFN. | ||
| """ | ||
|
|
||
| intermediate_tensors = IntermediateTensors( | ||
| { | ||
| "hidden_states": hidden_states, | ||
| } | ||
| ) | ||
| try: | ||
| self.process_group.send_tensor_dict( | ||
| intermediate_tensors.tensors, | ||
| all_gather_group=None, | ||
| ) | ||
| dst = ( | ||
| self.process_group.rank_in_group + 1 | ||
| ) % self.process_group.world_size | ||
| self.process_group.send_object(metadata, dst) | ||
| except Exception as e: | ||
| raise RuntimeError(f"Communication error: {e}") | ||
|
|
||
| def recv_attn_output( | ||
| self, | ||
| timeout_ms: Optional[int] = None, | ||
| ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: | ||
| """ | ||
| This method will be called by the FFN side. | ||
|
|
||
|
|
||
| * To receive the intermediate tensors from ATTN. | ||
| * And (Maybe) dispatch them from the receiver to other GPUs. | ||
| """ | ||
| intermediate_tensors = self.process_group.recv_tensor_dict( | ||
| all_gather_group=None, | ||
| ) | ||
| src = ( | ||
| self.process_group.rank_in_group - 1 | ||
| ) % self.process_group.world_size | ||
| metadata = self.process_group.recv_object(src) | ||
| return intermediate_tensors["hidden_states"], metadata | ||
|
|
||
| # ------------------------------------------------------------------------- | ||
| # attn <- ffn | ||
| # ------------------------------------------------------------------------- | ||
| def send_ffn_output( | ||
| self, | ||
| ffn_output: torch.Tensor, | ||
| metadata: "AFDConnectorMetadata", | ||
| ) -> None: | ||
| """ | ||
| This method will be called by the FFN side. | ||
|
|
||
|
|
||
| * To send the intermediate tensors generated by FFN instances back to | ||
| the sender (this should be the same GPU as it comes from) | ||
| """ | ||
| intermediate_tensors = IntermediateTensors( | ||
| { | ||
| "hidden_states": ffn_output, | ||
| } | ||
| ) | ||
| self.process_group.send_tensor_dict( | ||
| intermediate_tensors.tensors, | ||
| ) | ||
| dst = ( | ||
| self.process_group.rank_in_group + 1 | ||
| ) % self.process_group.world_size | ||
|
|
||
| self.process_group.send_object(metadata, dst) | ||
|
|
||
| def recv_ffn_output( | ||
| self, | ||
| handle: Any, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| This method will be called by the ATTN side. | ||
|
|
||
|
|
||
| * To receive the MOE output intermediate tensors. | ||
| * And (Maybe) dispatch them from the receiver to other GPUs. | ||
| (this should be the same GPU as it comes from) | ||
| """ | ||
| intermediate_tensors = self.process_group.recv_tensor_dict( | ||
| all_gather_group=None, | ||
| ) | ||
| src = ( | ||
| self.process_group.rank_in_group - 1 | ||
| ) % self.process_group.world_size | ||
|
|
||
| self.process_group.recv_object(src) | ||
| return intermediate_tensors["hidden_states"] | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there should be range(attn_size, ffn_size + attn_size)