Skip to content

Commit 7b0d3e5

Browse files
committed
better support for Batch and NAGBatch serialization. The proper objects will be returned even if NAG.load() or Data.load() is called on files containing batched data
1 parent cc225ae commit 7b0d3e5

File tree

8 files changed

+189
-60
lines changed

8 files changed

+189
-60
lines changed

docs/data_structures.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Important specificities of our `Data` object are:
4848
`j` with `i<j`
4949
- `NAG.get_sampling()` produces indices for sampling the superpoints with
5050
certain constraints
51-
- `NAG.save()` and `NAG.load()` allow optimized, memory-friedly I/O operations
51+
- `NAG.save()` and `NAG.load()` allow optimized, memory-friendly I/O operations
5252
- `NAG.select()` indexes the nodes of a specified partition level à la numpy
5353
and updates the rest of the `NAG` structure accordingly
5454
- `NAG.show()` for interactive visualization (see

src/data/cluster.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ def save(self, f, fp_dtype=torch.float):
136136
save_tensor(self.pointers, f, 'pointers', fp_dtype=fp_dtype)
137137
save_tensor(self.points, f, 'points', fp_dtype=fp_dtype)
138138

139-
@staticmethod
140-
def load(f, idx=None, update_sub=True, verbose=False):
139+
@classmethod
140+
def load(cls, f, idx=None, update_sub=True, verbose=False):
141141
"""Load Cluster from an HDF5 file. See `Cluster.save` for
142142
writing such file. Options allow reading only part of the
143143
clusters.
@@ -163,7 +163,7 @@ def load(f, idx=None, update_sub=True, verbose=False):
163163

164164
if not isinstance(f, (h5py.File, h5py.Group)):
165165
with h5py.File(f, 'r') as file:
166-
out = Cluster.load(
166+
out = cls.load(
167167
file, idx=idx, update_sub=update_sub, verbose=verbose)
168168
return out
169169

@@ -172,34 +172,34 @@ def load(f, idx=None, update_sub=True, verbose=False):
172172
start = time()
173173
idx = tensor_idx(idx)
174174
if verbose:
175-
print(f'Cluster.load tensor_idx : {time() - start:0.5f}s')
175+
print(f'{cls.__name__}.load tensor_idx : {time() - start:0.5f}s')
176176

177177
if idx is None or idx.shape[0] == 0:
178178
start = time()
179179
pointers = load_tensor(f['pointers'])
180180
points = load_tensor(f['points'])
181181
if verbose:
182-
print(f'Cluster.load read all : {time() - start:0.5f}s')
182+
print(f'{cls.__name__}.load read all : {time() - start:0.5f}s')
183183
start = time()
184-
out = Cluster(pointers, points), (None, None)
184+
out = cls(pointers, points), (None, None)
185185
if verbose:
186-
print(f'Cluster.load init : {time() - start:0.5f}s')
186+
print(f'{cls.__name__}.load init : {time() - start:0.5f}s')
187187
return out
188188

189189
# Read only pointers start and end indices based on idx
190190
start = time()
191191
ptr_start = load_tensor(f['pointers'], idx=idx)
192192
ptr_end = load_tensor(f['pointers'], idx=idx + 1)
193193
if verbose:
194-
print(f'Cluster.load read ptr : {time() - start:0.5f}s')
194+
print(f'{cls.__name__}.load read ptr : {time() - start:0.5f}s')
195195

196196
# Create the new pointers
197197
start = time()
198198
pointers = torch.cat([
199199
torch.zeros(1, dtype=ptr_start.dtype),
200200
torch.cumsum(ptr_end - ptr_start, 0)])
201201
if verbose:
202-
print(f'Cluster.load new pointers : {time() - start:0.5f}s')
202+
print(f'{cls.__name__}.load new pointers : {time() - start:0.5f}s')
203203

204204
# Create the indexing tensor to select and order values.
205205
# Simply, we could have used a list of slices, but we want to
@@ -212,19 +212,19 @@ def load(f, idx=None, update_sub=True, verbose=False):
212212
pointers[:-1]].repeat_interleave(sizes)
213213
val_idx += ptr_start.repeat_interleave(sizes)
214214
if verbose:
215-
print(f'Cluster.load val_idx : {time() - start:0.5f}s')
215+
print(f'{cls.__name__}.load val_idx : {time() - start:0.5f}s')
216216

217217
# Read the points, now we have computed the val_idx
218218
start = time()
219219
points = load_tensor(f['points'], idx=val_idx)
220220
if verbose:
221-
print(f'Cluster.load read points : {time() - start:0.5f}s')
221+
print(f'{cls.__name__}.load read points : {time() - start:0.5f}s')
222222

223223
# Build the Cluster object
224224
start = time()
225-
cluster = Cluster(pointers, points)
225+
cluster = cls(pointers, points)
226226
if verbose:
227-
print(f'Cluster.load init : {time() - start:0.5f}s')
227+
print(f'{cls.__name__}.load init : {time() - start:0.5f}s')
228228

229229
if not update_sub:
230230
return cluster, (None, None)
@@ -239,7 +239,7 @@ def load(f, idx=None, update_sub=True, verbose=False):
239239
idx_sub = cluster.points[perm]
240240
cluster.points = new_cluster_points
241241
if verbose:
242-
print(f'Cluster.load update_sub : {time() - start:0.5f}s')
242+
print(f'{cls.__name__}.load update_sub : {time() - start:0.5f}s')
243243

244244
# Selecting the subpoints with 'idx_sub' will not be
245245
# enough to maintain consistency with the current points. We
@@ -248,7 +248,7 @@ def load(f, idx=None, update_sub=True, verbose=False):
248248
start = time()
249249
sub_super = cluster.to_super_index()
250250
if verbose:
251-
print(f'Cluster.load super_index : {time() - start:0.5f}s')
251+
print(f'{cls.__name__}.load super_index : {time() - start:0.5f}s')
252252

253253
return cluster, (idx_sub, sub_super)
254254

src/data/csr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __getitem__(self, idx):
251251

252252
else:
253253
# Select the pointers and prepare the values indexing
254-
pointers, val_idx = CSRData.index_select_pointers(
254+
pointers, val_idx = self.__class__.index_select_pointers(
255255
self.pointers, idx)
256256
out.pointers = pointers
257257
out.values = [v[val_idx] for v in self.values]
@@ -348,8 +348,8 @@ def to(self, device, **kwargs):
348348
if self.__sizes__ is not None else None
349349
return out
350350

351-
@staticmethod
352-
def from_list(csr_list):
351+
@classmethod
352+
def from_list(cls, csr_list):
353353
assert isinstance(csr_list, list) and len(csr_list) > 0
354354
assert isinstance(csr_list[0], CSRData), \
355355
"All provided items must be CSRData objects."
@@ -392,7 +392,7 @@ def from_list(csr_list):
392392
for i in range(num_values):
393393
val_list = [csr.values[i] for csr in csr_list]
394394
if isinstance(csr_list[0].values[i], CSRData):
395-
val = CSRBatch.from_list(val_list)
395+
val = cls.from_list(val_list)
396396
elif is_index_value[i]:
397397
# "Index" values are stacked with updated indices.
398398
# For Clusters, this implies all point indices are

src/data/data.py

Lines changed: 116 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -616,11 +616,13 @@ def save(
616616
else:
617617
raise NotImplementedError(f'Unsupported type={type(val)}')
618618

619-
@staticmethod
619+
@classmethod
620620
def load(
621-
f, idx=None, keys_idx=None, keys=None, update_sub=True,
621+
cls, f, idx=None, keys_idx=None, keys=None, update_sub=True,
622622
verbose=False, rgb_to_float=False):
623-
"""Read an HDF5 file and return its content as a dictionary.
623+
"""Read an HDF5 file and return its content as a Data object.
624+
625+
NB: if relevant, a Batch object will be returned.
624626
625627
:param f: h5 file path of h5py.File or h5py.Group
626628
:param idx: int, list, numpy.ndarray, torch.Tensor
@@ -644,17 +646,25 @@ def load(
644646
"""
645647
if not isinstance(f, (h5py.File, h5py.Group)):
646648
with h5py.File(f, 'r') as file:
647-
out = Data.load(
649+
out = cls.load(
648650
file, idx=idx, keys_idx=keys_idx, keys=keys,
649651
update_sub=update_sub, verbose=verbose,
650652
rgb_to_float=rgb_to_float)
651653
return out
652654

655+
# Check if the file actually corresponds to a Batch object
656+
# rather than a simple Data object
657+
if 'batch_item_0' in f.keys():
658+
return Batch.load(
659+
f, idx=idx, keys_idx=keys_idx, keys=keys,
660+
update_sub=update_sub, verbose=verbose,
661+
rgb_to_float=rgb_to_float)
662+
653663
idx = tensor_idx(idx)
654664
if idx.shape[0] == 0:
655665
keys_idx = []
656666
elif keys_idx is None:
657-
keys_idx = list(set(f.keys()) - set(Data._NOT_INDEXABLE))
667+
keys_idx = list(set(f.keys()) - set(cls._NOT_INDEXABLE))
658668
if keys is None:
659669
all_keys = list(f.keys())
660670
for k in ['_csr_', '_cluster_', '_obj_']:
@@ -685,7 +695,7 @@ def load(
685695
elif k in keys:
686696
d_dict[k] = load_tensor(f[k])
687697
if verbose and k in d_dict.keys():
688-
print(f'Data.load {k:<22}: {time() - start:0.5f}s')
698+
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s')
689699

690700
# Update the 'keys_idx' with newly-found 'csr_keys',
691701
# 'cluster_keys', and 'obj_keys'
@@ -703,7 +713,7 @@ def load(
703713
elif k in keys:
704714
d_dict[k] = load_csr_to_dense(f['_csr_'][k], verbose=verbose)
705715
if verbose and k in d_dict.keys():
706-
print(f'Data.load {k:<22}: {time() - start:0.5f}s')
716+
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s')
707717

708718
# Special key '_cluster_' holds Cluster data
709719
for k in cluster_keys:
@@ -717,7 +727,7 @@ def load(
717727
f['_cluster_'][k], update_sub=update_sub,
718728
verbose=verbose)[0]
719729
if verbose and k in d_dict.keys():
720-
print(f'Data.load {k:<22}: {time() - start:0.5f}s')
730+
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s')
721731

722732
# Special key '_obj_' holds InstanceData data
723733
for k in obj_keys:
@@ -728,7 +738,7 @@ def load(
728738
elif k in keys:
729739
d_dict[k] = InstanceData.load(f['_obj_'][k], verbose=verbose)
730740
if verbose and k in d_dict.keys():
731-
print(f'Data.load {k:<22}: {time() - start:0.5f}s')
741+
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s')
732742

733743
# In case RGB is among the keys and is in integer type, convert
734744
# to float
@@ -737,7 +747,7 @@ def load(
737747
d_dict[k] = to_float_rgb(d_dict[k]) if rgb_to_float \
738748
else to_byte_rgb(d_dict[k])
739749

740-
return Data(**d_dict)
750+
return cls(**d_dict)
741751

742752
def estimate_instance_centroid(self, mode='iou'):
743753
"""Estimate the centroid position of each target instance
@@ -959,3 +969,99 @@ def get_example(self, idx):
959969
self.obj = obj_bckp
960970

961971
return data
972+
973+
def save(
974+
self,
975+
f,
976+
y_to_csr=True,
977+
pos_dtype=torch.float,
978+
fp_dtype=torch.float):
979+
"""Save Batch to HDF5 file.
980+
981+
:param f: h5 file path of h5py.File or h5py.Group
982+
:param y_to_csr: bool
983+
Convert 'y' to CSR format before saving. Only applies if
984+
'y' is a 2D histogram
985+
:param pos_dtype: torch dtype
986+
Data type to which 'pos' should be cast before saving. The
987+
reason for this separate treatment of 'pos' is that global
988+
coordinates may be too large and casting to 'fp_dtype' may
989+
result in hurtful precision loss
990+
:param fp_dtype: torch dtype
991+
Data type to which floating point tensors should be cast
992+
before saving
993+
:return:
994+
"""
995+
# To facilitate Batch serialization, we store the Batch as a
996+
# list of Data objects rather than a single Data object
997+
data_list = self.to_data_list()
998+
999+
if not isinstance(f, (h5py.File, h5py.Group)):
1000+
with h5py.File(f, 'w') as file:
1001+
self.save(
1002+
file,
1003+
y_to_csr=y_to_csr,
1004+
pos_dtype=pos_dtype,
1005+
fp_dtype=fp_dtype)
1006+
return
1007+
1008+
assert isinstance(f, (h5py.File, h5py.Group))
1009+
1010+
# Save each individual Data object
1011+
for i, data in enumerate(data_list):
1012+
g = f.create_group(f'batch_item_{i}')
1013+
data.save(
1014+
g,
1015+
y_to_csr=y_to_csr,
1016+
pos_dtype=pos_dtype,
1017+
fp_dtype=fp_dtype)
1018+
1019+
@classmethod
1020+
def load(
1021+
cls, f, idx=None, keys_idx=None, keys=None, update_sub=True,
1022+
verbose=False, rgb_to_float=False):
1023+
"""Read an HDF5 file and return its content as a Batch object.
1024+
1025+
:param f: h5 file path of h5py.File or h5py.Group
1026+
:param idx: int, list, numpy.ndarray, torch.Tensor
1027+
Used to select the elements in `keys_idx`. Supports fancy
1028+
indexing
1029+
:param keys_idx: List(str)
1030+
Keys on which the indexing should be applied
1031+
:param keys: List(str)
1032+
Keys should be loaded from the file, ignoring the rest
1033+
:param update_sub: bool
1034+
If True, the point (i.e. subpoint) indices will also be
1035+
updated to maintain dense indices. The output will then
1036+
contain '(idx_sub, sub_super)' which can help apply these
1037+
changes to maintain consistency with lower hierarchy levels
1038+
of a NAG.
1039+
:param verbose: bool
1040+
:param rgb_to_float: bool
1041+
If True and an integer 'rgb' or 'mean_rgb' attribute is
1042+
loaded, it will be cast to float
1043+
:return:
1044+
"""
1045+
if not isinstance(f, (h5py.File, h5py.Group)):
1046+
with h5py.File(f, 'r') as file:
1047+
out = cls.load(
1048+
file, idx=idx, keys_idx=keys_idx, keys=keys,
1049+
update_sub=update_sub, verbose=verbose,
1050+
rgb_to_float=rgb_to_float)
1051+
return out
1052+
1053+
# Recover each individual Data object making up the Batch object
1054+
data_list = []
1055+
num_batch_items = len(f)
1056+
for i in range(num_batch_items):
1057+
start = time()
1058+
data = Data.load(
1059+
f[f'batch_item_{i}'], idx=idx, keys_idx=keys_idx, keys=keys,
1060+
update_sub=update_sub, verbose=verbose,
1061+
rgb_to_float=rgb_to_float)
1062+
data_list.append(data)
1063+
if verbose:
1064+
print(f'{cls.__name__}.load item-{i:<15} : 'f'{time() - start:0.3f}s\n')
1065+
1066+
# Return a Batch object
1067+
return cls.from_data_list(data_list)

0 commit comments

Comments
 (0)