Skip to content

Commit 300315a

Browse files
committed
orion's review
1 parent 2850f50 commit 300315a

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

torch_sim/state.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,6 @@ def pbc(self) -> torch.Tensor:
109109

110110
def __post_init__(self) -> None:
111111
"""Initialize the SimState and validate the arguments."""
112-
# if devices aren't all the same, raise an error, in a clean way
113-
devices = {
114-
attr: getattr(self, attr).device
115-
for attr in ("positions", "masses", "cell", "atomic_numbers")
116-
}
117-
if len(set(devices.values())) > 1:
118-
raise ValueError("All tensors must be on the same device")
119-
120112
# Check that positions, masses and atomic numbers have compatible shapes
121113
shapes = [
122114
getattr(self, attr).shape[0]
@@ -132,9 +124,7 @@ def __post_init__(self) -> None:
132124
if isinstance(self.pbc, bool):
133125
self.pbc = [self.pbc] * 3
134126
if not isinstance(self.pbc, torch.Tensor):
135-
self.pbc = torch.tensor(
136-
self.pbc, dtype=torch.bool, device=self.positions.device
137-
)
127+
self.pbc = torch.tensor(self.pbc, dtype=torch.bool, device=self.device)
138128

139129
initial_system_idx = self.system_idx
140130
if initial_system_idx is None:
@@ -157,6 +147,21 @@ def __post_init__(self) -> None:
157147
f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}"
158148
)
159149

150+
# if devices aren't all the same, raise an error, in a clean way
151+
devices = {
152+
attr: getattr(self, attr).device
153+
for attr in (
154+
"positions",
155+
"masses",
156+
"cell",
157+
"atomic_numbers",
158+
"pbc",
159+
"system_idx",
160+
)
161+
}
162+
if len(set(devices.values())) > 1:
163+
raise ValueError("All tensors must be on the same device")
164+
160165
@property
161166
def wrap_positions(self) -> torch.Tensor:
162167
"""Atomic positions wrapped according to periodic boundary conditions if pbc=True,

0 commit comments

Comments
 (0)