Commit aa8c235

added some comments
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent 9c3b006 commit aa8c235

2 files changed: +5 -1 lines changed

tests/lora/test_lora.py

Lines changed: 3 additions & 1 deletion

@@ -88,7 +88,9 @@ def test_single_lora_spmd():
     # ensure_model_parallel_initialized(1, 1)

     # num_devices = jax.local_device_count() # why does this line cause hanging.
-    num_devices = 4
+    # To test the SPMD multi-chip case, only num_devices=2 works for this model (Qwen2.5-3B-Instruct),
+    # because it has num_kv_heads=2: https://github.com/vllm-project/tpu_commons/blob/a489e59c5b3a4d5c28e93775d5323970eecd66c9/tpu_commons/layers/jax/attention_interface.py#L275 shards num_kv_heads, and only 2 divides num_kv_heads here.
+    num_devices = 2
     print(f'xw32 using TP={num_devices}')
     llm = setup_vllm(1, num_devices)
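For reference, a minimal sketch of the constraint the new comment describes (not part of the commit; the helper name and candidate list are illustrative): under SPMD tensor parallelism the KV-head axis is sharded across devices, so the TP degree must divide num_kv_heads, which is 2 for Qwen2.5-3B-Instruct.

# Sketch only: illustrates why num_devices=4 fails while num_devices=2 works.
num_kv_heads = 2  # Qwen2.5-3B-Instruct has 2 KV heads

def valid_tp_degrees(num_kv_heads, candidates):
    """Return the TP degrees that split the KV-head axis with no remainder."""
    return [tp for tp in candidates if num_kv_heads % tp == 0]

print(valid_tp_degrees(num_kv_heads, [1, 2, 4, 8]))  # -> [1, 2]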

tpu_commons/models/jax/attention.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ def sharded_ragged_paged_attention(
     v_scale: float | None = None,
 ):
     """Shards along KV heads."""
+    # nonspmd (tp=1): q.shape=(16,16,128), k.shape=(16,2,128), kv_cache.shape=(40660,16,2,2,128)
     qkv_spec = P(None, "model", None)
     kv_cache_spec = P(None, None, "model")
     in_specs = (

@@ -86,6 +87,7 @@ def attention(
     md = attention_metadata

     # (T, N, H)
+    # nonspmd (tp=1): q.shape=(16,16,128), k.shape=(16,2,128), kv_cache.shape=(40660,16,2,2,128)
     output, kv_cache = sharded_ragged_paged_attention(
         head_dim_original**-0.5, mesh, attention_chunk_size, q_scale, k_scale,
         v_scale)(
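
To make the sharding comments concrete, here is a hedged sketch (assumes a 2-device mesh and toy array sizes; it is not the tpu_commons implementation) of what these PartitionSpecs do: P(None, "model", None) splits the head axis of q/k/v across the "model" mesh axis, and P(None, None, "model") splits the third axis of the KV cache, with axes not named in the spec left replicated.

# Sketch only: needs at least 2 JAX devices, e.g.
#   XLA_FLAGS=--xla_force_host_platform_device_count=2 python sketch.py
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("model",))

q = jnp.zeros((16, 16, 128))                # (tokens, num_q_heads, head_dim)
k = jnp.zeros((16, 2, 128))                 # (tokens, num_kv_heads, head_dim)
kv_cache = jnp.zeros((128, 16, 2, 2, 128))  # toy KV cache; axis 2 has size num_kv_heads=2

# qkv_spec: shard axis 1 (the head axis) over the "model" mesh axis.
q = jax.device_put(q, NamedSharding(mesh, P(None, "model", None)))
k = jax.device_put(k, NamedSharding(mesh, P(None, "model", None)))
# kv_cache_spec: shard axis 2 over "model"; trailing axes stay replicated.
kv_cache = jax.device_put(kv_cache, NamedSharding(mesh, P(None, None, "model")))

# Each device now holds 16/2 = 8 query heads and 2/2 = 1 KV head, which is why
# the test above can only use a TP degree that divides num_kv_heads=2.
print(q.sharding, k.sharding, kv_cache.sharding)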
