
Conversation

@coreyjadams (Collaborator) commented Oct 3, 2025

PhysicsNeMo Pull Request

This PR implements adjustments to Transolver to enable better and more efficient domain parallelism. Batch size > 1 locally will not be supported for domain parallelism.

The key detail of these changes is a rearrangement of the data shapes. Transolver takes inputs of shape [Batch, Tokens, Features]. Previously, this was converted internally to [Batch, N_head, N_tokens, Head_dim]. This PR essentially replaces that with [Batch, N_tokens, N_head, Head_dim] everywhere relevant. The motivation is data contiguity: keeping the rearrange on tensor dims greater than the sharded dim makes it a zero-communication op, and everything stays nicely sharded.
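
For illustration, here is a minimal sketch of the two layouts with made-up shapes (this is not the exact PhysicsNeMo code); the tokens dimension (dim 1) is the sharded one:

```python
import torch

B, N, H, D = 1, 4096, 8, 64   # batch, tokens, heads, head_dim (illustrative)
x = torch.randn(B, N, H * D)  # [Batch, Tokens, Features]

# Old layout: [Batch, N_head, N_tokens, Head_dim] -- the permute moves the
# sharded tokens dimension away from dim 1.
x_old = x.view(B, N, H, D).permute(0, 2, 1, 3)

# New layout: [Batch, N_tokens, N_head, Head_dim] -- only dims to the right of
# the sharded tokens dim are split, so the reshape stays local and the
# sharding is untouched.
x_new = x.view(B, N, H, D)
```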

Note as well that the attention mechanism itself is not actually parallel here (not that it impacts performance). See the Transolver++ paper for the details of their algorithm; it's actually really easy to implement with ShardTensor - it just works.

https://arxiv.org/pdf/2502.02414v1

The idea is that the slice tokens are computed from a matmul of two pieces, normed_weights and fx, and the matmul contracts over the tokens dimension (which is the sharded dimension). The output has a Partial placement, meaning each rank holds a properly shaped partial copy of the slice tokens that needs to be all-reduced.

So, right before the attention layer, I implement a redistribute to a Replicate() placement, which triggers the differentiable all-reduce. After the attention, the multiplication of out_slice_tokens and slice_weights produces a properly sharded output from PhysicsAttention. It's equivalent to this section of the upstream code (from the ++ model):
https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py#L71-L73

(That code has two reductions; we also have two, since computing normed_weights will automatically all-reduce.)
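
As a rough sketch of that Partial → Replicate step using plain torch DTensor placements (ShardTensor follows the same placement semantics; the mesh setup, shapes, and use of distribute_tensor here are illustrative and assume a multi-GPU torchrun launch):

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

B, N, M, C = 1, 4096, 64, 128  # batch, tokens, slice tokens, channels
# Both operands are sharded along the tokens dimension (dim 1).
normed_weights = distribute_tensor(torch.randn(B, N, M), mesh, [Shard(1)])
fx = distribute_tensor(torch.randn(B, N, C), mesh, [Shard(1)])

# Contracting over the sharded tokens dim: each rank produces a partial sum
# of the [B, M, C] slice tokens, so the result carries a Partial placement.
slice_tokens = normed_weights.transpose(1, 2) @ fx

# Redistribute to Replicate() right before attention -- this is the
# differentiable all-reduce described above.
slice_tokens = slice_tokens.redistribute(placements=[Replicate()])
```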

There is a boost to the attention path in ShardTensor, too: it was computing ring attention on replicated tensors, which is slow (and wrong!), and now it uses local attention on replicated tensors.
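
The fast path amounts to something like the following sketch (not the exact attention_patches.py code): when q/k/v are fully replicated there is no sequence sharding to ring over, so plain local SDPA on each rank is both correct and cheaper.

```python
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Replicate


def sdpa_with_replicated_fast_path(q: DTensor, k: DTensor, v: DTensor) -> DTensor:
    """Sketch: run local attention when q/k/v are fully replicated."""
    if all(isinstance(p, Replicate) for t in (q, k, v) for p in t.placements):
        # No sequence sharding to ring over: compute attention on the local
        # (already complete) copies and keep the replicated placement.
        out = F.scaled_dot_product_attention(q.to_local(), k.to_local(), v.to_local())
        return DTensor.from_local(out, q.device_mesh, q.placements)
    # Sharded inputs would take the ring-attention path (omitted in this sketch).
    raise NotImplementedError("ring attention path not shown here")
```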

With these adjustments, I can efficiently scale the Transolver model to about 3.6M input points on 8xA100.

Transolver ++

With these changes, we're actually only a hair's breadth away from having Transolver++. I have the remaining updates on a separate branch where I want to validate training accuracy.

Transformer Engine

Transformer Engine is not supported with domain parallelism. The ops go through a different dispatch path; we need to implement that at some point, but it's not in this PR.

Tests

I added a new section of multi-GPU tests that validate Transolver end to end with ShardTensor inputs. They check that the outputs have the right shape and are sharded properly (sketched below).
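
Roughly, the checks look like this (a sketch only; the real tests construct the actual Transolver model and ShardTensor inputs, for which `model` and `sharded_input` are stand-ins here):

```python
from torch.distributed.tensor import Shard


def check_sharded_forward(model, sharded_input, expected_global_shape):
    # `model` and `sharded_input` are placeholders for the Transolver model
    # and its ShardTensor input used in the real tests.
    out = model(sharded_input)
    # The global output shape matches the single-device result ...
    assert tuple(out.shape) == tuple(expected_global_shape)
    # ... and the output stays sharded along the tokens dimension.
    assert any(isinstance(p, Shard) and p.dim == 1 for p in out.placements)
```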

Changelog

I haven't updated it yet; I would like to bring Transolver++ in, in its entirety, and do it in one shot.

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Greptile Overview

Updated On: 2025-10-21 00:04:50 UTC

Greptile Summary

This PR optimizes Transolver for domain parallelism by restructuring internal tensor dimensions from [Batch, N_head, N_Tokens, Head_dim] to [Batch, N_tokens, N_Head, Head_dim]. This reordering maintains data contiguity on sharded dimensions, enabling zero-communication rearrange operations during distributed execution. The Physics Attention layer now performs a strategic Partial → Replicate redistribution before attention to trigger automatic all-reduce on slice tokens, matching the Transolver++ algorithm. A fast-path optimization was added to skip ring attention when tensors are replicated. These changes enable efficient scaling to 3.6M input points on 8xA100 GPUs. Note that batch size > 1 locally is not supported for domain parallelism, and Transformer Engine remains incompatible with this approach.

Changed Files
Filename | Score | Overview
test/models/test_transolver.py | 5/5 | Removed debug print statements from transolver2d forward test
test/distributed/shard_tensor/models/transolver.py | 4/5 | Added new multi-GPU tests validating domain-parallel execution with ShardTensor for structured and irregular geometries; reshape operations on lines 72-73 may need scrutiny for sharding semantics
physicsnemo/models/transolver/Physics_Attention.py | 0/5 | Unable to assess - file content not provided in change summary
physicsnemo/distributed/shard_utils/attention_patches.py | 5/5 | Added replicated-tensor fast-path in scaled dot product attention to bypass ring attention for efficiency

Confidence score: 3/5

  • This PR has specialized distributed computing logic that requires careful validation; tensor reshaping and placement transitions could introduce subtle correctness issues if sharding semantics aren't perfectly preserved.
  • Score reflects uncertainty about the Physics_Attention.py changes (content not provided), the correctness of reshape operations in test code (lines 72-73 in transolver.py), and potential edge cases around the replicated-tensor fast-path that weren't explicitly tested in the provided changes. Lack of CHANGELOG update (acknowledged by author) also indicates incomplete review readiness.
  • Pay close attention to physicsnemo/models/transolver/Physics_Attention.py (no content provided in summary) and the reshape operations in test/distributed/shard_tensor/models/transolver.py lines 72-73; verify these maintain correct sharding semantics for multi-dimensional spatial domains.

Sequence Diagram

sequenceDiagram
    participant User
    participant Transolver
    participant PhysicsAttention
    participant ShardTensor
    participant RingSDPA
    participant SDPA

    User->>Transolver: forward(embedding, functional_input)
    Note over Transolver: Input shape: [Batch, N_tokens, Features]
    
    Transolver->>ShardTensor: scatter_tensor(placements=Shard(1))
    Note over ShardTensor: Shard along tokens dimension
    
    Transolver->>PhysicsAttention: forward(sharded_input)
    Note over PhysicsAttention: Shape: [B, N_tokens, N_heads, Head_dim]<br/>(tokens dimension sharded)
    
    PhysicsAttention->>PhysicsAttention: project_input_onto_slices(x)
    Note over PhysicsAttention: Conv2D/Conv3D/Linear projection
    
    PhysicsAttention->>PhysicsAttention: compute_slices_from_projections(slice_proj, fx)
    Note over PhysicsAttention: matmul contracts over tokens<br/>produces Partial placement
    
    PhysicsAttention->>ShardTensor: redistribute(placements=[Replicate()])
    Note over ShardTensor: Differentiable all-reduce before attention
    
    alt use_te == True
        PhysicsAttention->>PhysicsAttention: compute_slice_attention_te(slice_tokens)
        Note over PhysicsAttention: TransformerEngine path<br/>NOT supported with domain parallel
    else use_te == False
        PhysicsAttention->>SDPA: scaled_dot_product_attention(q, k, v)
        Note over SDPA: Check if replicated or sharded
        
        alt placement.is_replicate()
            SDPA->>SDPA: Local attention on replicated tensors
            Note over SDPA: Fast path - no ring communication
        else placement.is_shard()
            SDPA->>RingSDPA: ring_sdpa(q, k, v)
            
            loop For each rank in mesh
                RingSDPA->>RingSDPA: efficient_attention(q, local_k, local_v)
                Note over RingSDPA: Compute on local block
                RingSDPA->>RingSDPA: stable_signed_accumulate(log_output, sign)
                Note over RingSDPA: Accumulate in log space
                RingSDPA->>RingSDPA: perform_ring_iteration(k, v)
                Note over RingSDPA: Send k,v to next rank
            end
            
            RingSDPA->>RingSDPA: exp(log_output - global_log_sumexp)
            Note over RingSDPA: Final normalization
        end
    end
    
    PhysicsAttention->>PhysicsAttention: project_attention_outputs(out_slice, weights)
    Note over PhysicsAttention: Replicated slice_tokens × sharded weights<br/>= sharded output
    
    PhysicsAttention-->>Transolver: sharded output [B, N_tokens, N_heads*Head_dim]
    Transolver-->>User: output (properly sharded)

@coreyjadams self-assigned this Oct 13, 2025
@pzharrington (Collaborator) left a comment

In general LGTM, just minor comments/clarifications. I think there is just one spot where the docstring should be updated to match the new tensor shapes. The inline comments are helpful in explaining what is actually happening internally and why.

@greptile-apps bot left a comment

4 files reviewed, 1 comment

@coreyjadams (Collaborator, Author)

/blossom-ci

@coreyjadams (Collaborator, Author)

/blossom-ci

@coreyjadams merged commit ad35426 into NVIDIA:main Oct 21, 2025
1 check passed
@coreyjadams deleted the transolver_domain_parallel branch October 21, 2025 14:55