@@ -109,14 +109,6 @@ def pbc(self) -> torch.Tensor:
109109
110110 def __post_init__ (self ) -> None :
111111 """Initialize the SimState and validate the arguments."""
112- # if devices aren't all the same, raise an error, in a clean way
113- devices = {
114- attr : getattr (self , attr ).device
115- for attr in ("positions" , "masses" , "cell" , "atomic_numbers" )
116- }
117- if len (set (devices .values ())) > 1 :
118- raise ValueError ("All tensors must be on the same device" )
119-
120112 # Check that positions, masses and atomic numbers have compatible shapes
121113 shapes = [
122114 getattr (self , attr ).shape [0 ]
@@ -132,9 +124,7 @@ def __post_init__(self) -> None:
132124 if isinstance (self .pbc , bool ):
133125 self .pbc = [self .pbc ] * 3
134126 if not isinstance (self .pbc , torch .Tensor ):
135- self .pbc = torch .tensor (
136- self .pbc , dtype = torch .bool , device = self .positions .device
137- )
127+ self .pbc = torch .tensor (self .pbc , dtype = torch .bool , device = self .device )
138128
139129 initial_system_idx = self .system_idx
140130 if initial_system_idx is None :
@@ -157,6 +147,21 @@ def __post_init__(self) -> None:
157147 f"Cell must have shape (n_systems, 3, 3), got { self .cell .shape } "
158148 )
159149
150+ # if devices aren't all the same, raise an error, in a clean way
151+ devices = {
152+ attr : getattr (self , attr ).device
153+ for attr in (
154+ "positions" ,
155+ "masses" ,
156+ "cell" ,
157+ "atomic_numbers" ,
158+ "pbc" ,
159+ "system_idx" ,
160+ )
161+ }
162+ if len (set (devices .values ())) > 1 :
163+ raise ValueError ("All tensors must be on the same device" )
164+
160165 @property
161166 def wrap_positions (self ) -> torch .Tensor :
162167 """Atomic positions wrapped according to periodic boundary conditions if pbc=True,
0 commit comments