diff --git a/physicsnemo/distributed/shard_utils/attention_patches.py b/physicsnemo/distributed/shard_utils/attention_patches.py index 9aeab05d38..ea92a95b6a 100644 --- a/physicsnemo/distributed/shard_utils/attention_patches.py +++ b/physicsnemo/distributed/shard_utils/attention_patches.py @@ -649,7 +649,43 @@ def sdpa_wrapper(func: Callable, types: Any, args: tuple, kwargs: dict) -> Shard if q._spec.mesh.ndim != 1: raise MissingShardPatch("q must be on a 1D mesh") - return ring_sdpa(q, k, v, attn_mask, **kwargs) + # This is to implement sequence-parallel attention. + # Make sure the shardings are all the same: + if not (q._spec.placements[0] == k._spec.placements[0] == v._spec.placements[0]): + raise MissingShardPatch("q, k, and v must all be on the same placement") + + # Make sure the attention mask, if provided, has the same placement as q, k, and v + if attn_mask is not None and hasattr(attn_mask, "_spec"): + if attn_mask._spec.placements[0] != q._spec.placements[0]: + raise MissingShardPatch( + "attn_mask must have the same placement as q, k, and v" + ) + + # if the placements are replicated (which is what we expect in transolver's + # Physics Attention) + # then just run locally and convert the output back to a replicated tensor: + + if v._spec.placements[0].is_replicate(): + local_q = q.to_local() + local_k = k.to_local() + local_v = v.to_local() + if attn_mask is not None: + local_attn_mask = attn_mask.to_local() + else: + local_attn_mask = None + local_output = torch.nn.functional.scaled_dot_product_attention( + local_q, local_k, local_v, attn_mask=local_attn_mask, **kwargs + ) + + output = ShardTensor.from_local( + local_output, + q._spec.mesh, + q._spec.placements, + # We don't have to worry about sharding shapes here since it's not sharded ... + ) + return output + else: + return ring_sdpa(q, k, v, attn_mask, **kwargs) def repackage_sdpa_args( diff --git a/physicsnemo/models/transolver/Physics_Attention.py b/physicsnemo/models/transolver/Physics_Attention.py index ca87ec4a2f..2c793aaf59 100644 --- a/physicsnemo/models/transolver/Physics_Attention.py +++ b/physicsnemo/models/transolver/Physics_Attention.py @@ -36,6 +36,10 @@ import torch.nn as nn import transformer_engine.pytorch as te # noqa: F401 from einops import rearrange +from torch.autograd.profiler import record_function +from torch.distributed.tensor.placement_types import Replicate + +from physicsnemo.distributed import ShardTensor class PhysicsAttentionBase(nn.Module, ABC): @@ -65,7 +69,7 @@ def __init__(self, dim, heads, dim_head, dropout, slice_num, use_te): self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) - self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) self.use_te = use_te if self.use_te: @@ -108,17 +112,22 @@ def compute_slices_from_projections( """ Compute slice weights and slice tokens from input projections and latent features. + In a domain-parallel setting, this function will do an implicit allreduce. + When we sum over the slice_weights over a sharded dimension + and use the output, it will resolve Partial->Replicated placement (aka + allreduce) implicitly. + Args: slice_projections (torch.Tensor): - The projected input tensor of shape [Batch, N_heads, N_tokens, Slice_num], + The projected input tensor of shape [Batch, N_tokens, N_heads, Slice_num], representing the projection of each token onto each slice for each attention head. 
            fx (torch.Tensor):
-                The latent feature tensor of shape [Batch, N_heads, N_tokens, Head_dim],
+                The latent feature tensor of shape [Batch, N_tokens, N_heads, Head_dim],
                 representing the learned states to be aggregated by the slice weights.
 
         Returns:
             tuple[torch.Tensor, torch.Tensor]:
-                - slice_weights: Tensor of shape [Batch, N_heads, N_tokens, Slice_num],
+                - slice_weights: Tensor of shape [Batch, N_tokens, N_heads, Slice_num],
                   representing the normalized weights for each slice per token and head.
                 - slice_token: Tensor of shape [Batch, N_heads, Slice_num, Head_dim],
                   representing the aggregated latent features for each slice, head, and batch.
