@@ -616,11 +616,13 @@ def save(
616
616
else :
617
617
raise NotImplementedError (f'Unsupported type={ type (val )} ' )
618
618
619
- @staticmethod
619
+ @classmethod
620
620
def load (
621
- f , idx = None , keys_idx = None , keys = None , update_sub = True ,
621
+ cls , f , idx = None , keys_idx = None , keys = None , update_sub = True ,
622
622
verbose = False , rgb_to_float = False ):
623
- """Read an HDF5 file and return its content as a dictionary.
623
+ """Read an HDF5 file and return its content as a Data object.
624
+
625
+ NB: if relevant, a Batch object will be returned.
624
626
625
627
:param f: h5 file path of h5py.File or h5py.Group
626
628
:param idx: int, list, numpy.ndarray, torch.Tensor
@@ -644,17 +646,25 @@ def load(
644
646
"""
645
647
if not isinstance (f , (h5py .File , h5py .Group )):
646
648
with h5py .File (f , 'r' ) as file :
647
- out = Data .load (
649
+ out = cls .load (
648
650
file , idx = idx , keys_idx = keys_idx , keys = keys ,
649
651
update_sub = update_sub , verbose = verbose ,
650
652
rgb_to_float = rgb_to_float )
651
653
return out
652
654
655
+ # Check if the file actually corresponds to a Batch object
656
+ # rather than a simple Data object
657
+ if 'batch_item_0' in f .keys ():
658
+ return Batch .load (
659
+ f , idx = idx , keys_idx = keys_idx , keys = keys ,
660
+ update_sub = update_sub , verbose = verbose ,
661
+ rgb_to_float = rgb_to_float )
662
+
653
663
idx = tensor_idx (idx )
654
664
if idx .shape [0 ] == 0 :
655
665
keys_idx = []
656
666
elif keys_idx is None :
657
- keys_idx = list (set (f .keys ()) - set (Data ._NOT_INDEXABLE ))
667
+ keys_idx = list (set (f .keys ()) - set (cls ._NOT_INDEXABLE ))
658
668
if keys is None :
659
669
all_keys = list (f .keys ())
660
670
for k in ['_csr_' , '_cluster_' , '_obj_' ]:
@@ -685,7 +695,7 @@ def load(
685
695
elif k in keys :
686
696
d_dict [k ] = load_tensor (f [k ])
687
697
if verbose and k in d_dict .keys ():
688
- print (f'Data .load { k :<22} : { time () - start :0.5f} s' )
698
+ print (f'{ cls . __name__ } .load { k :<22} : { time () - start :0.5f} s' )
689
699
690
700
# Update the 'keys_idx' with newly-found 'csr_keys',
691
701
# 'cluster_keys', and 'obj_keys'
@@ -703,7 +713,7 @@ def load(
703
713
elif k in keys :
704
714
d_dict [k ] = load_csr_to_dense (f ['_csr_' ][k ], verbose = verbose )
705
715
if verbose and k in d_dict .keys ():
706
- print (f'Data .load { k :<22} : { time () - start :0.5f} s' )
716
+ print (f'{ cls . __name__ } .load { k :<22} : { time () - start :0.5f} s' )
707
717
708
718
# Special key '_cluster_' holds Cluster data
709
719
for k in cluster_keys :
@@ -717,7 +727,7 @@ def load(
717
727
f ['_cluster_' ][k ], update_sub = update_sub ,
718
728
verbose = verbose )[0 ]
719
729
if verbose and k in d_dict .keys ():
720
- print (f'Data .load { k :<22} : { time () - start :0.5f} s' )
730
+ print (f'{ cls . __name__ } .load { k :<22} : { time () - start :0.5f} s' )
721
731
722
732
# Special key '_obj_' holds InstanceData data
723
733
for k in obj_keys :
@@ -728,7 +738,7 @@ def load(
728
738
elif k in keys :
729
739
d_dict [k ] = InstanceData .load (f ['_obj_' ][k ], verbose = verbose )
730
740
if verbose and k in d_dict .keys ():
731
- print (f'Data .load { k :<22} : { time () - start :0.5f} s' )
741
+ print (f'{ cls . __name__ } .load { k :<22} : { time () - start :0.5f} s' )
732
742
733
743
# In case RGB is among the keys and is in integer type, convert
734
744
# to float
@@ -737,7 +747,7 @@ def load(
737
747
d_dict [k ] = to_float_rgb (d_dict [k ]) if rgb_to_float \
738
748
else to_byte_rgb (d_dict [k ])
739
749
740
- return Data (** d_dict )
750
+ return cls (** d_dict )
741
751
742
752
def estimate_instance_centroid (self , mode = 'iou' ):
743
753
"""Estimate the centroid position of each target instance
@@ -959,3 +969,99 @@ def get_example(self, idx):
959
969
self .obj = obj_bckp
960
970
961
971
return data
972
+
973
def save(
        self,
        f,
        y_to_csr=True,
        pos_dtype=torch.float,
        fp_dtype=torch.float):
    """Save Batch to HDF5 file.

    To facilitate Batch serialization, the Batch is stored as a list
    of Data objects rather than a single Data object: the i-th item
    of `self.to_data_list()` is written to its own 'batch_item_{i}'
    group under `f`.

    :param f: h5 file path of h5py.File or h5py.Group
        If a path is passed, the file is opened in 'w' mode
    :param y_to_csr: bool
        Convert 'y' to CSR format before saving. Only applies if
        'y' is a 2D histogram
    :param pos_dtype: torch dtype
        Data type to which 'pos' should be cast before saving. The
        reason for this separate treatment of 'pos' is that global
        coordinates may be too large and casting to 'fp_dtype' may
        result in hurtful precision loss
    :param fp_dtype: torch dtype
        Data type to which floating point tensors should be cast
        before saving
    :return: None
    """
    # Split the Batch exactly once, and before any file is opened in
    # 'w' mode, so that a failure here does not leave a truncated
    # file behind
    data_list = self.to_data_list()

    def _write_items(group):
        # Save each individual Data object in its own sub-group
        for i, data in enumerate(data_list):
            g = group.create_group(f'batch_item_{i}')
            data.save(
                g,
                y_to_csr=y_to_csr,
                pos_dtype=pos_dtype,
                fp_dtype=fp_dtype)

    # Already an open h5py handle: write directly into it
    if isinstance(f, (h5py.File, h5py.Group)):
        _write_items(f)
        return

    # Otherwise, treat 'f' as a file path
    with h5py.File(f, 'w') as file:
        _write_items(file)
1018
+
1019
@classmethod
def load(
        cls, f, idx=None, keys_idx=None, keys=None, update_sub=True,
        verbose=False, rgb_to_float=False):
    """Read an HDF5 file and return its content as a Batch object.

    The file is expected to hold one 'batch_item_{i}' group per Data
    object making up the Batch. Each group is read with `Data.load`
    and the results are assembled with `cls.from_data_list`.

    NB: all loading arguments, including `idx`, are forwarded as-is
    to `Data.load` for every batch item. `idx` is hence applied to
    each item separately, not to the Batch as a whole.

    :param f: h5 file path of h5py.File or h5py.Group
    :param idx: int, list, numpy.ndarray, torch.Tensor
        Used to select the elements in `keys_idx`. Supports fancy
        indexing
    :param keys_idx: List(str)
        Keys on which the indexing should be applied
    :param keys: List(str)
        Keys should be loaded from the file, ignoring the rest
    :param update_sub: bool
        If True, the point (i.e. subpoint) indices will also be
        updated to maintain dense indices. The output will then
        contain '(idx_sub, sub_super)' which can help apply these
        changes to maintain consistency with lower hierarchy levels
        of a NAG.
    :param verbose: bool
    :param rgb_to_float: bool
        If True and an integer 'rgb' or 'mean_rgb' attribute is
        loaded, it will be cast to float
    :return: Batch
    """
    if not isinstance(f, (h5py.File, h5py.Group)):
        with h5py.File(f, 'r') as file:
            return cls.load(
                file, idx=idx, keys_idx=keys_idx, keys=keys,
                update_sub=update_sub, verbose=verbose,
                rgb_to_float=rgb_to_float)

    # Only count the groups that actually hold batch items. Relying
    # on `len(f)` would break as soon as the group holds any other
    # member
    prefix = 'batch_item_'
    num_batch_items = sum(k.startswith(prefix) for k in f.keys())

    # Recover each individual Data object making up the Batch object
    data_list = []
    for i in range(num_batch_items):
        start = time()
        data = Data.load(
            f[f'{prefix}{i}'], idx=idx, keys_idx=keys_idx, keys=keys,
            update_sub=update_sub, verbose=verbose,
            rgb_to_float=rgb_to_float)
        data_list.append(data)
        if verbose:
            print(
                f'{cls.__name__}.load item-{i:<15} : '
                f'{time() - start:0.3f}s\n')

    # Return a Batch object
    return cls.from_data_list(data_list)
0 commit comments