Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c3ab5db
swap vesin for torch cell list
abhijeetgangan Nov 20, 2025
5cbe090
correct shape for pbc tensor
abhijeetgangan Nov 20, 2025
45634e4
use linked cell for nequip
abhijeetgangan Nov 20, 2025
aa2d0f2
use linked cell for sevennet
abhijeetgangan Nov 20, 2025
95a60d9
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Nov 22, 2025
e1a9736
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Nov 22, 2025
b8f33d1
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Nov 23, 2025
7f510c3
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Nov 26, 2025
d06551e
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Nov 27, 2025
4bf3385
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Dec 2, 2025
4186552
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Dec 10, 2025
12ed04e
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Dec 10, 2025
967e147
Refactor to match the torch_nl api
abhijeetgangan Dec 10, 2025
6ade410
Merge branch 'ag/batch_nl_api' of https://github.com/TorchSim/torch-s…
abhijeetgangan Dec 10, 2025
f911f8c
fix nl for classical potentials
abhijeetgangan Dec 10, 2025
b93a416
fix nl for graphpes
abhijeetgangan Dec 10, 2025
08a9f67
add cuda batch nl
abhijeetgangan Dec 11, 2025
58ae254
make sure the pbc and cell tensor is contiguous
abhijeetgangan Dec 11, 2025
532647f
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Dec 11, 2025
44efe9e
Minor cleanup
abhijeetgangan Dec 11, 2025
4ed6d3d
fix cell shape
abhijeetgangan Dec 11, 2025
bc79385
Merge branch 'main' into ag/batch_nl_api
abhijeetgangan Dec 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ classifiers = [
requires-python = ">=3.12"
dependencies = [
"h5py>=3.12.1",
"nvalchemi-toolkit-ops",
"numpy>=1.26,<3",
"tables>=3.10.2",
"torch>=2",
"tqdm>=4.67",
"vesin-torch>=0.4.0, <0.5.0",
"vesin>=0.4.0, <0.5.0",
]

[project.optional-dependencies]
Expand Down
229 changes: 159 additions & 70 deletions tests/test_neighbors.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion torch_sim/models/graphpes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra
# model's cutoff value. To ensure no strange edge effects whereby
# edges that are exactly `cutoff` long are included/excluded,
# we bump cutoff + 1e-5 up slightly
nl, shifts = torchsim_nl(R, cell, state.pbc, cutoff + 1e-5)

# Create system_idx for this single system (all atoms belong to system 0)
system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device)
nl, _system_mapping, shifts = torchsim_nl(
R, cell, state.pbc, cutoff + 1e-5, system_idx_single
)

atomic_graph = AtomicGraph(
Z=Z.long(),
Expand Down
20 changes: 15 additions & 5 deletions torch_sim/models/lennard_jones.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,27 @@ def unbatched_forward(
pbc = state.pbc

if self.use_neighbor_list:
# Get neighbor list using vesin_nl_ts
mapping, shifts = torchsim_nl(
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
sort_id=False,
system_idx=system_idx,
)
# Get displacements using neighbor list
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions, cell=cell, pbc=pbc, pairs=mapping, shifts=shifts
positions=positions,
cell=cell,
pbc=pbc,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)
else:
# Get all pairwise displacements
Expand Down
45 changes: 13 additions & 32 deletions torch_sim/models/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
indicating which system each atom belongs to. If not provided with
atomic_numbers, all atoms are assumed to be in the same system.
neighbor_list_fn (Callable): Function to compute neighbor lists.
Defaults to vesin_nl_ts.
Defaults to torch_nl_linked_cell.
compute_forces (bool): Whether to compute forces. Defaults to True.
compute_stress (bool): Whether to compute stress. Defaults to True.
enable_cueq (bool): Whether to enable CuEq acceleration. Defaults to False.
Expand Down Expand Up @@ -298,37 +298,18 @@ def forward( # noqa: C901
):
self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx)

# Process each system's neighbor list separately
edge_indices = []
shifts_list = []
unit_shifts_list = []
offset = 0

# TODO (AG): Currently doesn't work for batched neighbor lists
for sys_idx in range(self.n_systems):
system_mask = sim_state.system_idx == sys_idx
# Calculate neighbor list for this system
edge_idx, shifts_idx = self.neighbor_list_fn(
positions=sim_state.positions[system_mask],
cell=sim_state.row_vector_cell[sys_idx],
pbc=sim_state.pbc,
cutoff=self.r_max,
)

# Adjust indices for the system
edge_idx = edge_idx + offset
shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx])

