Commit a3eb995

fix #842 : included scalars when dumping or loading a Session object (hdf5 + pickle formats)
1 parent 4a2bd7b commit a3eb995

File tree: 6 files changed (+248 -115 lines)


doc/source/changes/version_0_33.rst.inc

+2 -1

@@ -49,7 +49,8 @@ New features
 Miscellaneous improvements
 ^^^^^^^^^^^^^^^^^^^^^^^^^^

-* improved something.
+* scalar objects (i.e. of type int, float, bool, string, date, time or datetime) belonging to a session
+  are now also saved and loaded when using the HDF5 or pickle format (closes :issue:`842`).


 Fixes
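
To make the new behaviour concrete, here is a minimal round-trip sketch. The file name and scalar values are made up for illustration; it assumes the usual `Session.save` / `Session(filepath)` API, which picks the format from the file extension:

    from datetime import date
    from larray import Session, Axis, ndtest

    ses = Session()
    ses['age'] = Axis('age=0..10')          # Axis objects were already handled
    ses['arr'] = ndtest((2, 3))             # Array objects were already handled
    ses['growth_rate'] = 0.015              # float scalar, now dumped too
    ses['start_date'] = date(2020, 1, 10)   # date scalar, now dumped too

    ses.save('input.h5')                    # HDF5 format, inferred from the extension
    loaded = Session('input.h5')            # scalars come back along with the arrays
    assert loaded['growth_rate'] == 0.015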

larray/core/session.py

+170 -87
Large diffs are not rendered by default.
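
The bulk of the change lives in this file, but its diff is not rendered. As a purely speculative sketch (the names `_dump_all`, `handler` and `display` are not from the rendered diff), the per-item dispatch that the handler methods shown below are written against typically looks like:

    # speculative sketch only -- the actual session.py diff is not shown above.
    # It illustrates how a dump loop can rely on _dump_item() raising TypeError
    # for unsupported objects, as the HDF5 and pickle handlers below do.
    def _dump_all(handler, items, display=False):
        for key, value in items:
            try:
                handler._dump_item(key, value)
            except TypeError:
                if display:
                    print('skipping {}: unsupported type {}'.format(key, type(value).__name__))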

larray/inout/common.py

+15

@@ -1,11 +1,26 @@
 from __future__ import absolute_import, print_function

 import os
+from datetime import date, time, datetime
 from collections import OrderedDict

+from larray.util.compat import bytes, unicode
+from larray.core.axis import Axis
+from larray.core.group import Group
 from larray.core.array import Array


+# all formats
+_supported_larray_types = (Axis, Group, Array)
+
+# only for HDF5 and pickle formats
+# support list, tuple and dict?
+# replace unicode by str when Python 2.7 will no longer be supported
+_supported_scalars_types = (int, float, bool, bytes, unicode, date, time, datetime)
+_supported_types = _supported_larray_types + _supported_scalars_types
+_supported_typenames = {cls.__name__ for cls in _supported_types}
+
+
 def _get_index_col(nb_axes=None, index_col=None, wide=True):
     if not wide:
         if nb_axes is not None or index_col is not None:
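
These module-level tuples give the file handlers a single place to test whether an object is dumpable and to recover its type name on load. A small illustrative check (the `value` here is just an example):

    from datetime import date
    from larray.inout.common import _supported_types, _supported_typenames

    value = date(2020, 1, 10)
    assert isinstance(value, _supported_types)             # dumpable by the HDF5/pickle handlers
    assert type(value).__name__ in _supported_typenames    # 'date' can be dispatched on load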

larray/inout/hdf.py

+36 -21

@@ -3,6 +3,7 @@
 import warnings

 import numpy as np
+import pandas as pd
 from pandas import HDFStore

 from larray.core.array import Array
@@ -12,21 +13,27 @@
 from larray.core.metadata import Metadata
 from larray.util.misc import LHDFStore
 from larray.inout.session import register_file_handler
-from larray.inout.common import FileHandler
+from larray.inout.common import FileHandler, _supported_typenames, _supported_scalars_types
 from larray.inout.pandas import df_asarray
 from larray.example import get_example_filepath


+# for backward compatibility (larray < 0.29) but any object read from an hdf file should have
+# an attribute 'type'
+def _get_type_from_attrs(attrs):
+    return attrs.type if 'type' in attrs else 'Array'
+
+
 def read_hdf(filepath_or_buffer, key, fill_value=nan, na=nan, sort_rows=False, sort_columns=False,
              name=None, **kwargs):
-    r"""Reads an axis or group or array named key from a HDF5 file in filepath (path+name)
+    r"""Reads a scalar or an axis or group or array named key from a HDF5 file in filepath (path+name)

     Parameters
     ----------
     filepath_or_buffer : str or pandas.HDFStore
         Path and name where the HDF5 file is stored or a HDFStore object.
     key : str or Group
-        Name of the array.
+        Name of the scalar or axis or group or array.
     fill_value : scalar or Array, optional
         Value used to fill cells corresponding to label combinations which are not present in the input.
         Defaults to NaN.
@@ -70,11 +77,14 @@ def read_hdf(filepath_or_buffer, key, fill_value=nan, na=nan, sort_rows=False, s
     key = _translate_group_key_hdf(key)
     res = None
     with LHDFStore(filepath_or_buffer) as store:
-        pd_obj = store.get(key)
+        try:
+            pd_obj = store.get(key)
+        except KeyError:
+            filepath = filepath_or_buffer if isinstance(filepath_or_buffer, HDFStore) else store.filename
+            raise KeyError('No item with name {} has been found in file {}'.format(key, filepath))
         attrs = store.get_storer(key).attrs
         writer = attrs.writer if 'writer' in attrs else None
-        # for backward compatibility but any object read from an hdf file should have an attribute 'type'
-        _type = attrs.type if 'type' in attrs else 'Array'
+        _type = _get_type_from_attrs(attrs)
         _meta = attrs.metadata if 'metadata' in attrs else None
         if _type == 'Array':
             # cartesian product is not necessary if the array was written by LArray
@@ -110,6 +120,10 @@ def read_hdf(filepath_or_buffer, key, fill_value=nan, na=nan, sort_rows=False, s
                 key = np.char.decode(key, 'utf-8')
             axis = read_hdf(filepath_or_buffer, attrs['axis_key'])
             res = LGroup(key=key, name=name, axis=axis)
+        elif _type in _supported_typenames:
+            res = pd_obj.values
+            assert len(res) == 1
+            res = res[0]
     return res


@@ -126,36 +140,37 @@ def _open_for_write(self):

     def list_items(self):
         keys = [key.strip('/') for key in self.handle.keys()]
+        items = [(key, _get_type_from_attrs(self.handle.get_storer(key).attrs)) for key in keys if '/' not in key]
+        # ---- for backward compatibility (LArray < 0.33) ----
         # axes
-        items = [(key.split('/')[-1], 'Axis') for key in keys if '__axes__' in key]
+        items += [(key.split('/')[-1], 'Axis_Backward_Comp') for key in keys if '__axes__' in key]
         # groups
-        items += [(key.split('/')[-1], 'Group') for key in keys if '__groups__' in key]
-        # arrays
-        items += [(key, 'Array') for key in keys if '/' not in key]
+        items += [(key.split('/')[-1], 'Group_Backward_Comp') for key in keys if '__groups__' in key]
         return items

-    def _read_item(self, key, type, *args, **kwargs):
-        if type == 'Array':
+    def _read_item(self, key, typename, *args, **kwargs):
+        if typename in _supported_typenames:
             hdf_key = '/' + key
-        elif type == 'Axis':
+        # ---- for backward compatibility (LArray < 0.33) ----
+        elif typename == 'Axis_Backward_Comp':
             hdf_key = '__axes__/' + key
-        elif type == 'Group':
+        elif typename == 'Group_Backward_Comp':
             hdf_key = '__groups__/' + key
         else:
             raise TypeError()
         return read_hdf(self.handle, hdf_key, *args, **kwargs)

     def _dump_item(self, key, value, *args, **kwargs):
-        if isinstance(value, Array):
-            hdf_key = '/' + key
-            value.to_hdf(self.handle, hdf_key, *args, **kwargs)
-        elif isinstance(value, Axis):
-            hdf_key = '__axes__/' + key
+        hdf_key = '/' + key
+        if isinstance(value, (Array, Axis)):
             value.to_hdf(self.handle, hdf_key, *args, **kwargs)
         elif isinstance(value, Group):
-            hdf_key = '__groups__/' + key
-            hdf_axis_key = '__axes__/' + value.axis.name
+            hdf_axis_key = '/' + value.axis.name
             value.to_hdf(self.handle, hdf_key, hdf_axis_key, *args, **kwargs)
+        elif isinstance(value, _supported_scalars_types):
+            s = pd.Series(data=value)
+            self.handle.put(hdf_key, s)
+            self.handle.get_storer(hdf_key).attrs.type = type(value).__name__
         else:
             raise TypeError()
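
Stripped of the handler plumbing, the storage scheme for scalars is: wrap the value in a one-element pandas Series, `put` it in the store, and record the type name in the storer's attributes so `read_hdf` can dispatch on it when loading. A self-contained sketch (the file name is illustrative; pandas may emit a PerformanceWarning for object-dtype values such as dates):

    from datetime import date
    import pandas as pd

    with pd.HDFStore('scalars_demo.h5') as store:
        # dump: the scalar becomes a length-1 Series plus a 'type' attribute
        store.put('/start_date', pd.Series(data=date(2020, 1, 10)))
        store.get_storer('/start_date').attrs.type = 'date'

        # load: read the type name back and unwrap the single value
        assert store.get_storer('/start_date').attrs.type == 'date'
        values = store.get('/start_date').values
        assert len(values) == 1
        start_date = values[0]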

larray/inout/pickle.py

+8 -5

@@ -9,7 +9,7 @@
 from larray.core.metadata import Metadata
 from larray.util.compat import pickle
 from larray.inout.session import register_file_handler
-from larray.inout.common import FileHandler
+from larray.inout.common import FileHandler, _supported_types, _supported_typenames, _supported_scalars_types


 @register_file_handler('pickle', ['pkl', 'pickle'])
@@ -25,22 +25,25 @@ def _open_for_write(self):
         self.data = OrderedDict()

     def list_items(self):
+        # scalar
+        items = [(key, type(value).__name__) for key, value in self.data.items()
+                 if isinstance(value, _supported_scalars_types)]
         # axes
-        items = [(key, 'Axis') for key, value in self.data.items() if isinstance(value, Axis)]
+        items += [(key, 'Axis') for key, value in self.data.items() if isinstance(value, Axis)]
         # groups
         items += [(key, 'Group') for key, value in self.data.items() if isinstance(value, Group)]
         # arrays
         items += [(key, 'Array') for key, value in self.data.items() if isinstance(value, Array)]
         return items

-    def _read_item(self, key, type, *args, **kwargs):
-        if type in {'Array', 'Axis', 'Group'}:
+    def _read_item(self, key, typename, *args, **kwargs):
+        if typename in _supported_typenames:
             return self.data[key]
         else:
             raise TypeError()

     def _dump_item(self, key, value, *args, **kwargs):
-        if isinstance(value, (Array, Axis, Group)):
+        if isinstance(value, _supported_types):
             self.data[key] = value
         else:
             raise TypeError()
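
For pickle, no conversion is needed: supported scalars are stored as-is in the handler's OrderedDict, which is pickled wholesale. A session round-trip therefore reduces to the following sketch (file name made up, standard `Session.save` / `Session(filepath)` API assumed):

    from larray import Session, ndtest

    ses = Session()
    ses['arr'] = ndtest(3)
    ses['s_int'] = 5                  # kept by _dump_item thanks to _supported_types
    ses.save('demo_session.pkl')      # pickle format, inferred from the extension

    loaded = Session('demo_session.pkl')
    assert loaded['s_int'] == 5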

larray/tests/test_session.py

+17 -1

@@ -2,13 +2,15 @@

 import os
 import shutil
+from datetime import date, time, datetime

 import numpy as np
 import pandas as pd
 import pytest

 from larray.tests.common import (assert_array_nan_equal, inputpath, tmp_path, meta,
                                  needs_xlwings, needs_pytables, needs_xlrd)
+from larray.inout.common import _supported_scalars_types
 from larray import (Session, Axis, Array, Group, isnan, zeros_like, ndtest, ones_like, ones, full,
                     local_arrays, global_arrays, arrays)
 from larray.util.compat import pickle, PY2
@@ -178,7 +180,7 @@ def test_names(session):
 def _test_io(fpath, session, meta, engine):
     is_excel_or_csv = 'excel' in engine or 'csv' in engine

-    kind = Array if is_excel_or_csv else (Axis, Group, Array)
+    kind = Array if is_excel_or_csv else (Axis, Group, Array) + _supported_scalars_types
     session = session.filter(kind=kind)

     session.meta = meta
@@ -226,8 +228,21 @@ def _test_io(fpath, session, meta, engine):
     assert s.meta == meta


+def _add_scalars_to_session(s):
+    # 's' for scalar
+    s['s_int'] = 5
+    s['s_float'] = 5.5
+    s['s_bool'] = True
+    s['s_str'] = 'string'
+    s['s_date'] = date(2020, 1, 10)
+    s['s_time'] = time(11, 23, 54)
+    s['s_datetime'] = datetime(2020, 1, 10, 11, 23, 54)
+    return s
+
+
 @needs_pytables
 def test_h5_io(tmpdir, session, meta):
+    session = _add_scalars_to_session(session)
     fpath = tmp_path(tmpdir, 'test_session.h5')
     _test_io(fpath, session, meta, engine='pandas_hdf')

@@ -276,6 +291,7 @@ def test_csv_io(tmpdir, session, meta):


 def test_pickle_io(tmpdir, session, meta):
+    session = _add_scalars_to_session(session)
     fpath = tmp_path(tmpdir, 'test_session.pkl')
     _test_io(fpath, session, meta, engine='pickle')
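
In rough terms, the round-trip property these tests exercise for the new scalar support is the following (a condensed sketch of `_test_io`'s effect using the `_add_scalars_to_session` helper above, not the test itself):

    from datetime import date
    from larray import Session, ndtest

    ses = Session()
    ses['arr'] = ndtest((2, 3))
    ses = _add_scalars_to_session(ses)      # adds s_int, s_float, ..., s_datetime
    ses.save('test_session.h5')

    s = Session('test_session.h5')
    assert s['s_int'] == 5 and s['s_str'] == 'string'
    assert s['s_date'] == date(2020, 1, 10)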
