Skip to content

Commit 2850f50

Browse files
committed
Thomas' review
1 parent 5097cef commit 2850f50

File tree

6 files changed

+38
-32
lines changed

6 files changed

+38
-32
lines changed

tests/test_trajectory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_write_state_single(
9393
assert trajectory.get_array("positions").shape == (1, 10, 3)
9494
assert trajectory.get_array("atomic_numbers").shape == (1, 10)
9595
assert trajectory.get_array("cell").shape == (1, 3, 3)
96-
assert trajectory.get_array("pbc").shape == (1, 3)
96+
assert trajectory.get_array("pbc").shape == (3,)
9797

9898

9999
def test_write_state_multiple(
@@ -106,7 +106,7 @@ def test_write_state_multiple(
106106
assert trajectory.get_array("positions").shape == (2, 10, 3)
107107
assert trajectory.get_array("atomic_numbers").shape == (1, 10)
108108
assert trajectory.get_array("cell").shape == (2, 3, 3)
109-
assert trajectory.get_array("pbc").shape == (1, 3)
109+
assert trajectory.get_array("pbc").shape == (3,)
110110

111111

112112
def test_optional_arrays(trajectory: TorchSimTrajectory, random_state: MDState) -> None:

tests/test_transforms.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,10 +894,20 @@ def test_get_fractional_coordinates_batched() -> None:
894894
True,
895895
[[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]],
896896
),
897+
(
898+
[[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]],
899+
torch.eye(3, dtype=DTYPE) * 2.0,
900+
torch.tensor([True, False, True], dtype=torch.bool),
901+
[[0.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 0.2]],
902+
),
897903
],
898904
)
899905
def test_minimum_image_displacement(
900-
*, dr: list[list[float]], cell: torch.Tensor, pbc: bool, expected: list[list[float]]
906+
*,
907+
dr: list[list[float]],
908+
cell: torch.Tensor,
909+
pbc: bool | torch.Tensor,
910+
expected: list[list[float]],
901911
) -> None:
902912
"""Test minimum_image_displacement with various inputs.
903913

torch_sim/io.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,18 @@ def phonopy_to_state(
369369
torch.arange(len(phonopy_atoms_list), device=device), atoms_per_system
370370
)
371371

372+
"""
373+
NOTE: PhonopyAtoms does not have pbc attribute for Supercells assume True
374+
Verify consistent pbc
375+
if not all(all(at.pbc) == all(phonopy_atoms_lst[0].pbc) for at in phonopy_atoms_lst):
376+
raise ValueError("All systems must have the same periodic boundary conditions")
377+
"""
378+
372379
return ts.SimState(
373380
positions=positions,
374381
masses=masses,
375382
cell=cell,
376-
pbc=True, # phononpy always assumes periodic boundary conditions https://github.com/phonopy/phonopy/blob/develop/phonopy/structure/atoms.py#L140
383+
pbc=True,
377384
atomic_numbers=atomic_numbers,
378385
system_idx=system_idx,
379386
)

torch_sim/neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,9 +783,9 @@ def torch_nl_linked_cell(
783783
positions (torch.Tensor [n_atom, 3]):
784784
A tensor containing the positions of atoms wrapped inside
785785
their respective unit cells.
786-
cell (torch.Tensor [3*num_systems, 3]): Unit cell vectors according to
786+
cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to
787787
the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
788-
pbc (torch.Tensor [num_systems, 3] bool):
788+
pbc (torch.Tensor [n_systems, 3] bool):
789789
A tensor indicating the periodic boundary conditions to apply.
790790
Partial PBC are not supported yet.
791791
system_idx (torch.Tensor [n_atom,] torch.long):

torch_sim/trajectory.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -643,10 +643,15 @@ def get_array(
643643
if name not in self.array_registry:
644644
raise ValueError(f"Array {name} not found in registry")
645645

646-
return self._file.root.data.__getitem__(name).read(
646+
data = self._file.root.data.__getitem__(name).read(
647647
start=start, stop=stop, step=step
648648
)
649649

650+
if name == "pbc":
651+
return np.squeeze(data, axis=0)
652+
653+
return data
654+
650655
def get_steps(
651656
self,
652657
name: str,
@@ -829,11 +834,11 @@ def return_prop(self: Self, prop: str, frame: int) -> np.ndarray:
829834
start, stop = frame, frame + 1
830835
else: # Static prop
831836
start, stop = 0, 1
832-
return self.get_array(prop, start=start, stop=stop)[0]
837+
return self.get_array(prop, start=start, stop=stop)
833838

834-
arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)
835-
arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)
836-
arrays["masses"] = return_prop(self, "masses", frame)
839+
arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)[0]
840+
arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)[0]
841+
arrays["masses"] = return_prop(self, "masses", frame)[0]
837842
arrays["pbc"] = return_prop(self, "pbc", frame)
838843

839844
return arrays
@@ -926,14 +931,11 @@ def get_state(
926931
arrays = self._get_state_arrays(frame)
927932

928933
# Create SimState with required attributes
929-
pbc_tensor = torch.tensor(
930-
arrays["pbc"], device=device, dtype=torch.bool
931-
).squeeze()
932934
return SimState(
933935
positions=torch.tensor(arrays["positions"], device=device, dtype=dtype),
934936
masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype),
935937
cell=torch.tensor(arrays["cell"], device=device, dtype=dtype),
936-
pbc=pbc_tensor,
938+
pbc=torch.tensor(arrays["pbc"], device=device, dtype=torch.bool),
937939
atomic_numbers=torch.tensor(
938940
arrays["atomic_numbers"], device=device, dtype=torch.int
939941
),

torch_sim/transforms.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor:
113113

114114
@deprecated("Use wrap_positions instead")
115115
def pbc_wrap_general(
116-
positions: torch.Tensor,
117-
lattice_vectors: torch.Tensor,
118-
pbc: torch.Tensor | bool = True, # noqa: FBT002
116+
positions: torch.Tensor, lattice_vectors: torch.Tensor
119117
) -> torch.Tensor:
120118
"""Apply periodic boundary conditions using lattice
121119
vector transformation method.
@@ -131,16 +129,10 @@ def pbc_wrap_general(
131129
containing particle positions in real space.
132130
lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing
133131
lattice vectors as columns (A matrix in the equations).
134-
pbc (torch.Tensor | bool): Boolean tensor of shape (3,) or boolean indicating
135-
whether periodic boundary conditions are applied in each dimension.
136-
If a boolean is provided, all axes are assumed to have the same periodic
137-
boundary conditions.
138132
139133
Returns:
140134
torch.Tensor: Wrapped positions in real space with same shape as input positions.
141135
"""
142-
if isinstance(pbc, bool):
143-
pbc = torch.tensor([pbc] * 3)
144136
# Validate inputs
145137
if not torch.is_floating_point(positions) or not torch.is_floating_point(
146138
lattice_vectors
@@ -157,10 +149,7 @@ def pbc_wrap_general(
157149
frac_coords = positions @ torch.linalg.inv(lattice_vectors).T
158150

159151
# Wrap to reference cell [0,1) using modulo
160-
wrapped_frac = frac_coords.clone()
161-
wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0
162-
wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0
163-
wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0
152+
wrapped_frac = frac_coords % 1.0
164153

165154
# Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row
166155
return wrapped_frac @ lattice_vectors.T
@@ -223,9 +212,7 @@ def pbc_wrap_batched(
223212

224213
# Wrap to reference cell [0,1) using modulo
225214
wrapped_frac = frac_coords.clone()
226-
wrapped_frac[:, pbc[0]] = frac_coords[:, pbc[0]] % 1.0
227-
wrapped_frac[:, pbc[1]] = frac_coords[:, pbc[1]] % 1.0
228-
wrapped_frac[:, pbc[2]] = frac_coords[:, pbc[2]] % 1.0
215+
wrapped_frac[:, pbc] = frac_coords[:, pbc] % 1.0
229216

230217
# Transform back to real space: r = A·f
231218
# Get the cell for each atom based on its system index
@@ -262,7 +249,7 @@ def minimum_image_displacement(
262249
dr_frac = torch.einsum("ij,...j->...i", cell_inv, dr)
263250

264251
# Apply minimum image convention
265-
dr_frac -= torch.round(dr_frac)
252+
dr_frac -= torch.where(pbc, torch.round(dr_frac), torch.zeros_like(dr_frac))
266253

267254
# Convert back to cartesian
268255
return torch.einsum("ij,...j->...i", cell, dr_frac)

0 commit comments

Comments
 (0)