38 changes: 37 additions & 1 deletion physicsnemo/distributed/shard_utils/attention_patches.py
@@ -649,7 +649,43 @@ def sdpa_wrapper(func: Callable, types: Any, args: tuple, kwargs: dict) -> Shard
if q._spec.mesh.ndim != 1:
raise MissingShardPatch("q must be on a 1D mesh")

return ring_sdpa(q, k, v, attn_mask, **kwargs)
# This implements sequence-parallel attention.
# Make sure q, k, and v all have the same sharding:
if not (q._spec.placements[0] == k._spec.placements[0] == v._spec.placements[0]):
raise MissingShardPatch("q, k, and v must all be on the same placement")

# Make sure the attention mask, if provided, has the same placement as q, k, and v
if attn_mask is not None and hasattr(attn_mask, "_spec"):
if attn_mask._spec.placements[0] != q._spec.placements[0]:
raise MissingShardPatch(
"attn_mask must have the same placement as q, k, and v"
)

# If the placements are replicated (which is what we expect in Transolver's
# Physics Attention), just run SDPA locally and convert the output back to a
# replicated ShardTensor:

if v._spec.placements[0].is_replicate():
local_q = q.to_local()
local_k = k.to_local()
local_v = v.to_local()
if attn_mask is not None:
local_attn_mask = attn_mask.to_local()
else:
local_attn_mask = None
local_output = torch.nn.functional.scaled_dot_product_attention(
local_q, local_k, local_v, attn_mask=local_attn_mask, **kwargs
)

output = ShardTensor.from_local(
local_output,
q._spec.mesh,
q._spec.placements,
# No sharding-shape bookkeeping is needed here, since the output is replicated, not sharded.
)
return output
else:
return ring_sdpa(q, k, v, attn_mask, **kwargs)


def repackage_sdpa_args(
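
A minimal sketch of how the replicate fast path added above could be exercised, assuming a `torch.distributed` process group is already initialized, a 1-D `DeviceMesh` is available, and the ShardTensor dispatch patches from this module are registered. The import paths (`physicsnemo.distributed.ShardTensor`, `Replicate` from `torch.distributed.tensor.placement_types`) and the tensor shapes are assumptions for illustration, not taken from the diff:

```python
# Sketch only: assumes torch.distributed.init_process_group() has already run
# and that physicsnemo's ShardTensor attention patches are active.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.placement_types import Replicate  # path may vary by torch version

from physicsnemo.distributed import ShardTensor  # assumed import path

world_size = torch.distributed.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh, as sdpa_wrapper requires

# Fully replicated q, k, v (illustrative shapes: batch, heads, seq, head_dim).
local_q = torch.randn(2, 8, 128, 64, device="cuda")
local_k = torch.randn(2, 8, 128, 64, device="cuda")
local_v = torch.randn(2, 8, 128, 64, device="cuda")

q = ShardTensor.from_local(local_q, mesh, (Replicate(),))
k = ShardTensor.from_local(local_k, mesh, (Replicate(),))
v = ShardTensor.from_local(local_v, mesh, (Replicate(),))

# With replicated placements, the wrapper runs SDPA on the local tensors and
# re-wraps the result, instead of invoking ring attention.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# out is a ShardTensor with the same replicated placement as the inputs.
```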