
fix: round-robin DP rank assignment for service discovery mode #181

Open
raviguptaamd wants to merge 2 commits into vllm-project:main from raviguptaamd:ravgupta/discovery-dp-rank-roundrobin


@raviguptaamd


Summary

When using ZMQ service discovery with --intra-node-data-parallel-size > 1, worker addresses arrive without @rank suffixes. The router left prefill_dp_rank as None, so:

  • The prefill request lacked the X-data-parallel-rank header
  • The decode kv_transfer_params omitted remote_dp_rank

This caused the MoRI-IO WRITE mode handshake to always target DP rank 0, deadlocking all other DP ranks.
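
To make the failure mode concrete, here is a minimal sketch of how a rank suffix might be split off a worker address. The address shape (`http://host:port@rank`) and the helper name `split_dp_rank` are assumptions for illustration only; this is not the router's actual `extract_base_http_and_dp_rank`. An address discovered without a suffix yields `None`, which is the case the router previously passed through unchanged.

```rust
/// Hypothetical helper (not the router's real function): split an address
/// such as "http://host:port@3" into its base URL and an optional DP rank.
/// The "@rank" address shape is an assumption made for this sketch.
pub fn split_dp_rank(addr: &str) -> (&str, Option<usize>) {
    match addr.rsplit_once('@') {
        // Only treat the suffix as a rank if it parses as an unsigned integer.
        Some((base, suffix)) => match suffix.parse::<usize>() {
            Ok(rank) => (base, Some(rank)),
            Err(_) => (addr, None),
        },
        // No '@' at all: typical of addresses delivered by service discovery.
        None => (addr, None),
    }
}
```

With this sketch, `split_dp_rank("http://10.0.0.1:8000")` returns a `None` rank, so any downstream code that only sets the header when a rank is present would omit it.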

Fix

Add an atomic round-robin counter (discovery_dp_rank_counter) to VllmPDRouter that assigns DP ranks when extract_base_http_and_dp_rank returns None and DP size > 1. The assigned rank flows into:

  1. Prefill request's X-data-parallel-rank header
  2. Decode request's remote_dp_rank in kv_transfer_params

This matches the behaviour of the toy proxy which already does round-robin DP rank assignment.
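
The assignment rule described above can be sketched roughly as follows. The struct `RankAssigner` and the method `resolve` are hypothetical stand-ins, not the actual `VllmPDRouter` code; only the field names `discovery_dp_rank_counter` and `intra_node_data_parallel_size` come from the PR description. The memory ordering is also an assumption: `Relaxed` is enough for a counter whose only job is to spread requests.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical stand-in for the router state touched by this PR.
pub struct RankAssigner {
    /// Monotonic counter shared by all request handlers.
    discovery_dp_rank_counter: AtomicUsize,
    /// Number of DP ranks per node.
    intra_node_data_parallel_size: usize,
}

impl RankAssigner {
    pub fn new(dp_size: usize) -> Self {
        Self {
            discovery_dp_rank_counter: AtomicUsize::new(0),
            intra_node_data_parallel_size: dp_size,
        }
    }

    /// Keep a rank that was parsed from the address; otherwise assign one
    /// round-robin, but only when more than one DP rank exists.
    pub fn resolve(&self, parsed: Option<usize>) -> Option<usize> {
        match parsed {
            Some(rank) => Some(rank),
            None if self.intra_node_data_parallel_size > 1 => {
                // fetch_add returns the previous value, so ranks start at 0.
                let n = self.discovery_dp_rank_counter.fetch_add(1, Ordering::Relaxed);
                Some(n % self.intra_node_data_parallel_size)
            }
            None => None,
        }
    }
}
```

The value returned by `resolve` would then be used for both the prefill header and the decode `remote_dp_rank`, so the two legs of one request agree on the rank.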

Changes

  • src/routers/http/vllm_pd_router.rs:
    • Add discovery_dp_rank_counter: AtomicUsize field to VllmPDRouter
    • Initialize to 0 in both constructor paths (discovery and non-discovery)
    • In process_vllm_two_stage_request_discovered(), generate round-robin DP rank when prefill_dp_rank.is_none() and intra_node_data_parallel_size > 1

Test plan

  • 1P/1D DeepSeek-V3-5layer on MI300X (8 DP ranks per node, MoRI EP, WRITE mode) — warmup and benchmark pass with correct per-rank MoRI-IO handshakes
  • Unit tests for round-robin assignment with DP size > 1
  • Verify no regression for non-discovery (static URL) mode where @rank suffixes are present

Builds on top of #157 (MoRI WRITE mode concurrent dispatch).

raviguptaamd and others added 2 commits June 2, 2026 06:28
When using ZMQ service discovery with intra_node_data_parallel_size > 1,
worker addresses arrive without @rank suffixes. The router left
prefill_dp_rank as None, so the prefill request lacked
X-data-parallel-rank and the decode kv_transfer_params omitted
remote_dp_rank. This caused the MoRI-IO handshake to always target
DP rank 0, deadlocking all other ranks.

Add an atomic round-robin counter (discovery_dp_rank_counter) that
assigns DP ranks when extract_base_http_and_dp_rank returns None and
DP size > 1. The assigned rank flows into the prefill header and the
decode's remote_dp_rank, matching the behaviour of the toy proxy.

Tested with 1P/1D DeepSeek-V3-5layer on MI300X (8 DP ranks per node).

Co-authored-by: Cursor <cursoragent@cursor.com>
…uting

Captures the uncommitted working-tree edits (from /tmp/router-inspect, the tree
that built the validated vllm-router-2p2d-dpfix binary on 2026-06-14) on top of
PR vllm-project#181 (11841c0). Without these, the decode->prefill notify targets a different
DP rank than the prefill was pinned to -> "remote blocks never arrived" deferred
write expiry under 2P2D EP=16 load (reproduced: clean vllm-project#181 wedges, this serves).
Adds remote_dp_rank_override flag + remote_dp_size so the MoRIIO connector's rank
arithmetic (blake2s(rid)%dp_size) matches the topology on both legs.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6409ac1409


  "remote_block_ids": serde_json::Value::Null,
  "transfer_id": transfer_id.unwrap_or(""),
- "remote_dp_size": self.intra_node_data_parallel_size,
+ "remote_dp_size": self.effective_dp_size(),


P2: Keep READ-mode DP size consistent on decode

When --moriio-dp-size is set and MoRI-IO registers READ mode, this sends the prefill leg remote_dp_size as the effective cross-pod size, but the READ decode path still overwrites remote_dp_size with self.intra_node_data_parallel_size in build_decode_kv_transfer_params. In a 2P2D setup such as moriio_dp_size=16 and intra_node_data_parallel_size=8, the two legs of the same transfer disagree about the DP world size, so the decode handshake only targets the local-size world. Use effective_dp_size() in the READ decode path as well.
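
A rough sketch of the suggestion, assuming `effective_dp_size()` returns `moriio_dp_size` when it is non-zero and falls back to `intra_node_data_parallel_size` otherwise. That fallback rule is inferred from this review thread rather than confirmed against the code, and `DpConfig` plus the two `*_remote_dp_size` methods are hypothetical names.

```rust
/// Hypothetical subset of the router configuration discussed in this review.
pub struct DpConfig {
    pub intra_node_data_parallel_size: usize,
    /// Cross-pod MoRI-IO DP size; 0 is assumed to mean "not set".
    pub moriio_dp_size: usize,
}

impl DpConfig {
    /// Assumed semantics: prefer the cross-pod size, else the per-node size.
    pub fn effective_dp_size(&self) -> usize {
        if self.moriio_dp_size > 0 {
            self.moriio_dp_size
        } else {
            self.intra_node_data_parallel_size
        }
    }

    /// Value stamped as `remote_dp_size` on the prefill leg.
    pub fn prefill_remote_dp_size(&self) -> usize {
        self.effective_dp_size()
    }

    /// Value stamped as `remote_dp_size` on the decode leg. Routing this
    /// through the same helper keeps both legs on one DP world size.
    pub fn decode_remote_dp_size(&self) -> usize {
        self.effective_dp_size()
    }
}
```

In the 2P2D example from the comment (`moriio_dp_size=16`, `intra_node_data_parallel_size=8`), both methods would return 16 under these assumptions, instead of 16 on prefill and 8 on decode.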


Comment thread src/lib.rs
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
intra_node_data_parallel_size: self.intra_node_data_parallel_size,
moriio_dp_size: 0,


P2: Propagate MoRI-IO DP size through Python entrypoint

For routers started through the PyPI/Python entrypoint, Router.start() builds this RouterConfig, but the new value is hardcoded to 0 and the PyO3 constructor/RouterArgs do not expose a matching field. That means installed vllm-router users cannot configure the new cross-pod MoRI-IO DP size; Wide-EP deployments silently fall back to intra_node_data_parallel_size, so the new round-robin rank assignment and remote_dp_size stamping still use the wrong world size. Thread this option through the Python args/constructor instead of forcing 0.

