@@ -109,16 +109,6 @@ def pbc(self) -> torch.Tensor:
109109
110110 def __post_init__ (self ) -> None :
111111 """Initialize the SimState and validate the arguments."""
112- if isinstance (self .pbc , bool ):
113- self .pbc = [self .pbc ] * 3
114- if not isinstance (self .pbc , torch .Tensor ):
115- self .pbc = torch .tensor (
116- self .pbc , dtype = torch .bool , device = self .positions .device
117- )
118-
119- # Validate and process the state after initialization.
120- # data validation and fill system_idx
121- # should make pbc a tensor here
122112 # if devices aren't all the same, raise an error, in a clean way
123113 devices = {
124114 attr : getattr (self , attr ).device
@@ -139,6 +129,13 @@ def __post_init__(self) -> None:
139129 f"masses { shapes [1 ]} , atomic_numbers { shapes [2 ]} "
140130 )
141131
132+ if isinstance (self .pbc , bool ):
133+ self .pbc = [self .pbc ] * 3
134+ if not isinstance (self .pbc , torch .Tensor ):
135+ self .pbc = torch .tensor (
136+ self .pbc , dtype = torch .bool , device = self .positions .device
137+ )
138+
142139 initial_system_idx = self .system_idx
143140 if initial_system_idx is None :
144141 self .system_idx = torch .zeros (
0 commit comments