@@ -518,6 +518,20 @@ def write_arrays(
518518
519519 self .flush ()
520520
521+ def write_global_array (self , name : str , array : np .ndarray | torch .Tensor ) -> None :
522+ """Write a global array to the trajectory file.
523+
524+ This function is used to write a global array to the trajectory file.
525+ """
526+ if isinstance (array , torch .Tensor ):
527+ array = array .cpu ().detach ().numpy ()
528+
529+ steps = [0 ]
530+ if name not in self .array_registry :
531+ self ._initialize_array (name , array )
532+ self ._validate_array (name , array , steps )
533+ self ._serialize_array (name , array , steps )
534+
521535 def _initialize_array (self , name : str , array : np .ndarray ) -> None :
522536 """Initialize a single array and add it to the registry.
523537
@@ -643,15 +657,10 @@ def get_array(
643657 if name not in self .array_registry :
644658 raise ValueError (f"Array { name } not found in registry" )
645659
646- data = self ._file .root .data .__getitem__ (name ).read (
660+ return self ._file .root .data .__getitem__ (name ).read (
647661 start = start , stop = stop , step = step
648662 )
649663
650- if name == "pbc" :
651- return np .squeeze (data , axis = 0 )
652-
653- return data
654-
655664 def get_steps (
656665 self ,
657666 name : str ,
@@ -788,7 +797,8 @@ def write_state( # noqa: C901
788797 self .write_arrays ({"atomic_numbers" : state [0 ].atomic_numbers }, 0 )
789798
790799 if "pbc" not in self .array_registry :
791- self .write_arrays ({"pbc" : state [0 ].pbc }, 0 )
800+ print ("not in array registry" )
801+ self .write_global_array ("pbc" , state [0 ].pbc )
792802
793803 # Write all arrays to file
794804 self .write_arrays (data , steps )
@@ -830,15 +840,17 @@ def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]:
830840 arrays ["positions" ] = self .get_array ("positions" , start = frame , stop = frame + 1 )[0 ]
831841
832842 def return_prop (self : Self , prop : str , frame : int ) -> np .ndarray :
843+ if prop == "pbc" :
844+ return self .get_array (prop , start = 0 , stop = 3 )
833845 if getattr (self ._file .root .data , prop ).shape [0 ] > 1 : # Variable prop
834846 start , stop = frame , frame + 1
835847 else : # Static prop
836848 start , stop = 0 , 1
837- return self .get_array (prop , start = start , stop = stop )
849+ return self .get_array (prop , start = start , stop = stop )[ 0 ]
838850
839- arrays ["cell" ] = np .expand_dims (return_prop (self , "cell" , frame ), axis = 0 )[ 0 ]
840- arrays ["atomic_numbers" ] = return_prop (self , "atomic_numbers" , frame )[ 0 ]
841- arrays ["masses" ] = return_prop (self , "masses" , frame )[ 0 ]
851+ arrays ["cell" ] = np .expand_dims (return_prop (self , "cell" , frame ), axis = 0 )
852+ arrays ["atomic_numbers" ] = return_prop (self , "atomic_numbers" , frame )
853+ arrays ["masses" ] = return_prop (self , "masses" , frame )
842854 arrays ["pbc" ] = return_prop (self , "pbc" , frame )
843855
844856 return arrays
0 commit comments