Transolver domain parallel #1142
Conversation
…er ++ Need to fix 2D and 3D cases still and validate sharding.
tests to validate shard behavior in transolver
In general LGTM, just minor comments/clarifications. I think there is just one spot the docstring should be updated to match the new tensor shapes. The inline comments are helpful in explaining what is actually happening internally and why.
4 files reviewed, 1 comment
/blossom-ci
/blossom-ci
PhysicsNeMo Pull Request
This PR implements adjustments to Transolver to enable better and more efficient domain parallelism. Batch size > 1 locally is not supported for domain parallelism.
The key detail of these changes is a rearrangement of the data shapes. Transolver takes inputs of shape [Batch, Tokens, Features]. Previously, this was being converted to shapes of [Batch, N_head, N_tokens, Head_dim]. This PR essentially replaces that with [Batch, N_tokens, N_head, Head_dim] everywhere relevant. The motivation for this is data contiguousness: keeping the rearrange on the tensor dims with dim > sharded_dim makes it a 0-communication op, and everything stays nicely sharded (see the layout sketch below).
Note as well that the attention mechanism is actually not parallel here (not that it impacts performance). Check out the Transolver++ paper for details of their algorithm; it's actually really easy to implement with ShardTensor - it just works.
https://arxiv.org/pdf/2502.02414v1
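To illustrate why the new layout is communication-free, here's a minimal sketch assuming einops-style rearranges; the variable names and sizes are illustrative, not the actual PhysicsNeMo code:

```python
import torch
from einops import rearrange

B, N, H, D = 1, 4096, 8, 64           # batch, tokens, heads, head_dim (illustrative sizes)
x = torch.randn(B, N, H * D)          # [Batch, Tokens, Features], sharded along dim 1 (tokens)

# Old layout: the head dim is moved in front of the (sharded) token dim.
old = rearrange(x, "b n (h d) -> b h n d", h=H)   # [Batch, N_head, N_tokens, Head_dim]

# New layout: only dims to the right of the sharded token dim are touched, so the
# rearrange is a 0-communication op and the token sharding is preserved.
new = rearrange(x, "b n (h d) -> b n h d", h=H)   # [Batch, N_tokens, N_head, Head_dim]
```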
The idea is that the slice tokens are computed from a matmul of two pieces, normed_weights and fx, and the matmul contracts over the tokens dimension (which is the sharded dimension). The output is a Partial placement, meaning each rank has a properly shaped replica of the slice tokens that still needs to be all-reduced. So, right before the attention layer, I implement a redistribute to a Replicate() placement, which triggers the differentiable all-reduce. After the attention, the multiplication of out_slice_tokens and slice_weights produces a properly sharded output from the PhysicsAttention. It's equivalent to this section of the upstream code (from the ++ model):
https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py#L71-L73
(That code has two reductions; we also have two, since computing the normed_weights will automatically all-reduce.)
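For concreteness, here is a rough sketch of the reduction pattern, written against the DTensor-style placement API that ShardTensor mirrors. The function, tensor names, and einsum index layouts are illustrative assumptions, not the exact PhysicsNeMo code:

```python
import torch
from torch.distributed.tensor import Replicate  # ShardTensor mirrors the DTensor placement API


def physics_attention_sketch(normed_weights, fx, slice_weights, slice_attention):
    """Sharded slice-token reduction, roughly as described above.

    normed_weights / slice_weights: [Batch, N_tokens, N_head, N_slices], sharded on tokens
    fx:                             [Batch, N_tokens, N_head, Head_dim], sharded on tokens
    """
    # Contracting over the (sharded) token dim leaves a Partial placement: every rank
    # holds a properly shaped partial sum of the slice tokens.
    slice_tokens = torch.einsum("bnhs,bnhd->bhsd", normed_weights, fx)

    # Differentiable all-reduce right before attention: Partial -> Replicate.
    slice_tokens = slice_tokens.redistribute(placements=[Replicate()])

    # Attention over the slice tokens runs locally on the replicated tensor.
    out_slice_tokens = slice_attention(slice_tokens)

    # Contracting over the slice dim yields an output that is sharded on tokens again.
    return torch.einsum("bhsd,bnhs->bnhd", out_slice_tokens, slice_weights)
```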
There is a boost to the attention path in ShardTensor, too: it was computing ring attention on replicated tensors, which is slow (and wrong!); now it uses local attention on replicated tensors.
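In pseudocode, the fast path amounts to a check like the following; this is a hypothetical sketch, and the actual names in physicsnemo/distributed/shard_utils/attention_patches.py may differ:

```python
import torch
from torch.distributed.tensor import Replicate


def attention_with_replicated_fast_path(q, k, v, ring_attention_fn):
    """Hypothetical dispatch: skip ring attention when q/k/v are fully replicated."""
    if all(all(isinstance(p, Replicate) for p in t.placements) for t in (q, k, v)):
        # Every rank already holds the full sequence, so plain local SDPA is both
        # correct and cheaper than running the ring-attention communication pattern.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.to_local(), k.to_local(), v.to_local()
        )
        return type(q).from_local(out, q.device_mesh, q.placements)
    return ring_attention_fn(q, k, v)
```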
With these adjustments, I can efficiently scale the transolver model to about 3.6M input points on 8xA100.
Transolver ++
With these changes, we're actually only a hair's breadth away from having Transolver++. I have the remaining updates on a separate branch where I want to validate training accuracy.
Transformer Engine
Transformer Engine is not supported with domain parallelism. The ops go through a different dispatch path; we need to implement that at some point, but it's not done here.
Tests
I added a new section of multi-GPU tests that validate Transolver end to end with ShardTensor inputs. They check that the outputs have the right shape and are sharded properly.
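Roughly, the assertions look like this. This is an illustrative outline only; the ShardTensor.from_local call, import path, and device-mesh setup shown here are assumptions about the API, not a copy of the test:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from physicsnemo.distributed import ShardTensor  # assumed import path


def check_sharded_forward(model, n_points, in_features, out_features, world_size):
    mesh = init_device_mesh("cuda", (world_size,))
    # One local batch per rank (local batch size > 1 is unsupported for domain parallelism).
    local = torch.randn(1, n_points // world_size, in_features, device="cuda")
    x = ShardTensor.from_local(local, mesh, (Shard(1),))  # sharded on the token dim

    out = model(x)

    # The global output shape is correct and the token-dim sharding is preserved.
    assert out.shape == (1, n_points, out_features)
    assert out.placements == (Shard(1),)
```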
Changelog
I haven't updated it yet; I would like to bring Transolver++ in, in its entirety, and do it in one shot.
Greptile Overview
Updated On: 2025-10-21 00:04:50 UTC
Greptile Summary
This PR optimizes Transolver for domain parallelism by restructuring internal tensor dimensions from [Batch, N_head, N_Tokens, Head_dim] to [Batch, N_tokens, N_Head, Head_dim]. This reordering maintains data contiguity on sharded dimensions, enabling zero-communication rearrange operations during distributed execution. The Physics Attention layer now performs a strategic Partial → Replicate redistribution before attention to trigger automatic all-reduce on slice tokens, matching the Transolver++ algorithm. A fast-path optimization was added to skip ring attention when tensors are replicated. These changes enable efficient scaling to 3.6M input points on 8xA100 GPUs. Note that batch size > 1 locally is not supported for domain parallelism, and Transformer Engine remains incompatible with this approach.
Changed Files
test/models/test_transolver.py
test/distributed/shard_tensor/models/transolver.py
physicsnemo/models/transolver/Physics_Attention.py
physicsnemo/distributed/shard_utils/attention_patches.py
Confidence score: 3/5
The score reflects uncertainty around the Physics_Attention.py changes (content not provided), the correctness of the reshape operations in the test code (lines 72-73 in transolver.py), and potential edge cases around the replicated-tensor fast path that weren't explicitly tested in the provided changes. The lack of a CHANGELOG update (acknowledged by the author) also indicates incomplete review readiness. Pay close attention to physicsnemo/models/transolver/Physics_Attention.py (no content provided in the summary) and the reshape operations in test/distributed/shard_tensor/models/transolver.py lines 72-73; verify these maintain correct sharding semantics for multi-dimensional spatial domains.
Sequence Diagram