From c3ab5dbc3c060ce10caca66ef05f463ce73ccc72 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 20 Nov 2025 09:32:51 -0800 Subject: [PATCH 01/11] swap vesin for torch cell list --- torch_sim/models/mace.py | 49 ++++++++++++---------------------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index be7b3914..71bb50fe 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -28,7 +28,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import vesin_nl_ts +from torch_sim.neighbors import torch_nl_linked_cell from torch_sim.typing import StateDict @@ -107,7 +107,7 @@ def __init__( *, device: torch.device | None = None, dtype: torch.dtype = torch.float64, - neighbor_list_fn: Callable = vesin_nl_ts, + neighbor_list_fn: Callable = torch_nl_linked_cell, compute_forces: bool = True, compute_stress: bool = True, enable_cueq: bool = False, @@ -133,7 +133,7 @@ def __init__( indicating which system each atom belongs to. If not provided with atomic_numbers, all atoms are assumed to be in the same system. neighbor_list_fn (Callable): Function to compute neighbor lists. - Defaults to vesin_nl_ts. + Defaults to torch_nl_linked_cell. compute_forces (bool): Whether to compute forces. Defaults to True. compute_stress (bool): Whether to compute stress. Defaults to True. enable_cueq (bool): Whether to enable CuEq acceleration. Defaults to False. 
@@ -298,37 +298,18 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # ): self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) - # Process each system's neighbor list separately - edge_indices = [] - shifts_list = [] - unit_shifts_list = [] - offset = 0 - - # TODO (AG): Currently doesn't work for batched neighbor lists - for sys_idx in range(self.n_systems): - system_mask = sim_state.system_idx == sys_idx - # Calculate neighbor list for this system - edge_idx, shifts_idx = self.neighbor_list_fn( - positions=sim_state.positions[system_mask], - cell=sim_state.row_vector_cell[sys_idx], - pbc=sim_state.pbc, - cutoff=self.r_max, - ) - - # Adjust indices for the system - edge_idx = edge_idx + offset - shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx]) - - edge_indices.append(edge_idx) - unit_shifts_list.append(shifts_idx) - shifts_list.append(shifts) - - offset += len(sim_state.positions[system_mask]) - - # Combine all neighbor lists - edge_index = torch.cat(edge_indices, dim=1) - unit_shifts = torch.cat(unit_shifts_list, dim=0) - shifts = torch.cat(shifts_list, dim=0) + # Batched neighbor list using linked-cell algorithm + edge_index, mapping_system, unit_shifts = torch_nl_linked_cell( + sim_state.positions, + sim_state.row_vector_cell, + sim_state.pbc, + self.r_max, + sim_state.system_idx, + ) + # Convert unit cell shift indices to Cartesian shifts + shifts = ts.transforms.compute_cell_shifts( + sim_state.row_vector_cell, unit_shifts, mapping_system + ) # Get model output out = self.model( From 5cbe09047d23c7a6e8e264b43b7538ae21b8faeb Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 20 Nov 2025 09:43:45 -0800 Subject: [PATCH 02/11] correct shape for pbc tensor --- torch_sim/models/mace.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 71bb50fe..5e1dfbd6 100644 --- a/torch_sim/models/mace.py +++ 
b/torch_sim/models/mace.py @@ -299,10 +299,15 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) # Batched neighbor list using linked-cell algorithm + pbc_tensor = ( + sim_state.pbc.repeat(self.n_systems, 1) + if sim_state.pbc.ndim == 1 + else sim_state.pbc + ) edge_index, mapping_system, unit_shifts = torch_nl_linked_cell( sim_state.positions, sim_state.row_vector_cell, - sim_state.pbc, + pbc_tensor, self.r_max, sim_state.system_idx, ) From 45634e4f6613828fcaad450d5a226626f7c8fb17 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 20 Nov 2025 09:55:43 -0800 Subject: [PATCH 03/11] use linked cell for nequip --- torch_sim/models/mace.py | 2 +- torch_sim/models/nequip_framework.py | 50 +++++++++------------------- 2 files changed, 17 insertions(+), 35 deletions(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 5e1dfbd6..34b39e9b 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,7 +304,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # if sim_state.pbc.ndim == 1 else sim_state.pbc ) - edge_index, mapping_system, unit_shifts = torch_nl_linked_cell( + edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( sim_state.positions, sim_state.row_vector_cell, pbc_tensor, diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index 89f1ce56..8096f019 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -25,7 +25,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import vesin_nl_ts +from torch_sim.neighbors import torch_nl_linked_cell from torch_sim.typing import StateDict @@ -150,7 +150,7 @@ class NequIPFrameworkModel(ModelInterface): device (torch.device | None): Device to run calculations on. Defaults to CUDA if available, otherwise CPU. 
neighbor_list_fn (Callable): Function to compute neighbor lists. - Defaults to vesin_nl_ts. + Defaults to torch_nl_linked_cell. atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms]. If provided at initialization, cannot be provided again during forward pass. system_idx (torch.Tensor | None): Batch indices with shape [n_atoms] indicating @@ -165,7 +165,7 @@ def __init__( r_max: float, type_names: list[str], device: torch.device | None = None, - neighbor_list_fn: Callable = vesin_nl_ts, + neighbor_list_fn: Callable = torch_nl_linked_cell, atomic_numbers: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> None: @@ -304,37 +304,19 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # ): self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) - # Process each system's neighbor list separately - edge_indices = [] - shifts_list = [] - unit_shifts_list = [] - offset = 0 - - # TODO (AG): Currently doesn't work for batched neighbor lists - for sys_idx in range(self.n_systems): - system_idx_mask = sim_state.system_idx == sys_idx - # Calculate neighbor list for this system - edge_idx, shifts_idx = self.neighbor_list_fn( - positions=sim_state.positions[system_idx_mask], - cell=sim_state.row_vector_cell[sys_idx], - pbc=sim_state.pbc, - cutoff=self.r_max, - ) - - # Adjust indices for the batch - edge_idx = edge_idx + offset - shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx]) - - edge_indices.append(edge_idx) - unit_shifts_list.append(shifts_idx) - shifts_list.append(shifts) - - offset += len(sim_state.positions[system_idx_mask]) - - # Combine all neighbor lists - edge_index = torch.cat(edge_indices, dim=1) - unit_shifts = torch.cat(unit_shifts_list, dim=0) - shifts = torch.cat(shifts_list, dim=0) + # Batched neighbor list using linked-cell algorithm (row-vector cell convention) + pbc_tensor = ( + sim_state.pbc.repeat(self.n_systems, 1) + if sim_state.pbc.ndim == 1 + else 
sim_state.pbc + ) + edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn( + sim_state.positions, + sim_state.row_vector_cell, + pbc_tensor, + self.r_max, + sim_state.system_idx, + ) atomic_types = ChemicalSpeciesToAtomTypeMapper(self.type_names)( sim_state.atomic_numbers ) From aa2d0f2bb2005bb3e9f0864d3e3b0fcab48486db Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 20 Nov 2025 10:14:09 -0800 Subject: [PATCH 04/11] use linked cell for sevennet --- torch_sim/models/sevennet.py | 59 ++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 90c56892..23a12b8b 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -12,7 +12,7 @@ import torch_sim as ts from torch_sim.elastic import voigt_6_to_full_3x3_stress from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import vesin_nl_ts +from torch_sim.neighbors import torch_nl_linked_cell if TYPE_CHECKING: @@ -85,7 +85,7 @@ def __init__( model: AtomGraphSequential | str | Path, *, # force remaining arguments to be keyword-only modal: str | None = None, - neighbor_list_fn: Callable = vesin_nl_ts, + neighbor_list_fn: Callable = torch_nl_linked_cell, device: torch.device | str | None = None, dtype: torch.dtype = torch.float32, ) -> None: @@ -102,7 +102,7 @@ def __init__( for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24). neighbor_list_fn (Callable): Neighbor list function to use. - Default is vesin_nl_ts. + Default is torch_nl_linked_cell. device (torch.device | str | None): Device to run the model on dtype (torch.dtype): Data type for computation @@ -191,27 +191,48 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # TODO: is this clone necessary? 
sim_state = sim_state.clone() + # Batched neighbor list using linked-cell algorithm with row-vector cell + n_systems = sim_state.system_idx.max().item() + 1 + pbc_tensor = ( + sim_state.pbc.repeat(n_systems, 1) + if sim_state.pbc.ndim == 1 + else sim_state.pbc + ) + edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( + sim_state.positions, + sim_state.row_vector_cell, + pbc_tensor, + self.cutoff, + sim_state.system_idx, + ) + + # Build per-system SevenNet AtomGraphData by slicing the global NL + n_atoms_per_system = sim_state.system_idx.bincount() + stride = torch.cat( + ( + torch.tensor([0], device=self.device, dtype=torch.long), + n_atoms_per_system.cumsum(0), + ) + ) + data_list = [] - for sys_idx in range(sim_state.system_idx.max().item() + 1): - system_mask = sim_state.system_idx == sys_idx + for sys_idx in range(n_systems): + sys_start = stride[sys_idx].item() + sys_end = stride[sys_idx + 1].item() - pos = sim_state.positions[system_mask] - # SevenNet uses row vector cell convention for neighbor list + pos = sim_state.positions[sys_start:sys_end] row_vector_cell = sim_state.row_vector_cell[sys_idx] - pbc = sim_state.pbc - atomic_nums = sim_state.atomic_numbers[system_mask] - - edge_idx, shifts_idx = self.neighbor_list_fn( - positions=pos, - cell=row_vector_cell, - pbc=pbc, - cutoff=self.cutoff, - ) + atomic_nums = sim_state.atomic_numbers[sys_start:sys_end] + + mask = mapping_system == sys_idx + edge_idx_sys_global = edge_index[:, mask] + unit_shifts_sys = unit_shifts[mask] - shifts = torch.mm(shifts_idx, row_vector_cell) + # Convert global indices to local indices + edge_idx = edge_idx_sys_global - sys_start + shifts = torch.mm(unit_shifts_sys, row_vector_cell) edge_vec = pos[edge_idx[1]] - pos[edge_idx[0]] + shifts vol = torch.det(row_vector_cell) - # vol = vol if vol > 0.0 else torch.tensor(np.finfo(float).eps) data = { key.NODE_FEATURE: atomic_nums, @@ -220,7 +241,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, 
torch.Tensor]: key.EDGE_IDX: edge_idx, key.EDGE_VEC: edge_vec, key.CELL: row_vector_cell, - key.CELL_SHIFT: shifts_idx, + key.CELL_SHIFT: unit_shifts_sys, key.CELL_VOLUME: vol, key.NUM_ATOMS: torch.tensor(len(atomic_nums), device=self.device), key.DATA_MODALITY: self.modal, From 967e147ce69a5acde56c8a7e2bc2aa4bb39a996d Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 10 Dec 2025 08:59:29 -0800 Subject: [PATCH 05/11] Refactor to match the torch_nl api --- tests/test_neighbors.py | 120 ++--- torch_sim/neighbors.py | 880 -------------------------------- torch_sim/neighbors/__init__.py | 117 +++++ torch_sim/neighbors/standard.py | 547 ++++++++++++++++++++ torch_sim/neighbors/torch_nl.py | 230 +++++++++ torch_sim/neighbors/vesin.py | 328 ++++++++++++ 6 files changed, 1284 insertions(+), 938 deletions(-) delete mode 100644 torch_sim/neighbors.py create mode 100644 torch_sim/neighbors/__init__.py create mode 100644 torch_sim/neighbors/standard.py create mode 100644 torch_sim/neighbors/torch_nl.py create mode 100644 torch_sim/neighbors/vesin.py diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 886f2181..fc20cecf 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -247,7 +247,7 @@ def test_neighbor_list_implementations( *, cutoff: float, atoms_list: str, - nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor]], + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], request: pytest.FixtureRequest, ) -> None: """Check that different neighbor list implementations give the same results as ASE @@ -261,15 +261,20 @@ def test_neighbor_list_implementations( row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE) + # Create system_idx for single system (all atoms belong to system 0) + system_idx = torch.zeros(len(pos), dtype=torch.long, device=DEVICE) + # Get the neighbor list from the implementation being tested - 
mapping, shifts = nl_implementation( + mapping, _sys_map, shifts = nl_implementation( positions=pos, cell=row_vector_cell, pbc=pbc, cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), + system_idx=system_idx, ) # Calculate distances with cell shifts + # (shifts are now shift indices, same as shifts for single system) cell_shifts = torch.mm(shifts, row_vector_cell) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) dds = np.sort(dds.numpy()) @@ -305,7 +310,12 @@ def test_neighbor_list_implementations( @pytest.mark.parametrize("self_interaction", [True, False]) @pytest.mark.parametrize( "nl_implementation", - [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell], + [ + neighbors.torch_nl_n2, + neighbors.torch_nl_linked_cell, + neighbors.standard_nl, + ] + + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []), ) def test_torch_nl_implementations( *, @@ -315,7 +325,11 @@ def test_torch_nl_implementations( molecule_atoms_set: list[Atoms], periodic_atoms_set: list[Atoms], ) -> None: - """Check that torch neighbor list implementations give the same results as ASE.""" + """Check that batched neighbor list implementations give the same results as ASE. + + This tests the native batched implementations (torch_nl_n2, torch_nl_linked_cell) + and the unified implementations (standard_nl, vesin_nl) in batched mode. 
+ """ atoms_list = molecule_atoms_set + periodic_atoms_set # Convert to torch batch (concatenate all tensors) @@ -399,28 +413,19 @@ def test_standard_nl_edge_cases() -> None: pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) + system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE) # Test different PBC combinations for pbc in (True, False): - mapping, _shifts = neighbors.standard_nl( + mapping, _sys_map, _shifts = neighbors.standard_nl( positions=pos, cell=cell, pbc=torch.tensor([pbc] * 3, device=DEVICE, dtype=DTYPE), cutoff=cutoff, + system_idx=system_idx, ) assert len(mapping[0]) > 0 # Should find neighbors - # Test sort_id - mapping, _shifts = neighbors.standard_nl( - positions=pos, - cell=cell, - pbc=torch.Tensor([True, True, True]), - cutoff=cutoff, - sort_id=True, - ) - # Check if indices are sorted - assert torch.all(mapping[0][1:] >= mapping[0][:-1]) - @pytest.mark.skipif(not neighbors.VESIN_AVAILABLE, reason="Vesin not available") def test_vesin_nl_edge_cases() -> None: @@ -428,6 +433,7 @@ def test_vesin_nl_edge_cases() -> None: pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) + system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE) # Test both implementations for nl_fn in (neighbors.vesin_nl, neighbors.vesin_nl_ts): @@ -436,29 +442,22 @@ def test_vesin_nl_edge_cases() -> None: torch.Tensor([True, True, True]), torch.Tensor([False, False, False]), ): - mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=pbc, cutoff=cutoff) + mapping, _sys_map, _shifts = nl_fn( + positions=pos, cell=cell, pbc=pbc, cutoff=cutoff, system_idx=system_idx + ) assert len(mapping[0]) > 0 # Should find neighbors - # Test sort_id - mapping, _shifts = nl_fn( - positions=pos, - cell=cell, - 
pbc=torch.Tensor([True, True, True]), - cutoff=cutoff, - sort_id=True, - ) - # Check if indices are sorted - assert torch.all(mapping[0][1:] >= mapping[0][:-1]) - # Test different precisions if nl_fn == neighbors.vesin_nl: # vesin_nl_ts doesn't support float32 pos_f32 = pos.to(dtype=torch.float32) cell_f32 = cell.to(dtype=torch.float32) - mapping, _shifts = nl_fn( + system_idx_f32 = torch.zeros(2, dtype=torch.long, device=DEVICE) + mapping, _sys_map, _shifts = nl_fn( positions=pos_f32, cell=cell_f32, pbc=torch.Tensor([True, True, True]), cutoff=cutoff, + system_idx=system_idx_f32, ) assert len(mapping[0]) > 0 # Should find neighbors @@ -488,23 +487,26 @@ def test_torchsim_nl_consistency() -> None: cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([False, False, False], device=device) cutoff = torch.tensor(1.5, device=device, dtype=dtype) + system_idx = torch.zeros(4, dtype=torch.long, device=device) # Test torchsim_nl against standard_nl - mapping_torchsim, shifts_torchsim = neighbors.torchsim_nl( - positions, cell, pbc, cutoff + mapping_torchsim, sys_map_ts, shifts_torchsim = neighbors.torchsim_nl( + positions, cell, pbc, cutoff, system_idx ) - mapping_standard, shifts_standard = neighbors.standard_nl( - positions, cell, pbc, cutoff + mapping_standard, sys_map_std, shifts_standard = neighbors.standard_nl( + positions, cell, pbc, cutoff, system_idx ) # torchsim_nl should always give consistent shape with standard_nl assert mapping_torchsim.shape == mapping_standard.shape assert shifts_torchsim.shape == shifts_standard.shape + assert sys_map_ts.shape == sys_map_std.shape # When vesin is unavailable, torchsim_nl should match standard_nl exactly if not neighbors.VESIN_AVAILABLE: torch.testing.assert_close(mapping_torchsim, mapping_standard) torch.testing.assert_close(shifts_torchsim, shifts_standard) + torch.testing.assert_close(sys_map_ts, sys_map_std) @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available for testing") @@ 
-521,12 +523,16 @@ def test_torchsim_nl_gpu() -> None: cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([True, True, True], device=device) cutoff = torch.tensor(1.5, device=device, dtype=dtype) + system_idx = torch.zeros(2, dtype=torch.long, device=device) # Should work on GPU regardless of vesin availability - mapping, shifts = neighbors.torchsim_nl(positions, cell, pbc, cutoff) + mapping, sys_map, shifts = neighbors.torchsim_nl( + positions, cell, pbc, cutoff, system_idx + ) assert mapping.device.type == "cuda" assert shifts.device.type == "cuda" + assert sys_map.device.type == "cuda" assert mapping.shape[0] == 2 # (2, num_neighbors) @@ -551,24 +557,26 @@ def test_torchsim_nl_fallback_when_vesin_unavailable( cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([False, False, False], device=device) cutoff = torch.tensor(1.5, device=device, dtype=dtype) + system_idx = torch.zeros(4, dtype=torch.long, device=device) # Monkeypatch VESIN_AVAILABLE to False to simulate vesin not being installed monkeypatch.setattr(neighbors, "VESIN_AVAILABLE", False) # Call torchsim_nl with mocked unavailable vesin - mapping_torchsim, shifts_torchsim = neighbors.torchsim_nl( - positions, cell, pbc, cutoff + mapping_torchsim, sys_map_ts, shifts_torchsim = neighbors.torchsim_nl( + positions, cell, pbc, cutoff, system_idx ) # Call standard_nl directly for comparison - mapping_standard, shifts_standard = neighbors.standard_nl( - positions, cell, pbc, cutoff + mapping_standard, sys_map_std, shifts_standard = neighbors.standard_nl( + positions, cell, pbc, cutoff, system_idx ) # When VESIN_AVAILABLE is False, torchsim_nl should use standard_nl # and produce identical results torch.testing.assert_close(mapping_torchsim, mapping_standard) torch.testing.assert_close(shifts_torchsim, shifts_standard) + torch.testing.assert_close(sys_map_ts, sys_map_std) def test_strict_nl_edge_cases() -> None: @@ -624,7 +632,10 @@ def 
test_neighbor_lists_time_and_memory() -> None: ] if neighbors.VESIN_AVAILABLE: nl_implementations.extend( - [neighbors.vesin_nl_ts, cast("Callable[..., Any]", neighbors.vesin_nl)] + [ + neighbors.vesin_nl_ts, + cast("Callable[..., Any]", neighbors.vesin_nl), + ] ) for nl_fn in nl_implementations: @@ -639,25 +650,18 @@ def test_neighbor_lists_time_and_memory() -> None: # Time the execution start_time = time.perf_counter() - if nl_fn in (neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell): - system_idx = torch.zeros(n_atoms, dtype=torch.long, device=DEVICE) - # Fix pbc tensor shape - pbc = torch.tensor([[True, True, True]], device=DEVICE) - _mapping, _mapping_system, _shifts_idx = nl_fn( - positions=pos, - cell=cell, - pbc=pbc, - cutoff=cutoff, - system_idx=system_idx, - self_interaction=False, - ) - else: - _mapping, _shifts = nl_fn( - positions=pos, - cell=cell, - pbc=torch.Tensor([True, True, True]), - cutoff=cutoff, - ) + # All neighbor list functions now use the unified API with system_idx + system_idx = torch.zeros(n_atoms, dtype=torch.long, device=DEVICE) + # Fix pbc tensor shape + pbc = torch.tensor([[True, True, True]], device=DEVICE) + _mapping, _mapping_system, _shifts_idx = nl_fn( + positions=pos, + cell=cell, + pbc=pbc, + cutoff=cutoff, + system_idx=system_idx, + self_interaction=False, + ) end_time = time.perf_counter() execution_time = end_time - start_time diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py deleted file mode 100644 index eb90ec92..00000000 --- a/torch_sim/neighbors.py +++ /dev/null @@ -1,880 +0,0 @@ -"""Utilities for neighbor list calculations.""" - -import torch - - -# Make vesin optional - fall back to pure PyTorch implementation if unavailable -try: - from vesin import NeighborList as VesinNeighborList - from vesin.torch import NeighborList as VesinNeighborListTorch - - VESIN_AVAILABLE = True -except ImportError: - VESIN_AVAILABLE = False - VesinNeighborList = None - VesinNeighborListTorch = None - -import 
torch_sim.math as fm -from torch_sim import transforms - - -@torch.jit.script -def primitive_neighbor_list( # noqa: C901, PLR0915 - quantities: str, - pbc: torch.Tensor, - cell: torch.Tensor, - positions: torch.Tensor, - cutoff: torch.Tensor, - device: torch.device, - dtype: torch.dtype, - self_interaction: bool = False, # noqa: FBT001, FBT002 - use_scaled_positions: bool = False, # noqa: FBT001, FBT002 - max_n_bins: int = int(1e6), -) -> list[torch.Tensor]: - """Compute a neighbor list for an atomic configuration. - - ASE periodic neighbor list implementation - Atoms outside periodic boundaries are mapped into the unit cell. Atoms - outside non-periodic boundaries are included in the neighbor list - but complexity of neighbor list search for those can become n^2. - The neighbor list is sorted by first atom index 'i', but not by second - atom index 'j'. - - Args: - quantities: Quantities to compute by the neighbor list algorithm. Each character - in this string defines a quantity. They are returned in a tuple of - the same order. Possible quantities are - * 'i' : first atom index - * 'j' : second atom index - * 'd' : absolute distance - * 'D' : distance vector - * 'S' : shift vector (number of cell boundaries crossed by the bond - between atom i and j). With the shift vector S, the - distances D between atoms can be computed from: - D = positions[j]-positions[i]+S.dot(cell) - pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in - each axis. - cell: Unit cell vectors according to the row vector convention, i.e. - `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - positions: Atomic positions. Anything that can be converted to an ndarray of - shape (n, 3) will do: [(x1,y1,z1), (x2,y2,z2), ...]. If - use_scaled_positions is set to true, this must be scaled positions. - cutoff: Cutoff for neighbor search. It can be: - * A single float: This is a global cutoff for all elements. - * A dictionary: This specifies cutoff values for element - pairs. 
Specification accepts element numbers of symbols. - Example: {(1, 6): 1.1, (1, 1): 1.0, ('C', 'C'): 1.85} - * A list/array with a per atom value: This specifies the radius of - an atomic sphere for each atoms. If spheres overlap, atoms are - within each others neighborhood. - See :func:`~ase.neighborlist.natural_cutoffs` - for an example on how to get such a list. - device: PyTorch device to use for computations - dtype: PyTorch data type to use - self_interaction: Return the atom itself as its own neighbor if set to true. - Default: False - use_scaled_positions: If set to true, positions are expected to be - scaled positions. - max_n_bins: Maximum number of bins used in neighbor search. This is used to limit - the maximum amount of memory required by the neighbor list. - - Returns: - list[torch.Tensor]: One tensor for each item in `quantities`. Indices in `i` - are returned in ascending order 0..len(a)-1, but the order of (i,j) - pairs is not guaranteed. - - References: - - This code is modified version of the github gist - https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 - """ - # Naming conventions: Suffixes indicate the dimension of an array. The - # following convention is used here: - # c: Cartesian index, can have values 0, 1, 2 - # i: Global atom index, can have values 0..len(a)-1 - # xyz: Bin index, three values identifying x-, y- and z-component of a - # spatial bin that is used to make neighbor search O(n) - # b: Linearized version of the 'xyz' bin index - # a: Bin-local atom index, i.e. index identifying an atom *within* a - # bin - # p: Pair index, can have value 0 or 1 - # n: (Linear) neighbor index - - if len(positions) == 0: - raise RuntimeError("No atoms provided") - - # Compute reciprocal lattice vectors. - recip_cell = torch.linalg.pinv(cell).T - b1_c, b2_c, b3_c = recip_cell[0], recip_cell[1], recip_cell[2] - - # Compute distances of cell faces. 
- l1 = torch.linalg.norm(b1_c) - l2 = torch.linalg.norm(b2_c) - l3 = torch.linalg.norm(b3_c) - pytorch_scalar_1 = torch.as_tensor(1.0, device=device, dtype=dtype) - face_dist_c = torch.hstack( - [ - 1 / l1 if l1 > 0 else pytorch_scalar_1, - 1 / l2 if l2 > 0 else pytorch_scalar_1, - 1 / l3 if l3 > 0 else pytorch_scalar_1, - ] - ) - if face_dist_c.shape != (3,): - raise ValueError(f"face_dist_c.shape={face_dist_c.shape} != (3,)") - - # we don't handle other fancier cutoffs - max_cutoff: torch.Tensor = cutoff - - # We use a minimum bin size of 3 A - bin_size = torch.maximum(max_cutoff, torch.tensor(3.0, device=device, dtype=dtype)) - # Compute number of bins such that a sphere of radius cutoff fits into - # eight neighboring bins. - n_bins_c = torch.maximum( - (face_dist_c / bin_size).to(dtype=torch.long, device=device), - torch.ones(3, dtype=torch.long, device=device), - ) - n_bins = torch.prod(n_bins_c) - # Make sure we limit the amount of memory used by the explicit bins. - while n_bins > max_n_bins: - n_bins_c = torch.maximum( - n_bins_c // 2, torch.ones(3, dtype=torch.long, device=device) - ) - n_bins = torch.prod(n_bins_c) - - # Compute over how many bins we need to loop in the neighbor list search. - neigh_search = torch.ceil(bin_size * n_bins_c / face_dist_c).to( - dtype=torch.long, device=device - ) - neigh_search_x, neigh_search_y, neigh_search_z = ( - neigh_search[0], - neigh_search[1], - neigh_search[2], - ) - - # If we only have a single bin and the system is not periodic, then we - # do not need to search neighboring bins - pytorch_scalar_int_0 = torch.as_tensor(0, dtype=torch.long, device=device) - neigh_search_x = ( - pytorch_scalar_int_0 if n_bins_c[0] == 1 and not pbc[0] else neigh_search_x - ) - neigh_search_y = ( - pytorch_scalar_int_0 if n_bins_c[1] == 1 and not pbc[1] else neigh_search_y - ) - neigh_search_z = ( - pytorch_scalar_int_0 if n_bins_c[2] == 1 and not pbc[2] else neigh_search_z - ) - - # Sort atoms into bins. 
- if not any(pbc): - scaled_positions_ic = positions - elif use_scaled_positions: - scaled_positions_ic = positions - positions = torch.dot(scaled_positions_ic, cell) - else: - scaled_positions_ic = torch.linalg.solve(cell.T, positions.T).T - - bin_index_ic = torch.floor(scaled_positions_ic * n_bins_c).to( - dtype=torch.long, device=device - ) - cell_shift_ic = torch.zeros_like(bin_index_ic, device=device) - - for c in range(3): - if pbc[c]: - # (Note: torch.divmod does not exist in older numpy versions) - cell_shift_ic[:, c], bin_index_ic[:, c] = fm.torch_divmod( - bin_index_ic[:, c], n_bins_c[c] - ) - else: - bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1) - - # Convert Cartesian bin index to unique scalar bin index. - bin_index_i = bin_index_ic[:, 0] + n_bins_c[0] * ( - bin_index_ic[:, 1] + n_bins_c[1] * bin_index_ic[:, 2] - ) - - # atom_i contains atom index in new sort order. - atom_i = torch.argsort(bin_index_i) - bin_index_i = bin_index_i[atom_i] - - # Find max number of atoms per bin - max_n_atoms_per_bin = torch.bincount(bin_index_i).max() - - # Sort atoms into bins: atoms_in_bin_ba contains for each bin (identified - # by its scalar bin index) a list of atoms inside that bin. This list is - # homogeneous, i.e. has the same size *max_n_atoms_per_bin* for all bins. - # The list is padded with -1 values. - atoms_in_bin_ba = -torch.ones( - n_bins.item(), max_n_atoms_per_bin.item(), dtype=torch.long, device=device - ) - for bin_cnt in range(int(max_n_atoms_per_bin.item())): - # Create a mask array that identifies the first atom of each bin. - mask = torch.cat( - ( - torch.ones(1, dtype=torch.bool, device=device), - bin_index_i[:-1] != bin_index_i[1:], - ), - dim=0, - ) - # Assign all first atoms. - atoms_in_bin_ba[bin_index_i[mask], bin_cnt] = atom_i[mask] - - # Remove atoms that we just sorted into atoms_in_bin_ba. The next - # "first" atom will be the second and so on. 
- mask = torch.logical_not(mask) - atom_i = atom_i[mask] - bin_index_i = bin_index_i[mask] - - # Make sure that all atoms have been sorted into bins. - if len(atom_i) != 0: - raise ValueError(f"len(atom_i)={len(atom_i)} != 0") - if len(bin_index_i) != 0: - raise ValueError(f"len(bin_index_i)={len(bin_index_i)} != 0") - - # Now we construct neighbor pairs by pairing up all atoms within a bin or - # between bin and neighboring bin. atom_pairs_pn is a helper buffer that - # contains all potential pairs of atoms between two bins, i.e. it is a list - # of length max_n_atoms_per_bin**2. - # atom_pairs_pn_np = np.indices( - # (max_n_atoms_per_bin, max_n_atoms_per_bin), dtype=int - # ).reshape(2, -1) - atom_pairs_pn = torch.cartesian_prod( - torch.arange(max_n_atoms_per_bin, device=device), - torch.arange(max_n_atoms_per_bin, device=device), - ) - atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1) - - # Initialized empty neighbor list buffers. - first_at_neigh_tuple_nn = [] - second_at_neigh_tuple_nn = [] - cell_shift_vector_x_n = [] - cell_shift_vector_y_n = [] - cell_shift_vector_z_n = [] - - # This is the main neighbor list search. We loop over neighboring bins and - # then construct all possible pairs of atoms between two bins, assuming - # that each bin contains exactly max_n_atoms_per_bin atoms. We then throw - # out pairs involving pad atoms with atom index -1 below. - binz_xyz, biny_xyz, binx_xyz = torch.meshgrid( - torch.arange(n_bins_c[2], device=device), - torch.arange(n_bins_c[1], device=device), - torch.arange(n_bins_c[0], device=device), - indexing="ij", - ) - # The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing - # the respective bin index leads to a linearly increasing consecutive list. - # The following assert statement succeeds: - # b_b = (binx_xyz + n_bins_c[0] * (biny_xyz + n_bins_c[1] * - # binz_xyz)).ravel() - # assert (b_b == torch.arange(torch.prod(n_bins_c))).all() - - # First atoms in pair. 
- _first_at_neigh_tuple_n = atoms_in_bin_ba[:, atom_pairs_pn[0]] - for dz in range(-int(neigh_search_z.item()), int(neigh_search_z.item()) + 1): - for dy in range(-int(neigh_search_y.item()), int(neigh_search_y.item()) + 1): - for dx in range(-int(neigh_search_x.item()), int(neigh_search_x.item()) + 1): - # Bin index of neighboring bin and shift vector. - shiftx_xyz, neighbinx_xyz = fm.torch_divmod(binx_xyz + dx, n_bins_c[0]) - shifty_xyz, neighbiny_xyz = fm.torch_divmod(biny_xyz + dy, n_bins_c[1]) - shiftz_xyz, neighbinz_xyz = fm.torch_divmod(binz_xyz + dz, n_bins_c[2]) - neighbin_b = ( - neighbinx_xyz - + n_bins_c[0] * (neighbiny_xyz + n_bins_c[1] * neighbinz_xyz) - ).ravel() - - # Second atom in pair. - _second_at_neigh_tuple_n = atoms_in_bin_ba[neighbin_b][ - :, atom_pairs_pn[1] - ] - - # Shift vectors. - # TODO: was np.resize: - # _cell_shift_vector_x_n_np = np.resize( - # shiftx_xyz.reshape(-1, 1).numpy(), - # (int(max_n_atoms_per_bin.item() ** 2), shiftx_xyz.numel()), - # ).T - # _cell_shift_vector_y_n_np = np.resize( - # shifty_xyz.reshape(-1, 1).numpy(), - # (int(max_n_atoms_per_bin.item() ** 2), shifty_xyz.numel()), - # ).T - # _cell_shift_vector_z_n_np = np.resize( - # shiftz_xyz.reshape(-1, 1).numpy(), - # (int(max_n_atoms_per_bin.item() ** 2), shiftz_xyz.numel()), - # ).T - # this basically just tiles shiftx_xyz.reshape(-1, 1) n times - _cell_shift_vector_x_n = shiftx_xyz.reshape(-1, 1).repeat( - (1, int(max_n_atoms_per_bin.item() ** 2)) - ) - # assert _cell_shift_vector_x_n.shape == _cell_shift_vector_x_n_np.shape - # assert np.allclose( - # _cell_shift_vector_x_n.numpy(), _cell_shift_vector_x_n_np - # ) - _cell_shift_vector_y_n = shifty_xyz.reshape(-1, 1).repeat( - (1, int(max_n_atoms_per_bin.item() ** 2)) - ) - # assert _cell_shift_vector_y_n.shape == _cell_shift_vector_y_n_np.shape - # assert np.allclose( - # _cell_shift_vector_y_n.numpy(), _cell_shift_vector_y_n_np - # ) - _cell_shift_vector_z_n = shiftz_xyz.reshape(-1, 1).repeat( - (1, 
int(max_n_atoms_per_bin.item() ** 2)) - ) - # assert _cell_shift_vector_z_n.shape == _cell_shift_vector_z_n_np.shape - # assert np.allclose( - # _cell_shift_vector_z_n.numpy(), _cell_shift_vector_z_n_np - # ) - - # We have created too many pairs because we assumed each bin - # has exactly max_n_atoms_per_bin atoms. Remove all superfluous - # pairs. Those are pairs that involve an atom with index -1. - mask = torch.logical_and( - _first_at_neigh_tuple_n != -1, _second_at_neigh_tuple_n != -1 - ) - if mask.sum() > 0: - first_at_neigh_tuple_nn += [_first_at_neigh_tuple_n[mask]] - second_at_neigh_tuple_nn += [_second_at_neigh_tuple_n[mask]] - cell_shift_vector_x_n += [_cell_shift_vector_x_n[mask]] - cell_shift_vector_y_n += [_cell_shift_vector_y_n[mask]] - cell_shift_vector_z_n += [_cell_shift_vector_z_n[mask]] - - # Flatten overall neighbor list. - first_at_neigh_tuple_n = torch.cat(first_at_neigh_tuple_nn) - second_at_neigh_tuple_n = torch.cat(second_at_neigh_tuple_nn) - cell_shift_vector_n = torch.vstack( - [ - torch.cat(cell_shift_vector_x_n), - torch.cat(cell_shift_vector_y_n), - torch.cat(cell_shift_vector_z_n), - ] - ).T - - # Add global cell shift to shift vectors - cell_shift_vector_n += ( - cell_shift_ic[first_at_neigh_tuple_n] - cell_shift_ic[second_at_neigh_tuple_n] - ) - - # Remove all self-pairs that do not cross the cell boundary. - if not self_interaction: - m = torch.logical_not( - torch.logical_and( - first_at_neigh_tuple_n == second_at_neigh_tuple_n, - (cell_shift_vector_n == 0).all(dim=1), - ) - ) - first_at_neigh_tuple_n = first_at_neigh_tuple_n[m] - second_at_neigh_tuple_n = second_at_neigh_tuple_n[m] - cell_shift_vector_n = cell_shift_vector_n[m] - - # For non-periodic directions, remove any bonds that cross the domain - # boundary. 
- for c in range(3): - if not pbc[c]: - m = cell_shift_vector_n[:, c] == 0 - first_at_neigh_tuple_n = first_at_neigh_tuple_n[m] - second_at_neigh_tuple_n = second_at_neigh_tuple_n[m] - cell_shift_vector_n = cell_shift_vector_n[m] - - # Sort neighbor list. - bin_cnt = torch.argsort(first_at_neigh_tuple_n) - first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt] - second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt] - cell_shift_vector_n = cell_shift_vector_n[bin_cnt] - - # Compute distance vectors. - # TODO: Use .T? - distance_vector_nc = ( - positions[second_at_neigh_tuple_n] - - positions[first_at_neigh_tuple_n] - + cell_shift_vector_n.to(cell.dtype).matmul(cell) - ) - abs_distance_vector_n = torch.sqrt( - torch.sum(distance_vector_nc * distance_vector_nc, dim=1) - ) - - # We have still created too many pairs. Only keep those with distance - # smaller than max_cutoff. - mask = abs_distance_vector_n < max_cutoff - first_at_neigh_tuple_n = first_at_neigh_tuple_n[mask] - second_at_neigh_tuple_n = second_at_neigh_tuple_n[mask] - cell_shift_vector_n = cell_shift_vector_n[mask] - distance_vector_nc = distance_vector_nc[mask] - abs_distance_vector_n = abs_distance_vector_n[mask] - - # Assemble return tuple. - ret_vals = [] - for quant in quantities: - if quant == "i": - ret_vals += [first_at_neigh_tuple_n] - elif quant == "j": - ret_vals += [second_at_neigh_tuple_n] - elif quant == "D": - ret_vals += [distance_vector_nc] - elif quant == "d": - ret_vals += [abs_distance_vector_n] - elif quant == "S": - ret_vals += [cell_shift_vector_n] - else: - raise ValueError("Unsupported quantity specified.") - - return ret_vals - - -@torch.jit.script -def standard_nl( - positions: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - cutoff: torch.Tensor, - sort_id: bool = False, # noqa: FBT001, FBT002 -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute neighbor lists using primitive neighbor list algorithm. 
- - This function provides a standardized interface for computing neighbor lists - in atomic systems, wrapping the more general primitive_neighbor_list implementation. - It handles both periodic and non-periodic boundary conditions and returns - neighbor pairs along with their periodic shifts. - - The function follows ASE's neighbor list conventions (see ASE: - https://gitlab.com/ase/ase/-/blob/master/ase/neighborlist.py?ref_type=heads#L152 - but provides a simplified interface focused on the most common use case of - getting neighbor pairs and shifts. - - Key Features: - - Handles both periodic and non-periodic systems - - Returns both neighbor indices and shift vectors for periodic systems - - Optional sorting of neighbors by first index for better memory access patterns - - Fully compatible with PyTorch's automatic differentiation - - Args: - positions: Atomic positions tensor of shape (num_atoms, 3) - cell: Unit cell vectors according to the row vector convention, i.e. - `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in - each axis. - cutoff: Maximum distance for considering atoms as neighbors - sort_id: If True, sort neighbors by first atom index for better memory - access patterns - - Returns: - tuple containing: - - mapping: Tensor of shape (2, num_neighbors) containing pairs of - atom indices that are neighbors. Each column (i,j) represents a - neighbor pair. - - shifts: Tensor of shape (num_neighbors, 3) containing the periodic - shift vectors needed to get the correct periodic image for each - neighbor pair. 
- - Example: - >>> # Get neighbors for a periodic system - >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) - >>> cell = torch.eye(3) * 10.0 - >>> mapping, shifts = standard_nl(positions, cell, True, 1.5) - >>> print(mapping) # Shows pairs of neighboring atoms - >>> print(shifts) # Shows corresponding periodic shifts - - Notes: - - The function uses primitive_neighbor_list internally but provides a simpler - interface - - For non-periodic systems, shifts will be zero vectors - - The neighbor list includes both (i,j) and (j,i) pairs for complete force - computation - - Memory usage scales with system size and number of neighbors per atom - - References: - - https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 - """ - device = positions.device - dtype = positions.dtype - i, j, S = primitive_neighbor_list( - quantities="ijS", - positions=positions, - cell=cell, - pbc=pbc, - cutoff=cutoff, - device=device, - dtype=dtype, - self_interaction=False, - use_scaled_positions=False, - max_n_bins=torch.tensor(1e6, dtype=torch.int64, device=device), - ) - - mapping = torch.stack((i, j), dim=0) - mapping = mapping.to(dtype=torch.long) - shifts = S.to(dtype=dtype) - - if sort_id: - idx = torch.argsort(mapping[0]) - mapping = mapping[:, idx] - shifts = shifts[idx, :] - - return mapping, shifts - - -if VESIN_AVAILABLE: - - @torch.jit.script - def vesin_nl_ts( - positions: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - cutoff: torch.Tensor, - sort_id: bool = False, # noqa: FBT001, FBT002 - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute neighbor lists using TorchScript-compatible Vesin. - - This function provides a TorchScript-compatible interface to the Vesin - neighbor list algorithm using VesinNeighborListTorch. It handles both - periodic and non-periodic systems and returns neighbor pairs along with - their periodic shifts. 
- - Args: - positions: Atomic positions tensor of shape (num_atoms, 3) - cell: Unit cell vectors according to the row vector convention, i.e. - `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in - each axis. - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors - sort_id: If True, sort neighbors by first atom index for better memory - access patterns - - Returns: - tuple containing: - - mapping: Tensor of shape (2, num_neighbors) containing pairs of - atom indices that are neighbors. Each column (i,j) represents a - neighbor pair. - - shifts: Tensor of shape (num_neighbors, 3) containing the periodic - shift vectors needed to get the correct periodic image for each - neighbor pair. - - Notes: - - Uses VesinNeighborListTorch for TorchScript compatibility - - Requires CPU tensors in float64 precision internally - - Returns tensors on the same device as input with original precision - - For non-periodic systems, shifts will be zero vectors - - The neighbor list includes both (i,j) and (j,i) pairs - - References: - https://github.com/Luthaf/vesin - """ - device = positions.device - dtype = positions.dtype - - neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) - - # Convert tensors to CPU and float64 properly - positions_cpu = positions.cpu().to(dtype=torch.float64) - cell_cpu = cell.cpu().to(dtype=torch.float64) - periodic_cpu = pbc.to(dtype=torch.bool).cpu() - - # Only works on CPU and requires float64 - i, j, S = neighbor_list_fn.compute( - points=positions_cpu, - box=cell_cpu, - periodic=periodic_cpu, - quantities="ijS", - ) - - mapping = torch.stack((i, j), dim=0) - mapping = mapping.to(dtype=torch.long, device=device) - shifts = S.to(dtype=dtype, device=device) - - if sort_id: - idx = torch.argsort(mapping[0]) - mapping = mapping[:, idx] - shifts = shifts[idx, :] - - return mapping, shifts - - def vesin_nl( - positions: torch.Tensor, - cell: 
torch.Tensor, - pbc: torch.Tensor, - cutoff: float | torch.Tensor, - sort_id: bool = False, # noqa: FBT001, FBT002 - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute neighbor lists using the standard Vesin implementation. - - This function provides an interface to the standard Vesin neighbor list - algorithm using VesinNeighborList. It handles both periodic and non-periodic - systems and returns neighbor pairs along with their periodic shifts. - - Args: - positions: Atomic positions tensor of shape (num_atoms, 3) - cell: Unit cell vectors according to the row vector convention, i.e. - `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in - each axis. - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors - sort_id: If True, sort neighbors by first atom index for better memory - access patterns - - Returns: - tuple containing: - - mapping: Tensor of shape (2, num_neighbors) containing pairs of - atom indices that are neighbors. Each column (i,j) represents a - neighbor pair. - - shifts: Tensor of shape (num_neighbors, 3) containing the periodic - shift vectors needed to get the correct periodic image for each - neighbor pair. 
- - Notes: - - Uses standard VesinNeighborList implementation - - Requires CPU tensors in float64 precision internally - - Returns tensors on the same device as input with original precision - - For non-periodic systems (pbc=False), shifts will be zero vectors - - The neighbor list includes both (i,j) and (j,i) pairs - - Supports pre-sorting through the VesinNeighborList constructor - - References: - - https://github.com/Luthaf/vesin - """ - device = positions.device - dtype = positions.dtype - - neighbor_list_fn = VesinNeighborList( - (float(cutoff)), full_list=True, sorted=sort_id - ) - - # Convert tensors to CPU and float64 without gradients - positions_cpu = positions.detach().cpu().to(dtype=torch.float64) - cell_cpu = cell.detach().cpu().to(dtype=torch.float64) - periodic_cpu = pbc.detach().to(dtype=torch.bool).cpu() - - # Only works on CPU and returns numpy arrays - i, j, S = neighbor_list_fn.compute( - points=positions_cpu, - box=cell_cpu, - periodic=periodic_cpu, - quantities="ijS", - ) - i, j = ( - torch.tensor(i, dtype=torch.long, device=device), - torch.tensor(j, dtype=torch.long, device=device), - ) - mapping = torch.stack((i, j), dim=0) - shifts = torch.tensor(S, dtype=dtype, device=device) - - return mapping, shifts - - -def torchsim_nl( - positions: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - cutoff: torch.Tensor, - sort_id: bool = False, # noqa: FBT001, FBT002 -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute neighbor lists with automatic fallback for AMD ROCm compatibility. - - This function automatically selects the best available neighbor list implementation. - When vesin is available, it uses vesin_nl_ts for optimal performance. When vesin - is not available (e.g., on AMD ROCm systems), it falls back to standard_nl. - - Args: - positions: Atomic positions tensor of shape (num_atoms, 3) - cell: Unit cell vectors according to the row vector convention, i.e. - `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. 
- pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in - each axis. - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors - sort_id: If True, sort neighbors by first atom index for better memory - access patterns - - Returns: - tuple containing: - - mapping: Tensor of shape (2, num_neighbors) containing pairs of - atom indices that are neighbors. Each column (i,j) represents a - neighbor pair. - - shifts: Tensor of shape (num_neighbors, 3) containing the periodic - shift vectors needed to get the correct periodic image for each - neighbor pair. - - Notes: - - Automatically uses vesin_nl_ts when vesin is available - - Falls back to standard_nl when vesin is unavailable (AMD ROCm) - - Fallback works on NVIDIA CUDA, AMD ROCm, and CPU - - For non-periodic systems (pbc=False), shifts will be zero vectors - - The neighbor list includes both (i,j) and (j,i) pairs - """ - if not VESIN_AVAILABLE: - return standard_nl(positions, cell, pbc, cutoff, sort_id) - - return vesin_nl_ts(positions, cell, pbc, cutoff, sort_id) - - -def strict_nl( - cutoff: float, - positions: torch.Tensor, - cell: torch.Tensor, - mapping: torch.Tensor, - system_mapping: torch.Tensor, - shifts_idx: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Apply a strict cutoff to the neighbor list defined in the mapping. - - This function filters the neighbor list based on a specified cutoff distance. - It computes the squared distances between pairs of positions and retains only - those pairs that are within the cutoff distance. The function also accounts - for periodic boundary conditions by applying cell shifts when necessary. - - Args: - cutoff (float): - The maximum distance for considering two atoms as neighbors. This value - is used to filter the neighbor pairs based on their distances. - positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing - the positions of the atoms. 
- cell (torch.Tensor): Unit cell vectors according to the row vector convention, - i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - mapping (torch.Tensor): - A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` - for which to compute distances. - system_mapping (torch.Tensor): - A tensor that maps the shifts to the corresponding cells, used in conjunction - with `shifts_idx` to compute the correct periodic shifts. - shifts_idx (torch.Tensor): - A tensor of shape (n_shifts, 3) representing the indices for shifts to apply - to the distances for periodic boundary conditions. - - Returns: - tuple: - A tuple containing: - - mapping (torch.Tensor): A filtered tensor of shape (2, n_filtered_pairs) - with pairs of indices that are within the cutoff distance. - - mapping_system (torch.Tensor): A tensor of shape (n_filtered_pairs,) - that maps the filtered pairs to their corresponding systems. - - shifts_idx (torch.Tensor): A tensor of shape (n_filtered_pairs, 3) - containing the periodic shift indices for the filtered pairs. - - Notes: - - The function computes the squared distances to avoid the computational cost - of taking square roots, which is unnecessary for comparison. - - If no cell shifts are needed (i.e., for non-periodic systems), the function - directly computes the squared distances between the positions. 
- - References: - - https://github.com/felixmusil/torch_nl - """ - cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping) - if cell_shifts is None: - d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1) - else: - d2 = ( - (positions[mapping[0]] - positions[mapping[1]] - cell_shifts) - .square() - .sum(dim=1) - ) - - mask = d2 < cutoff * cutoff - mapping = mapping[:, mask] - mapping_system = system_mapping[mask] - shifts_idx = shifts_idx[mask] - return mapping, mapping_system, shifts_idx - - -@torch.jit.script -def torch_nl_n2( - positions: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - cutoff: torch.Tensor, - system_idx: torch.Tensor, - self_interaction: bool = False, # noqa: FBT001, FBT002 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute the neighbor list for a set of atomic structures using a - naive neighbor search before applying a strict `cutoff`. - The atomic positions `pos` should be wrapped inside their respective unit cells. - - Args: - cutoff (float): - The cutoff radius used for the neighbor search. - positions (torch.Tensor [n_atom, 3]): A tensor containing the positions - of atoms wrapped inside their respective unit cells. - cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to - the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [n_structure, 3] bool): - A tensor indicating the periodic boundary conditions to apply. - Partial PBC are not supported yet. - system_idx (torch.Tensor [n_atom,] torch.long): - A tensor containing the index of the structure to which each atom belongs. - self_interaction (bool, optional): - A flag to indicate whether to keep the center atoms as their own neighbors. - Default is False. - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mapping (torch.Tensor [2, n_neighbors]): - A tensor containing the indices of the neighbor list for the given - positions array. 
`mapping[0]` corresponds to the central atom indices, - and `mapping[1]` corresponds to the neighbor atom indices. - system_mapping (torch.Tensor [n_neighbors]): - A tensor mapping the neighbor atoms to their respective structures. - shifts_idx (torch.Tensor [n_neighbors, 3]): - A tensor containing the cell shift indices used to reconstruct the - neighbor atom positions. - - References: - - https://github.com/felixmusil/torch_nl - """ - n_atoms = torch.bincount(system_idx) - mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( - positions, cell, pbc, cutoff.item(), n_atoms, self_interaction - ) - mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx - ) - return mapping, mapping_system, shifts_idx - - -@torch.jit.script -def torch_nl_linked_cell( - positions: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - cutoff: torch.Tensor, - system_idx: torch.Tensor, - self_interaction: bool = False, # noqa: FBT001, FBT002 (*, not compatible with torch.jit.script) -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute the neighbor list for a set of atomic structures using the linked - cell algorithm before applying a strict `cutoff`. The atoms positions `pos` - should be wrapped inside their respective unit cells. - - Args: - cutoff (float): - The cutoff radius used for the neighbor search. - positions (torch.Tensor [n_atom, 3]): - A tensor containing the positions of atoms wrapped inside - their respective unit cells. - cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to - the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [n_systems, 3] bool): - A tensor indicating the periodic boundary conditions to apply. - Partial PBC are not supported yet. - system_idx (torch.Tensor [n_atom,] torch.long): - A tensor containing the index of the structure to which each atom belongs. 
- self_interaction (bool, optional): - A flag to indicate whether to keep the center atoms as their own neighbors. - Default is False. - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - A tuple containing: - - mapping (torch.Tensor [2, n_neighbors]): - A tensor containing the indices of the neighbor list for the given - positions array. `mapping[0]` corresponds to the central atom - indices, and `mapping[1]` corresponds to the neighbor atom indices. - - system_mapping (torch.Tensor [n_neighbors]): - A tensor mapping the neighbor atoms to their respective structures. - - shifts_idx (torch.Tensor [n_neighbors, 3]): - A tensor containing the cell shift indices used to reconstruct the - neighbor atom positions. - - References: - - https://github.com/felixmusil/torch_nl - """ - n_atoms = torch.bincount(system_idx) - mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( - positions, cell, pbc, cutoff.item(), n_atoms, self_interaction - ) - - mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx - ) - return mapping, mapping_system, shifts_idx diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py new file mode 100644 index 00000000..d41aff84 --- /dev/null +++ b/torch_sim/neighbors/__init__.py @@ -0,0 +1,117 @@ +"""Neighbor list implementations for torch-sim. + +This module provides multiple neighbor list implementations with automatic +fallback based on available dependencies. The API supports both single-system +and batched (multi-system) calculations. 
+ +Available Implementations: + - Primitive: Pure PyTorch implementation (always available) + - Vesin: High-performance neighbor lists (optional, requires vesin package) + - Batched: Optimized for multiple systems (torch_nl_n2, torch_nl_linked_cell) + +Default Neighbor Lists: + The module automatically selects the best available implementation: + - For single systems: vesin_nl (if available) or standard_nl (fallback) + - For batched systems: torch_nl_linked_cell (always available) +""" + +import torch + +from torch_sim.neighbors.standard import primitive_neighbor_list, standard_nl +from torch_sim.neighbors.torch_nl import strict_nl, torch_nl_linked_cell, torch_nl_n2 + + +# Try to import Vesin implementations +try: + from torch_sim.neighbors.vesin import ( + VESIN_AVAILABLE, + VesinNeighborList, + VesinNeighborListTorch, + vesin_nl, + vesin_nl_ts, + ) +except ImportError: + VESIN_AVAILABLE = False + VesinNeighborList = None # type: ignore[assignment,misc] + VesinNeighborListTorch = None # type: ignore[assignment,misc] + vesin_nl = None # type: ignore[assignment] + vesin_nl_ts = None # type: ignore[assignment] + +# Set default neighbor list based on what's available +if VESIN_AVAILABLE: + default_nl = vesin_nl + default_nl_ts = vesin_nl_ts +else: + default_nl = standard_nl + default_nl_ts = standard_nl + +# For batched calculations, always use linked cell as default +default_batched_nl = torch_nl_linked_cell + + +def torchsim_nl( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists with automatic fallback for AMD ROCm compatibility. + + This function automatically selects the best available neighbor list implementation. + When vesin is available, it uses vesin_nl_ts for optimal performance. 
When vesin + is not available (e.g., on AMD ROCm systems), it falls back to standard_nl. + + Args: + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [3*n_systems, 3] (row vector convention) + pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors + system_idx: Tensor [n_atoms] indicating which system each atom belongs to. + For single system, use torch.zeros(n_atoms, dtype=torch.long) + self_interaction: If True, include self-pairs. Default: False + + Returns: + tuple containing: + - mapping: Tensor [2, num_neighbors] - pairs of atom indices + - system_mapping: Tensor [num_neighbors] - system assignment for each pair + - shifts_idx: Tensor [num_neighbors, 3] - periodic shift indices + + Notes: + - Automatically uses vesin_nl_ts when vesin is available + - Falls back to standard_nl when vesin is unavailable (AMD ROCm) + - Fallback works on NVIDIA CUDA, AMD ROCm, and CPU + - For non-periodic systems (pbc=False), shifts will be zero vectors + - The neighbor list includes both (i,j) and (j,i) pairs + """ + if not VESIN_AVAILABLE: + return torch_nl_linked_cell( + positions, cell, pbc, cutoff, system_idx, self_interaction + ) + + return vesin_nl_ts(positions, cell, pbc, cutoff, system_idx, self_interaction) + + +__all__ = [ + # Availability + "VESIN_AVAILABLE", + "VesinNeighborList", + "VesinNeighborListTorch", + # Defaults + "default_batched_nl", + "default_nl", + "default_nl_ts", + # Core implementations + "primitive_neighbor_list", + "standard_nl", + # Utilities + "strict_nl", + # Batched implementations + "torch_nl_linked_cell", + "torch_nl_n2", + "torchsim_nl", + # Vesin implementations + "vesin_nl", + "vesin_nl_ts", +] diff --git a/torch_sim/neighbors/standard.py b/torch_sim/neighbors/standard.py new file mode 100644 index 00000000..574d0387 --- /dev/null +++ b/torch_sim/neighbors/standard.py @@ -0,0 +1,547 @@ +"""Pure PyTorch neighbor list 
implementation. + +This module provides a native PyTorch implementation of neighbor list calculation +that works on any device (CPU, CUDA, ROCm) without external dependencies. +""" + +import torch + +import torch_sim.math as fm + + +@torch.jit.script +def primitive_neighbor_list( # noqa: C901, PLR0915 + quantities: str, + pbc: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: torch.Tensor, + device: torch.device, + dtype: torch.dtype, + self_interaction: bool = False, # noqa: FBT001, FBT002 + use_scaled_positions: bool = False, # noqa: FBT001, FBT002 + max_n_bins: int = int(1e6), +) -> list[torch.Tensor]: + """Compute a neighbor list for an atomic configuration. + + ASE periodic neighbor list implementation + Atoms outside periodic boundaries are mapped into the unit cell. Atoms + outside non-periodic boundaries are included in the neighbor list + but complexity of neighbor list search for those can become n^2. + The neighbor list is sorted by first atom index 'i', but not by second + atom index 'j'. + + Args: + quantities: Quantities to compute by the neighbor list algorithm. Each character + in this string defines a quantity. They are returned in a tuple of + the same order. Possible quantities are + * 'i' : first atom index + * 'j' : second atom index + * 'd' : absolute distance + * 'D' : distance vector + * 'S' : shift vector (number of cell boundaries crossed by the bond + between atom i and j). With the shift vector S, the + distances D between atoms can be computed from: + D = positions[j]-positions[i]+S.dot(cell) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. + cell: Unit cell vectors according to the row vector convention, i.e. + `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + positions: Atomic positions. Anything that can be converted to an ndarray of + shape (n, 3) will do: [(x1,y1,z1), (x2,y2,z2), ...]. If + use_scaled_positions is set to true, this must be scaled positions. 
+ cutoff: Cutoff for neighbor search. It can be: + * A single float: This is a global cutoff for all elements. + * A dictionary: This specifies cutoff values for element + pairs. Specification accepts element numbers of symbols. + Example: {(1, 6): 1.1, (1, 1): 1.0, ('C', 'C'): 1.85} + * A list/array with a per atom value: This specifies the radius of + an atomic sphere for each atoms. If spheres overlap, atoms are + within each others neighborhood. + See :func:`~ase.neighborlist.natural_cutoffs` + for an example on how to get such a list. + device: PyTorch device to use for computations + dtype: PyTorch data type to use + self_interaction: Return the atom itself as its own neighbor if set to true. + Default: False + use_scaled_positions: If set to true, positions are expected to be + scaled positions. + max_n_bins: Maximum number of bins used in neighbor search. This is used to limit + the maximum amount of memory required by the neighbor list. + + Returns: + list[torch.Tensor]: One tensor for each item in `quantities`. Indices in `i` + are returned in ascending order 0..len(a)-1, but the order of (i,j) + pairs is not guaranteed. + + References: + - This code is modified version of the github gist + https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 + """ + # Naming conventions: Suffixes indicate the dimension of an array. The + # following convention is used here: + # c: Cartesian index, can have values 0, 1, 2 + # i: Global atom index, can have values 0..len(a)-1 + # xyz: Bin index, three values identifying x-, y- and z-component of a + # spatial bin that is used to make neighbor search O(n) + # b: Linearized version of the 'xyz' bin index + # a: Bin-local atom index, i.e. index identifying an atom *within* a + # bin + # p: Pair index, can have value 0 or 1 + # n: (Linear) neighbor index + + if len(positions) == 0: + raise RuntimeError("No atoms provided") + + # Compute reciprocal lattice vectors. 
+ recip_cell = torch.linalg.pinv(cell).T + b1_c, b2_c, b3_c = recip_cell[0], recip_cell[1], recip_cell[2] + + # Compute distances of cell faces. + l1 = torch.linalg.norm(b1_c) + l2 = torch.linalg.norm(b2_c) + l3 = torch.linalg.norm(b3_c) + pytorch_scalar_1 = torch.as_tensor(1.0, device=device, dtype=dtype) + face_dist_c = torch.hstack( + [ + 1 / l1 if l1 > 0 else pytorch_scalar_1, + 1 / l2 if l2 > 0 else pytorch_scalar_1, + 1 / l3 if l3 > 0 else pytorch_scalar_1, + ] + ) + if face_dist_c.shape != (3,): + raise ValueError(f"face_dist_c.shape={face_dist_c.shape} != (3,)") + + # we don't handle other fancier cutoffs + max_cutoff: torch.Tensor = cutoff + + # We use a minimum bin size of 3 A + bin_size = torch.maximum(max_cutoff, torch.tensor(3.0, device=device, dtype=dtype)) + # Compute number of bins such that a sphere of radius cutoff fits into + # eight neighboring bins. + n_bins_c = torch.maximum( + (face_dist_c / bin_size).to(dtype=torch.long, device=device), + torch.ones(3, dtype=torch.long, device=device), + ) + n_bins = torch.prod(n_bins_c) + # Make sure we limit the amount of memory used by the explicit bins. + while n_bins > max_n_bins: + n_bins_c = torch.maximum( + n_bins_c // 2, torch.ones(3, dtype=torch.long, device=device) + ) + n_bins = torch.prod(n_bins_c) + + # Compute over how many bins we need to loop in the neighbor list search. 
+ neigh_search = torch.ceil(bin_size * n_bins_c / face_dist_c).to( + dtype=torch.long, device=device + ) + neigh_search_x, neigh_search_y, neigh_search_z = ( + neigh_search[0], + neigh_search[1], + neigh_search[2], + ) + + # If we only have a single bin and the system is not periodic, then we + # do not need to search neighboring bins + pytorch_scalar_int_0 = torch.as_tensor(0, dtype=torch.long, device=device) + neigh_search_x = ( + pytorch_scalar_int_0 if n_bins_c[0] == 1 and not pbc[0] else neigh_search_x + ) + neigh_search_y = ( + pytorch_scalar_int_0 if n_bins_c[1] == 1 and not pbc[1] else neigh_search_y + ) + neigh_search_z = ( + pytorch_scalar_int_0 if n_bins_c[2] == 1 and not pbc[2] else neigh_search_z + ) + + # Sort atoms into bins. + if not any(pbc): + scaled_positions_ic = positions + elif use_scaled_positions: + scaled_positions_ic = positions + positions = torch.dot(scaled_positions_ic, cell) + else: + scaled_positions_ic = torch.linalg.solve(cell.T, positions.T).T + + bin_index_ic = torch.floor(scaled_positions_ic * n_bins_c).to( + dtype=torch.long, device=device + ) + cell_shift_ic = torch.zeros_like(bin_index_ic, device=device) + + for c in range(3): + if pbc[c]: + # (Note: fm.torch_divmod is used in place of np.divmod from the NumPy original) + cell_shift_ic[:, c], bin_index_ic[:, c] = fm.torch_divmod( + bin_index_ic[:, c], n_bins_c[c] + ) + else: + bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1) + + # Convert Cartesian bin index to unique scalar bin index. + bin_index_i = bin_index_ic[:, 0] + n_bins_c[0] * ( + bin_index_ic[:, 1] + n_bins_c[1] * bin_index_ic[:, 2] + ) + + # atom_i contains atom index in new sort order. + atom_i = torch.argsort(bin_index_i) + bin_index_i = bin_index_i[atom_i] + + # Find max number of atoms per bin + max_n_atoms_per_bin = torch.bincount(bin_index_i).max() + + # Sort atoms into bins: atoms_in_bin_ba contains for each bin (identified + # by its scalar bin index) a list of atoms inside that bin.
This list is + # homogeneous, i.e. has the same size *max_n_atoms_per_bin* for all bins. + # The list is padded with -1 values. + atoms_in_bin_ba = -torch.ones( + n_bins.item(), max_n_atoms_per_bin.item(), dtype=torch.long, device=device + ) + for bin_cnt in range(int(max_n_atoms_per_bin.item())): + # Create a mask array that identifies the first atom of each bin. + mask = torch.cat( + ( + torch.ones(1, dtype=torch.bool, device=device), + bin_index_i[:-1] != bin_index_i[1:], + ), + dim=0, + ) + # Assign all first atoms. + atoms_in_bin_ba[bin_index_i[mask], bin_cnt] = atom_i[mask] + + # Remove atoms that we just sorted into atoms_in_bin_ba. The next + # "first" atom will be the second and so on. + mask = torch.logical_not(mask) + atom_i = atom_i[mask] + bin_index_i = bin_index_i[mask] + + # Make sure that all atoms have been sorted into bins. + if len(atom_i) != 0: + raise ValueError(f"len(atom_i)={len(atom_i)} != 0") + if len(bin_index_i) != 0: + raise ValueError(f"len(bin_index_i)={len(bin_index_i)} != 0") + + # Now we construct neighbor pairs by pairing up all atoms within a bin or + # between bin and neighboring bin. atom_pairs_pn is a helper buffer that + # contains all potential pairs of atoms between two bins, i.e. it is a list + # of length max_n_atoms_per_bin**2. + # atom_pairs_pn_np = np.indices( + # (max_n_atoms_per_bin, max_n_atoms_per_bin), dtype=int + # ).reshape(2, -1) + atom_pairs_pn = torch.cartesian_prod( + torch.arange(max_n_atoms_per_bin, device=device), + torch.arange(max_n_atoms_per_bin, device=device), + ) + atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1) + + # Initialized empty neighbor list buffers. + first_at_neigh_tuple_nn = [] + second_at_neigh_tuple_nn = [] + cell_shift_vector_x_n = [] + cell_shift_vector_y_n = [] + cell_shift_vector_z_n = [] + + # This is the main neighbor list search. 
We loop over neighboring bins and + # then construct all possible pairs of atoms between two bins, assuming + # that each bin contains exactly max_n_atoms_per_bin atoms. We then throw + # out pairs involving pad atoms with atom index -1 below. + binz_xyz, biny_xyz, binx_xyz = torch.meshgrid( + torch.arange(n_bins_c[2], device=device), + torch.arange(n_bins_c[1], device=device), + torch.arange(n_bins_c[0], device=device), + indexing="ij", + ) + # The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing + # the respective bin index leads to a linearly increasing consecutive list. + # The following assert statement succeeds: + # b_b = (binx_xyz + n_bins_c[0] * (biny_xyz + n_bins_c[1] * + # binz_xyz)).ravel() + # assert (b_b == torch.arange(torch.prod(n_bins_c))).all() + + # First atoms in pair. + _first_at_neigh_tuple_n = atoms_in_bin_ba[:, atom_pairs_pn[0]] + for dz in range(-int(neigh_search_z.item()), int(neigh_search_z.item()) + 1): + for dy in range(-int(neigh_search_y.item()), int(neigh_search_y.item()) + 1): + for dx in range(-int(neigh_search_x.item()), int(neigh_search_x.item()) + 1): + # Bin index of neighboring bin and shift vector. + shiftx_xyz, neighbinx_xyz = fm.torch_divmod(binx_xyz + dx, n_bins_c[0]) + shifty_xyz, neighbiny_xyz = fm.torch_divmod(biny_xyz + dy, n_bins_c[1]) + shiftz_xyz, neighbinz_xyz = fm.torch_divmod(binz_xyz + dz, n_bins_c[2]) + neighbin_b = ( + neighbinx_xyz + + n_bins_c[0] * (neighbiny_xyz + n_bins_c[1] * neighbinz_xyz) + ).ravel() + + # Second atom in pair. + _second_at_neigh_tuple_n = atoms_in_bin_ba[neighbin_b][ + :, atom_pairs_pn[1] + ] + + # Shift vectors. 
+ # TODO: was np.resize: + # _cell_shift_vector_x_n_np = np.resize( + # shiftx_xyz.reshape(-1, 1).numpy(), + # (int(max_n_atoms_per_bin.item() ** 2), shiftx_xyz.numel()), + # ).T + # _cell_shift_vector_y_n_np = np.resize( + # shifty_xyz.reshape(-1, 1).numpy(), + # (int(max_n_atoms_per_bin.item() ** 2), shifty_xyz.numel()), + # ).T + # _cell_shift_vector_z_n_np = np.resize( + # shiftz_xyz.reshape(-1, 1).numpy(), + # (int(max_n_atoms_per_bin.item() ** 2), shiftz_xyz.numel()), + # ).T + # this basically just tiles shiftx_xyz.reshape(-1, 1) n times + _cell_shift_vector_x_n = shiftx_xyz.reshape(-1, 1).repeat( + (1, int(max_n_atoms_per_bin.item() ** 2)) + ) + # assert _cell_shift_vector_x_n.shape == _cell_shift_vector_x_n_np.shape + # assert np.allclose( + # _cell_shift_vector_x_n.numpy(), _cell_shift_vector_x_n_np + # ) + _cell_shift_vector_y_n = shifty_xyz.reshape(-1, 1).repeat( + (1, int(max_n_atoms_per_bin.item() ** 2)) + ) + # assert _cell_shift_vector_y_n.shape == _cell_shift_vector_y_n_np.shape + # assert np.allclose( + # _cell_shift_vector_y_n.numpy(), _cell_shift_vector_y_n_np + # ) + _cell_shift_vector_z_n = shiftz_xyz.reshape(-1, 1).repeat( + (1, int(max_n_atoms_per_bin.item() ** 2)) + ) + # assert _cell_shift_vector_z_n.shape == _cell_shift_vector_z_n_np.shape + # assert np.allclose( + # _cell_shift_vector_z_n.numpy(), _cell_shift_vector_z_n_np + # ) + + # We have created too many pairs because we assumed each bin + # has exactly max_n_atoms_per_bin atoms. Remove all superfluous + # pairs. Those are pairs that involve an atom with index -1. 
+ mask = torch.logical_and( + _first_at_neigh_tuple_n != -1, _second_at_neigh_tuple_n != -1 + ) + if mask.sum() > 0: + first_at_neigh_tuple_nn += [_first_at_neigh_tuple_n[mask]] + second_at_neigh_tuple_nn += [_second_at_neigh_tuple_n[mask]] + cell_shift_vector_x_n += [_cell_shift_vector_x_n[mask]] + cell_shift_vector_y_n += [_cell_shift_vector_y_n[mask]] + cell_shift_vector_z_n += [_cell_shift_vector_z_n[mask]] + + # Flatten overall neighbor list. + first_at_neigh_tuple_n = torch.cat(first_at_neigh_tuple_nn) + second_at_neigh_tuple_n = torch.cat(second_at_neigh_tuple_nn) + cell_shift_vector_n = torch.vstack( + [ + torch.cat(cell_shift_vector_x_n), + torch.cat(cell_shift_vector_y_n), + torch.cat(cell_shift_vector_z_n), + ] + ).T + + # Add global cell shift to shift vectors + cell_shift_vector_n += ( + cell_shift_ic[first_at_neigh_tuple_n] - cell_shift_ic[second_at_neigh_tuple_n] + ) + + # Remove all self-pairs that do not cross the cell boundary. + if not self_interaction: + m = torch.logical_not( + torch.logical_and( + first_at_neigh_tuple_n == second_at_neigh_tuple_n, + (cell_shift_vector_n == 0).all(dim=1), + ) + ) + first_at_neigh_tuple_n = first_at_neigh_tuple_n[m] + second_at_neigh_tuple_n = second_at_neigh_tuple_n[m] + cell_shift_vector_n = cell_shift_vector_n[m] + + # For non-periodic directions, remove any bonds that cross the domain + # boundary. + for c in range(3): + if not pbc[c]: + m = cell_shift_vector_n[:, c] == 0 + first_at_neigh_tuple_n = first_at_neigh_tuple_n[m] + second_at_neigh_tuple_n = second_at_neigh_tuple_n[m] + cell_shift_vector_n = cell_shift_vector_n[m] + + # Sort neighbor list. + bin_cnt = torch.argsort(first_at_neigh_tuple_n) + first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt] + second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt] + cell_shift_vector_n = cell_shift_vector_n[bin_cnt] + + # Compute distance vectors. + # TODO: Use .T? 
+ distance_vector_nc = ( + positions[second_at_neigh_tuple_n] + - positions[first_at_neigh_tuple_n] + + cell_shift_vector_n.to(cell.dtype).matmul(cell) + ) + abs_distance_vector_n = torch.sqrt( + torch.sum(distance_vector_nc * distance_vector_nc, dim=1) + ) + + # We have still created too many pairs. Only keep those with distance + # smaller than max_cutoff. + mask = abs_distance_vector_n < max_cutoff + first_at_neigh_tuple_n = first_at_neigh_tuple_n[mask] + second_at_neigh_tuple_n = second_at_neigh_tuple_n[mask] + cell_shift_vector_n = cell_shift_vector_n[mask] + distance_vector_nc = distance_vector_nc[mask] + abs_distance_vector_n = abs_distance_vector_n[mask] + + # Assemble return tuple. + ret_vals = [] + for quant in quantities: + if quant == "i": + ret_vals += [first_at_neigh_tuple_n] + elif quant == "j": + ret_vals += [second_at_neigh_tuple_n] + elif quant == "D": + ret_vals += [distance_vector_nc] + elif quant == "d": + ret_vals += [abs_distance_vector_n] + elif quant == "S": + ret_vals += [cell_shift_vector_n] + else: + raise ValueError("Unsupported quantity specified.") + + return ret_vals + + +def standard_nl( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists using primitive neighbor list algorithm. + + This function provides a standardized interface for computing neighbor lists + in atomic systems. It handles both single systems and batched (multi-system) + calculations with a unified API. 
+ + Key Features: + - Unified API for single and batched systems + - Supports both periodic and non-periodic boundary conditions + - Returns neighbor indices, system mapping, and shift vectors + - Fully compatible with PyTorch's automatic differentiation + - Consistent with torch_nl_n2 and torch_nl_linked_cell API + + Args: + positions: Atomic positions tensor of shape [n_atoms, 3] + cell: Unit cell vectors [3*n_systems, 3] (row vector convention) + pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cutoff: Maximum distance for considering atoms as neighbors + system_idx: Tensor [n_atoms] indicating which system each atom belongs to. + For single system, use torch.zeros(n_atoms, dtype=torch.long) + self_interaction: If True, include self-pairs. Default: False + + Returns: + tuple containing: + - mapping: Tensor [2, num_neighbors] - pairs of atom indices + - system_mapping: Tensor [num_neighbors] - system assignment for each pair + - shifts_idx: Tensor [num_neighbors, 3] - periodic shift indices + + Example: + >>> # Single system (all atoms belong to system 0) + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + >>> cell = torch.eye(3) * 10.0 + >>> pbc = torch.tensor([True, True, True]) + >>> cutoff = torch.tensor(1.5) + >>> system_idx = torch.zeros(2, dtype=torch.long) + >>> mapping, sys_map, shifts = standard_nl( + ... positions, cell, pbc, cutoff, system_idx + ... ) + + >>> # Batched systems + >>> positions = torch.randn(20, 3) # 20 atoms total + >>> cell = torch.eye(3).repeat(2, 1) * 10.0 # 2 systems + >>> system_idx = torch.cat([torch.zeros(10), torch.ones(10)]).long() + >>> mapping, sys_map, shifts = standard_nl( + ... positions, cell, pbc, cutoff, system_idx + ... 
) + + References: + - https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 + """ + device = positions.device + dtype = positions.dtype + n_systems = system_idx.max().item() + 1 + + # Handle PBC: reshape if needed + if pbc.ndim == 1: + if pbc.shape[0] == 3: + # Single PBC for all systems + pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) + elif pbc.shape[0] == n_systems * 3: + # Flat concatenated PBC, reshape to [n_systems, 3] + pbc_per_system = pbc.reshape(n_systems, 3) + else: + raise ValueError(f"Unexpected PBC shape: {pbc.shape}") + else: + # Already [n_systems, 3] + pbc_per_system = pbc + + # Process each system's neighbor list separately + edge_indices = [] + shifts_idx_list = [] + system_mapping_list = [] + offset = 0 + + for sys_idx in range(n_systems): + system_mask = system_idx == sys_idx + n_atoms_in_system = system_mask.sum().item() + + if n_atoms_in_system == 0: + continue + + # Get the cell for this system + cell_sys = cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] + + # Calculate neighbor list for this system using primitive_neighbor_list + # Ensure tensors are contiguous for TorchScript + positions_sys = positions[system_mask].contiguous() + cell_sys = cell_sys.contiguous() + pbc_sys = pbc_per_system[sys_idx].contiguous() + + i, j, S = primitive_neighbor_list( + quantities="ijS", + positions=positions_sys, + cell=cell_sys, + pbc=pbc_sys, + cutoff=cutoff, + device=device, + dtype=dtype, + self_interaction=self_interaction, + use_scaled_positions=False, + max_n_bins=int(1e6), + ) + + edge_idx = torch.stack((i, j), dim=0).to(dtype=torch.long) + shifts = S.to(dtype=dtype) + + # Adjust indices for the global atom indexing + edge_idx = edge_idx + offset + + edge_indices.append(edge_idx) + shifts_idx_list.append(shifts) + system_mapping_list.append( + torch.full((edge_idx.shape[1],), sys_idx, dtype=torch.long, device=device) + ) + + offset += n_atoms_in_system + + # Combine all neighbor lists + if 
len(edge_indices) == 0: + # No neighbors found + mapping = torch.zeros((2, 0), dtype=torch.long, device=device) + system_mapping = torch.zeros(0, dtype=torch.long, device=device) + shifts_idx = torch.zeros((0, 3), dtype=dtype, device=device) + else: + mapping = torch.cat(edge_indices, dim=1) + shifts_idx = torch.cat(shifts_idx_list, dim=0) + system_mapping = torch.cat(system_mapping_list, dim=0) + + return mapping, system_mapping, shifts_idx diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py new file mode 100644 index 00000000..8e111c39 --- /dev/null +++ b/torch_sim/neighbors/torch_nl.py @@ -0,0 +1,230 @@ +"""Batched neighbor list implementations for multiple systems. + +This module provides neighbor list calculations optimized for batched processing +of multiple atomic systems simultaneously. These implementations are designed for +use with multiple systems that may have different numbers of atoms. + +The API follows the batched convention used in MACE and other models: +- Requires system_idx to identify which system each atom belongs to +- Returns (mapping, system_mapping, shifts_idx) tuples +- mapping: [2, n_neighbors] - pairs of atom indices +- system_mapping: [n_neighbors] - which system each neighbor pair belongs to +- shifts_idx: [n_neighbors, 3] - periodic shift indices +""" + +import torch + +from torch_sim import transforms + + +def strict_nl( + cutoff: float, + positions: torch.Tensor, + cell: torch.Tensor, + mapping: torch.Tensor, + system_mapping: torch.Tensor, + shifts_idx: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply a strict cutoff to the neighbor list defined in the mapping. + + This function filters the neighbor list based on a specified cutoff distance. + It computes the squared distances between pairs of positions and retains only + those pairs that are within the cutoff distance. The function also accounts + for periodic boundary conditions by applying cell shifts when necessary. 
+ + Args: + cutoff (float): + The maximum distance for considering two atoms as neighbors. This value + is used to filter the neighbor pairs based on their distances. + positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing + the positions of the atoms. + cell (torch.Tensor): Unit cell vectors according to the row vector convention, + i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + mapping (torch.Tensor): + A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` + for which to compute distances. + system_mapping (torch.Tensor): + A tensor that maps the shifts to the corresponding cells, used in conjunction + with `shifts_idx` to compute the correct periodic shifts. + shifts_idx (torch.Tensor): + A tensor of shape (n_shifts, 3) representing the indices for shifts to apply + to the distances for periodic boundary conditions. + + Returns: + tuple: + A tuple containing: + - mapping (torch.Tensor): A filtered tensor of shape (2, n_filtered_pairs) + with pairs of indices that are within the cutoff distance. + - mapping_system (torch.Tensor): A tensor of shape (n_filtered_pairs,) + that maps the filtered pairs to their corresponding systems. + - shifts_idx (torch.Tensor): A tensor of shape (n_filtered_pairs, 3) + containing the periodic shift indices for the filtered pairs. + + Notes: + - The function computes the squared distances to avoid the computational cost + of taking square roots, which is unnecessary for comparison. + - If no cell shifts are needed (i.e., for non-periodic systems), the function + directly computes the squared distances between the positions. 
+ + References: + - https://github.com/felixmusil/torch_nl + """ + cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping) + if cell_shifts is None: + d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1) + else: + d2 = ( + (positions[mapping[0]] - positions[mapping[1]] - cell_shifts) + .square() + .sum(dim=1) + ) + + mask = d2 < cutoff * cutoff + mapping = mapping[:, mask] + mapping_system = system_mapping[mask] + shifts_idx = shifts_idx[mask] + return mapping, mapping_system, shifts_idx + + +@torch.jit.script +def torch_nl_n2( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the neighbor list for a set of atomic structures using a + naive neighbor search before applying a strict `cutoff`. + + Atoms in `positions` should be wrapped inside their respective unit cells. + + This implementation uses a naive O(N²) neighbor search which can be slow for + large systems but is simple and works reliably for small to medium systems. + + Args: + positions (torch.Tensor [n_atom, 3]): A tensor containing the positions + of atoms wrapped inside their respective unit cells. + cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to + the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + pbc (torch.Tensor [n_structure, 3] bool): + A tensor indicating the periodic boundary conditions to apply. + Partial PBC are not supported yet. + cutoff (torch.Tensor): + The cutoff radius used for the neighbor search. + system_idx (torch.Tensor [n_atom,] torch.long): + A tensor containing the index of the structure to which each atom belongs. + self_interaction (bool, optional): + A flag to indicate whether to keep the center atoms as their own neighbors. + Default is False.
+ + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mapping (torch.Tensor [2, n_neighbors]): + A tensor containing the indices of the neighbor list for the given + positions array. `mapping[0]` corresponds to the central atom indices, + and `mapping[1]` corresponds to the neighbor atom indices. + system_mapping (torch.Tensor [n_neighbors]): + A tensor mapping the neighbor atoms to their respective structures. + shifts_idx (torch.Tensor [n_neighbors, 3]): + A tensor containing the cell shift indices used to reconstruct the + neighbor atom positions. + + Example: + >>> # Create a batched system with 2 structures + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]]) + >>> cell = torch.eye(3).repeat(2, 1) * 10.0 # Two cells + >>> pbc = torch.tensor([[True, True, True], [True, True, True]]) + >>> cutoff = torch.tensor(2.0) + >>> # First 2 atoms in system 0, last in system 1 + >>> system_idx = torch.tensor([0, 0, 1]) + >>> mapping, sys_map, shifts = torch_nl_n2( + ... positions, cell, pbc, cutoff, system_idx + ... ) + + References: + - https://github.com/felixmusil/torch_nl + """ + n_atoms = torch.bincount(system_idx) + mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( + positions, cell, pbc, cutoff.item(), n_atoms, self_interaction + ) + mapping, mapping_system, shifts_idx = strict_nl( + cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx + ) + return mapping, mapping_system, shifts_idx + + +@torch.jit.script +def torch_nl_linked_cell( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 (*, not compatible with torch.jit.script) +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the neighbor list for a set of atomic structures using the linked + cell algorithm before applying a strict `cutoff`. 
+ + Atoms in `positions` should be wrapped inside their respective unit cells. + + This is the recommended default for batched neighbor list calculations as it + provides good performance for systems of various sizes using the linked cell + algorithm which has O(N) complexity. + + Args: + positions (torch.Tensor [n_atom, 3]): + A tensor containing the positions of atoms wrapped inside + their respective unit cells. + cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to + the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + pbc (torch.Tensor [n_systems, 3] bool): + A tensor indicating the periodic boundary conditions to apply. + Partial PBC are not supported yet. + cutoff (torch.Tensor): + The cutoff radius used for the neighbor search. + system_idx (torch.Tensor [n_atom,] torch.long): + A tensor containing the index of the structure to which each atom belongs. + self_interaction (bool, optional): + A flag to indicate whether to keep the center atoms as their own neighbors. + Default is False. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A tuple containing: + - mapping (torch.Tensor [2, n_neighbors]): + A tensor containing the indices of the neighbor list for the given + positions array. `mapping[0]` corresponds to the central atom + indices, and `mapping[1]` corresponds to the neighbor atom indices. + - system_mapping (torch.Tensor [n_neighbors]): + A tensor mapping the neighbor atoms to their respective structures. + - shifts_idx (torch.Tensor [n_neighbors, 3]): + A tensor containing the cell shift indices used to reconstruct the + neighbor atom positions.
+ + Example: + >>> # Create a batched system with 2 structures + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]]) + >>> cell = torch.eye(3).repeat(2, 1) * 10.0 # Two cells + >>> pbc = torch.tensor([[True, True, True], [True, True, True]]) + >>> cutoff = torch.tensor(2.0) + >>> # First 2 atoms in system 0, last in system 1 + >>> system_idx = torch.tensor([0, 0, 1]) + >>> mapping, sys_map, shifts = torch_nl_linked_cell( + ... positions, cell, pbc, cutoff, system_idx + ... ) + + References: + - https://github.com/felixmusil/torch_nl + """ + n_atoms = torch.bincount(system_idx) + mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( + positions, cell, pbc, cutoff.item(), n_atoms, self_interaction + ) + + mapping, mapping_system, shifts_idx = strict_nl( + cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx + ) + return mapping, mapping_system, shifts_idx diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py new file mode 100644 index 00000000..d45a41a1 --- /dev/null +++ b/torch_sim/neighbors/vesin.py @@ -0,0 +1,328 @@ +"""Vesin-based neighbor list implementations. + +This module provides high-performance neighbor list calculations using the +Vesin library. It includes both TorchScript-compatible and standard implementations. 
+ +Vesin is available at: https://github.com/Luthaf/vesin +""" + +import torch + + +try: + from vesin import NeighborList as VesinNeighborList + from vesin.torch import NeighborList as VesinNeighborListTorch + + VESIN_AVAILABLE = True +except ImportError: + VESIN_AVAILABLE = False + VesinNeighborList = None # type: ignore[assignment, misc] + VesinNeighborListTorch = None # type: ignore[assignment, misc] + +__all__ = [ + "VESIN_AVAILABLE", + "VesinNeighborList", + "VesinNeighborListTorch", + "vesin_nl", + "vesin_nl_ts", +] + + +if VESIN_AVAILABLE: + + def vesin_nl_ts( # noqa: PLR0915 + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists using TorchScript-compatible Vesin. + + This function provides a TorchScript-compatible interface to the Vesin + neighbor list algorithm using VesinNeighborListTorch. It handles both + single systems and batched (multi-system) calculations with a unified API. + + Args: + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [3*n_systems, 3] (row vector convention) + pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors + system_idx: Tensor [n_atoms] indicating which system each atom belongs to. + For single system, use torch.zeros(n_atoms, dtype=torch.long) + self_interaction: If True, include self-pairs. 
Default: False + + Returns: + tuple containing: + - mapping: Tensor [2, num_neighbors] - pairs of atom indices + - system_mapping: Tensor [num_neighbors] - system assignment for each pair + - shifts_idx: Tensor [num_neighbors, 3] - periodic shift indices + + Example: + >>> # Single system + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + >>> system_idx = torch.zeros(2, dtype=torch.long) + >>> mapping, sys_map, shifts = vesin_nl_ts( + ... positions, cell, pbc, cutoff, system_idx + ... ) + + Notes: + - Uses VesinNeighborListTorch for TorchScript compatibility + - Requires CPU tensors in float64 precision internally + - Returns tensors on the same device as input with original precision + - For non-periodic systems, shifts will be zero vectors + - The neighbor list includes both (i,j) and (j,i) pairs + + References: + https://github.com/Luthaf/vesin + """ + device = positions.device + dtype = positions.dtype + n_systems = system_idx.max().item() + 1 + + # Handle PBC: reshape if needed + if pbc.ndim == 1: + if pbc.shape[0] == 3: + # Single PBC for all systems + pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) + elif pbc.shape[0] == n_systems * 3: + # Flat concatenated PBC, reshape to [n_systems, 3] + pbc_per_system = pbc.reshape(n_systems, 3) + else: + raise ValueError(f"Unexpected PBC shape: {pbc.shape}") + else: + # Already [n_systems, 3] + pbc_per_system = pbc + + # Process each system's neighbor list separately + edge_indices = [] + shifts_idx_list = [] + system_mapping_list = [] + offset = 0 + + for sys_idx in range(n_systems): + system_mask = system_idx == sys_idx + n_atoms_in_system = system_mask.sum().item() + + if n_atoms_in_system == 0: + continue + + # Calculate neighbor list for this system + neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) + + # Get the cell for this system + cell_sys = ( + cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] + ) + + # Convert tensors to CPU and float64 
properly + positions_cpu = positions[system_mask].cpu().to(dtype=torch.float64) + cell_cpu = cell_sys.cpu().to(dtype=torch.float64) + periodic_cpu = pbc_per_system[sys_idx].to(dtype=torch.bool).cpu() + + # Only works on CPU and requires float64 + i, j, S = neighbor_list_fn.compute( + points=positions_cpu, + box=cell_cpu, + periodic=periodic_cpu, + quantities="ijS", + ) + + edge_idx = torch.stack((i, j), dim=0).to(dtype=torch.long, device=device) + shifts = S.to(dtype=dtype, device=device) + + # Adjust indices for the global atom indexing + edge_idx = edge_idx + offset + + edge_indices.append(edge_idx) + shifts_idx_list.append(shifts) + system_mapping_list.append( + torch.full((edge_idx.shape[1],), sys_idx, dtype=torch.long, device=device) + ) + + offset += n_atoms_in_system + + # Combine all neighbor lists + if len(edge_indices) == 0: + # No neighbors found + mapping = torch.zeros((2, 0), dtype=torch.long, device=device) + system_mapping = torch.zeros(0, dtype=torch.long, device=device) + shifts_idx = torch.zeros((0, 3), dtype=dtype, device=device) + else: + mapping = torch.cat(edge_indices, dim=1) + shifts_idx = torch.cat(shifts_idx_list, dim=0) + system_mapping = torch.cat(system_mapping_list, dim=0) + + # Add self-interactions if requested + if self_interaction: + n_atoms = positions.shape[0] + self_pairs = torch.arange(n_atoms, device=device, dtype=torch.long) + self_mapping = torch.stack([self_pairs, self_pairs], dim=0) + self_shifts = torch.zeros((n_atoms, 3), dtype=dtype, device=device) + self_sys_mapping = system_idx + + mapping = torch.cat([mapping, self_mapping], dim=1) + shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0) + system_mapping = torch.cat([system_mapping, self_sys_mapping], dim=0) + + return mapping, system_mapping, shifts_idx + + def vesin_nl( # noqa: PLR0915 + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: float | torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, 
FBT002 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists using the standard Vesin implementation. + + This function provides an interface to the standard Vesin neighbor list + algorithm using VesinNeighborList. It handles both single systems and + batched (multi-system) calculations with a unified API. + + Args: + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [3*n_systems, 3] (row vector convention) + pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cutoff: Maximum distance for considering atoms as neighbors + system_idx: Tensor [n_atoms] indicating which system each atom belongs to. + For single system, use torch.zeros(n_atoms, dtype=torch.long) + self_interaction: If True, include self-pairs. Default: False + + Returns: + tuple containing: + - mapping: Tensor [2, num_neighbors] - pairs of atom indices + - system_mapping: Tensor [num_neighbors] - system assignment for each pair + - shifts_idx: Tensor [num_neighbors, 3] - periodic shift indices + + Example: + >>> # Single system + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + >>> system_idx = torch.zeros(2, dtype=torch.long) + >>> mapping, sys_map, shifts = vesin_nl( + ... positions, cell, pbc, cutoff, system_idx + ... 
) + + Notes: + - Uses standard VesinNeighborList implementation + - Requires CPU tensors in float64 precision internally + - Returns tensors on the same device as input with original precision + - For non-periodic systems, shifts will be zero vectors + - The neighbor list includes both (i,j) and (j,i) pairs + + References: + - https://github.com/Luthaf/vesin + """ + device = positions.device + dtype = positions.dtype + n_systems = system_idx.max().item() + 1 + + # Handle PBC: reshape if needed + if pbc.ndim == 1: + if pbc.shape[0] == 3: + # Single PBC for all systems + pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) + elif pbc.shape[0] == n_systems * 3: + # Flat concatenated PBC, reshape to [n_systems, 3] + pbc_per_system = pbc.reshape(n_systems, 3) + else: + raise ValueError(f"Unexpected PBC shape: {pbc.shape}") + else: + # Already [n_systems, 3] + pbc_per_system = pbc + + # Process each system's neighbor list separately + edge_indices = [] + shifts_idx_list = [] + system_mapping_list = [] + offset = 0 + + for sys_idx in range(n_systems): + system_mask = system_idx == sys_idx + n_atoms_in_system = system_mask.sum().item() + + if n_atoms_in_system == 0: + continue + + # Get the cell for this system + cell_sys = ( + cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] + ) + + # Calculate neighbor list for this system + neighbor_list_fn = VesinNeighborList( + (float(cutoff)), full_list=True, sorted=False + ) + + # Convert tensors to CPU and float64 without gradients + positions_cpu = positions[system_mask].detach().cpu().to(dtype=torch.float64) + cell_cpu = cell_sys.detach().cpu().to(dtype=torch.float64) + periodic_cpu = pbc_per_system[sys_idx].detach().to(dtype=torch.bool).cpu() + + # Only works on CPU and returns numpy arrays + i, j, S = neighbor_list_fn.compute( + points=positions_cpu, + box=cell_cpu, + periodic=periodic_cpu, + quantities="ijS", + ) + i, j = ( + torch.tensor(i, dtype=torch.long, device=device), + torch.tensor(j, 
dtype=torch.long, device=device), + ) + edge_idx = torch.stack((i, j), dim=0) + shifts = torch.tensor(S, dtype=dtype, device=device) + + # Adjust indices for the global atom indexing + edge_idx = edge_idx + offset + + edge_indices.append(edge_idx) + shifts_idx_list.append(shifts) + system_mapping_list.append( + torch.full((edge_idx.shape[1],), sys_idx, dtype=torch.long, device=device) + ) + + offset += n_atoms_in_system + + # Combine all neighbor lists + if len(edge_indices) == 0: + # No neighbors found + mapping = torch.zeros((2, 0), dtype=torch.long, device=device) + system_mapping = torch.zeros(0, dtype=torch.long, device=device) + shifts_idx = torch.zeros((0, 3), dtype=dtype, device=device) + else: + mapping = torch.cat(edge_indices, dim=1) + shifts_idx = torch.cat(shifts_idx_list, dim=0) + system_mapping = torch.cat(system_mapping_list, dim=0) + + # Add self-interactions if requested + if self_interaction: + n_atoms = positions.shape[0] + self_pairs = torch.arange(n_atoms, device=device, dtype=torch.long) + self_mapping = torch.stack([self_pairs, self_pairs], dim=0) + self_shifts = torch.zeros((n_atoms, 3), dtype=dtype, device=device) + self_sys_mapping = system_idx + + mapping = torch.cat([mapping, self_mapping], dim=1) + shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0) + system_mapping = torch.cat([system_mapping, self_sys_mapping], dim=0) + + return mapping, system_mapping, shifts_idx + +else: + # Provide stub functions that raise informative errors + def vesin_nl_ts( # type: ignore[misc] + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stub function when Vesin is not available.""" + raise ImportError("Vesin is not installed. 
Install it with: pip install vesin") + + def vesin_nl( # type: ignore[misc] + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stub function when Vesin is not available.""" + raise ImportError("Vesin is not installed. Install it with: pip install vesin") From f911f8c087ba430e055865f232a2ef79d979afaf Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 10 Dec 2025 10:16:20 -0800 Subject: [PATCH 06/11] fix nl for classical potentials --- torch_sim/models/lennard_jones.py | 20 ++++++++++++++----- torch_sim/models/morse.py | 15 +++++++++++---- torch_sim/models/particle_life.py | 22 +++++++++++++-------- torch_sim/models/soft_sphere.py | 32 ++++++++++++++++++++----------- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 888798a1..22d1c088 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -261,17 +261,27 @@ def unbatched_forward( pbc = state.pbc if self.use_neighbor_list: - # Get neighbor list using vesin_nl_ts - mapping, shifts = torchsim_nl( + # Get neighbor list using torchsim_nl + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + mapping, system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=pbc, cutoff=self.cutoff, - sort_id=False, + system_idx=system_idx, ) - # Get displacements using neighbor list + # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( - positions=positions, cell=cell, pbc=pbc, pairs=mapping, shifts=shifts + positions=positions, + cell=cell, + pbc=pbc, + pairs=(mapping[0], mapping[1]), + shifts=shifts_idx, ) else: # Get all pairwise displacements diff --git a/torch_sim/models/morse.py 
b/torch_sim/models/morse.py index 046c3415..513c3ed6 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -267,19 +267,26 @@ def unbatched_forward( pbc = sim_state.pbc if self.use_neighbor_list: - mapping, shifts = torchsim_nl( + # Ensure system_idx exists (create if None for single system) + system_idx = ( + sim_state.system_idx + if sim_state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + mapping, system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=pbc, cutoff=self.cutoff, - sort_id=False, + system_idx=system_idx, ) + # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( positions=positions, cell=cell, pbc=pbc, - pairs=mapping, - shifts=shifts, + pairs=(mapping[0], mapping[1]), + shifts=shifts_idx, ) else: dr_vec, distances = transforms.get_pair_displacements( diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index dd478299..270f79f3 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -150,21 +150,27 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: cell = cell.squeeze(0) # Squeeze the first dimension if self.use_neighbor_list: - # Get neighbor list using wrapping_nl - mapping, shifts = torchsim_nl( + # Get neighbor list using torchsim_nl + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + mapping, system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=pbc, - cutoff=float(self.cutoff), - sort_id=False, + cutoff=self.cutoff, + system_idx=system_idx, ) - # Get displacements using neighbor list + # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = 
transforms.get_pair_displacements( positions=positions, cell=cell, pbc=pbc, - pairs=mapping, - shifts=shifts, + pairs=(mapping[0], mapping[1]), + shifts=shifts_idx, ) else: # Get all pairwise displacements @@ -180,7 +186,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: mask = distances < self.cutoff # Get valid pairs - match neighbor list convention for pair order i, j = torch.where(mask) - mapping = torch.stack([j, i]) # Changed from [j, i] to [i, j] + mapping = torch.stack([j, i]) # Get valid displacements and distances dr_vec = dr_vec[mask] distances = distances[mask] diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 79e70a38..6dc9d1ce 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -285,21 +285,27 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: pbc = state.pbc if self.use_neighbor_list: - # Get neighbor list using vesin_nl_ts - mapping, shifts = torchsim_nl( + # Get neighbor list using torchsim_nl + # Ensure system_idx exists (create if None for single system) + system_idx = ( + state.system_idx + if state.system_idx is not None + else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + ) + mapping, system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=pbc, cutoff=self.cutoff, - sort_id=False, + system_idx=system_idx, ) - # Get displacements between neighbor pairs + # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( positions=positions, cell=cell, pbc=pbc, - pairs=mapping, - shifts=shifts, + pairs=(mapping[0], mapping[1]), + shifts=shifts_idx, ) else: @@ -710,20 +716,24 @@ def unbatched_forward( # noqa: PLR0915 # Compute neighbor list or full distance matrix if self.use_neighbor_list: # Get neighbor list for efficient computation - mapping, shifts = torchsim_nl( + # Ensure system_idx exists (create if None for 
single system) + system_idx = torch.zeros( + positions.shape[0], dtype=torch.long, device=self.device + ) + mapping, system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=self.pbc, cutoff=self.cutoff, - sort_id=False, + system_idx=system_idx, ) - # Get displacements between neighbor pairs + # Pass shifts_idx directly - get_pair_displacements will convert them dr_vec, distances = transforms.get_pair_displacements( positions=positions, cell=cell, pbc=self.pbc, - pairs=mapping, - shifts=shifts, + pairs=(mapping[0], mapping[1]), + shifts=shifts_idx, ) else: From b93a41656e50fbf6feb65ecd3660a5fcf8e2c56c Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 10 Dec 2025 10:57:06 -0800 Subject: [PATCH 07/11] fix nl for graphpes --- torch_sim/models/graphpes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 4ce37120..e380cab4 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -75,7 +75,12 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra # model's cutoff value. 
To ensure no strange edge effects whereby # edges that are exactly `cutoff` long are included/excluded, # we bump cutoff + 1e-5 up slightly - nl, shifts = torchsim_nl(R, cell, state.pbc, cutoff + 1e-5) + + # Create system_idx for this single system (all atoms belong to system 0) + system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device) + nl, _system_mapping, shifts = torchsim_nl( + R, cell, state.pbc, cutoff + 1e-5, system_idx_single + ) atomic_graph = AtomicGraph( Z=Z.long(), From 08a9f676714c035c356c375e24038b8cb1e31da3 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 10 Dec 2025 20:12:12 -0800 Subject: [PATCH 08/11] add cuda batch nl --- pyproject.toml | 3 +- tests/test_neighbors.py | 115 ++++++++++++++++++--- torch_sim/models/mace.py | 4 +- torch_sim/models/nequip_framework.py | 4 +- torch_sim/models/sevennet.py | 4 +- torch_sim/neighbors/__init__.py | 60 ++++++----- torch_sim/neighbors/alchemiops.py | 145 +++++++++++++++++++++++++++ 7 files changed, 285 insertions(+), 50 deletions(-) create mode 100644 torch_sim/neighbors/alchemiops.py diff --git a/pyproject.toml b/pyproject.toml index 14bf4d0a..8529a316 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,11 @@ classifiers = [ requires-python = ">=3.12" dependencies = [ "h5py>=3.12.1", + "nvalchemi-toolkit-ops", "numpy>=1.26,<3", "tables>=3.10.2", "torch>=2", "tqdm>=4.67", - "vesin-torch>=0.4.0, <0.5.0", - "vesin>=0.4.0, <0.5.0", ] [project.optional-dependencies] diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index fc20cecf..ce0adcb1 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -315,7 +315,8 @@ def test_neighbor_list_implementations( neighbors.torch_nl_linked_cell, neighbors.standard_nl, ] - + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []), + + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []) + + ([neighbors.alchemiops_nl_n2] if neighbors.ALCHEMIOPS_AVAILABLE 
else []), ) def test_torch_nl_implementations( *, @@ -463,8 +464,10 @@ def test_vesin_nl_edge_cases() -> None: def test_torchsim_nl_availability() -> None: - """Test that VESIN_AVAILABLE flag is correctly set.""" + """Test that availability flags are correctly set.""" assert isinstance(neighbors.VESIN_AVAILABLE, bool) + assert isinstance(neighbors.ALCHEMIOPS_AVAILABLE, bool) + if neighbors.VESIN_AVAILABLE: assert neighbors.VesinNeighborList is not None assert neighbors.VesinNeighborListTorch is not None @@ -472,6 +475,79 @@ def test_torchsim_nl_availability() -> None: assert neighbors.VesinNeighborList is None assert neighbors.VesinNeighborListTorch is None + if neighbors.ALCHEMIOPS_AVAILABLE: + assert neighbors.alchemiops_nl_n2 is not None + else: + assert neighbors.alchemiops_nl_n2 is None + + +@pytest.mark.skipif( + not neighbors.ALCHEMIOPS_AVAILABLE or not torch.cuda.is_available(), + reason="Alchemiops requires CUDA", +) +def test_alchemiops_nl_edge_cases() -> None: + """Test edge cases for alchemiops_nl_n2 implementation (CUDA only).""" + device = torch.device("cuda") + dtype = torch.float32 + + pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=device, dtype=dtype) + cell = torch.eye(3, device=device, dtype=dtype) * 2.0 + cutoff = torch.tensor(1.5, device=device, dtype=dtype) + system_idx = torch.zeros(2, dtype=torch.long, device=device) + + # Test alchemiops_nl_n2 + for pbc in ( + torch.tensor([True, True, True], device=device), + torch.tensor([False, False, False], device=device), + ): + mapping, sys_map, _shifts = neighbors.alchemiops_nl_n2( + positions=pos, + cell=cell, + pbc=pbc, + cutoff=cutoff, + system_idx=system_idx, + ) + assert len(mapping[0]) > 0 # Should find neighbors + assert (sys_map == 0).all() # All in system 0 + + +def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that torch-sim works correctly without alchemiops (CI compatibility).""" + # This test ensures CI works even if 
alchemiops fails to import + # torchsim_nl should fall back to pure PyTorch implementations + device = torch.device("cpu") + dtype = torch.float32 + + positions = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], + device=device, + dtype=dtype, + ) + cell = torch.eye(3, device=device, dtype=dtype) * 3.0 + pbc = torch.tensor([False, False, False], device=device) + cutoff = torch.tensor(1.5, device=device, dtype=dtype) + system_idx = torch.zeros(4, dtype=torch.long, device=device) + + # Use monkeypatch to temporarily disable alchemiops + monkeypatch.setattr(neighbors, "ALCHEMIOPS_AVAILABLE", False) + + # torchsim_nl should always work (with fallback) + mapping, sys_map, _shifts = neighbors.torchsim_nl( + positions, cell, pbc, cutoff, system_idx + ) + + # Should find neighbors + assert mapping.shape[0] == 2 + assert mapping.shape[1] > 0 + assert sys_map.shape[0] == mapping.shape[1] + + # default_batched_nl should always be available + assert neighbors.default_batched_nl is not None + mapping2, _sys_map2, _shifts2 = neighbors.default_batched_nl( + positions, cell, pbc, cutoff, system_idx + ) + assert mapping2.shape[1] > 0 + def test_torchsim_nl_consistency() -> None: """Test that torchsim_nl produces consistent results.""" @@ -512,6 +588,9 @@ def test_torchsim_nl_consistency() -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available for testing") def test_torchsim_nl_gpu() -> None: """Test that torchsim_nl works on GPU (CUDA/ROCm).""" + torch.cuda.empty_cache() + torch.cuda.synchronize() + device = torch.device("cuda") dtype = torch.float32 @@ -525,7 +604,7 @@ def test_torchsim_nl_gpu() -> None: cutoff = torch.tensor(1.5, device=device, dtype=dtype) system_idx = torch.zeros(2, dtype=torch.long, device=device) - # Should work on GPU regardless of vesin availability + # Should work on GPU regardless of implementation availability mapping, sys_map, shifts = neighbors.torchsim_nl( positions, cell, pbc, 
cutoff, system_idx ) @@ -535,15 +614,18 @@ def test_torchsim_nl_gpu() -> None: assert sys_map.device.type == "cuda" assert mapping.shape[0] == 2 # (2, num_neighbors) + # Cleanup + torch.cuda.empty_cache() + def test_torchsim_nl_fallback_when_vesin_unavailable( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Test that torchsim_nl falls back to standard_nl when vesin is unavailable. + """Test that torchsim_nl falls back to torch_nl when alchemiops/vesin unavailable. - This test simulates the case where vesin is not installed by monkeypatching - VESIN_AVAILABLE to False. This ensures the fallback logic is tested even in - CI environments where vesin is actually installed. + This test simulates the case where alchemiops and vesin are not available by + monkeypatching their availability flags to False. This ensures the fallback logic + is tested even in environments where they are actually installed. """ device = torch.device("cpu") dtype = torch.float32 @@ -559,24 +641,25 @@ def test_torchsim_nl_fallback_when_vesin_unavailable( cutoff = torch.tensor(1.5, device=device, dtype=dtype) system_idx = torch.zeros(4, dtype=torch.long, device=device) - # Monkeypatch VESIN_AVAILABLE to False to simulate vesin not being installed + # Monkeypatch both availability flags to False monkeypatch.setattr(neighbors, "VESIN_AVAILABLE", False) + monkeypatch.setattr(neighbors, "ALCHEMIOPS_AVAILABLE", False) - # Call torchsim_nl with mocked unavailable vesin + # Call torchsim_nl with mocked unavailable implementations mapping_torchsim, sys_map_ts, shifts_torchsim = neighbors.torchsim_nl( positions, cell, pbc, cutoff, system_idx ) - # Call standard_nl directly for comparison - mapping_standard, sys_map_std, shifts_standard = neighbors.standard_nl( + # Call torch_nl_linked_cell directly for comparison + mapping_expected, sys_map_exp, shifts_expected = neighbors.torch_nl_linked_cell( positions, cell, pbc, cutoff, system_idx ) - # When VESIN_AVAILABLE is False, torchsim_nl should use 
standard_nl + # When both are unavailable, torchsim_nl should use torch_nl_linked_cell # and produce identical results - torch.testing.assert_close(mapping_torchsim, mapping_standard) - torch.testing.assert_close(shifts_torchsim, shifts_standard) - torch.testing.assert_close(sys_map_ts, sys_map_std) + torch.testing.assert_close(mapping_torchsim, mapping_expected) + torch.testing.assert_close(shifts_torchsim, shifts_expected) + torch.testing.assert_close(sys_map_ts, sys_map_exp) def test_strict_nl_edge_cases() -> None: @@ -637,6 +720,8 @@ def test_neighbor_lists_time_and_memory() -> None: cast("Callable[..., Any]", neighbors.vesin_nl), ] ) + if neighbors.ALCHEMIOPS_AVAILABLE and DEVICE.type == "cuda": + nl_implementations.append(neighbors.alchemiops_nl_n2) for nl_fn in nl_implementations: # Get initial memory usage diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 22aa1c0f..401243cd 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -28,7 +28,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torch_nl_linked_cell +from torch_sim.neighbors import torchsim_nl from torch_sim.typing import StateDict @@ -107,7 +107,7 @@ def __init__( *, device: torch.device | None = None, dtype: torch.dtype = torch.float64, - neighbor_list_fn: Callable = torch_nl_linked_cell, + neighbor_list_fn: Callable = torchsim_nl, compute_forces: bool = True, compute_stress: bool = True, enable_cueq: bool = False, diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index 8096f019..6266016a 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -25,7 +25,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torch_nl_linked_cell +from torch_sim.neighbors import torchsim_nl from torch_sim.typing import StateDict @@ -165,7 +165,7 @@ def __init__( r_max: float, 
type_names: list[str], device: torch.device | None = None, - neighbor_list_fn: Callable = torch_nl_linked_cell, + neighbor_list_fn: Callable = torchsim_nl, atomic_numbers: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, ) -> None: diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 23a12b8b..52d5393b 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -12,7 +12,7 @@ import torch_sim as ts from torch_sim.elastic import voigt_6_to_full_3x3_stress from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torch_nl_linked_cell +from torch_sim.neighbors import torchsim_nl if TYPE_CHECKING: @@ -85,7 +85,7 @@ def __init__( model: AtomGraphSequential | str | Path, *, # force remaining arguments to be keyword-only modal: str | None = None, - neighbor_list_fn: Callable = torch_nl_linked_cell, + neighbor_list_fn: Callable = torchsim_nl, device: torch.device | str | None = None, dtype: torch.dtype = torch.float32, ) -> None: diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index d41aff84..3e685c8f 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -21,6 +21,13 @@ from torch_sim.neighbors.torch_nl import strict_nl, torch_nl_linked_cell, torch_nl_n2 +# Try to import Alchemiops implementations (NVIDIA CUDA acceleration) +try: + from torch_sim.neighbors.alchemiops import ALCHEMIOPS_AVAILABLE, alchemiops_nl_n2 +except ImportError: + ALCHEMIOPS_AVAILABLE = False + alchemiops_nl_n2 = None # type: ignore[assignment] + # Try to import Vesin implementations try: from torch_sim.neighbors.vesin import ( @@ -37,16 +44,16 @@ vesin_nl = None # type: ignore[assignment] vesin_nl_ts = None # type: ignore[assignment] -# Set default neighbor list based on what's available -if VESIN_AVAILABLE: - default_nl = vesin_nl - default_nl_ts = vesin_nl_ts +# Set default neighbor list based on what's available (priority order) +if 
ALCHEMIOPS_AVAILABLE: + # Alchemiops is fastest on NVIDIA GPUs + default_batched_nl = alchemiops_nl_n2 +elif VESIN_AVAILABLE: + # Vesin is good fallback + default_batched_nl = vesin_nl_ts # Still use native for batched else: - default_nl = standard_nl - default_nl_ts = standard_nl - -# For batched calculations, always use linked cell as default -default_batched_nl = torch_nl_linked_cell + # Pure PyTorch fallback + default_batched_nl = torch_nl_linked_cell def torchsim_nl( @@ -57,11 +64,13 @@ def torchsim_nl( system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute neighbor lists with automatic fallback for AMD ROCm compatibility. + """Compute neighbor lists with automatic selection of best available implementation. - This function automatically selects the best available neighbor list implementation. - When vesin is available, it uses vesin_nl_ts for optimal performance. When vesin - is not available (e.g., on AMD ROCm systems), it falls back to standard_nl. + This function automatically selects the best available neighbor list implementation + based on what's installed. Priority order: + 1. Alchemiops (NVIDIA CUDA optimized) if available + 2. Vesin (fast, cross-platform) if available + 3. 
torch_nl_linked_cell (pure PyTorch fallback) Args: positions: Atomic positions tensor [n_atoms, 3] @@ -79,39 +88,36 @@ def torchsim_nl( - shifts_idx: Tensor [num_neighbors, 3] - periodic shift indices Notes: - - Automatically uses vesin_nl_ts when vesin is available - - Falls back to standard_nl when vesin is unavailable (AMD ROCm) + - Automatically uses best available implementation + - Priority: Alchemiops > Vesin > torch_nl_linked_cell - Fallback works on NVIDIA CUDA, AMD ROCm, and CPU - For non-periodic systems (pbc=False), shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs """ - if not VESIN_AVAILABLE: - return torch_nl_linked_cell( + if ALCHEMIOPS_AVAILABLE: + return alchemiops_nl_n2( positions, cell, pbc, cutoff, system_idx, self_interaction ) - - return vesin_nl_ts(positions, cell, pbc, cutoff, system_idx, self_interaction) + if VESIN_AVAILABLE: + return vesin_nl_ts(positions, cell, pbc, cutoff, system_idx, self_interaction) + return torch_nl_linked_cell( + positions, cell, pbc, cutoff, system_idx, self_interaction + ) __all__ = [ - # Availability + "ALCHEMIOPS_AVAILABLE", "VESIN_AVAILABLE", "VesinNeighborList", "VesinNeighborListTorch", - # Defaults + "alchemiops_nl_n2", "default_batched_nl", - "default_nl", - "default_nl_ts", - # Core implementations "primitive_neighbor_list", "standard_nl", - # Utilities "strict_nl", - # Batched implementations "torch_nl_linked_cell", "torch_nl_n2", "torchsim_nl", - # Vesin implementations "vesin_nl", "vesin_nl_ts", ] diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py new file mode 100644 index 00000000..10c34faf --- /dev/null +++ b/torch_sim/neighbors/alchemiops.py @@ -0,0 +1,145 @@ +"""Alchemiops-based neighbor list implementations. + +This module provides high-performance CUDA-accelerated neighbor list calculations +using the nvalchemiops library. Uses the naive N^2 implementation for reliability. 
+ +nvalchemiops is available at: https://github.com/NVIDIA/nvalchemiops +""" + +import warnings + +import torch + + +try: + from nvalchemiops.neighborlist import ( + batch_naive_neighbor_list as _batch_naive_neighbor_list, + ) + from nvalchemiops.neighborlist.neighbor_utils import estimate_max_neighbors + + ALCHEMIOPS_AVAILABLE = True +except ImportError: + ALCHEMIOPS_AVAILABLE = False + _batch_naive_neighbor_list = None # type: ignore[assignment] + estimate_max_neighbors = None # type: ignore[assignment] + +__all__ = [ + "ALCHEMIOPS_AVAILABLE", + "alchemiops_nl_n2", +] + + +def _prepare_inputs(cell: torch.Tensor, pbc: torch.Tensor, system_idx: torch.Tensor): # noqa: ANN202 + """Prepare cell and PBC tensors for alchemiops functions.""" + n_systems = system_idx.max().item() + 1 + + # Reshape cell: [3*n_systems, 3] or [3, 3] -> [n_systems, 3, 3] + if cell.ndim == 2: + cell_reshaped = ( + cell.unsqueeze(0) if cell.shape[0] == 3 else cell.reshape(n_systems, 3, 3) + ) + else: + cell_reshaped = cell + + # Reshape PBC: various formats -> [n_systems, 3] + if pbc.ndim == 1: + pbc_reshaped = ( + pbc.unsqueeze(0).expand(n_systems, -1) + if pbc.shape[0] == 3 + else pbc.reshape(n_systems, 3) + ) + else: + pbc_reshaped = pbc + + return cell_reshaped, pbc_reshaped.to(torch.bool), n_systems + + +if ALCHEMIOPS_AVAILABLE: + + def alchemiops_nl_n2( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute neighbor lists using Alchemiops naive N^2 algorithm. 
+ + Args: + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [3*n_systems, 3] (row vector convention) + pbc: Boolean tensor [3] or [n_systems, 3] for PBC + cutoff: Maximum distance (scalar tensor) + system_idx: Tensor [n_atoms] indicating system assignment + self_interaction: If True, include self-pairs + + Returns: + (mapping, system_mapping, shifts_idx) + """ + r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff + cell_reshaped, pbc_bool, _n_systems = _prepare_inputs(cell, pbc, system_idx) + + # Call alchemiops neighbor list + res = _batch_naive_neighbor_list( + positions=positions, + cutoff=r_max, + batch_idx=system_idx.to(torch.int32), + cell=cell_reshaped, + pbc=pbc_bool, + return_neighbor_list=True, + ) + + # Parse results: (neighbor_list, neighbor_ptr[, neighbor_list_shifts]) + if len(res) == 3: # type: ignore[arg-type] + mapping, _, shifts_idx = res # type: ignore[misc] + else: + mapping, _ = res # type: ignore[misc] + shifts_idx = torch.zeros( + (mapping.shape[1], 3), dtype=positions.dtype, device=positions.device + ) + + # Convert dtypes + mapping = mapping.to(dtype=torch.long) + # Convert shifts_idx to floating point to match cell dtype (for einsum) + shifts_idx = shifts_idx.to(dtype=cell.dtype) + + # Create system_mapping + system_mapping = system_idx[mapping[0]] + + # Alchemiops does NOT include self-interactions by default + # Add them only if requested + if self_interaction: + n_atoms = positions.shape[0] + self_pairs = torch.arange(n_atoms, device=positions.device, dtype=torch.long) + self_mapping = torch.stack([self_pairs, self_pairs], dim=0) + # Self-shifts should match shifts_idx dtype + self_shifts = torch.zeros( + (n_atoms, 3), dtype=cell.dtype, device=positions.device + ) + + mapping = torch.cat([mapping, self_mapping], dim=1) + shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0) + system_mapping = torch.cat([system_mapping, system_idx], dim=0) + + # Check if neighbors exceed estimate + 
max_neighbors_estimate = estimate_max_neighbors(r_max) + if mapping.shape[1] > max_neighbors_estimate: + warnings.warn( + f"Number of neighbors {mapping.shape[1]} exceeds estimated max " + f"{max_neighbors_estimate} for cutoff {r_max}.", + UserWarning, + stacklevel=2, + ) + return mapping, system_mapping, shifts_idx + +else: + # Provide stub function that raises informative error + def alchemiops_nl_n2( # type: ignore[misc] + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stub function when nvalchemiops is not available.""" + raise ImportError( + "nvalchemiops is not installed. Install it with: pip install nvalchemiops" + ) From 58ae2549a2a3a9cfdf1dad39799acb13e37ddd2c Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 10 Dec 2025 20:22:01 -0800 Subject: [PATCH 09/11] make sure the pbc and cell tensor is contiguous --- torch_sim/neighbors/alchemiops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index 10c34faf..d0d510db 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -30,7 +30,10 @@ def _prepare_inputs(cell: torch.Tensor, pbc: torch.Tensor, system_idx: torch.Tensor): # noqa: ANN202 - """Prepare cell and PBC tensors for alchemiops functions.""" + """Prepare cell and PBC tensors for alchemiops functions. + + Ensures tensors are properly shaped and contiguous for Warp backend. 
+ """ n_systems = system_idx.max().item() + 1 # Reshape cell: [3*n_systems, 3] or [3, 3] -> [n_systems, 3, 3] @@ -51,6 +54,10 @@ def _prepare_inputs(cell: torch.Tensor, pbc: torch.Tensor, system_idx: torch.Ten else: pbc_reshaped = pbc + # Ensure tensors are contiguous for Warp backend + cell_reshaped = cell_reshaped.contiguous() + pbc_reshaped = pbc_reshaped.contiguous() + return cell_reshaped, pbc_reshaped.to(torch.bool), n_systems From 44efe9e77707f2072c2874076824315f43d6185a Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 11 Dec 2025 14:18:13 -0800 Subject: [PATCH 10/11] Minor cleanup --- torch_sim/models/mace.py | 7 +-- torch_sim/models/nequip_framework.py | 9 +--- torch_sim/models/sevennet.py | 7 +-- torch_sim/neighbors/__init__.py | 39 ++++++++++++++-- torch_sim/neighbors/alchemiops.py | 64 +++++--------------------- torch_sim/neighbors/standard.py | 45 ++++-------------- torch_sim/neighbors/torch_nl.py | 36 ++++++++++++--- torch_sim/neighbors/vesin.py | 69 ++++++++++------------------ 8 files changed, 113 insertions(+), 163 deletions(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 401243cd..65a616f4 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -299,15 +299,10 @@ def forward( # noqa: C901 self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) # Batched neighbor list using linked-cell algorithm - pbc_tensor = ( - sim_state.pbc.repeat(self.n_systems, 1) - if sim_state.pbc.ndim == 1 - else sim_state.pbc - ) edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( sim_state.positions, sim_state.row_vector_cell, - pbc_tensor, + sim_state.pbc, self.r_max, sim_state.system_idx, ) diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index 6266016a..775abf16 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -304,16 +304,11 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, 
torch.Tensor]: # ): self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) - # Batched neighbor list using linked-cell algorithm (row-vector cell convention) - pbc_tensor = ( - sim_state.pbc.repeat(self.n_systems, 1) - if sim_state.pbc.ndim == 1 - else sim_state.pbc - ) + # Batched neighbor list using linked-cell algorithm edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn( sim_state.positions, sim_state.row_vector_cell, - pbc_tensor, + sim_state.pbc, self.r_max, sim_state.system_idx, ) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 52d5393b..81b76820 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -193,15 +193,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # Batched neighbor list using linked-cell algorithm with row-vector cell n_systems = sim_state.system_idx.max().item() + 1 - pbc_tensor = ( - sim_state.pbc.repeat(n_systems, 1) - if sim_state.pbc.ndim == 1 - else sim_state.pbc - ) edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( sim_state.positions, sim_state.row_vector_cell, - pbc_tensor, + sim_state.pbc, self.cutoff, sim_state.system_idx, ) diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 3e685c8f..2d4b8335 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -21,6 +21,36 @@ from torch_sim.neighbors.torch_nl import strict_nl, torch_nl_linked_cell, torch_nl_n2 +def _normalize_inputs( + cell: torch.Tensor, pbc: torch.Tensor, n_systems: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Normalize cell and PBC tensors to standard batch format. + + Handles multiple input formats: + - cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3] + - pbc: [3], [n_systems, 3], or [n_systems*3] + + Returns: + (cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3]) + Both tensors are guaranteed to be contiguous. 
+ """ + # Normalize cell + if cell.ndim == 2: + if cell.shape[0] == 3: + cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() + else: + cell = cell.reshape(n_systems, 3, 3) + + # Normalize PBC + if pbc.ndim == 1: + if pbc.shape[0] == 3: + pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() + else: + pbc = pbc.reshape(n_systems, 3) + + return cell, pbc + + # Try to import Alchemiops implementations (NVIDIA CUDA acceleration) try: from torch_sim.neighbors.alchemiops import ALCHEMIOPS_AVAILABLE, alchemiops_nl_n2 @@ -74,11 +104,10 @@ def torchsim_nl( Args: positions: Atomic positions tensor [n_atoms, 3] - cell: Unit cell vectors [3*n_systems, 3] (row vector convention) - pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors - system_idx: Tensor [n_atoms] indicating which system each atom belongs to. - For single system, use torch.zeros(n_atoms, dtype=torch.long) + system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. 
Default: False Returns: @@ -93,6 +122,8 @@ def torchsim_nl( - Fallback works on NVIDIA CUDA, AMD ROCm, and CPU - For non-periodic systems (pbc=False), shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs + - Accepts both single-system [3, 3] or batched [n_systems, 3, 3] cell formats + - Accepts both single [3] or batched [n_systems, 3] PBC formats """ if ALCHEMIOPS_AVAILABLE: return alchemiops_nl_n2( diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index d0d510db..59d1ba8d 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -6,21 +6,17 @@ nvalchemiops is available at: https://github.com/NVIDIA/nvalchemiops """ -import warnings - import torch try: - from nvalchemiops.neighborlist import ( - batch_naive_neighbor_list as _batch_naive_neighbor_list, - ) + from nvalchemiops.neighborlist import batch_naive_neighbor_list from nvalchemiops.neighborlist.neighbor_utils import estimate_max_neighbors ALCHEMIOPS_AVAILABLE = True except ImportError: ALCHEMIOPS_AVAILABLE = False - _batch_naive_neighbor_list = None # type: ignore[assignment] + batch_naive_neighbor_list = None # type: ignore[assignment] estimate_max_neighbors = None # type: ignore[assignment] __all__ = [ @@ -29,38 +25,6 @@ ] -def _prepare_inputs(cell: torch.Tensor, pbc: torch.Tensor, system_idx: torch.Tensor): # noqa: ANN202 - """Prepare cell and PBC tensors for alchemiops functions. - - Ensures tensors are properly shaped and contiguous for Warp backend. 
- """ - n_systems = system_idx.max().item() + 1 - - # Reshape cell: [3*n_systems, 3] or [3, 3] -> [n_systems, 3, 3] - if cell.ndim == 2: - cell_reshaped = ( - cell.unsqueeze(0) if cell.shape[0] == 3 else cell.reshape(n_systems, 3, 3) - ) - else: - cell_reshaped = cell - - # Reshape PBC: various formats -> [n_systems, 3] - if pbc.ndim == 1: - pbc_reshaped = ( - pbc.unsqueeze(0).expand(n_systems, -1) - if pbc.shape[0] == 3 - else pbc.reshape(n_systems, 3) - ) - else: - pbc_reshaped = pbc - - # Ensure tensors are contiguous for Warp backend - cell_reshaped = cell_reshaped.contiguous() - pbc_reshaped = pbc_reshaped.contiguous() - - return cell_reshaped, pbc_reshaped.to(torch.bool), n_systems - - if ALCHEMIOPS_AVAILABLE: def alchemiops_nl_n2( @@ -75,8 +39,8 @@ def alchemiops_nl_n2( Args: positions: Atomic positions tensor [n_atoms, 3] - cell: Unit cell vectors [3*n_systems, 3] (row vector convention) - pbc: Boolean tensor [3] or [n_systems, 3] for PBC + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] cutoff: Maximum distance (scalar tensor) system_idx: Tensor [n_atoms] indicating system assignment self_interaction: If True, include self-pairs @@ -84,16 +48,19 @@ def alchemiops_nl_n2( Returns: (mapping, system_mapping, shifts_idx) """ + from torch_sim.neighbors import _normalize_inputs + r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff - cell_reshaped, pbc_bool, _n_systems = _prepare_inputs(cell, pbc, system_idx) + n_systems = system_idx.max().item() + 1 + cell, pbc = _normalize_inputs(cell, pbc, n_systems) # Call alchemiops neighbor list - res = _batch_naive_neighbor_list( + res = batch_naive_neighbor_list( positions=positions, cutoff=r_max, batch_idx=system_idx.to(torch.int32), - cell=cell_reshaped, - pbc=pbc_bool, + cell=cell, + pbc=pbc.to(torch.bool), return_neighbor_list=True, ) @@ -129,15 +96,6 @@ def alchemiops_nl_n2( shifts_idx = torch.cat([shifts_idx, self_shifts], dim=0) system_mapping = 
torch.cat([system_mapping, system_idx], dim=0) - # Check if neighbors exceed estimate - max_neighbors_estimate = estimate_max_neighbors(r_max) - if mapping.shape[1] > max_neighbors_estimate: - warnings.warn( - f"Number of neighbors {mapping.shape[1]} exceeds estimated max " - f"{max_neighbors_estimate} for cutoff {r_max}.", - UserWarning, - stacklevel=2, - ) return mapping, system_mapping, shifts_idx else: diff --git a/torch_sim/neighbors/standard.py b/torch_sim/neighbors/standard.py index 574d0387..826ba752 100644 --- a/torch_sim/neighbors/standard.py +++ b/torch_sim/neighbors/standard.py @@ -418,24 +418,12 @@ def standard_nl( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute neighbor lists using primitive neighbor list algorithm. - This function provides a standardized interface for computing neighbor lists - in atomic systems. It handles both single systems and batched (multi-system) - calculations with a unified API. - - Key Features: - - Unified API for single and batched systems - - Supports both periodic and non-periodic boundary conditions - - Returns neighbor indices, system mapping, and shift vectors - - Fully compatible with PyTorch's automatic differentiation - - Consistent with torch_nl_n2 and torch_nl_linked_cell API - Args: - positions: Atomic positions tensor of shape [n_atoms, 3] - cell: Unit cell vectors [3*n_systems, 3] (row vector convention) - pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + positions: Atomic positions tensor [n_atoms, 3] + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] cutoff: Maximum distance for considering atoms as neighbors - system_idx: Tensor [n_atoms] indicating which system each atom belongs to. - For single system, use torch.zeros(n_atoms, dtype=torch.long) + system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. 
Default: False Returns: @@ -466,23 +454,12 @@ def standard_nl( References: - https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 """ + from torch_sim.neighbors import _normalize_inputs + device = positions.device dtype = positions.dtype n_systems = system_idx.max().item() + 1 - - # Handle PBC: reshape if needed - if pbc.ndim == 1: - if pbc.shape[0] == 3: - # Single PBC for all systems - pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) - elif pbc.shape[0] == n_systems * 3: - # Flat concatenated PBC, reshape to [n_systems, 3] - pbc_per_system = pbc.reshape(n_systems, 3) - else: - raise ValueError(f"Unexpected PBC shape: {pbc.shape}") - else: - # Already [n_systems, 3] - pbc_per_system = pbc + cell, pbc = _normalize_inputs(cell, pbc, n_systems) # Process each system's neighbor list separately edge_indices = [] @@ -498,13 +475,11 @@ def standard_nl( continue # Get the cell for this system - cell_sys = cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] + cell_sys = cell[sys_idx] # Calculate neighbor list for this system using primitive_neighbor_list - # Ensure tensors are contiguous for TorchScript - positions_sys = positions[system_mask].contiguous() - cell_sys = cell_sys.contiguous() - pbc_sys = pbc_per_system[sys_idx].contiguous() + positions_sys = positions[system_mask] + pbc_sys = pbc[sys_idx] i, j, S = primitive_neighbor_list( quantities="ijS", diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index 8e111c39..c8a6562e 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -17,6 +17,28 @@ from torch_sim import transforms +@torch.jit.script +def _normalize_inputs_jit( + cell: torch.Tensor, pbc: torch.Tensor, n_systems: int +) -> tuple[torch.Tensor, torch.Tensor]: + """JIT-compatible input normalization for torch_nl functions.""" + # Normalize cell + if cell.ndim == 2: + if cell.shape[0] == 3: + cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() + 
else: + cell = cell.reshape(n_systems, 3, 3) + + # Normalize PBC + if pbc.ndim == 1: + if pbc.shape[0] == 3: + pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() + else: + pbc = pbc.reshape(n_systems, 3) + + return cell, pbc + + def strict_nl( cutoff: float, positions: torch.Tensor, @@ -106,11 +128,9 @@ def torch_nl_n2( Args: positions (torch.Tensor [n_atom, 3]): A tensor containing the positions of atoms wrapped inside their respective unit cells. - cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to - the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [n_structure, 3] bool): + cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. + pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. - Partial PBC are not supported yet. cutoff (torch.Tensor): The cutoff radius used for the neighbor search. system_idx (torch.Tensor [n_atom,] torch.long): @@ -146,6 +166,8 @@ def torch_nl_n2( References: - https://github.com/felixmusil/torch_nl """ + n_systems = system_idx.max().item() + 1 + cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( positions, cell, pbc, cutoff.item(), n_atoms, self_interaction @@ -178,11 +200,9 @@ def torch_nl_linked_cell( positions (torch.Tensor [n_atom, 3]): A tensor containing the positions of atoms wrapped inside their respective unit cells. - cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to - the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. - Partial PBC are not supported yet. cutoff (torch.Tensor): The cutoff radius used for the neighbor search. 
system_idx (torch.Tensor [n_atom,] torch.long): @@ -219,6 +239,8 @@ def torch_nl_linked_cell( References: - https://github.com/felixmusil/torch_nl """ + n_systems = system_idx.max().item() + 1 + cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff.item(), n_atoms, self_interaction diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index d45a41a1..f301a0ca 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -41,16 +41,14 @@ def vesin_nl_ts( # noqa: PLR0915 """Compute neighbor lists using TorchScript-compatible Vesin. This function provides a TorchScript-compatible interface to the Vesin - neighbor list algorithm using VesinNeighborListTorch. It handles both - single systems and batched (multi-system) calculations with a unified API. + neighbor list algorithm using VesinNeighborListTorch. Args: positions: Atomic positions tensor [n_atoms, 3] - cell: Unit cell vectors [3*n_systems, 3] (row vector convention) - pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors - system_idx: Tensor [n_atoms] indicating which system each atom belongs to. - For single system, use torch.zeros(n_atoms, dtype=torch.long) + system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. 
Default: False Returns: @@ -81,19 +79,17 @@ def vesin_nl_ts( # noqa: PLR0915 dtype = positions.dtype n_systems = system_idx.max().item() + 1 - # Handle PBC: reshape if needed + # Normalize inputs to batch format + if cell.ndim == 2: + if cell.shape[0] == 3: + cell = cell.unsqueeze(0).expand(n_systems, -1, -1) + else: + cell = cell.reshape(n_systems, 3, 3) if pbc.ndim == 1: if pbc.shape[0] == 3: - # Single PBC for all systems - pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) - elif pbc.shape[0] == n_systems * 3: - # Flat concatenated PBC, reshape to [n_systems, 3] - pbc_per_system = pbc.reshape(n_systems, 3) + pbc = pbc.unsqueeze(0).expand(n_systems, -1) else: - raise ValueError(f"Unexpected PBC shape: {pbc.shape}") - else: - # Already [n_systems, 3] - pbc_per_system = pbc + pbc = pbc.reshape(n_systems, 3) # Process each system's neighbor list separately edge_indices = [] @@ -112,14 +108,12 @@ def vesin_nl_ts( # noqa: PLR0915 neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) # Get the cell for this system - cell_sys = ( - cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] - ) + cell_sys = cell[sys_idx] # Convert tensors to CPU and float64 properly positions_cpu = positions[system_mask].cpu().to(dtype=torch.float64) cell_cpu = cell_sys.cpu().to(dtype=torch.float64) - periodic_cpu = pbc_per_system[sys_idx].to(dtype=torch.bool).cpu() + periodic_cpu = pbc[sys_idx].to(dtype=torch.bool).cpu() # Only works on CPU and requires float64 i, j, S = neighbor_list_fn.compute( @@ -168,7 +162,7 @@ def vesin_nl_ts( # noqa: PLR0915 return mapping, system_mapping, shifts_idx - def vesin_nl( # noqa: PLR0915 + def vesin_nl( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, @@ -179,16 +173,14 @@ def vesin_nl( # noqa: PLR0915 """Compute neighbor lists using the standard Vesin implementation. This function provides an interface to the standard Vesin neighbor list - algorithm using VesinNeighborList. 
It handles both single systems and - batched (multi-system) calculations with a unified API. + algorithm using VesinNeighborList. Args: positions: Atomic positions tensor [n_atoms, 3] - cell: Unit cell vectors [3*n_systems, 3] (row vector convention) - pbc: Boolean tensor [3] or [n_systems, 3] for periodic boundary conditions + cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] + pbc: Boolean tensor [n_systems, 3] or [3] cutoff: Maximum distance for considering atoms as neighbors - system_idx: Tensor [n_atoms] indicating which system each atom belongs to. - For single system, use torch.zeros(n_atoms, dtype=torch.long) + system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. Default: False Returns: @@ -215,23 +207,12 @@ def vesin_nl( # noqa: PLR0915 References: - https://github.com/Luthaf/vesin """ + from torch_sim.neighbors import _normalize_inputs + device = positions.device dtype = positions.dtype n_systems = system_idx.max().item() + 1 - - # Handle PBC: reshape if needed - if pbc.ndim == 1: - if pbc.shape[0] == 3: - # Single PBC for all systems - pbc_per_system = pbc.unsqueeze(0).expand(n_systems, -1) - elif pbc.shape[0] == n_systems * 3: - # Flat concatenated PBC, reshape to [n_systems, 3] - pbc_per_system = pbc.reshape(n_systems, 3) - else: - raise ValueError(f"Unexpected PBC shape: {pbc.shape}") - else: - # Already [n_systems, 3] - pbc_per_system = pbc + cell, pbc = _normalize_inputs(cell, pbc, n_systems) # Process each system's neighbor list separately edge_indices = [] @@ -247,9 +228,7 @@ def vesin_nl( # noqa: PLR0915 continue # Get the cell for this system - cell_sys = ( - cell if cell.shape[0] == 3 else cell[sys_idx * 3 : (sys_idx + 1) * 3] - ) + cell_sys = cell[sys_idx] # Calculate neighbor list for this system neighbor_list_fn = VesinNeighborList( @@ -259,7 +238,7 @@ def vesin_nl( # noqa: PLR0915 # Convert tensors to CPU and float64 without gradients positions_cpu = 
positions[system_mask].detach().cpu().to(dtype=torch.float64) cell_cpu = cell_sys.detach().cpu().to(dtype=torch.float64) - periodic_cpu = pbc_per_system[sys_idx].detach().to(dtype=torch.bool).cpu() + periodic_cpu = pbc[sys_idx].detach().to(dtype=torch.bool).cpu() # Only works on CPU and returns numpy arrays i, j, S = neighbor_list_fn.compute( From 4ed6d3d38b811069f1daac4798b3962bc7ddabde Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Thu, 11 Dec 2025 14:29:14 -0800 Subject: [PATCH 11/11] fix cell shape --- torch_sim/neighbors/__init__.py | 8 ++++++-- torch_sim/neighbors/torch_nl.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 2d4b8335..915beaee 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -39,14 +39,18 @@ def _normalize_inputs( if cell.shape[0] == 3: cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() else: - cell = cell.reshape(n_systems, 3, 3) + cell = cell.reshape(n_systems, 3, 3).contiguous() + else: + cell = cell.contiguous() # Normalize PBC if pbc.ndim == 1: if pbc.shape[0] == 3: pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() else: - pbc = pbc.reshape(n_systems, 3) + pbc = pbc.reshape(n_systems, 3).contiguous() + else: + pbc = pbc.contiguous() return cell, pbc diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index c8a6562e..e8ddec4f 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -27,14 +27,18 @@ def _normalize_inputs_jit( if cell.shape[0] == 3: cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() else: - cell = cell.reshape(n_systems, 3, 3) + cell = cell.reshape(n_systems, 3, 3).contiguous() + else: + cell = cell.contiguous() # Normalize PBC if pbc.ndim == 1: if pbc.shape[0] == 3: pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() else: - pbc = pbc.reshape(n_systems, 3) + pbc = pbc.reshape(n_systems, 
3).contiguous() + else: + pbc = pbc.contiguous() return cell, pbc