Skip to content

Commit 4375c2e

Browse files
committed
make convert_to_8bit_clip_percentile_normalization more robust
wip trajectory fix trajectory rm print
1 parent 300315a commit 4375c2e

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

tests/test_trajectory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def random_state() -> MDState:
3333
cell=torch.unsqueeze(torch.eye(3) * 10.0, 0),
3434
atomic_numbers=torch.ones(10, dtype=torch.int32),
3535
system_idx=torch.zeros(10, dtype=torch.int32),
36-
pbc=True,
36+
pbc=[True, True, False],
3737
)
3838

3939

@@ -473,6 +473,7 @@ def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> Non
473473
assert state.positions.dtype == expected_dtype
474474
assert state.cell.dtype == expected_dtype
475475
assert state.atomic_numbers.dtype == torch.int # Should always be int
476+
assert state.pbc.dtype == torch.bool # Should always be bool
476477

477478
# Test values (convert to CPU for comparison)
478479
np.testing.assert_allclose(state.positions, random_state.positions)
@@ -509,7 +510,7 @@ def test_write_ase_trajectory(
509510
np.testing.assert_allclose(
510511
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
511512
)
512-
np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy()[0])
513+
np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy())
513514

514515
# Clean up
515516
ase_traj.close()

torch_sim/trajectory.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,20 @@ def write_arrays(
518518

519519
self.flush()
520520

521+
def write_global_array(self, name: str, array: np.ndarray | torch.Tensor) -> None:
522+
"""Write a global array to the trajectory file.
523+
524+
This function is used to write a global array to the trajectory file.
525+
"""
526+
if isinstance(array, torch.Tensor):
527+
array = array.cpu().detach().numpy()
528+
529+
steps = [0]
530+
if name not in self.array_registry:
531+
self._initialize_array(name, array)
532+
self._validate_array(name, array, steps)
533+
self._serialize_array(name, array, steps)
534+
521535
def _initialize_array(self, name: str, array: np.ndarray) -> None:
522536
"""Initialize a single array and add it to the registry.
523537
@@ -643,15 +657,10 @@ def get_array(
643657
if name not in self.array_registry:
644658
raise ValueError(f"Array {name} not found in registry")
645659

646-
data = self._file.root.data.__getitem__(name).read(
660+
return self._file.root.data.__getitem__(name).read(
647661
start=start, stop=stop, step=step
648662
)
649663

650-
if name == "pbc":
651-
return np.squeeze(data, axis=0)
652-
653-
return data
654-
655664
def get_steps(
656665
self,
657666
name: str,
@@ -788,7 +797,8 @@ def write_state( # noqa: C901
788797
self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0)
789798

790799
if "pbc" not in self.array_registry:
791-
self.write_arrays({"pbc": state[0].pbc}, 0)
800+
print("not in array registry")
801+
self.write_global_array("pbc", state[0].pbc)
792802

793803
# Write all arrays to file
794804
self.write_arrays(data, steps)
@@ -830,15 +840,17 @@ def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]:
830840
arrays["positions"] = self.get_array("positions", start=frame, stop=frame + 1)[0]
831841

832842
def return_prop(self: Self, prop: str, frame: int) -> np.ndarray:
843+
if prop == "pbc":
844+
return self.get_array(prop, start=0, stop=3)
833845
if getattr(self._file.root.data, prop).shape[0] > 1: # Variable prop
834846
start, stop = frame, frame + 1
835847
else: # Static prop
836848
start, stop = 0, 1
837-
return self.get_array(prop, start=start, stop=stop)
849+
return self.get_array(prop, start=start, stop=stop)[0]
838850

839-
arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)[0]
840-
arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)[0]
841-
arrays["masses"] = return_prop(self, "masses", frame)[0]
851+
arrays["cell"] = np.expand_dims(return_prop(self, "cell", frame), axis=0)
852+
arrays["atomic_numbers"] = return_prop(self, "atomic_numbers", frame)
853+
arrays["masses"] = return_prop(self, "masses", frame)
842854
arrays["pbc"] = return_prop(self, "pbc", frame)
843855

844856
return arrays

0 commit comments

Comments
 (0)