Enable DoMINO parallelization via ShardTensor #838

Merged
190 commits, merged May 21, 2025
Commits
ede3f3a
adding multiscale feature to model
Feb 6, 2025
3ce5bfa
bug fixes in model
Feb 9, 2025
844b09a
removing non-dim scaling from domino datapipe
Feb 10, 2025
abc1688
adding tests for multiscale training and inference script
Feb 10, 2025
39814ca
minor fix for surface training blow up
Feb 10, 2025
6fad1f2
fixing bug in inference
Feb 11, 2025
cea5a3f
surface volume radii
Feb 19, 2025
07f8baa
Profiling (#787)
coreyjadams Feb 20, 2025
e735f11
Enable Domain Parallelism with ShardTensor (#784)
coreyjadams Feb 20, 2025
48f2b21
name change
ktangsali Feb 21, 2025
6d7af37
name change docs
ktangsali Feb 21, 2025
ad9aa79
refactoring
Feb 21, 2025
e1a9ce8
This commit addresses two issues:
coreyjadams Feb 21, 2025
88065d3
hyper tuning
Feb 24, 2025
b0b4bab
Merge branch 'enh-domino-multiscale' of https://github.com/RishikeshR…
Feb 24, 2025
b24d9cd
fixing duplication in model
Feb 24, 2025
44b3bb8
Minor fixes and updates to the profiling utility.
coreyjadams Feb 25, 2025
c9ecd53
Add functionality to distributed manager to provide mesh-wide groups.
coreyjadams Feb 25, 2025
dfac94e
Performance enhancements to shard tensor. Not fully optimized yet bu…
coreyjadams Feb 26, 2025
0ce7735
model improvement
Feb 27, 2025
b2067a2
Hot fix - interface for mesh names was incorrect.
coreyjadams Feb 28, 2025
19238db
Small updates to ShardTensor and redistribution methods.
coreyjadams Feb 28, 2025
42af72a
This commit improves the functionality, readability, and maintainabil…
coreyjadams Mar 11, 2025
bfcc018
Add support for a select group of conv_transpose shapes where kernel …
coreyjadams Mar 11, 2025
db3ff9f
Enable group normalization with shard tensor.
coreyjadams Mar 11, 2025
48dab53
Add attention mechanism (scaled_dot_product_attention) to supported S…
coreyjadams Mar 12, 2025
9f83671
Add average pooling fucntionality for select shapes.
coreyjadams Mar 12, 2025
90a2ec0
Enable pooling, normalization, and attention patches when registering…
coreyjadams Mar 12, 2025
942ec2a
Remove printouts ...
coreyjadams Mar 12, 2025
61ae414
Merge branch 'main' into shardTensorFeature
coreyjadams Mar 31, 2025
9f982a0
Merge branch modulus:main into shardTensorFeature
coreyjadams Mar 31, 2025
67bc29d
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 1, 2025
e67ec56
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 2, 2025
17e3ffc
This commit addresses issues that arose in the merge of my feature br…
coreyjadams Apr 2, 2025
018e5d9
Add a sharding propagation for aten.select.int.
coreyjadams Apr 2, 2025
4ad0dae
Reorganize the halo and ring message passing to be easier to follow a…
coreyjadams Apr 2, 2025
be2ef41
Merge branch 'shardTensorFeature' of github.com:coreyjadams/physicsne…
coreyjadams Apr 2, 2025
1268634
This commit adds support for Max Pooling, Unpooling via nearest neigh…
coreyjadams Apr 2, 2025
b0a82af
This commit adds tests for RingBallQuery (which is ball query on shar…
coreyjadams Apr 4, 2025
6573fc4
make sure that convolutions and ball query compute output shapes and …
coreyjadams Apr 4, 2025
6bd2791
Add profiling hooks to convolution wrapper and halo padding.
coreyjadams Apr 7, 2025
c8d07a3
adding resampling and geo encoding type
Apr 8, 2025
aab6a71
Merge branch 'NVIDIA:main' into shardTensorFeature
coreyjadams Apr 8, 2025
b6711fd
Disable the `length` variables in BallQuery. They are unused, but st…
coreyjadams Apr 8, 2025
0d9be0d
This commit applies some reorganizations to the ball query layer to e…
coreyjadams Apr 10, 2025
ddb2a00
Merge branch 'NVIDIA:main' into domino_perf
coreyjadams Apr 10, 2025
45c5f08
Caching Mechanism For DoMINO Training (#805)
Mx7f Apr 14, 2025
46ba853
Add the Datacenter use case (#783)
derek-wistron Feb 11, 2025
d889256
Update README.md (#769)
eltociear Feb 11, 2025
a1d3377
Update README.md (#780)
ram-cherukuri Feb 11, 2025
8446c8f
Merge dlwp-healpix updates from modulus-uw (#785)
daviddpruitt Feb 12, 2025
33b0772
Update Dockerfile (#791)
ktangsali Feb 12, 2025
19c0595
Add random walk noise and kinematic mask (#786)
Alexey-Kamenev Feb 12, 2025
4a99170
Adds new Modulus devs to /blossom-ci authorized users (#792)
peterdsharpe Feb 13, 2025
9e0942c
Dockerfile changes to handle onnxruntime dependency (#793)
ktangsali Feb 13, 2025
7ff65c4
Fix NCCL_ASYNC_ERROR_HANDLING deprecation warning (#711)
simonbyrne Mar 17, 2025
0f21709
Profiling (#787)
coreyjadams Feb 20, 2025
13e8838
Enable Domain Parallelism with ShardTensor (#784)
coreyjadams Feb 20, 2025
7446701
name change
ktangsali Feb 21, 2025
65d30dd
name change docs
ktangsali Feb 21, 2025
25d7b69
These two files should not be included in the release. They are gene…
coreyjadams Feb 21, 2025
2e9b69c
RC fixes 1
ktangsali Feb 27, 2025
401b5d9
L-MGN: improve inference
Alexey-Kamenev Feb 27, 2025
b4df071
Remove obsolete config
Alexey-Kamenev Mar 5, 2025
4af6e1c
Docs fixes
ktangsali Mar 7, 2025
d4f07a0
Readme updates
ktangsali Mar 12, 2025
c33909c
Add notice about the rename
ktangsali Mar 13, 2025
03509be
Profiler Fixes. Duplicate of #172
coreyjadams Mar 13, 2025
f5bec4a
backward compatibility fix with old modulus namespace
loliverhennigh Mar 14, 2025
b84044d
Add custom installation of pyspng for arm
ktangsali Mar 14, 2025
71cb4b2
post release updates to version, add migration guide to readme and up…
ktangsali Mar 18, 2025
45181e9
Post rename updates (#816)
ktangsali Mar 19, 2025
b29363f
Initial ReGen model release (#810)
pzharrington Mar 20, 2025
6f44efe
Bug entry point (#818)
loliverhennigh Mar 25, 2025
12beb10
Address pytorch versioning issues. (#820)
coreyjadams Mar 25, 2025
8af07d5
1.0.1 rc rebase (#829)
ktangsali Mar 26, 2025
848e07f
Comment warnings setting (#830)
NickGeneva Mar 31, 2025
f2d034b
Update pyproject.toml links (#832)
coreyjadams Mar 31, 2025
a4f0369
Update README.md reference link (#821)
wangguan1995 Mar 31, 2025
43ab827
Update README.md (#833)
ktangsali Apr 1, 2025
43162ac
Dockerfile Fixes (#835)
NickGeneva Apr 2, 2025
370ccbe
MSC Checkpointing Changes (#789)
chris-hawes Apr 2, 2025
11d6e96
Fixes DeprecationWarning introduced in setuptools>=77 (#837)
peterdsharpe Apr 8, 2025
f253c73
Cordiff usability and performance enhancements for custom dataset tra…
CharlelieLrt Apr 8, 2025
4bf3a06
Update from_checkpoint docs (#843)
pzharrington Apr 9, 2025
5f58986
resolving merge conflicts
Apr 14, 2025
22c769a
fixing minor issues
Apr 14, 2025
35686da
resolving conflicts
Apr 14, 2025
0ec101a
fixing conflicts
Apr 14, 2025
daec55c
Optimizations and efficiency improvements in the domino datapipe. Hi…
coreyjadams Apr 14, 2025
9e93054
Merge remote-tracking branch 'upstream/domino' into domino_perf
coreyjadams Apr 15, 2025
13daff1
Remove obsolete and unused dataclasses - it's a flat config heirarchy…
coreyjadams Apr 15, 2025
610bc2a
This commit enables reading the old-style pickled files by default. …
coreyjadams Apr 16, 2025
473f7d2
Provide more robust reading of pickled files.
coreyjadams Apr 16, 2025
685e79c
fixing bugs
Apr 16, 2025
33bf72c
Fix several small bugs: the dataloader sometimes implicitly uses cupy…
coreyjadams Apr 16, 2025
3309bc3
Fix issue if using CPU data loading.
coreyjadams Apr 16, 2025
2c94eb1
Ensure all gpu preprocessing is directed to the proper device
coreyjadams Apr 16, 2025
0d1412e
Ensure that the dataloader doesn't waste GPU memory. Previously, loa…
coreyjadams Apr 16, 2025
10c42b9
Enable zarr readers. Use file path to toggle which type of file to r…
coreyjadams Apr 17, 2025
e869c8e
Improve logging and track memory leak. Enable zarr.
coreyjadams Apr 17, 2025
b07a897
Add GPU monitoring to the training script, and recreate the knn class…
coreyjadams Apr 17, 2025
21d4689
Update README.md
RishikeshRanade Apr 17, 2025
60068d2
Merge branch 'domino' into domino_perf
coreyjadams Apr 17, 2025
4279259
Merge branch 'domino_perf' into shardTensorFeature
coreyjadams Apr 17, 2025
6fcdb07
Enforce the determinism request in the domino pipeline.
coreyjadams Apr 17, 2025
dac4734
This commit makes an improvement to the zarr reading: reads are now _…
coreyjadams Apr 17, 2025
c05f235
Put ALL zarr chunk reads into futures and thread the IO.
coreyjadams Apr 17, 2025
439deac
Introduce a Sharded data pipeline for DoMINO. This class is construc…
coreyjadams Apr 18, 2025
275d7c5
Merge branch 'domino_perf' into shardTensorFeature
coreyjadams Apr 18, 2025
294636d
Domino perf (#848)
coreyjadams Apr 21, 2025
166272a
bug fix - validation step commented out
Apr 21, 2025
9b91893
Update ball query module to call to the functional interface to leverage
coreyjadams Apr 21, 2025
5e45e4a
Merge remote-tracking branch 'upstream/domino' into shardTensorFeature
coreyjadams Apr 21, 2025
65d5fd9
minor fixes to train.py
Apr 21, 2025
03f2288
Fix CUPY float/int datatype casting. (#852)
coreyjadams Apr 22, 2025
207a578
This commit creates alternative versions of the domino loss functions…
coreyjadams Apr 23, 2025
b9236f7
Remove older loss functions and consolidate script.
coreyjadams Apr 23, 2025
1edf788
Merge branch 'domino_loss_fn' into shardTensorFeature
coreyjadams Apr 23, 2025
3960424
Merge loss function updates.
coreyjadams Apr 23, 2025
a380cf0
Update model.py (#855)
coreyjadams Apr 25, 2025
88dd5de
modifying train.py
Apr 22, 2025
bafd631
minor fixes
Apr 25, 2025
5a52947
Domino Loss Functions (#853)
coreyjadams Apr 25, 2025
0f78214
fourier features to model params and cleanup
Apr 25, 2025
e73314e
modifying train.py
Apr 22, 2025
69b4dd4
minor fixes
Apr 25, 2025
173739e
merging changes in train.py
Apr 25, 2025
ec09d7c
resolving merge conflicts in train.py
Apr 25, 2025
5e52106
Merges `main` branch back into `domino` branch (#856)
peterdsharpe Apr 25, 2025
16deebb
DoMINO Model Refactor (#840)
peterdsharpe Apr 25, 2025
138adac
This commit address a bug in shard tensor: torch.tensor_split and
coreyjadams Apr 28, 2025
b2ab2c0
Ensure the backwards gradient computations uses consistent types.
coreyjadams Apr 28, 2025
3c21174
In ring calculations, the global rank was being used to compute sourc…
coreyjadams Apr 28, 2025
3075207
Implement sharded version of torch's index_select.
coreyjadams Apr 28, 2025
5efead5
Merge branch 'domino' into shardTensorFeature
coreyjadams Apr 28, 2025
5c9ece9
This commit enables the following pieces:
coreyjadams Apr 29, 2025
21b97f6
This commit handles some of the final updates required to enable full…
coreyjadams Apr 30, 2025
d0aa534
Add profiling hooks to the domino model.
coreyjadams Apr 30, 2025
19a025a
updating model and fixing bug in datapipe
May 6, 2025
084e96f
This is the last commit enabling sharding. The model is fully compat…
coreyjadams May 6, 2025
2867344
Update the domino readme to include information on domain parallelism.
coreyjadams May 6, 2025
08b3352
Explicit Warp device allocation for SDF and Ball Query (#876)
mnabian May 7, 2025
4983fb8
A few fixes for the domino pipeline. (#863)
coreyjadams May 9, 2025
63b4161
Add first draft of domain parallelism detailed tutorial.
coreyjadams May 9, 2025
b4046d4
Update two pieces of shard tensor:
coreyjadams May 9, 2025
cd8ec87
Add annotations and docstrings to sharded reduction operators
coreyjadams May 12, 2025
620e184
Domino merge from `main` (#888)
coreyjadams May 12, 2025
1b728b8
Remove `wrapt` usage from all but one patch (and it's coming next.)
coreyjadams May 13, 2025
d53d21e
Add tests to verify that shard tensor operations do not trigger on to…
coreyjadams May 13, 2025
1e543a5
Automatically enable all shard tensor execution paths, now that they
coreyjadams May 13, 2025
e71fa1e
Add first draft of tutorial for extending shard tensor.
coreyjadams May 13, 2025
8201bba
Update tests to accomodate new domino model. Minor tweaks to domino …
coreyjadams May 14, 2025
064d29a
Merge branch 'main' into domino
coreyjadams May 14, 2025
e3a40dd
fixing minor bugs
May 14, 2025
02276bc
Exclusively fix linting errors. (#895)
coreyjadams May 14, 2025
8e7c774
Merge branch 'main' into domino
coreyjadams May 14, 2025
5c7fe0b
Domino datapipe test (#896)
coreyjadams May 14, 2025
0ebd3a2
Merge branch 'main' into domino
coreyjadams May 14, 2025
0be2cd5
Merge branch 'main' into domino
coreyjadams May 15, 2025
42f7611
Fix ruff error.
coreyjadams May 15, 2025
93cb7d7
Remove numpy conversion since sdf now returns a numpy array directly
coreyjadams May 15, 2025
289cd49
Enable cupy usage in computing scaling factors.
coreyjadams May 15, 2025
e0ca512
Merge remote-tracking branch 'upstream/domino' into shardTensorFeature
coreyjadams May 15, 2025
500c29e
Add a test on consequtive reductions, which was really failing when d…
coreyjadams May 15, 2025
2e01406
Update training scripts slightly: move nvinit from pynvml to after di…
coreyjadams May 15, 2025
407e3a5
Add more context to the tutorial on implementing custom domain parall…
coreyjadams May 15, 2025
cbe79e4
Ensure sharded ops are mapped, and update domain parallel tutorial
coreyjadams May 16, 2025
d087552
Merge branch 'main' into shardTensorFeature
coreyjadams May 19, 2025
8770c08
Slight reorganization of scripts in the tutorials to make testing and…
coreyjadams May 19, 2025
272508c
Fix typos and links and minor details in domain parallel tutorials.
coreyjadams May 19, 2025
160770a
Remove file that should not be present
coreyjadams May 19, 2025
f026476
Remove file that should not be present (2)
coreyjadams May 19, 2025
efe773e
Update ahmed_body.ipynb
coreyjadams May 19, 2025
412674b
Update profiling.rst
coreyjadams May 19, 2025
2c4d9bf
Update test.py
coreyjadams May 19, 2025
1f1649b
Ensure download links are not messed up ...
coreyjadams May 19, 2025
aa130e9
Update download_dataset.sh
coreyjadams May 19, 2025
f4b173c
Fix typo:p primative fixed to primitive.
coreyjadams May 20, 2025
675d18b
Increase verbosity of comments in tutorial scripts
coreyjadams May 20, 2025
78e079a
Update examples/cfd/external_aerodynamics/domino/README.md
coreyjadams May 20, 2025
927d32f
Fix bug in shard tensor dispatch.
coreyjadams May 21, 2025
911df60
Merge branch 'shardTensorFeature' of github.com:coreyjadams/physicsne…
coreyjadams May 21, 2025
966ac40
Ensure domino returns properly after removing comparison tools.
coreyjadams May 21, 2025
5be00bc
Fix failing test in CI...hopefully.
coreyjadams May 21, 2025
92647a2
Resolve comments from PR Review.
coreyjadams May 21, 2025
a06273b
Fix broken import in tests
coreyjadams May 21, 2025
ff7a967
Missing multigpu marker
coreyjadams May 21, 2025
b1dfea3
Fix broken imports for shard tensor dispatch bindings
coreyjadams May 21, 2025
1aa482b
tensors need to be contiguous for NCCL
coreyjadams May 21, 2025
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -9,7 +9,7 @@ NVIDIA PhysicsNeMo Core (Latest Release)
   tutorials/simple_training_example.rst
   tutorials/simple_logging_and_checkpointing.rst
   tutorials/profiling.rst
-  tutorials/fsdp_and_shard_tensor.rst
+  tutorials/domain_parallelism_entry_point.rst

.. toctree::
   :maxdepth: 2
@@ -0,0 +1,53 @@
import torch
import time

# This time, let's make two moderately large tensors since we'll have to, at least briefly,
# construct a tensor of their point-by-point difference.
N_points_to_search = 234_567
N_target_points = 12_345
num_neighbors = 17


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# We'll make these 3D tensors to represent 3D points
a = torch.randn(N_points_to_search, 3, device=device)
b = torch.randn(N_target_points, 3, device=device)

def knn(x, y, n):
    # Return the n nearest neighbors in x for each point in y,
    # along with their distances.

    # First, compute the pairwise difference between all points in x and y.
    displacement_vec = x[None, :, :] - y[:, None, :]

    # Use the norm to compute the distance:
    distance = torch.norm(displacement_vec, dim=2)

    distances, indices = torch.topk(distance, k=n, dim=1, largest=False)

    x_results = x[indices]

    return x_results, distances

y_neighbors_to_x, neighbor_distances = knn(a, b, num_neighbors)
print(y_neighbors_to_x.shape) # should be (N_target_points, num_neighbors, 3)
print(neighbor_distances.shape) # should be (N_target_points, num_neighbors)

# run a couple times to warmup:
for i in range(5):
    _ = knn(a, b, num_neighbors)

# Optional: Benchmark it if you like:

# Measure execution time (guard the synchronize calls so the CPU fallback still runs)
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.time()
for i in range(10):
    _ = knn(a, b, num_neighbors)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.time()
elapsed_time = end_time - start_time

print(f"Execution time for 10 runs: {elapsed_time:.4f} seconds")
@@ -0,0 +1,210 @@
import torch
import torch.distributed as dist
from torch.overrides import handle_torch_function, has_torch_function
import time

from physicsnemo.distributed import DistributedManager, scatter_tensor, ShardTensor
from torch.distributed.tensor.placement_types import Shard, Replicate

from physicsnemo.distributed.shard_utils.ring import perform_ring_iteration, RingPassingConfig

# This time, let's make two moderately large tensors since we'll have to, at least briefly,
# construct a tensor of their point-by-point difference.
N_points_to_search = 234_567
N_target_points = 12_345
num_neighbors = 17

DistributedManager.initialize()
dm = DistributedManager()

device = dm.device

# We'll make these 3D tensors to represent 3D points
a = torch.randn(N_points_to_search, 3, device=device)
b = torch.randn(N_target_points, 3, device=device)

def knn(x, y, n):
    # This is to enable torch to track this knn function and route it correctly in ShardTensor:
    if has_torch_function((x, y)):
        return handle_torch_function(
            knn, (x, y), x, y, n
        )

    # Return the n nearest neighbors in x for each point in y.

    # First, compute the pairwise difference between all points in x and y.
    displacement_vec = x[None, :, :] - y[:, None, :]

    # Use the norm to compute the distance:
    distance = torch.norm(displacement_vec, dim=2)

    distances, indices = torch.topk(distance, k=n, dim=1, largest=False)

    x_results = x[indices]

    return x_results, distances

# Get the baseline result
y_neighbors_to_x, neighbor_distances = knn(a, b, num_neighbors)

if dm.rank == 0:
    print(y_neighbors_to_x.shape) # should be (N_target_points, num_neighbors, 3)
    print(neighbor_distances.shape) # should be (N_target_points, num_neighbors)

# DeviceMesh is a PyTorch object - you can initialize it directly, or, for added
# flexibility, physicsnemo can infer up to one mesh dimension for you
# (as a -1, like in a tensor.reshape() call...).
mesh = dm.initialize_mesh(mesh_shape = [-1,], mesh_dim_names = ["domain"])
# Shard(i) indicates we want the final tensor to be sharded along tensor dimension i.
# The placements argument is a tuple or list giving the desired placement along each mesh dimension.
placements = (Shard(0),)
# This function will distribute the tensor from global_src to the specified mesh,
# using the input placements.
# Note that in multi-level parallelism, the source is the _global_ rank not the mesh group rank.
a_sharded = scatter_tensor(tensor = a, global_src = 0, mesh = mesh, placements = placements)
b_sharded = scatter_tensor(tensor = b, global_src = 0, mesh = mesh, placements = placements)
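
# As a rough sketch (an illustration, assuming 4 ranks in the "domain" mesh):
# a_sharded.to_local() on each rank now holds about 234_567 / 4 ≈ 58_642 rows
# of ``a``, and the ring handler below only ever operates on these local shards.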


def knn_ring(func, types, args, kwargs):
    # Wrapper to intercept knn and compute it in a ring.
    # Never fully realizes the distance product.

    def extract_args(x, y, n, *args, **kwargs):
        return x, y, n
    x, y, n = extract_args(*args, **kwargs)

    # Each tensor has a _spec attribute, which contains information about the tensor's placement
    # and the devices it lives on:
    x_spec = x._spec
    y_spec = y._spec

    # ** In general ** you want to do some checking on the placements, since each
    # point cloud might be sharded differently. By construction, I know they're both
    # sharded along the points axis here (and not, say, replicated).

    if not x_spec.mesh == y_spec.mesh:
        raise NotImplementedError("Tensors must be sharded on the same mesh")

    mesh = x_spec.mesh
    local_group = mesh.get_group(0)
    local_size = dist.get_world_size(group=local_group)
    mesh_rank = mesh.get_local_rank()

    # x and y are both sharded - and since we're returning the nearest
    # neighbors to x, let's make sure the output keeps that sharding too.

    # One memory-efficient way to do this is with a ring computation.
    # We'll compute the knn on the local tensors, get the distances and outputs,
    # then shuffle the y shards along the mesh.

    # We'll need to sort the results and make sure we have just the top-k,
    # which is a little extra computation.

    # PhysicsNeMo has a ring-passing utility we can use.
    ring_config = RingPassingConfig(
        mesh_dim = 0,
        mesh_size = local_size,
        ring_direction = "forward",
        communication_method = "p2p"
    )

    local_x, local_y = x.to_local(), y.to_local()
    current_dists = None
    current_topk_y = None

    # Per-rank shard shapes of x along mesh dimension 0 - point-cloud shards
    # can be uneven, so each receive needs the incoming shard's shape.
    x_sharding_shapes = x._spec.sharding_shapes()[0]

    for i in range(local_size):
        source_rank = (mesh_rank - i) % local_size

        # For point clouds, we need to pass the size of the incoming shard.
        next_source_rank = (source_rank - 1) % local_size
        recv_shape = x_sharding_shapes[next_source_rank]
        if i != local_size - 1:
            # Don't do a ring on the last iteration.
            next_local_x = perform_ring_iteration(
                local_x,
                mesh,
                ring_config,
                recv_shape=recv_shape,
            )

        # Compute the knn on the local tensors:
        local_x_results, local_distances = func(local_x, local_y, n)

        if current_dists is None:
            current_dists = local_distances
            current_topk_y = local_x_results
        else:
            # Combine with the topk so far:
            current_dists = torch.cat([current_dists, local_distances], dim=1)
            current_topk_y = torch.cat([current_topk_y, local_x_results], dim=1)
            # And take the topk again:
            current_dists, running_indexes = torch.topk(current_dists, k=n, dim=1, largest=False)

            # This creates proper indexing to select specific elements along dim 1
            current_topk_y = torch.gather(current_topk_y, 1,
                running_indexes.unsqueeze(-1).expand(-1, -1, 3))

        if i != local_size - 1:
            # Don't do a ring on the last iteration.
            local_x = next_local_x

    # Finally, return the outputs as ShardTensors.
    topk_y = ShardTensor.from_local(
        current_topk_y,
        device_mesh = mesh,
        placements = y._spec.placements,
        sharding_shapes = y._spec.sharding_shapes(),
    )

    distances = ShardTensor.from_local(
        current_dists,
        device_mesh = mesh,
        placements = y._spec.placements,
        sharding_shapes = y._spec.sharding_shapes(),
    )

    return topk_y, distances


ShardTensor.register_function_handler(knn, knn_ring)
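# With the handler registered, a call to ``knn`` with ShardTensor inputs trips the
# ``has_torch_function`` check inside ``knn`` and is routed to ``knn_ring``;
# ordinary torch.Tensor inputs still take the local code path above.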

# Get the sharded result
y_neighbors_to_x_sharded, neighbor_distances_sharded = knn(a_sharded, b_sharded, num_neighbors)

# Check for agreement:
y_neighbors_to_x_sharded = y_neighbors_to_x_sharded.full_tensor()
neighbor_distances_sharded = neighbor_distances_sharded.full_tensor()

if dm.rank == 0:
    print(f"Neighbors agreement? {torch.allclose(y_neighbors_to_x, y_neighbors_to_x_sharded)}")
    print(f"Distances agreement? {torch.allclose(neighbor_distances, neighbor_distances_sharded)}")


# run a couple times to warmup:
for i in range(5):
    _ = knn(a_sharded, b_sharded, num_neighbors)

# Optional: Benchmark it if you like:

# Measure execution time
torch.cuda.synchronize()
start_time = time.time()
for i in range(10):
    _ = knn(a_sharded, b_sharded, num_neighbors)
torch.cuda.synchronize()
end_time = time.time()
elapsed_time = end_time - start_time

if dm.rank == 0:
    print(f"Execution time for 10 runs: {elapsed_time:.4f} seconds")
@@ -0,0 +1,86 @@
import torch
import torch.distributed as dist
import time

from physicsnemo.distributed import DistributedManager, scatter_tensor, ShardTensor
from torch.distributed.tensor.placement_types import Shard, Replicate

from physicsnemo.distributed.shard_utils.ring import perform_ring_iteration, RingPassingConfig

# This time, let's make two moderately large tensors since we'll have to, at least briefly,
# construct a tensor of their point-by-point difference.
N_points_to_search = 234_567
N_target_points = 12_345
num_neighbors = 17

DistributedManager.initialize()
dm = DistributedManager()

# We'll make these 3D tensors to represent 3D points
a = torch.randn(N_points_to_search, 3, device=dm.device)
b = torch.randn(N_target_points, 3, device=dm.device)

def knn(x, y, n):
    # Return the n nearest neighbors in x for each point in y.

    # First, compute the pairwise difference between all points in x and y.
    displacement_vec = x[None, :, :] - y[:, None, :]

    # Use the norm to compute the distance:
    distance = torch.norm(displacement_vec, dim=2)

    distances, indices = torch.topk(distance, k=n, dim=1, largest=False)

    x_results = x[indices]

    return x_results, distances

# Get the baseline result
y_neighbors_to_x, neighbor_distances = knn(a, b, num_neighbors)

if dm.rank == 0:
    print(y_neighbors_to_x.shape) # should be (N_target_points, num_neighbors, 3)
    print(neighbor_distances.shape) # should be (N_target_points, num_neighbors)

# DeviceMesh is a PyTorch object - you can initialize it directly, or, for added
# flexibility, physicsnemo can infer up to one mesh dimension for you
# (as a -1, like in a tensor.reshape() call...).
mesh = dm.initialize_mesh(mesh_shape = [-1,], mesh_dim_names = ["domain"])
# Shard(i) indicates we want the final tensor to be sharded along tensor dimension i.
# The placements argument is a tuple or list giving the desired placement along each mesh dimension.
placements = (Shard(0),)
# This function will distribute the tensor from global_src to the specified mesh,
# using the input placements.
# Note that in multi-level parallelism, the source is the _global_ rank not the mesh group rank.
a_sharded = scatter_tensor(tensor = a, global_src = 0, mesh = mesh, placements = placements)
b_sharded = scatter_tensor(tensor = b, global_src = 0, mesh = mesh, placements = placements)

# Get the sharded result
y_neighbors_to_x_sharded, neighbor_distances_sharded = knn(a_sharded, b_sharded, num_neighbors)

# Check for agreement:
y_neighbors_to_x_sharded = y_neighbors_to_x_sharded.full_tensor()
neighbor_distances_sharded = neighbor_distances_sharded.full_tensor()

if dm.rank == 0:
    # Note - do the ``full_tensor`` call outside this if-block or it will hang!
    print(f"Neighbors agreement? {torch.allclose(y_neighbors_to_x, y_neighbors_to_x_sharded)}")
    print(f"Distances agreement? {torch.allclose(neighbor_distances, neighbor_distances_sharded)}")

# run a couple times to warmup:
for i in range(5):
    _ = knn(a_sharded, b_sharded, num_neighbors)

# Optional: Benchmark it if you like:
# Measure execution time
torch.cuda.synchronize()
start_time = time.time()
for i in range(10):
    _ = knn(a_sharded, b_sharded, num_neighbors)
torch.cuda.synchronize()
end_time = time.time()
elapsed_time = end_time - start_time

if dm.rank == 0:
    print(f"Execution time for 10 runs: {elapsed_time:.4f} seconds")
