diff --git a/tests/test_state.py b/tests/test_state.py index 426e3404..d489af39 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -650,3 +650,34 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: "New state doesn't have correct device!" ) assert si_sim_state is not new_state_gpu, "New state is not a different object!" + + +def test_state_set_cell(ti_sim_state: SimState) -> None: + """Test the set_cell method of SimState.""" + new_cell = ( + torch.diag_embed( + torch.tensor( + [3.0, 4.0, 5.0], device=ti_sim_state.device, dtype=ti_sim_state.dtype + ) + ) + @ ti_sim_state.cell + ) + ase_atoms = ti_sim_state.to_atoms()[0] + ti_sim_state.set_cell(new_cell, scale_atoms=True) + ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True) + assert torch.allclose( + ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) + ) + + M = torch.tensor( + [[[1.0, 0.2, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 2.0]]], + device=DEVICE, + dtype=ti_sim_state.dtype, + ) + new_cell = M @ ti_sim_state.cell + ase_atoms = ti_sim_state.to_atoms()[0] + ti_sim_state.set_cell(new_cell, scale_atoms=True) + ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True) + assert torch.allclose( + ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) + ) diff --git a/torch_sim/state.py b/torch_sim/state.py index 0b2d6ef8..bb0bf1e7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -251,6 +251,31 @@ def row_vector_cell(self, value: torch.Tensor) -> None: """ self.cell = value.mT + def set_cell( + self, + cell: torch.Tensor, + scale_atoms: bool = False, # noqa: FBT001, FBT002 + ) -> None: + """Set the unit cell of the system, optionally scaling atomic positions. + Torch version of ASE Atoms.set_cell. + + Args: + cell (torch.Tensor): New unit cell with shape (n_systems, 3, 3) + scale_atoms (bool, optional): Whether to scale atomic positions according to + the change in cell. Defaults to False. + """ + if cell.shape != self.cell.shape: + raise ValueError( + f"New cell must have shape {self.cell.shape}, got {cell.shape}" + ) + if scale_atoms: + M = torch.linalg.solve(self.cell.mT, cell.mT) + self.positions = torch.bmm( + self.positions.unsqueeze(1), M[self.system_idx] + ).squeeze(1) + + self.cell = cell + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate degrees of freedom accounting for constraints.