edge_indices.append(edge_idx)
unit_shifts_list.append(shifts_idx)
shifts_list.append(shifts)

offset += len(sim_state.positions[system_mask])

# Combine all neighbor lists
edge_index = torch.cat(edge_indices, dim=1)
unit_shifts = torch.cat(unit_shifts_list, dim=0)
shifts = torch.cat(shifts_list, dim=0)
# Batched neighbor list using linked-cell algorithm
edge_index, mapping_system, unit_shifts = self.neighbor_list_fn(
sim_state.positions,
sim_state.row_vector_cell,
sim_state.pbc,
self.r_max,
sim_state.system_idx,
)
# Convert unit cell shift indices to Cartesian shifts
shifts = ts.transforms.compute_cell_shifts(
sim_state.row_vector_cell, unit_shifts, mapping_system
)

# Build data dict for MACE model
data_dict = dict(
Expand Down
15 changes: 11 additions & 4 deletions torch_sim/models/morse.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,26 @@ def unbatched_forward(
pbc = sim_state.pbc

if self.use_neighbor_list:
mapping, shifts = torchsim_nl(
# Ensure system_idx exists (create if None for single system)
system_idx = (
sim_state.system_idx
if sim_state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
sort_id=False,
system_idx=system_idx,
)
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
cell=cell,
pbc=pbc,
pairs=mapping,
shifts=shifts,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)
else:
dr_vec, distances = transforms.get_pair_displacements(
Expand Down
41 changes: 9 additions & 32 deletions torch_sim/models/nequip_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class NequIPFrameworkModel(ModelInterface):
device (torch.device | None): Device to run calculations on.
Defaults to CUDA if available, otherwise CPU.
neighbor_list_fn (Callable): Function to compute neighbor lists.
Defaults to vesin_nl_ts.
Defaults to torch_nl_linked_cell.
atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms].
If provided at initialization, cannot be provided again during forward pass.
system_idx (torch.Tensor | None): Batch indices with shape [n_atoms] indicating
Expand Down Expand Up @@ -304,37 +304,14 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: #
):
self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx)

# Process each system's neighbor list separately
edge_indices = []
shifts_list = []
unit_shifts_list = []
offset = 0

# TODO (AG): Currently doesn't work for batched neighbor lists
for sys_idx in range(self.n_systems):
system_idx_mask = sim_state.system_idx == sys_idx
# Calculate neighbor list for this system
edge_idx, shifts_idx = self.neighbor_list_fn(
positions=sim_state.positions[system_idx_mask],
cell=sim_state.row_vector_cell[sys_idx],
pbc=sim_state.pbc,
cutoff=self.r_max,
)

# Adjust indices for the batch
edge_idx = edge_idx + offset
shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx])

edge_indices.append(edge_idx)
unit_shifts_list.append(shifts_idx)
shifts_list.append(shifts)

offset += len(sim_state.positions[system_idx_mask])

# Combine all neighbor lists
edge_index = torch.cat(edge_indices, dim=1)
unit_shifts = torch.cat(unit_shifts_list, dim=0)
shifts = torch.cat(shifts_list, dim=0)
# Batched neighbor list using linked-cell algorithm
edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn(
sim_state.positions,
sim_state.row_vector_cell,
sim_state.pbc,
self.r_max,
sim_state.system_idx,
)
atomic_types = ChemicalSpeciesToAtomTypeMapper(self.type_names)(
sim_state.atomic_numbers
)
Expand Down
22 changes: 14 additions & 8 deletions torch_sim/models/particle_life.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,27 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
cell = cell.squeeze(0) # Squeeze the first dimension

if self.use_neighbor_list:
# Get neighbor list using wrapping_nl
mapping, shifts = torchsim_nl(
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
cell=cell,
pbc=pbc,
cutoff=float(self.cutoff),
sort_id=False,
cutoff=self.cutoff,
system_idx=system_idx,
)
# Get displacements using neighbor list
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
cell=cell,
pbc=pbc,
pairs=mapping,
shifts=shifts,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)
else:
# Get all pairwise displacements
Expand All @@ -180,7 +186,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
mask = distances < self.cutoff
# Get valid pairs - match neighbor list convention for pair order
i, j = torch.where(mask)
mapping = torch.stack([j, i]) # Changed from [j, i] to [i, j]
mapping = torch.stack([j, i])
# Get valid displacements and distances
dr_vec = dr_vec[mask]
distances = distances[mask]
Expand Down
50 changes: 33 additions & 17 deletions torch_sim/models/sevennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24'
(OMat24).
neighbor_list_fn (Callable): Neighbor list function to use.
Default is vesin_nl_ts.
Default is torch_nl_linked_cell.
device (torch.device | str | None): Device to run the model on
dtype (torch.dtype): Data type for computation

Expand Down Expand Up @@ -191,27 +191,43 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:
# TODO: is this clone necessary?
sim_state = sim_state.clone()

# Batched neighbor list using linked-cell algorithm with row-vector cell
n_systems = sim_state.system_idx.max().item() + 1
edge_index, mapping_system, unit_shifts = self.neighbor_list_fn(
sim_state.positions,
sim_state.row_vector_cell,
sim_state.pbc,
self.cutoff,
sim_state.system_idx,
)

# Build per-system SevenNet AtomGraphData by slicing the global NL
n_atoms_per_system = sim_state.system_idx.bincount()
stride = torch.cat(
(
torch.tensor([0], device=self.device, dtype=torch.long),
n_atoms_per_system.cumsum(0),
)
)

data_list = []
for sys_idx in range(sim_state.system_idx.max().item() + 1):
system_mask = sim_state.system_idx == sys_idx
for sys_idx in range(n_systems):
sys_start = stride[sys_idx].item()
sys_end = stride[sys_idx + 1].item()

pos = sim_state.positions[system_mask]
# SevenNet uses row vector cell convention for neighbor list
pos = sim_state.positions[sys_start:sys_end]
row_vector_cell = sim_state.row_vector_cell[sys_idx]
pbc = sim_state.pbc
atomic_nums = sim_state.atomic_numbers[system_mask]

edge_idx, shifts_idx = self.neighbor_list_fn(
positions=pos,
cell=row_vector_cell,
pbc=pbc,
cutoff=self.cutoff,
)
atomic_nums = sim_state.atomic_numbers[sys_start:sys_end]

mask = mapping_system == sys_idx
edge_idx_sys_global = edge_index[:, mask]
unit_shifts_sys = unit_shifts[mask]

shifts = torch.mm(shifts_idx, row_vector_cell)
# Convert global indices to local indices
edge_idx = edge_idx_sys_global - sys_start
shifts = torch.mm(unit_shifts_sys, row_vector_cell)
edge_vec = pos[edge_idx[1]] - pos[edge_idx[0]] + shifts
vol = torch.det(row_vector_cell)
# vol = vol if vol > 0.0 else torch.tensor(np.finfo(float).eps)

data = {
key.NODE_FEATURE: atomic_nums,
Expand All @@ -220,7 +236,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:
key.EDGE_IDX: edge_idx,
key.EDGE_VEC: edge_vec,
key.CELL: row_vector_cell,
key.CELL_SHIFT: shifts_idx,
key.CELL_SHIFT: unit_shifts_sys,
key.CELL_VOLUME: vol,
key.NUM_ATOMS: torch.tensor(len(atomic_nums), device=self.device),
key.DATA_MODALITY: self.modal,
Expand Down
32 changes: 21 additions & 11 deletions torch_sim/models/soft_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,21 +285,27 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
pbc = state.pbc

if self.use_neighbor_list:
# Get neighbor list using vesin_nl_ts
mapping, shifts = torchsim_nl(
# Get neighbor list using torchsim_nl
# Ensure system_idx exists (create if None for single system)
system_idx = (
state.system_idx
if state.system_idx is not None
else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
cell=cell,
pbc=pbc,
cutoff=self.cutoff,
sort_id=False,
system_idx=system_idx,
)
# Get displacements between neighbor pairs
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
cell=cell,
pbc=pbc,
pairs=mapping,
shifts=shifts,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)

else:
Expand Down Expand Up @@ -710,20 +716,24 @@ def unbatched_forward( # noqa: PLR0915
# Compute neighbor list or full distance matrix
if self.use_neighbor_list:
# Get neighbor list for efficient computation
mapping, shifts = torchsim_nl(
# Ensure system_idx exists (create if None for single system)
system_idx = torch.zeros(
positions.shape[0], dtype=torch.long, device=self.device
)
mapping, system_mapping, shifts_idx = torchsim_nl(
positions=positions,
cell=cell,
pbc=self.pbc,
cutoff=self.cutoff,
sort_id=False,
system_idx=system_idx,
)
# Get displacements between neighbor pairs
# Pass shifts_idx directly - get_pair_displacements will convert them
dr_vec, distances = transforms.get_pair_displacements(
positions=positions,
cell=cell,
pbc=self.pbc,
pairs=mapping,
shifts=shifts,
pairs=(mapping[0], mapping[1]),
shifts=shifts_idx,
)

else:
Expand Down
Loading
Loading