@@ -129,28 +138,48 @@
             - The aggregated features are normalized by the sum of weights for numerical stability.
         """
 
-        # Project the latent space vectors on to the weight computation space,
-        # and compute a temperature adjusted softmax.
-        clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to(
-            slice_projections.dtype
-        )
-        slice_weights = nn.functional.softmax(
-            slice_projections / clamped_temp, dim=-1
-        )  # [Batch, N_heads, N_tokens, Slice_num]
+        with record_function("compute_slices_from_projections"):
+            # Project the latent space vectors on to the weight computation space,
+            # and compute a temperature adjusted softmax.
+            clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to(
+                slice_projections.dtype
+            )
+
+            slice_weights = nn.functional.softmax(
+                slice_projections / clamped_temp, dim=-1
+            )  # [Batch, N_tokens, N_heads, Slice_num]
+
+            # Cast to the computation type (since the parameter is probably fp32)
+            slice_weights = slice_weights.to(slice_projections.dtype)
 
-        # Cast to the computation type (since the parameter is probably fp32)
-        slice_weights = slice_weights.to(slice_projections.dtype)
+            # This does the projection of the latent space fx by the weights:
 
-        # This does the projection of the latent space fx by the weights:
+            # Computing the slice tokens is a matmul followed by a normalization.
+            # It can, unfortunately, overflow in reduced precision, so normalize first:
+            slice_norm = slice_weights.sum(1)  # [Batch, N_heads, Slice_num]
+            # Sharded note: slice_norm is a partial sum at this point,
+            # because we're summing over the tokens, which are distributed.
+            normed_weights = slice_weights / (slice_norm[:, None, :, :] + 1e-2)
+            # Normed weights has shape
+            # (batch, n_tokens, n_heads, slice_num)
 
-        # Computing the slice tokens is a matmul followed by a normalization.
-        # It can, unfortunately, overflow in reduced precision, so normalize first:
-        slice_norm = slice_weights.sum(2)  # [Batch, N_heads, Slice_num]
-        normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2)
-        slice_token = torch.matmul(normed_weights.transpose(2, 3), fx)
+            # Sharded note: dividing by slice_norm resolves its Partial placement
+            # (an implicit allreduce), and the resulting normed_weights stay sharded.
+
+            # fx has shape (Batch, n_tokens, n_heads, head_dim)
+            # This matmul needs to contract over the tokens.
+            # This should produce an output with shape
+            # [Batch, N_heads, Slice_num, Head_dim]
 
-        # Return the original weights, not the normed weights:
-        return slice_weights, slice_token
+            # Like the weight norm, this is a **partial** sum, since we are summing
+            # over the tokens.
+
+            slice_token = torch.matmul(
+                normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3)
+            )
+
+            # Return the original weights, not the normed weights:
+
+            return slice_weights, slice_token
 
     def compute_slice_attention_te(self, slice_tokens: torch.Tensor) -> torch.Tensor:
         """
@@ -171,18 +200,33 @@ def compute_slice_attention_sdpa(self, slice_tokens: torch.Tensor) -> torch.Tensor:
         """
         Torch SDPA implementation of slice attention
+
+        Args:
+            slice_tokens (torch.Tensor):
+                The slice tokens tensor of shape [Batch, N_heads, Slice_num, Head_dim].
+
+        Returns:
+            torch.Tensor:
+                The output tensor of shape [Batch, N_heads, Slice_num, Head_dim].
         """
+        with record_function("compute_slice_attention_sdpa"):
+            # If the inputs are ShardTensors, ensure the slice tokens are *replicated*:
 
-        qkv = self.qkv_project(slice_tokens)
-        qkv = rearrange(qkv, " b h s (t d) -> t b h s d", t=3, d=self.dim_head)
+            qkv = self.qkv_project(slice_tokens)
 
-        q_slice_token, k_slice_token, v_slice_token = qkv.unbind(0)
+            qkv = rearrange(qkv, " b h s (t d) -> b h s t d", t=3, d=self.dim_head)
 
