Skip to content

Commit fcf1d1a

Browse files
committed
cleanup state
1 parent 6d518ff commit fcf1d1a

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

examples/scripts/7_Others/7.3_Batched_neighbor_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
cutoff = torch.tensor(4.0, dtype=pos.dtype)
1919
self_interaction = False
2020

21+
# Ensure pbc has the correct shape [n_systems, 3]
2122
pbc_tensor = torch.tensor(pbc).repeat(state.n_systems, 1)
2223

2324
mapping, mapping_system, shifts_idx = torch_nl_linked_cell(

torch_sim/state.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,6 @@ def pbc(self) -> torch.Tensor:
109109

110110
def __post_init__(self) -> None:
111111
"""Initialize the SimState and validate the arguments."""
112-
if isinstance(self.pbc, bool):
113-
self.pbc = [self.pbc] * 3
114-
if not isinstance(self.pbc, torch.Tensor):
115-
self.pbc = torch.tensor(
116-
self.pbc, dtype=torch.bool, device=self.positions.device
117-
)
118-
119-
# Validate and process the state after initialization.
120-
# data validation and fill system_idx
121-
# should make pbc a tensor here
122112
# if devices aren't all the same, raise an error, in a clean way
123113
devices = {
124114
attr: getattr(self, attr).device
@@ -139,6 +129,13 @@ def __post_init__(self) -> None:
139129
f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
140130
)
141131

132+
if isinstance(self.pbc, bool):
133+
self.pbc = [self.pbc] * 3
134+
if not isinstance(self.pbc, torch.Tensor):
135+
self.pbc = torch.tensor(
136+
self.pbc, dtype=torch.bool, device=self.positions.device
137+
)
138+
142139
initial_system_idx = self.system_idx
143140
if initial_system_idx is None:
144141
self.system_idx = torch.zeros(

0 commit comments

Comments
 (0)