Skip to content

Commit 33b6c2f

Browse files
committed
trajectory runs
1 parent d75633d commit 33b6c2f

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

tests/test_trajectory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,9 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non
439439
np.testing.assert_allclose(
440440
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
441441
)
442-
assert atoms.pbc.all() == random_state.pbc
442+
np.testing.assert_array_equal(
443+
atoms.pbc, random_state.pbc.detach().cpu().numpy()
444+
)
443445

444446

445447
def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> None:
@@ -509,7 +511,9 @@ def test_write_ase_trajectory(
509511
np.testing.assert_allclose(
510512
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
511513
)
512-
assert atoms.pbc.all() == random_state.pbc
514+
np.testing.assert_array_equal(
515+
atoms.pbc, random_state.pbc.detach().cpu().numpy()
516+
)
513517

514518
# Clean up
515519
ase_traj.close()

torch_sim/integrators/md.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ class MDState(SimState):
4848
SimState._system_attributes | {"energy"} # noqa: SLF001
4949
)
5050

51+
def __post_init__(self) -> None:
52+
"""Ensure SimState initialization logic runs for MDState."""
53+
super().__init__(
54+
positions=self.positions,
55+
masses=self.masses,
56+
cell=self.cell,
57+
pbc=self.pbc,
58+
atomic_numbers=self.atomic_numbers,
59+
system_idx=self.system_idx,
60+
)
61+
5162
@property
5263
def velocities(self) -> torch.Tensor:
5364
"""Velocities calculated from momenta and masses with shape

torch_sim/trajectory.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,10 @@ def write_state( # noqa: C901
742742

743743
if len(sub_states) != len(steps):
744744
raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}")
745+
746+
# Use the selected states for data serialization
747+
state = sub_states
748+
745749
# Initialize data dictionary with required arrays
746750
data = {
747751
"positions": torch.stack([s.positions for s in state]),
@@ -781,8 +785,7 @@ def write_state( # noqa: C901
781785
# Save atomic numbers only for first frame
782786
self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0)
783787

784-
if "pbc" not in self.array_registry:
785-
self.write_arrays({"pbc": state[0].pbc}, [0])
788+
data["pbc"] = torch.stack([s.pbc.reshape(-1) for s in state])
786789

787790
# Write all arrays to file
788791
self.write_arrays(data, steps)
@@ -833,7 +836,7 @@ def return_prop(self: Self, prop: str, frame: int) -> np.ndarray:
833836
arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)
834837
arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)
835838
arrays["masses"] = return_prop(self, "masses", frame)
836-
arrays["pbc"] = np.expand_dims(return_prop(self, "pbc", frame), axis=0)
839+
arrays["pbc"] = return_prop(self, "pbc", frame)
837840

838841
return arrays
839842

@@ -897,7 +900,7 @@ def get_atoms(self, frame: int = -1) -> "Atoms":
897900
numbers=np.ascontiguousarray(arrays["atomic_numbers"]),
898901
positions=np.ascontiguousarray(arrays["positions"]),
899902
cell=np.ascontiguousarray(arrays["cell"])[0],
900-
pbc=np.ascontiguousarray(arrays["pbc"])[0],
903+
pbc=np.ascontiguousarray(arrays["pbc"]),
901904
)
902905

903906
def get_state(

0 commit comments

Comments
 (0)