-        out_slice_token3 = torch.nn.functional.scaled_dot_product_attention(
-            q_slice_token, k_slice_token, v_slice_token, is_causal=False
-        )
+            if isinstance(qkv, ShardTensor):
+                # This will be a differentiable allreduce
+                qkv = qkv.redistribute(placements=[Replicate()])
+
+            q_slice_token, k_slice_token, v_slice_token = qkv.unbind(3)
+
+            out_slice_token = torch.nn.functional.scaled_dot_product_attention(
+                q_slice_token, k_slice_token, v_slice_token, is_causal=False
+            )
 
-        return out_slice_token3
+            return out_slice_token
 
     def project_attention_outputs(
         self, out_slice_token: torch.Tensor, slice_weights: torch.Tensor
@@ -190,12 +234,15 @@
         """
         Project the attended slice tokens back onto the original token space.
 
+        Note that in the distributed case, this will have both replicated and
+        sharded inputs: slice tokens will be replicated, and slice weights will be sharded.
+
         Args:
             out_slice_token (torch.Tensor):
                 The output tensor from the attention mechanism over slices,
                 of shape [Batch, N_heads, Slice_num, Head_dim].
             slice_weights (torch.Tensor):
-                The slice weights tensor of shape [Batch, N_heads, N_tokens, Slice_num],
+                The slice weights tensor of shape [Batch, N_tokens, N_heads, Slice_num],
                 representing the contribution of each slice to each token.
 
         Returns:
@@ -207,11 +254,21 @@
            - The function projects the attended slice tokens back to the token space using the slice weights.
            - The output is reshaped to concatenate all attention heads for each token.
""" + with record_function("project_attention_outputs"): + # Slice weights has shape (Batch, n_tokens, n_heads, slice_num) + # Out slice tokens has shape (Batch, n_heads, slice_num, head_dim) + # The output of this function needs to have shape + # (Batch, n_tokens, n_channels) == (Batch, n_tokens, n_heads * head_dim) + # Note that tokens may be sharded, in which case slice_weights + # is a sharded tensor and out_slice_token is a replicated tensor + + out_x = torch.einsum("bths,bhsd->bthd", slice_weights, out_slice_token) + + # Condense the last two dimensions: + out_x = rearrange(out_x, "b t h d -> b t (h d)") - out_x = torch.matmul(slice_weights, out_slice_token) - out_x = rearrange(out_x, "b h n d -> b n (h d)") - out_x = self.out_linear(out_x) - return self.out_dropout(out_x) + out_x = self.out_linear(out_x) + return self.out_dropout(out_x) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -220,35 +277,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Input x should have shape of [Batch, N_tokens, N_Channels] ([B, N, C]) """ - # Project the inputs onto learned spaces: - x_mid, fx_mid = self.project_input_onto_slices(x) + with record_function("forward"): + # Project the inputs onto learned spaces: + x_mid, fx_mid = self.project_input_onto_slices(x) - # Perform the linear projection of learned latent space onto slices: - slice_projections = self.in_project_slice(x_mid) + # Perform the linear projection of learned latent space onto slices: - # Slice projections has shape [B, N_head, N_tokens, Head_dim], but head_dim may have changed! + slice_projections = self.in_project_slice(x_mid) - # Use the slice projections and learned spaces to compute the slices, and their weights: - slice_weights, slice_tokens = self.compute_slices_from_projections( - slice_projections, fx_mid - ) - # slice_weights has shape [Batch, N_heads, N_tokens, Slice_num] - # slice_tokens has shape [Batch, N_heads, N_tokens, head_dim] + # Slice projections has shape [B, N_tokens, N_head, Head_dim], but head_dim may have changed! - # Apply attention to the slice tokens - if self.use_te: - out_slice_token = self.compute_slice_attention_te(slice_tokens) - else: - out_slice_token = self.compute_slice_attention_sdpa(slice_tokens) + # Use the slice projections and learned spaces to compute the slices, and their weights: + slice_weights, slice_tokens = self.compute_slices_from_projections( + slice_projections, fx_mid + ) + # slice_weights has shape [Batch, N_tokens, N_heads, Slice_num] + # slice_tokens has shape [Batch, N_tokens, N_heads, head_dim] + + # Apply attention to the slice tokens + if self.use_te: + out_slice_token = self.compute_slice_attention_te(slice_tokens) + else: + out_slice_token = self.compute_slice_attention_sdpa(slice_tokens) - # Shape unchanged + # Shape unchanged - # Deslice: - outputs = self.project_attention_outputs(out_slice_token, slice_weights) + # Deslice: + outputs = self.project_attention_outputs(out_slice_token, slice_weights) - # Outputs now has the same shape as the original input x + # Outputs now has the same shape as the original input x - return outputs + return outputs class PhysicsAttentionIrregularMesh(PhysicsAttentionBase): @@ -271,12 +330,19 @@ def __init__( def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: """ Project the input onto the slice space. 
+
+        Args:
+            x (torch.Tensor): The input tensor of shape [Batch, N_tokens, N_Channels]
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The projected x and fx tensors, each of shape [Batch, N_tokens, N_heads, Head_dim]
+
         """
-        fx_mid = rearrange(
-            self.in_project_fx(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head
-        )
+        fx = self.in_project_fx(x)
+        fx_mid = rearrange(fx, "B N (h d) -> B N h d", h=self.heads, d=self.dim_head)
+
         x_mid = rearrange(
-            self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head
+            self.in_project_x(x), "B N (h d) -> B N h d", h=self.heads, d=self.dim_head
         )
 
         return x_mid, fx_mid
@@ -285,6 +351,8 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]:
 class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase):
     """
     Specialization for 2d image-like meshes
+
+    Only implements the projection onto the slice space.
     """
 
     def __init__(
@@ -309,7 +377,11 @@
 
     def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]:
         # Rearrange the input tokens back to an image shape:
-        x = rearrange(x, "b (h w) c -> b c h w", h=self.H, w=self.W)
+        b = x.shape[0]
+        c = x.shape[-1]
+
+        x = x.view(b, self.H, self.W, c)
+        x = x.permute(0, 3, 1, 2)
 
         # Apply the projections, here they are convolutions in 2D:
         input_projected_fx = self.in_project_fx(x)
@@ -318,13 +390,13 @@
         # Next, re-reshape the projections into token-like shapes:
         input_projected_fx = rearrange(
             input_projected_fx,
-            "b (n_heads head_dim) h w -> b n_heads (h w) head_dim",
+            "b (n_heads head_dim) h w -> b (h w) n_heads head_dim",
             head_dim=self.dim_head,
             n_heads=self.heads,
         )
         input_projected_x = rearrange(
             input_projected_x,
-            "b (n_heads head_dim) h w -> b n_heads (h w) head_dim",
+            "b (n_heads head_dim) h w -> b (h w) n_heads head_dim",
             head_dim=self.dim_head,
             n_heads=self.heads,
         )
@@ -336,6 +408,8 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]:
 class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase):
     """
     Specialization for 3D-image like meshes
+
+    Only implements the projection onto the slice space.
     """
 
     def __init__(
@@ -362,9 +436,16 @@
     def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]:
         """
        Project the input onto the slice space.
+ + Input tensor has shape [Batch, N_tokens, N_Channels] """ - x = rearrange(x, "b (h w) c -> b c h w", h=self.H, w=self.W) + b = x.shape[0] + c = x.shape[-1] + + # x = rearrange(x, "b (h w d) c -> b c h w d", h=self.H, w=self.W, d=self.D) + x = x.view(b, self.H, self.W, self.D, c) + x = x.permute(0, 4, 1, 2, 3) # Apply the projections, here they are convolutions: input_projected_fx = self.in_project_fx(x) @@ -373,13 +454,13 @@ def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: # Next, re-reshape the projections into token-like shapes: input_projected_fx = rearrange( input_projected_fx, - "b (n_heads head_dim) h w -> b n_heads (h w) head_dim", + "b (n_heads head_dim) h w d-> b (h w d) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) input_projected_x = rearrange( input_projected_x, - "b (n_heads head_dim) h w -> b n_heads (h w) head_dim", + "b (n_heads head_dim) h w d -> b (h w d) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) diff --git a/test/distributed/shard_tensor/models/transolver.py b/test/distributed/shard_tensor/models/transolver.py new file mode 100644 index 0000000000..8e75f7dd00 --- /dev/null +++ b/test/distributed/shard_tensor/models/transolver.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy +import pytest +import torch +from torch.distributed.tensor import distribute_module +from torch.distributed.tensor.placement_types import Shard + +from physicsnemo.distributed import DistributedManager, scatter_tensor +from physicsnemo.models.transolver import Transolver + + +@pytest.mark.multigpu_static +@pytest.mark.parametrize("n_dims", [2, 3]) +def test_transolver_nd_distributed( + distributed_mesh, + n_dims, +): + """Test transolver 2D and 3D distributed forward pass""" + + dm = DistributedManager() + + spatial_dims = (128,) * n_dims + + # Construct transolver model + model = Transolver( + structured_shape=spatial_dims, + n_layers=8, + n_hidden=64, + dropout=0, + n_head=4, + time_input=False, + act="gelu", + mlp_ratio=1, + functional_dim=3, + embedding_dim=5, + out_dim=2, + slice_num=32, + ref=1, + unified_pos=False, + use_te=False, + ).to(dm.device) + + # Create data: + image_embedding = torch.randn(1, *spatial_dims, 5).to(dm.device) + functional_input = torch.randn(1, *spatial_dims, 3).to(dm.device) + + # Scatter the data + placements = (Shard(1),) + + sharded_image_embedding = scatter_tensor( + image_embedding, 0, distributed_mesh, placements, requires_grad=False + ) + sharded_functional_input = scatter_tensor( + functional_input, 0, distributed_mesh, placements, requires_grad=False + ) + + sharded_image_embedding = sharded_image_embedding.reshape(1, -1, 5) + sharded_functional_input = sharded_functional_input.reshape(1, -1, 3) + + model = distribute_module(model, device_mesh=distributed_mesh) + + # Run model + output = model(sharded_image_embedding, sharded_functional_input) + + # Check output + assert output.shape == (1, numpy.prod(spatial_dims), 2) + + # Make sure the output is sharded, too: + assert output._spec.placements == sharded_image_embedding._spec.placements + + +@pytest.mark.multigpu_static +def test_transolver_irregular_distributed( + distributed_mesh, +): + """Test transolver irregular distributed forward pass""" + + dm = DistributedManager() + + spatial_dims = (16384,) + + # Construct transolver model + model = Transolver( + structured_shape=None, + n_layers=8, + n_hidden=64, + dropout=0, + n_head=4, + time_input=False, + act="gelu", + mlp_ratio=1, + functional_dim=3, + embedding_dim=5, + out_dim=2, + slice_num=32, + ref=1, + unified_pos=False, + use_te=False, + ).to(dm.device) + + # Create data: + image_embedding = torch.randn(1, *spatial_dims, 5).to(dm.device) + functional_input = torch.randn(1, *spatial_dims, 3).to(dm.device) + + # Scatter the data + placements = (Shard(1),) + + sharded_image_embedding = scatter_tensor( + image_embedding, 0, distributed_mesh, placements, requires_grad=False + ) + sharded_functional_input = scatter_tensor( + functional_input, 0, distributed_mesh, placements, requires_grad=False + ) + + # Distribute the model to DTensor: + model = distribute_module(model, device_mesh=distributed_mesh) + + # Run model + output = model(sharded_image_embedding, sharded_functional_input) + + # Check output + assert output.shape == (1, *spatial_dims, 2) + + # Make sure the output is sharded, too: + assert output._spec.placements == sharded_image_embedding._spec.placements diff --git a/test/models/test_transolver.py b/test/models/test_transolver.py index 66985bab9f..6040bd87a6 100644 --- a/test/models/test_transolver.py +++ b/test/models/test_transolver.py @@ -61,11 +61,6 @@ def test_transolver2d_forward(device): fx = torch.randn(bsize, 85 * 85, 1).to(device) embedding = torch.randn(bsize, 85, 85).to(device) - print(f"fx: {fx.shape}") - 
print(f"embedding: {embedding.shape}") - - print(f"output shape: {model(fx, embedding).shape}") - assert validate_forward_accuracy( model, (