From b4f0396afa27b8abcee3f29ebce8c5f6eb79d6ae Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Tue, 4 Nov 2025 16:19:50 -0800 Subject: [PATCH 1/5] v1 for pbc wrap only the pbc axis for each atom support init simstate with list of bools better test that uses itertools update vesin version fix pbc check in integrators fix more pytests more fixing more fixes VesinNeighborListTorch is slow fix metatensor for pbc more fixes wip fix pbc trajectory trajectory runs bump metatomic version or vesin will complain fix errors wip fix more trajectory issues make trajectory pass fix pbc in diff_sim fix ase atoms to state conversion for pbc fix neighbors.py assert the pymatgen pbc is valid proper conversion between pbc for atoms, and pymatgen do not pass in pbc to phononpy rm warning and add doc to github make consistent with prev implementation fix io tests lint minor simplification simplify test more simplification changes fix some tests satisfy prek but ruff check errors :/ we'll fix this later more cleanup rename renamove more diffs more changes wip fix rm init in md add type checking and fix pbc type in dataclass def cleanup state loosen test for nl --- .../7_Others/7.3_Batched_neighbor_list.py | 4 +- examples/tutorials/diff_sim.py | 17 +++++-- pyproject.toml | 6 +-- tests/models/test_soft_sphere.py | 4 +- tests/test_io.py | 10 ++-- tests/test_neighbors.py | 45 +++++++++++----- tests/test_trajectory.py | 10 ++-- tests/test_transforms.py | 4 +- torch_sim/integrators/md.py | 7 ++- torch_sim/integrators/npt.py | 8 +-- torch_sim/io.py | 26 +++++----- torch_sim/models/metatomic.py | 6 +-- torch_sim/models/nequip_framework.py | 6 +-- torch_sim/models/orb.py | 11 ++-- torch_sim/models/soft_sphere.py | 11 ++-- torch_sim/neighbors.py | 37 ++++++++------ torch_sim/state.py | 21 ++++++-- torch_sim/trajectory.py | 12 +++-- torch_sim/transforms.py | 51 +++++++++++++++---- 19 files changed, 184 insertions(+), 112 deletions(-) diff --git 
a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 724231e3..1393b908 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -18,8 +18,8 @@ cutoff = torch.tensor(4.0, dtype=pos.dtype) self_interaction = False -# Fix: Ensure pbc has the correct shape [n_systems, 3] -pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) +# Ensure pbc has the correct shape [n_systems, 3] +pbc_tensor = torch.tensor(pbc).repeat(state.n_systems, 1) mapping, mapping_system, shifts_idx = torch_nl_linked_cell( pos, cell, pbc_tensor, cutoff, system_idx, self_interaction diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index 3835c926..461dcec4 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -117,7 +117,7 @@ class BaseState: positions: torch.Tensor cell: torch.Tensor - pbc: bool + pbc: torch.Tensor species: torch.Tensor @@ -133,14 +133,18 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.float32, *, # Force keyword-only arguments - pbc: bool = True, + pbc: torch.Tensor | bool = True, cutoff: float | None = None, ) -> None: """Initialize a soft sphere model for multi-component systems.""" super().__init__() self.device = device or torch.device("cpu") self.dtype = dtype - self.pbc = pbc + self.pbc = ( + pbc + if isinstance(pbc, torch.Tensor) + else torch.tensor([pbc] * 3, dtype=torch.bool) + ) # Store species list and determine number of unique species self.species = species @@ -382,7 +386,12 @@ def simulation( # Minimize to the nearest minimum. 
init_fn, apply_fn = gradient_descent(model, lr=0.1) - custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + custom_state = BaseState( + positions=R, + cell=cell, + species=species, + pbc=torch.tensor([True] * 3, dtype=torch.bool), + ) state = init_fn(custom_state) for _ in range(simulation_steps): state = apply_fn(state) diff --git a/pyproject.toml b/pyproject.toml index 850f35b6..996757d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,8 @@ dependencies = [ "tables>=3.10.2", "torch>=2", "tqdm>=4.67", - "vesin-torch>=0.3.7, <0.4.0", - "vesin>=0.3.7, <0.4.0", + "vesin-torch>=0.4.0, <0.5.0", + "vesin>=0.4.0, <0.5.0", ] [project.optional-dependencies] @@ -48,7 +48,7 @@ test = [ io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] mace = ["mace-torch>=0.3.12"] mattersim = ["mattersim>=0.1.2"] -metatomic = ["metatomic-torch>=0.1.1", "metatrain[pet]>=2025.7"] +metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index 99c2ed64..a07c8282 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -350,8 +350,8 @@ def test_multispecies_cutoff_default() -> None: @pytest.mark.parametrize( ("flag_name", "flag_value"), [ - ("pbc", True), - ("pbc", False), + ("pbc", torch.tensor([True, True, True])), + ("pbc", torch.tensor([False, False, False])), ("compute_forces", False), ("compute_stress", True), ("per_atom_energies", True), diff --git a/tests/test_io.py b/tests/test_io.py index 26e4a34c..a2c25ab4 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -46,7 +46,7 @@ def test_multiple_structures_to_state(si_structure: Structure) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert 
state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -64,7 +64,7 @@ def test_single_atoms_to_state(si_atoms: Atoms) -> None: assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (8,) assert state.system_idx.shape == (8,) assert torch.all(state.system_idx == 0) @@ -79,7 +79,7 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -171,7 +171,7 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -246,7 +246,7 @@ def test_state_round_trip( assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) assert torch.all(sim_state.system_idx == round_trip_state.system_idx) - assert sim_state.pbc == round_trip_state.pbc + assert torch.equal(sim_state.pbc, round_trip_state.pbc) if isinstance(intermediate_format[0], Atoms): # TODO: masses round trip for pmg and phonopy masses is not exact diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 8a1e0f7a..72cdbf13 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -170,7 +170,7 @@ def test_primitive_neighbor_list( pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) - pbc = atoms.pbc.any() + pbc = 
torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE) # Get the neighbor list using the appropriate function (jitted or non-jitted) # Note: No self-interaction @@ -178,7 +178,7 @@ def test_primitive_neighbor_list( quantities="ijS", positions=pos, cell=row_vector_cell, - pbc=(pbc, pbc, pbc), + pbc=pbc, cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), device=DEVICE, dtype=DTYPE, @@ -258,7 +258,7 @@ def test_neighbor_list_implementations( # Convert to torch tensors pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) - pbc = atoms.pbc.any() + pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE) # Get the neighbor list from the implementation being tested mapping, shifts = nl_implementation( @@ -371,7 +371,7 @@ def test_primitive_neighbor_list_edge_cases() -> None: quantities="ijS", positions=pos, cell=cell, - pbc=pbc, + pbc=torch.tensor(pbc, device=DEVICE, dtype=DTYPE), cutoff=cutoff, device=DEVICE, dtype=DTYPE, @@ -383,7 +383,7 @@ def test_primitive_neighbor_list_edge_cases() -> None: quantities="ijS", positions=pos, cell=cell, - pbc=(True, True, True), + pbc=torch.Tensor([True, True, True]), cutoff=cutoff, device=DEVICE, dtype=DTYPE, @@ -404,7 +404,7 @@ def test_standard_nl_edge_cases() -> None: mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, - pbc=pbc, + pbc=torch.tensor([pbc] * 3, device=DEVICE, dtype=DTYPE), cutoff=cutoff, ) assert len(mapping[0]) > 0 # Should find neighbors @@ -413,7 +413,7 @@ def test_standard_nl_edge_cases() -> None: mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, - pbc=True, + pbc=torch.Tensor([True, True, True]), cutoff=cutoff, sort_id=True, ) @@ -430,13 +430,20 @@ def test_vesin_nl_edge_cases() -> None: # Test both implementations for nl_fn in (neighbors.vesin_nl, neighbors.vesin_nl_ts): # Test different PBC combinations - for pbc in (True, False): + for pbc in ( + torch.Tensor([True, True, True]), + 
torch.Tensor([False, False, False]), + ): mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=pbc, cutoff=cutoff) assert len(mapping[0]) > 0 # Should find neighbors # Test sort_id mapping, _shifts = nl_fn( - positions=pos, cell=cell, pbc=True, cutoff=cutoff, sort_id=True + positions=pos, + cell=cell, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, + sort_id=True, ) # Check if indices are sorted assert torch.all(mapping[0][1:] >= mapping[0][:-1]) @@ -446,7 +453,10 @@ def test_vesin_nl_edge_cases() -> None: pos_f32 = pos.to(dtype=torch.float32) cell_f32 = cell.to(dtype=torch.float32) mapping, _shifts = nl_fn( - positions=pos_f32, cell=cell_f32, pbc=True, cutoff=cutoff + positions=pos_f32, + cell=cell_f32, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, ) assert len(mapping[0]) > 0 # Should find neighbors @@ -528,7 +538,12 @@ def test_neighbor_lists_time_and_memory() -> None: self_interaction=False, ) else: - _mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) + _mapping, _shifts = nl_fn( + positions=pos, + cell=cell, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, + ) end_time = time.perf_counter() execution_time = end_time - start_time @@ -551,4 +566,10 @@ def test_neighbor_lists_time_and_memory() -> None: assert cpu_memory_used < 5e8, ( f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB" ) - assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" + if nl_fn == neighbors.standard_nl: + # this function is just quite slow. So we have a higher tolerance. + # I tried removing "@jit.script" and it was still slow. 
+ # (This nl function is just slow) + assert execution_time < 3, f"{fn_name} took too long: {execution_time}s" + else: + assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 215877c6..1ff750f1 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -93,7 +93,7 @@ def test_write_state_single( assert trajectory.get_array("positions").shape == (1, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (1, 3, 3) - assert trajectory.get_array("pbc").shape == (1,) + assert trajectory.get_array("pbc").shape == (1, 3) def test_write_state_multiple( @@ -106,7 +106,7 @@ def test_write_state_multiple( assert trajectory.get_array("positions").shape == (2, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (2, 3, 3) - assert trajectory.get_array("pbc").shape == (1,) + assert trajectory.get_array("pbc").shape == (1, 3) def test_optional_arrays(trajectory: TorchSimTrajectory, random_state: MDState) -> None: @@ -439,7 +439,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() ) - assert atoms.pbc.all() == random_state.pbc + np.testing.assert_array_equal(atoms.pbc, random_state.pbc.detach().cpu().numpy()) def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> None: @@ -478,7 +478,7 @@ def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> Non np.testing.assert_allclose(state.positions, random_state.positions) np.testing.assert_allclose(state.cell, random_state.cell) np.testing.assert_allclose(state.atomic_numbers, random_state.atomic_numbers) - assert state.pbc == random_state.pbc + assert torch.equal(state.pbc, random_state.pbc) def test_write_ase_trajectory( @@ -509,7 +509,7 @@ def 
test_write_ase_trajectory( np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() ) - assert atoms.pbc.all() == random_state.pbc + np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy()[0]) # Clean up ase_traj.close() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 067abad8..23b9dfd5 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,4 +1,6 @@ # ruff: noqa: PT011 +import itertools + import numpy as np import pytest import torch @@ -195,7 +197,7 @@ def test_pbc_wrap_general_batch() -> None: @pytest.mark.parametrize( - "pbc", [[True, True, True], [True, True, False], [False, False, False], True, False] + "pbc", [*list(itertools.product([False, True], repeat=3)), True, False] ) @pytest.mark.parametrize("pretty_translation", [True, False]) def test_wrap_positions_matches_ase( diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 92de1ccc..b03fb7a3 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -175,10 +175,13 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt - if state.pbc: + if state.pbc.any(): # Split positions and cells by system new_positions = transforms.pbc_wrap_batched( - new_positions, state.cell, state.system_idx + new_positions, + state.cell, + state.system_idx, + state.pbc, ) state.positions = new_positions diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4984baf4..db6e2b15 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -367,9 +367,9 @@ def _npt_langevin_position_step( state.positions = c_1 + c_2.unsqueeze(-1) * c_3 # Apply periodic boundary conditions if needed - if state.pbc: + if state.pbc.any(): state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx + state.positions, state.cell, state.system_idx, state.pbc ) return 
state @@ -1030,9 +1030,9 @@ def _npt_nose_hoover_exp_iL1( # noqa: N802 new_positions = state.positions + new_positions # Apply periodic boundary conditions if needed - if state.pbc: + if state.pbc.any(): return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx + new_positions, state.current_cell, state.system_idx, pbc=state.pbc ) return new_positions diff --git a/torch_sim/io.py b/torch_sim/io.py index a2081c20..796541b5 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -117,7 +117,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: # Create structure for this system struct = Structure( - lattice=Lattice(system_cell), + lattice=Lattice(system_cell, pbc=(state.pbc.tolist())), species=species, coords=system_positions, coords_are_cartesian=True, @@ -164,8 +164,11 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: # Convert atomic numbers to chemical symbols symbols = [chemical_symbols[z] for z in system_numbers] + + # Note: pbc is not used in the init since it's always assumed to be true + # https://github.com/phonopy/phonopy/blob/develop/phonopy/structure/atoms.py#L140 phonopy_atoms = PhonopyAtoms( - symbols=symbols, positions=system_positions, cell=system_cell, pbc=state.pbc + symbols=symbols, positions=system_positions, cell=system_cell ) phonopy_atoms_list.append(phonopy_atoms) @@ -225,14 +228,14 @@ def atoms_to_state( ) # Verify consistent pbc - if not all(all(at.pbc) == all(atoms_list[0].pbc) for at in atoms_list): + if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=all(atoms_list[0].pbc), + pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, ) @@ -294,11 +297,15 @@ def structures_to_state( torch.arange(len(struct_list), device=device), atoms_per_system ) + # Verify consistent 
pbc + if not all(tuple(s.pbc) == tuple(struct_list[0].pbc) for s in struct_list[1:]): + raise ValueError("All systems must have the same periodic boundary conditions") + return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=True, # Structures are always periodic + pbc=struct_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, ) @@ -362,18 +369,11 @@ def phonopy_to_state( torch.arange(len(phonopy_atoms_list), device=device), atoms_per_system ) - """ - NOTE: PhonopyAtoms does not have pbc attribute for Supercells assume True - Verify consistent pbc - if not all(all(at.pbc) == all(phonopy_atoms_lst[0].pbc) for at in phonopy_atoms_lst): - raise ValueError("All systems must have the same periodic boundary conditions") - """ - return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=True, + pbc=True, # phonopy always assumes periodic boundary conditions https://github.com/phonopy/phonopy/blob/develop/phonopy/structure/atoms.py#L140 atomic_numbers=atomic_numbers, system_idx=system_idx, ) diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 1e97cdda..2110f8c6 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -173,7 +173,6 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # atomic_nums = sim_state.atomic_numbers cell = sim_state.row_vector_cell positions = sim_state.positions - pbc = sim_state.pbc # Check dtype (metatomic models require a specific input dtype) if positions.dtype != self._dtype: @@ -196,9 +195,6 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # system_mask = sim_state.system_idx == sys_idx system_positions = positions[system_mask] system_cell = cell[sys_idx] - system_pbc = torch.tensor( - [pbc, pbc, pbc], device=self._device, dtype=torch.bool - ) system_atomic_numbers = atomic_nums[system_mask] # Create a System object for this system @@ -217,7 +213,7 @@ def forward(self, state: 
ts.SimState | StateDict) -> dict[str, torch.Tensor]: # positions=system_positions, types=system_atomic_numbers, cell=system_cell, - pbc=system_pbc, + pbc=sim_state.pbc, ) ) diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index 916281a5..89f1ce56 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -345,11 +345,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # "cell": sim_state.row_vector_cell, "batch": sim_state.system_idx, "num_atoms": sim_state.system_idx.bincount(), - "pbc": torch.tensor( - [sim_state.pbc, sim_state.pbc, sim_state.pbc], - dtype=torch.bool, - device=self.device, - ), + "pbc": sim_state.pbc, "atomic_numbers": sim_state.atomic_numbers, "atom_types": atomic_types, "edge_index": edge_index, diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 7fda6c1f..16c469f1 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -147,9 +147,6 @@ def state_to_atom_graphs( # noqa: PLR0915 ) # Orb uses row vector cell convention for neighbor list atomic_numbers = state.atomic_numbers.long() - # Create PBC tensor based on state.pbc - pbc = torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool) - max_num_neighbors = max_num_neighbors or system_config.max_num_neighbors # Get atom embeddings for the model @@ -168,7 +165,7 @@ def state_to_atom_graphs( # noqa: PLR0915 atomic_numbers_embedding = atom_type_embedding.to(output_dtype) # Wrap positions into the central cell if needed - if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)): + if wrap and (torch.any(row_vector_cell != 0) and torch.any(state.pbc)): positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node) n_systems = state.system_idx.max().item() + 1 @@ -190,13 +187,13 @@ def state_to_atom_graphs( # noqa: PLR0915 atomic_numbers_per_system = atomic_numbers[system_mask] atomic_numbers_embedding_per_system = 
atomic_numbers_embedding[system_mask] cell_per_system = row_vector_cell[sys_idx] - pbc_per_system = pbc + pbc = state.pbc # Compute edges directly for this system edges, vectors, unit_shifts = feat_util.compute_pbc_radius_graph( positions=positions_per_system, cell=cell_per_system, - pbc=pbc_per_system, + pbc=pbc, radius=system_config.radius, max_number_neighbors=max_num_neighbors, edge_method=edge_method, @@ -230,7 +227,7 @@ def state_to_atom_graphs( # noqa: PLR0915 graph_feats = { "cell": cell_per_system, - "pbc": pbc_per_system, + "pbc": pbc, "lattice": lattice_per_system.to(device=positions_per_system.device), } diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index a9a1b97a..c641eebb 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -509,7 +509,7 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.float64, *, # Force keyword-only arguments - pbc: bool = True, + pbc: torch.Tensor | bool = True, compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, @@ -538,8 +538,9 @@ def __init__( device (torch.device | None): Device for computations. If None, uses CPU. Defaults to None. dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - pbc (bool): Whether to use periodic boundary conditions. Defaults to - True. + pbc (torch.Tensor | bool): Boolean tensor of shape (3,) indicating periodic + boundary conditions in each axis. If a bool is given, it is applied to + all three axes. Defaults to True. compute_forces (bool): Whether to compute forces. Defaults to True. compute_stress (bool): Whether to compute stress tensor. Defaults to False. per_atom_energies (bool): Whether to compute per-atom energy decomposition. 
@@ -597,7 +598,7 @@ def __init__( super().__init__() self._device = device or torch.device("cpu") self._dtype = dtype - self.pbc = pbc + self.pbc = torch.tensor([pbc] * 3) if isinstance(pbc, bool) else pbc self._compute_forces = compute_forces self._compute_stress = compute_stress self.per_atom_energies = per_atom_energies @@ -714,7 +715,7 @@ def unbatched_forward( # noqa: PLR0915 cell=cell, pbc=self.pbc, cutoff=self.cutoff, - sorti=False, + sort_id=False, ) # Get displacements between neighbor pairs dr_vec, distances = transforms.get_pair_displacements( diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 491d72cb..9c4cb0cf 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -11,7 +11,7 @@ @torch.jit.script def primitive_neighbor_list( # noqa: C901, PLR0915 quantities: str, - pbc: tuple[bool, bool, bool], + pbc: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, cutoff: torch.Tensor, @@ -42,8 +42,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 between atom i and j). With the shift vector S, the distances D between atoms can be computed from: D = positions[j]-positions[i]+S.dot(cell) - pbc: 3-tuple indicating giving periodic boundaries in the three Cartesian - directions. + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. positions: Atomic positions. Anything that can be converted to an ndarray of @@ -411,7 +411,7 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 def standard_nl( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -437,7 +437,8 @@ def standard_nl( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. 
`[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cutoff: Maximum distance for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -462,7 +463,7 @@ def standard_nl( Notes: - The function uses primitive_neighbor_list internally but provides a simpler interface - - For non-periodic systems (pbc=False), shifts will be zero vectors + - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs for complete force computation - Memory usage scales with system size and number of neighbors per atom @@ -476,7 +477,7 @@ def standard_nl( quantities="ijS", positions=positions, cell=cell, - pbc=(pbc, pbc, pbc), + pbc=pbc, cutoff=cutoff, device=device, dtype=dtype, @@ -501,7 +502,7 @@ def standard_nl( def vesin_nl_ts( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -515,7 +516,8 @@ def vesin_nl_ts( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. 
cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -533,7 +535,7 @@ def vesin_nl_ts( - Uses VesinNeighborListTorch for TorchScript compatibility - Requires CPU tensors in float64 precision internally - Returns tensors on the same device as input with original precision - - For non-periodic systems (pbc=False), shifts will be zero vectors + - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs References: @@ -547,12 +549,13 @@ def vesin_nl_ts( # Convert tensors to CPU and float64 properly positions_cpu = positions.cpu().to(dtype=torch.float64) cell_cpu = cell.cpu().to(dtype=torch.float64) + periodic_cpu = pbc.to(dtype=torch.bool).cpu() # Only works on CPU and requires float64 i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, - periodic=pbc, + periodic=periodic_cpu, quantities="ijS", ) @@ -571,7 +574,7 @@ def vesin_nl_ts( def vesin_nl( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: float | torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -585,7 +588,8 @@ def vesin_nl( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. 
cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -618,12 +622,13 @@ def vesin_nl( # Convert tensors to CPU and float64 without gradients positions_cpu = positions.detach().cpu().to(dtype=torch.float64) cell_cpu = cell.detach().cpu().to(dtype=torch.float64) + periodic_cpu = pbc.detach().to(dtype=torch.bool).cpu() # Only works on CPU and returns numpy arrays i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, - periodic=pbc, + periodic=periodic_cpu, quantities="ijS", ) i, j = ( @@ -778,9 +783,9 @@ def torch_nl_linked_cell( positions (torch.Tensor [n_atom, 3]): A tensor containing the positions of atoms wrapped inside their respective unit cells. - cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to + cell (torch.Tensor [3*num_systems, 3]): Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [n_structure, 3] bool): + pbc (torch.Tensor [num_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. system_idx (torch.Tensor [n_atom,] torch.long): diff --git a/torch_sim/state.py b/torch_sim/state.py index ccaffb27..4340a0a5 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -47,7 +47,9 @@ class SimState: stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. - pbc (bool): Boolean indicating whether to use periodic boundary conditions + pbc (bool | list[bool] | torch.Tensor): indicates periodic boundary + conditions in each axis. If a boolean is provided, all axes are + assumed to have the same periodic boundary conditions. atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. 
Has shape (n_atoms,), must be unique consecutive integers starting from 0. @@ -80,7 +82,7 @@ class SimState: positions: torch.Tensor masses: torch.Tensor cell: torch.Tensor - pbc: bool # TODO: do all calculators support mixed pbc? + pbc: torch.Tensor | list[bool] | bool atomic_numbers: torch.Tensor system_idx: torch.Tensor | None = field(default=None) @@ -91,6 +93,11 @@ def system_idx(self) -> torch.Tensor: """A getter for system_idx that tells type checkers it's always defined.""" return self.system_idx + @property + def pbc(self) -> torch.Tensor: + """A getter for pbc that tells type checkers it's always defined.""" + return self.pbc + _atom_attributes: ClassVar[set[str]] = { "positions", "masses", @@ -102,9 +109,6 @@ def system_idx(self) -> torch.Tensor: def __post_init__(self) -> None: """Initialize the SimState and validate the arguments.""" - # Validate and process the state after initialization. - # data validation and fill system_idx - # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way devices = { attr: getattr(self, attr).device @@ -125,6 +129,13 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) + if isinstance(self.pbc, bool): + self.pbc = [self.pbc] * 3 + if not isinstance(self.pbc, torch.Tensor): + self.pbc = torch.tensor( + self.pbc, dtype=torch.bool, device=self.positions.device + ) + initial_system_idx = self.system_idx if initial_system_idx is None: self.system_idx = torch.zeros( diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 3220b24b..c4009283 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -736,6 +736,7 @@ def write_state( # noqa: C901 if len(sub_states) != len(steps): raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}") + # Initialize data dictionary with required arrays data = { "positions": torch.stack([s.positions for s in state]), @@ -776,7 +777,7 @@ def write_state( # noqa: C901 
self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0) if "pbc" not in self.array_registry: - self.write_arrays({"pbc": np.array(state[0].pbc)}, 0) + self.write_arrays({"pbc": state[0].pbc}, 0) # Write all arrays to file self.write_arrays(data, steps) @@ -887,13 +888,11 @@ def get_atoms(self, frame: int = -1) -> "Atoms": arrays = self._get_state_arrays(frame) - pbc = arrays.get("pbc", True) - return Atoms( numbers=np.ascontiguousarray(arrays["atomic_numbers"]), positions=np.ascontiguousarray(arrays["positions"]), cell=np.ascontiguousarray(arrays["cell"])[0], - pbc=pbc, + pbc=np.ascontiguousarray(arrays["pbc"]), ) def get_state( @@ -921,11 +920,14 @@ def get_state( arrays = self._get_state_arrays(frame) # Create SimState with required attributes + pbc_tensor = torch.tensor( + arrays["pbc"], device=device, dtype=torch.bool + ).squeeze() return SimState( positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=bool(arrays.get("pbc", True)), + pbc=pbc_tensor, atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 53570a66..cd735122 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -113,7 +113,9 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: @deprecated("Use wrap_positions instead") def pbc_wrap_general( - positions: torch.Tensor, lattice_vectors: torch.Tensor + positions: torch.Tensor, + lattice_vectors: torch.Tensor, + pbc: torch.Tensor | bool = True, # noqa: FBT002 ) -> torch.Tensor: """Apply periodic boundary conditions using lattice vector transformation method. @@ -129,10 +131,16 @@ def pbc_wrap_general( containing particle positions in real space. 
lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing lattice vectors as columns (A matrix in the equations). + pbc (torch.Tensor | bool): Boolean tensor of shape (3,) or boolean indicating + whether periodic boundary conditions are applied in each dimension. + If a boolean is provided, all axes are assumed to have the same periodic + boundary conditions. Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. """ + if isinstance(pbc, bool): + pbc = torch.tensor([pbc] * 3) # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point( lattice_vectors @@ -149,14 +157,20 @@ def pbc_wrap_general( frac_coords = positions @ torch.linalg.inv(lattice_vectors).T # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords % 1.0 + wrapped_frac = frac_coords.clone() + wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0 + wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0 + wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0 # Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row return wrapped_frac @ lattice_vectors.T def pbc_wrap_batched( - positions: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor + positions: torch.Tensor, + cell: torch.Tensor, + system_idx: torch.Tensor, + pbc: torch.Tensor | bool = True, # noqa: FBT002 ) -> torch.Tensor: """Apply periodic boundary conditions to batched systems. @@ -171,10 +185,16 @@ def pbc_wrap_batched( lattice vectors as column vectors. system_idx (torch.Tensor): Tensor of shape (n_atoms,) containing system indices for each atom. + pbc (torch.Tensor | bool): Tensor of shape (3,) containing boolean values + indicating whether periodic boundary conditions are applied in each dimension. + Can also be a bool. Defaults to True. Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. 
""" + if isinstance(pbc, bool): + pbc = torch.tensor([pbc, pbc, pbc], dtype=torch.bool, device=positions.device) + # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point(cell): raise TypeError("Positions and lattice vectors must be floating point tensors.") @@ -202,7 +222,10 @@ def pbc_wrap_batched( frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords % 1.0 + wrapped_frac = frac_coords.clone() + wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0 + wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0 + wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0 # Transform back to real space: r = A·f # Get the cell for each atom based on its system index @@ -216,19 +239,22 @@ def minimum_image_displacement( *, dr: torch.Tensor, cell: torch.Tensor | None = None, - pbc: bool = True, + pbc: torch.Tensor | bool = True, ) -> torch.Tensor: """Apply minimum image convention to displacement vectors. Args: dr (torch.Tensor): Displacement vectors [N, 3] or [N, N, 3]. cell (Optional[torch.Tensor]): Unit cell matrix [3, 3]. - pbc (bool): Whether to apply periodic boundary conditions. + pbc (Optional[torch.Tensor]): Boolean tensor of shape (3,) indicating + periodic boundary conditions in each dimension. Returns: torch.Tensor: Minimum image displacement vectors with same shape as input. 
""" - if cell is None or not pbc: + if isinstance(pbc, bool): + pbc = torch.tensor([pbc] * 3, dtype=torch.bool, device=dr.device) + if cell is None or not pbc.any(): return dr # Convert to fractional coordinates @@ -246,7 +272,7 @@ def get_pair_displacements( *, positions: torch.Tensor, cell: torch.Tensor | None = None, - pbc: bool = True, + pbc: torch.Tensor | bool = True, pairs: tuple[torch.Tensor, torch.Tensor] | None = None, shifts: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -255,7 +281,8 @@ def get_pair_displacements( Args: positions (torch.Tensor): Atomic positions [N, 3]. cell (Optional[torch.Tensor]): Unit cell matrix [3, 3]. - pbc (bool): Whether to apply periodic boundary conditions. + pbc (Optional[torch.Tensor]): Boolean tensor of shape (3,) indicating + periodic boundary conditions in each dimension. pairs (Optional[Tuple[torch.Tensor, torch.Tensor]]): (i, j) indices for specific pairs to compute. shifts (Optional[torch.Tensor]): Shift vectors for periodic images [n_pairs, 3]. @@ -265,13 +292,15 @@ def get_pair_displacements( - Displacement vectors [n_pairs, 3]. - Distances [n_pairs]. 
""" + if isinstance(pbc, bool): + pbc = torch.tensor([pbc] * 3, dtype=torch.bool, device=positions.device) if pairs is None: # Create full distance matrix ri = positions.unsqueeze(0) # [1, N, 3] rj = positions.unsqueeze(1) # [N, 1, 3] dr = rj - ri # [N, N, 3] - if cell is not None and pbc: + if cell is not None and pbc.any(): dr = minimum_image_displacement(dr=dr, cell=cell, pbc=pbc) # Calculate distances @@ -287,7 +316,7 @@ def get_pair_displacements( i, j = pairs dr = positions[j] - positions[i] # [n_pairs, 3] - if cell is not None and pbc: + if cell is not None and pbc.any(): if shifts is not None: # Apply provided shifts dr = dr + torch.einsum("ij,kj->ki", cell, shifts) From 48920667f01ed747fcef3b1f4f6825d202663b32 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 7 Nov 2025 17:38:17 -0800 Subject: [PATCH 2/5] Thomas' review --- tests/test_trajectory.py | 4 ++-- tests/test_transforms.py | 12 +++++++++++- torch_sim/io.py | 9 ++++++++- torch_sim/neighbors.py | 4 ++-- torch_sim/trajectory.py | 20 +++++++++++--------- torch_sim/transforms.py | 21 ++++----------------- 6 files changed, 38 insertions(+), 32 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 1ff750f1..9c224705 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -93,7 +93,7 @@ def test_write_state_single( assert trajectory.get_array("positions").shape == (1, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (1, 3, 3) - assert trajectory.get_array("pbc").shape == (1, 3) + assert trajectory.get_array("pbc").shape == (3,) def test_write_state_multiple( @@ -106,7 +106,7 @@ def test_write_state_multiple( assert trajectory.get_array("positions").shape == (2, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (2, 3, 3) - assert trajectory.get_array("pbc").shape == (1, 3) + assert trajectory.get_array("pbc").shape == (3,) def 
test_optional_arrays(trajectory: TorchSimTrajectory, random_state: MDState) -> None: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 23b9dfd5..16565b73 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -894,10 +894,20 @@ def test_get_fractional_coordinates_batched() -> None: True, [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]], ), + ( + [[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]], + torch.eye(3, dtype=DTYPE) * 2.0, + torch.tensor([True, False, True], dtype=torch.bool), + [[0.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 0.2]], + ), ], ) def test_minimum_image_displacement( - *, dr: list[list[float]], cell: torch.Tensor, pbc: bool, expected: list[list[float]] + *, + dr: list[list[float]], + cell: torch.Tensor, + pbc: bool | torch.Tensor, + expected: list[list[float]], ) -> None: """Test minimum_image_displacement with various inputs. diff --git a/torch_sim/io.py b/torch_sim/io.py index 796541b5..253b29ed 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -369,11 +369,18 @@ def phonopy_to_state( torch.arange(len(phonopy_atoms_list), device=device), atoms_per_system ) + """ + NOTE: PhonopyAtoms does not have pbc attribute for Supercells assume True + Verify consistent pbc + if not all(all(at.pbc) == all(phonopy_atoms_lst[0].pbc) for at in phonopy_atoms_lst): + raise ValueError("All systems must have the same periodic boundary conditions") + """ + return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=True, # phononpy always assumes periodic boundary conditions https://github.com/phonopy/phonopy/blob/develop/phonopy/structure/atoms.py#L140 + pbc=True, atomic_numbers=atomic_numbers, system_idx=system_idx, ) diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 9c4cb0cf..5562a38a 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -783,9 +783,9 @@ def torch_nl_linked_cell( positions (torch.Tensor [n_atom, 3]): A tensor containing the positions of atoms wrapped inside 
their respective unit cells. - cell (torch.Tensor [3*num_systems, 3]): Unit cell vectors according to + cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [num_systems, 3] bool): + pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. system_idx (torch.Tensor [n_atom,] torch.long): diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index c4009283..40c2b232 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -637,10 +637,15 @@ def get_array( if name not in self.array_registry: raise ValueError(f"Array {name} not found in registry") - return self._file.root.data.__getitem__(name).read( + data = self._file.root.data.__getitem__(name).read( start=start, stop=stop, step=step ) + if name == "pbc": + return np.squeeze(data, axis=0) + + return data + def get_steps( self, name: str, @@ -823,11 +828,11 @@ def return_prop(self: Self, prop: str, frame: int) -> np.ndarray: start, stop = frame, frame + 1 else: # Static prop start, stop = 0, 1 - return self.get_array(prop, start=start, stop=stop)[0] + return self.get_array(prop, start=start, stop=stop) - arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0) - arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame) - arrays["masses"] = return_prop(self, "masses", frame) + arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)[0] + arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)[0] + arrays["masses"] = return_prop(self, "masses", frame)[0] arrays["pbc"] = return_prop(self, "pbc", frame) return arrays @@ -920,14 +925,11 @@ def get_state( arrays = self._get_state_arrays(frame) # Create SimState with required attributes - pbc_tensor = torch.tensor( - arrays["pbc"], device=device, dtype=torch.bool - ).squeeze() return SimState( 
positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=pbc_tensor, + pbc=torch.tensor(arrays["pbc"], device=device, dtype=torch.bool), atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index cd735122..2ab4ab2e 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -113,9 +113,7 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: @deprecated("Use wrap_positions instead") def pbc_wrap_general( - positions: torch.Tensor, - lattice_vectors: torch.Tensor, - pbc: torch.Tensor | bool = True, # noqa: FBT002 + positions: torch.Tensor, lattice_vectors: torch.Tensor ) -> torch.Tensor: """Apply periodic boundary conditions using lattice vector transformation method. @@ -131,16 +129,10 @@ def pbc_wrap_general( containing particle positions in real space. lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing lattice vectors as columns (A matrix in the equations). - pbc (torch.Tensor | bool): Boolean tensor of shape (3,) or boolean indicating - whether periodic boundary conditions are applied in each dimension. - If a boolean is provided, all axes are assumed to have the same periodic - boundary conditions. Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. 
""" - if isinstance(pbc, bool): - pbc = torch.tensor([pbc] * 3) # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point( lattice_vectors @@ -157,10 +149,7 @@ def pbc_wrap_general( frac_coords = positions @ torch.linalg.inv(lattice_vectors).T # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords.clone() - wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0 - wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0 - wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0 + wrapped_frac = frac_coords % 1.0 # Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row return wrapped_frac @ lattice_vectors.T @@ -223,9 +212,7 @@ def pbc_wrap_batched( # Wrap to reference cell [0,1) using modulo wrapped_frac = frac_coords.clone() - wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0 - wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0 - wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0 + wrapped_frac[:, pbc] = frac_coords[:, pbc] % 1.0 # Transform back to real space: r = A·f # Get the cell for each atom based on its system index @@ -262,7 +249,7 @@ def minimum_image_displacement( dr_frac = torch.einsum("ij,...j->...i", cell_inv, dr) # Apply minimum image convention - dr_frac -= torch.round(dr_frac) + dr_frac -= torch.where(pbc, torch.round(dr_frac), torch.zeros_like(dr_frac)) # Convert back to cartesian return torch.einsum("ij,...j->...i", cell, dr_frac) From 994b3ae4b7b70b0e0d6e17253cfc4e7013d5f0ac Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 9 Nov 2025 16:40:17 -0800 Subject: [PATCH 3/5] orion's review --- torch_sim/state.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 4340a0a5..c59c2efd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -109,14 +109,6 @@ def pbc(self) -> torch.Tensor: def __post_init__(self) -> None: """Initialize the SimState and validate the arguments.""" - # if 
devices aren't all the same, raise an error, in a clean way - devices = { - attr: getattr(self, attr).device - for attr in ("positions", "masses", "cell", "atomic_numbers") - } - if len(set(devices.values())) > 1: - raise ValueError("All tensors must be on the same device") - # Check that positions, masses and atomic numbers have compatible shapes shapes = [ getattr(self, attr).shape[0] @@ -132,9 +124,7 @@ def __post_init__(self) -> None: if isinstance(self.pbc, bool): self.pbc = [self.pbc] * 3 if not isinstance(self.pbc, torch.Tensor): - self.pbc = torch.tensor( - self.pbc, dtype=torch.bool, device=self.positions.device - ) + self.pbc = torch.tensor(self.pbc, dtype=torch.bool, device=self.device) initial_system_idx = self.system_idx if initial_system_idx is None: @@ -157,6 +147,21 @@ def __post_init__(self) -> None: f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" ) + # if devices aren't all the same, raise an error, in a clean way + devices = { + attr: getattr(self, attr).device + for attr in ( + "positions", + "masses", + "cell", + "atomic_numbers", + "pbc", + "system_idx", + ) + } + if len(set(devices.values())) > 1: + raise ValueError("All tensors must be on the same device") + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, From e5ceb14705c183339624be536c1280e09e854378 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 10 Nov 2025 20:21:21 -0800 Subject: [PATCH 4/5] solve pbc trajectory issues --- tests/test_trajectory.py | 5 +++-- torch_sim/trajectory.py | 33 ++++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 9c224705..7f7252cf 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -33,7 +33,7 @@ def random_state() -> MDState: cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), atomic_numbers=torch.ones(10, dtype=torch.int32), 
system_idx=torch.zeros(10, dtype=torch.int32), - pbc=True, + pbc=[True, True, False], ) @@ -473,6 +473,7 @@ def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> Non assert state.positions.dtype == expected_dtype assert state.cell.dtype == expected_dtype assert state.atomic_numbers.dtype == torch.int # Should always be int + assert state.pbc.dtype == torch.bool # Should always be bool # Test values (convert to CPU for comparison) np.testing.assert_allclose(state.positions, random_state.positions) @@ -509,7 +510,7 @@ def test_write_ase_trajectory( np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() ) - np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy()[0]) + np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy()) # Clean up ase_traj.close() diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 40c2b232..ab841a85 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -512,6 +512,20 @@ def write_arrays( self.flush() + def write_global_array(self, name: str, array: np.ndarray | torch.Tensor) -> None: + """Write a global array to the trajectory file. + + This function is used to write a global array to the trajectory file. + """ + if isinstance(array, torch.Tensor): + array = array.cpu().detach().numpy() + + steps = [0] + if name not in self.array_registry: + self._initialize_array(name, array) + self._validate_array(name, array, steps) + self._serialize_array(name, array, steps) + def _initialize_array(self, name: str, array: np.ndarray) -> None: """Initialize a single array and add it to the registry. 
@@ -637,15 +651,10 @@ def get_array( if name not in self.array_registry: raise ValueError(f"Array {name} not found in registry") - data = self._file.root.data.__getitem__(name).read( + return self._file.root.data.__getitem__(name).read( start=start, stop=stop, step=step ) - if name == "pbc": - return np.squeeze(data, axis=0) - - return data - def get_steps( self, name: str, @@ -782,7 +791,7 @@ def write_state( # noqa: C901 self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0) if "pbc" not in self.array_registry: - self.write_arrays({"pbc": state[0].pbc}, 0) + self.write_global_array("pbc", state[0].pbc) # Write all arrays to file self.write_arrays(data, steps) @@ -824,15 +833,17 @@ def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: arrays["positions"] = self.get_array("positions", start=frame, stop=frame + 1)[0] def return_prop(self: Self, prop: str, frame: int) -> np.ndarray: + if prop == "pbc": + return self.get_array(prop, start=0, stop=3) if getattr(self._file.root.data, prop).shape[0] > 1: # Variable prop start, stop = frame, frame + 1 else: # Static prop start, stop = 0, 1 - return self.get_array(prop, start=start, stop=stop) + return self.get_array(prop, start=start, stop=stop)[0] - arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)[0] - arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)[0] - arrays["masses"] = return_prop(self, "masses", frame)[0] + arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0) + arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame) + arrays["masses"] = return_prop(self, "masses", frame) arrays["pbc"] = return_prop(self, "pbc", frame) return arrays From ab97b2b3ff787623b429faef2417f35dcb30d582 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 13 Nov 2025 19:59:32 -0800 Subject: [PATCH 5/5] try fix orb issues --- torch_sim/models/orb.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 16c469f1..dd742c73 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -82,11 +82,12 @@ def cell_to_cellpar( x = torch.dot(cell[j], cell[k]) / ll angle = 180.0 / torch.pi * torch.arccos(x) else: - angle = 90.0 + angle = torch.tensor(90.0, dtype=cell.dtype, device=cell.device) angles.append(angle) + angles_tensor = torch.stack(angles) if radians: - angles = [angle * torch.pi / 180 for angle in angles] - return torch.concat((lengths, torch.stack(angles))) + angles_tensor = angles_tensor * torch.pi / 180.0 + return torch.concat((lengths, angles_tensor)) def state_to_atom_graphs( # noqa: PLR0915