@@ -742,6 +742,10 @@ def write_state( # noqa: C901
742742
743743 if len (sub_states ) != len (steps ):
744744 raise ValueError (f"{ len (sub_states )= } must match the { len (steps )= } " )
745+
746+ # Use the selected states for data serialization
747+ state = sub_states
748+
745749 # Initialize data dictionary with required arrays
746750 data = {
747751 "positions" : torch .stack ([s .positions for s in state ]),
@@ -781,8 +785,7 @@ def write_state( # noqa: C901
781785 # Save atomic numbers only for first frame
782786 self .write_arrays ({"atomic_numbers" : state [0 ].atomic_numbers }, 0 )
783787
784- if "pbc" not in self .array_registry :
785- self .write_arrays ({"pbc" : state [0 ].pbc }, [0 ])
788+ data ["pbc" ] = torch .stack ([s .pbc .reshape (- 1 ) for s in state ])
786789
787790 # Write all arrays to file
788791 self .write_arrays (data , steps )
@@ -833,7 +836,7 @@ def return_prop(self: Self, prop: str, frame: int) -> np.ndarray:
833836 arrays ["cell" ] = np .expand_dims (return_prop (self , "cell" , frame ), axis = 0 )
834837 arrays ["atomic_numbers" ] = return_prop (self , "atomic_numbers" , frame )
835838 arrays ["masses" ] = return_prop (self , "masses" , frame )
836- arrays ["pbc" ] = np . expand_dims ( return_prop (self , "pbc" , frame ), axis = 0 )
839+ arrays ["pbc" ] = return_prop (self , "pbc" , frame )
837840
838841 return arrays
839842
@@ -897,7 +900,7 @@ def get_atoms(self, frame: int = -1) -> "Atoms":
897900 numbers = np .ascontiguousarray (arrays ["atomic_numbers" ]),
898901 positions = np .ascontiguousarray (arrays ["positions" ]),
899902 cell = np .ascontiguousarray (arrays ["cell" ])[0 ],
900- pbc = np .ascontiguousarray (arrays ["pbc" ])[ 0 ] ,
903+ pbc = np .ascontiguousarray (arrays ["pbc" ]),
901904 )
902905
903906 def get_state (
0 commit comments