diff --git a/torch_sim/state.py b/torch_sim/state.py index 813354fe..2cd5396e 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -167,8 +167,11 @@ def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ - # TODO: implement a wrapping method - return self.positions + if not self.pbc: + return self.positions + return ts.transforms.pbc_wrap_batched( + self.positions, self.cell, self.system_idx, self.pbc + ) @property def device(self) -> torch.device: