
Commit 014add7

committed
added some comments
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent 9c3b006 commit 014add7

2 files changed: +3, -1 lines changed


tests/lora/test_lora.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def test_single_lora_spmd():
     # ensure_model_parallel_initialized(1, 1)
 
     # num_devices = jax.local_device_count() # why does this line cause hanging.
-    num_devices = 4
+    num_devices = 8
     print(f'xw32 using TP={num_devices}')
     llm = setup_vllm(1, num_devices)
 
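For context on the commented-out jax.local_device_count() line above, the sketch below only shows the two standard JAX device-count queries that a test like this could use instead of hard-coding num_devices = 8. It is an illustration, not the test's actual setup; setup_vllm and the hang mentioned in the comment are specific to this repository and are not reproduced here.

import jax

# Devices attached to the current host/process.
num_local = jax.local_device_count()
# Devices across the whole JAX runtime (equal to the local count on a single host).
num_global = jax.device_count()
print(f"local_device_count={num_local}, device_count={num_global}")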
tpu_commons/models/jax/attention.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ def sharded_ragged_paged_attention(
     v_scale: float | None = None,
 ):
     """Shards along KV heads."""
+    # nonspmd(tp=1):q.shape=(16,16,128),k.shape=(16,2,128),kv_cache.shape=(40660,16,2,2,128)
     qkv_spec = P(None, "model", None)
     kv_cache_spec = P(None, None, "model")
     in_specs = (
@@ -86,6 +87,7 @@ def attention(
     md = attention_metadata
 
     # (T, N, H)
+    # nonspmd(tp=1):q.shape=(16,16,128),k.shape=(16,2,128),kv_cache.shape=(40660,16,2,2,128)
     output, kv_cache = sharded_ragged_paged_attention(
         head_dim_original**-0.5, mesh, attention_chunk_size, q_scale, k_scale,
         v_scale)(
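The sketch below illustrates what qkv_spec = P(None, "model", None) expresses for a (T, N, H) activation, per the "Shards along KV heads" docstring: the head axis N is split across a 1-D "model" mesh axis. This is a standalone, assumption-laden illustration rather than code from tpu_commons; the shapes are taken from the nonspmd(tp=1) comment in the hunk, and it assumes the device count divides N=16.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D tensor-parallel mesh named "model", matching the axis name used in the specs above.
mesh = Mesh(np.array(jax.devices()), ("model",))

# q with the (T, N, H) = (16, 16, 128) shape from the diff comment;
# P(None, "model", None) splits the head axis N across the "model" mesh axis.
q = jnp.zeros((16, 16, 128))
q_sharded = jax.device_put(q, NamedSharding(mesh, P(None, "model", None)))

# Each device holds a (16, 16 // num_devices, 128) slice of q.
print(q_sharded.addressable_shards[0].data.shape)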
