Skip to content

Commit 8eb51f7

Browse files
committed
FEAT: added support for mixed types dataframes in H5 files
1 parent 8c595ea commit 8eb51f7

File tree

1 file changed

+66
-3
lines changed

1 file changed

+66
-3
lines changed

larray_editor/arrayadapter.py

Lines changed: 66 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2444,15 +2444,59 @@ class PyTablesPandasFrameAdapter(AbstractColumnarAdapter):
24442444
def __init__(self, data, attributes):
24452445
super().__init__(data=data, attributes=attributes)
24462446
attrs = data._v_attrs
2447-
assert hasattr(attrs, 'nblocks') and attrs.nblocks == 1, "not implemented for nblocks > 1"
2447+
assert hasattr(attrs, 'nblocks')
24482448
assert hasattr(attrs, 'axis0_variety') and attrs.axis0_variety in {'regular', 'multi'}
24492449
assert hasattr(attrs, 'axis1_variety') and attrs.axis1_variety in {'regular', 'multi'}
24502450
self._axis0_variety = attrs.axis0_variety
24512451
self._axis1_variety = attrs.axis1_variety
24522452
self._encoding = getattr(attrs, 'encoding', None)
2453+
nblocks = attrs.nblocks
2454+
self._block_values_nodes = [data._f_get_child(f'block{i}_values')
2455+
for i in range(nblocks)]
2456+
assert not (nblocks > 1 and attrs.axis0_variety == 'multi'), \
2457+
("loading mixed type DataFrames with a multi-index in columns "
2458+
"from HDF5 is not implemented yet")
2459+
2460+
# data.block0_values.shape[0] is not always correct (when there are multiple blocks)
2461+
if attrs.axis1_variety == 'multi':
2462+
self._num_rows = data.axis1_label0.shape[0]
2463+
else:
2464+
self._num_rows = data.axis1.shape[0]
2465+
2466+
if nblocks > 1:
2467+
import tables
2468+
2469+
axis_node = data._f_get_child(f'axis0')
2470+
col_names = axis_node.read().tolist()
2471+
# {col_idx: (block_idx, idx_in_block)}
2472+
column_source = {}
2473+
cached_string_blocks = {}
2474+
for block_idx in range(nblocks):
2475+
block_values_node = data._f_get_child(f'block{block_idx}_values')
2476+
block_items = data._f_get_child(f'block{block_idx}_items').read()
2477+
2478+
if isinstance(block_values_node, tables.VLArray):
2479+
# This is very unfortunate but we cannot slice those blocks
2480+
# on disk because they are stored as a single blob
2481+
# We load the full block and keep it cached in
2482+
# memory so that we do not reload the whole block on each
2483+
# scroll
2484+
block_values = block_values_node.read()[0]
2485+
cached_string_blocks[block_idx] = block_values
2486+
for idx_in_block, col_name in enumerate(block_items):
2487+
col_idx = col_names.index(col_name)
2488+
column_source[col_idx] = (block_idx, idx_in_block)
2489+
self._cached_string_blocks = cached_string_blocks
2490+
self._column_source = column_source
2491+
self._num_columns = len(column_source)
2492+
else:
2493+
self._cached_string_blocks = None
2494+
self._column_source = None
2495+
self._num_columns = data.block0_values.shape[1]
2496+
24532497

24542498
def shape2d(self):
2455-
return self.data.block0_values.shape
2499+
return self._num_rows, self._num_columns
24562500

24572501
def _get_axis_names(self, axis_num: int) -> list[str]:
24582502
group = self.data
@@ -2515,7 +2559,25 @@ def get_vlabels_values(self, start, stop):
25152559
return self._get_axis_labels(1, start, stop).transpose()
25162560

25172561
def get_values(self, h_start, v_start, h_stop, v_stop):
2518-
return self.data.block0_values[v_start:v_stop, h_start:h_stop]
2562+
data = self.data
2563+
attrs = data._v_attrs
2564+
if attrs.nblocks == 1:
2565+
return data.block0_values[v_start:v_stop, h_start:h_stop]
2566+
else:
2567+
import tables
2568+
block_nodes = self._block_values_nodes
2569+
# TODO: for performance, we should probably read all columns from
2570+
# the same block at once
2571+
np_columns = []
2572+
for col_idx in range(h_start, h_stop):
2573+
block_idx, idx_in_block = self._column_source[col_idx]
2574+
block_node = block_nodes[block_idx]
2575+
if isinstance(block_node, tables.VLArray):
2576+
block_node = self._cached_string_blocks[block_idx]
2577+
chunk = block_node[v_start:v_stop, idx_in_block]
2578+
np_columns.append(chunk)
2579+
2580+
return np.stack(np_columns, axis=1, dtype=object)
25192581

25202582

25212583
@adapter_for('tables.Group')
@@ -2586,6 +2648,7 @@ def get_values(self, h_start, v_start, h_stop, v_stop):
25862648

25872649

25882650
@path_adapter_for('.h5', 'tables')
2651+
@path_adapter_for('.hdf', 'tables')
25892652
class H5PathAdapter(PyTablesFileAdapter):
25902653
@classmethod
25912654
def open(cls, fpath):

0 commit comments

Comments
 (0)