Skip to content

Commit

Permalink
add: compress opts
Browse files Browse the repository at this point in the history
  • Loading branch information
szsdk committed Apr 15, 2024
1 parent cb4c597 commit e712671
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions emcfile/_h5helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,21 @@ def _write_single(
v: _T,
overwrite: bool,
verbose: bool,
compression: Optional[str],
compression_opts: Union[None, str, int],
) -> None:
if isinstance(v, np.ndarray):
if not isinstance(group, (h5py.File, h5py.Group)):
raise Exception(f"Cannot write type {type(v)} to {type(group)}")
if _check_exists(k, group, overwrite, verbose):
del group[k]
group.create_dataset(k, data=v)
group.create_dataset(
k, data=v, compression=compression, compression_opts=compression_opts
)
elif isinstance(v, dict):
if not isinstance(group, (h5py.File, h5py.Group)):
raise Exception(f"Cannot write type {type(v)} to {type(group)}")
_write_group(group, k, v, overwrite, verbose)
_write_group(group, k, v, overwrite, verbose, compression, compression_opts)
else:
group.attrs[k] = v

Expand All @@ -318,6 +322,8 @@ def _write_group(
obj: _T,
overwrite: bool,
verbose: bool,
compression: Optional[str],
compression_opts: Union[None, str, int],
) -> None:
if not isinstance(obj, dict):
raise Exception(f"Cannot write type {type(obj)}")
Expand All @@ -334,7 +340,7 @@ def _write_group(
if isinstance(g, h5py.Datatype):
raise NotImplementedError("The support for h5py.Datatype is not implemented.")
for k, v in obj.items():
_write_single(g, k, v, overwrite, verbose)
_write_single(g, k, v, overwrite, verbose, compression, compression_opts)
if obj_dot is not None:
obj["."] = obj_dot

Expand All @@ -344,6 +350,8 @@ def write_obj_h5(
obj: _T,
overwrite: bool = False,
verbose: bool = False,
compression: Optional[str] = None,
compression_opts: Union[None, str, int] = None,
) -> None:
"""Save dict `obj` to a `h5py.File`. The `np.ndarray` values are saved as
h5 datasets. Others are saved in attributes of h5 group `group_name`.
Expand All @@ -356,7 +364,7 @@ def write_obj_h5(
"""
f = h5path(fn)
with f.open("a") as fp:
_write_group(fp, f.gn, obj, overwrite, verbose)
_write_group(fp, f.gn, obj, overwrite, verbose, compression, compression_opts)


def _read_group(g: Union[h5py.File, h5py.Group]) -> dict[str, _T]:
Expand Down

0 comments on commit e712671

Please sign in to comment.