Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,34 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None:
"New state doesn't have correct device!"
)
assert si_sim_state is not new_state_gpu, "New state is not a different object!"


def test_state_set_cell(ti_sim_state: SimState) -> None:
"""Test the set_cell method of SimState."""
new_cell = (
torch.diag_embed(
torch.tensor(
[3.0, 4.0, 5.0], device=ti_sim_state.device, dtype=ti_sim_state.dtype
)
)
@ ti_sim_state.cell
)
ase_atoms = ti_sim_state.to_atoms()[0]
ti_sim_state.set_cell(new_cell, scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True)
assert torch.allclose(
ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions)
)

M = torch.tensor(
[[[1.0, 0.2, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 2.0]]],
device=DEVICE,
dtype=ti_sim_state.dtype,
)
new_cell = M @ ti_sim_state.cell
ase_atoms = ti_sim_state.to_atoms()[0]
ti_sim_state.set_cell(new_cell, scale_atoms=True)
ase_atoms.set_cell(new_cell[0].T.cpu().numpy(), scale_atoms=True)
assert torch.allclose(
ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions)
)
25 changes: 25 additions & 0 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,31 @@ def row_vector_cell(self, value: torch.Tensor) -> None:
"""
self.cell = value.mT

def set_cell(
self,
cell: torch.Tensor,
scale_atoms: bool = False, # noqa: FBT001, FBT002
) -> None:
"""Set the unit cell of the system, optionally scaling atomic positions.
Torch version of ASE Atoms.set_cell.

Args:
cell (torch.Tensor): New unit cell with shape (n_systems, 3, 3)
scale_atoms (bool, optional): Whether to scale atomic positions according to
the change in cell. Defaults to False.
"""
if cell.shape != self.cell.shape:
raise ValueError(
f"New cell must have shape {self.cell.shape}, got {cell.shape}"
)
if scale_atoms:
M = torch.linalg.solve(self.cell.mT, cell.mT)
self.positions = torch.bmm(
self.positions.unsqueeze(1), M[self.system_idx]
).squeeze(1)

self.cell = cell

def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
"""Calculate degrees of freedom accounting for constraints.

Expand Down