Skip to content

Commit

Permalink
add bytesio
Browse files Browse the repository at this point in the history
  • Loading branch information
szsdk committed Aug 20, 2024
1 parent f07007f commit 02919b1
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 19 deletions.
20 changes: 18 additions & 2 deletions emcfile/_pattern_sone.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import logging
from collections.abc import Callable, Iterable, Mapping, Sequence
from pathlib import Path
Expand Down Expand Up @@ -228,7 +229,7 @@ def __getitem__(

def write(
self,
path: PATH_TYPE,
path: Union[PATH_TYPE, io.BytesIO],
*,
h5version: str = "2",
overwrite: bool = False,
Expand Down Expand Up @@ -365,6 +366,18 @@ def _write_bin(datas: Sequence[PatternsSOne], path: Path, overwrite: bool) -> No
getattr(data, g).tofile(fptr)


def _write_bytes(datas: Sequence[PatternsSOne], path: io.BytesIO) -> None:
num_data = np.sum([data.num_data for data in datas])
num_pix = datas[0].num_pix

header = np.zeros((256), dtype="i4")
header[:2] = [num_data, num_pix]
path.write(header.tobytes())
for g in PatternsSOne.ATTRS:
for data in datas:
path.write(getattr(data, g).tobytes())


def _write_h5_v2(
datas: Sequence[PatternsSOne],
path: H5Path,
Expand Down Expand Up @@ -403,14 +416,17 @@ def _write_h5_v2(

def write_patterns(
datas: Sequence[PatternsSOne],
path: PATH_TYPE,
path: Union[PATH_TYPE, io.BytesIO],
*,
h5version: str = "2",
overwrite: bool = False,
buffer_size: int = 1073741824, # 2 ** 30 bytes = 1 GB
compression: Union[None, int, str] = None,
) -> None:
# TODO: performance test
if isinstance(path, io.BytesIO):
return _write_bytes(datas, path)

f = make_path(path)
if isinstance(f, Path):
if f.suffix in [".emc", ".bin"]:
Expand Down
53 changes: 38 additions & 15 deletions emcfile/_pattern_sone_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from collections.abc import Sequence
from io import BufferedReader
from io import BufferedReader, BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, cast, overload

Expand Down Expand Up @@ -41,7 +41,7 @@ def concat_continous(a: npt.NDArray[Any]) -> npt.NDArray[Any]:


def read_indexed_array(
fin: BufferedReader,
fin: Union[BufferedReader, BytesIO],
idx_con: npt.NDArray["np.integer[T1]"],
arr_idx: npt.NDArray["np.integer[T1]"],
e0: int,
Expand All @@ -51,9 +51,9 @@ def read_indexed_array(
e = arr_idx[e]
s = arr_idx[s]
fin.seek(I4 * (int(s) - e0), os.SEEK_CUR)
return np.fromfile(fin, count=int(e - s), dtype=np.int32), int(e) - int(
arr_idx[-1]
)
return np.frombuffer(
fin.read(int(e - s) * I4), count=int(e - s), dtype=np.int32
), int(e) - int(arr_idx[-1])

ans = []
for s, e in idx_con:
Expand All @@ -69,8 +69,7 @@ def read_indexed_array(


def read_patterns(
fn: Path,
fin: BufferedReader,
fin: Union[BufferedReader, BytesIO],
idx_con: npt.NDArray["np.integer[T1]"],
ones_idx: npt.NDArray["np.integer[T1]"],
multi_idx: npt.NDArray["np.integer[T1]"],
Expand All @@ -84,17 +83,15 @@ def read_patterns(
if fin.read(1):
total = seek_start + place_ones.nbytes + place_multi.nbytes + count_multi.nbytes
_log.error(
"START: %d, place_ones: %d, place_multi: %d, count_multi: %d, total=%d;"
"filesize = %d; e0: %d",
"START: %d, place_ones: %d, place_multi: %d, count_multi: %d, total=%d; e0: %d",
seek_start,
place_ones.nbytes,
place_multi.nbytes,
count_multi.nbytes,
total,
fn.stat().st_size,
e0,
)
raise ValueError(f"Error when parsing {fn}")
raise ValueError("Error when parsing")
return place_ones.view("u4"), place_multi.view("u4"), count_multi


Expand Down Expand Up @@ -207,12 +204,40 @@ def _read_patterns(
) -> tuple[npt.NDArray[np.uint32], npt.NDArray[np.uint32], npt.NDArray[np.int32]]:
self.init_idx()
with self._fn.open("rb") as fin:
return read_patterns(self._fn, fin, idx_con, self.ones_idx, self.multi_idx)
return read_patterns(fin, idx_con, self.ones_idx, self.multi_idx)

def open(self) -> PatternsSOneEMCReadBuffer:
return PatternsSOneEMCReadBuffer(self._fn)


class _PatternsSOneBytes(PatternsSOneFile):
HEADER_BYTES = 1024

def __init__(self, fn: BytesIO):
self._fn = fn
self._fn.seek(0)
self.num_data = np.frombuffer(self._fn.read(4), dtype=np.int32, count=1)[0]
self.num_pix = np.frombuffer(self._fn.read(4), dtype=np.int32, count=1)[0]
self.ndim = 2
self.shape = (self.num_data, self.num_pix)
self._init_idx = False

def _read_ones_multi(self) -> tuple[npt.NDArray[np.uint32], npt.NDArray[np.uint32]]:
self._fn.seek(1024)
return np.frombuffer(
self._fn.read(I4 * self.num_data), dtype=np.int32, count=self.num_data
), np.frombuffer(
self._fn.read(I4 * self.num_data), dtype=np.int32, count=self.num_data
)

def _read_patterns(
self, idx_con: npt.NDArray["np.integer[T1]"]
) -> tuple[npt.NDArray[np.uint32], npt.NDArray[np.uint32], npt.NDArray[np.int32]]:
self.init_idx()
self._fn.seek(0)
return read_patterns(self._fn, idx_con, self.ones_idx, self.multi_idx)


class PatternsSOneEMCReadBuffer(PatternsSOneEMC):
def __init__(self, fn: "str | Path"):
super().__init__(fn)
Expand All @@ -238,9 +263,7 @@ def _read_patterns(
self, idx_con: npt.NDArray["np.integer[T1]"]
) -> tuple[npt.NDArray[np.uint32], npt.NDArray[np.uint32], npt.NDArray[np.int32]]:
self.init_idx()
return read_patterns(
self._fn, self._file_handle, idx_con, self.ones_idx, self.multi_idx
)
return read_patterns(self._file_handle, idx_con, self.ones_idx, self.multi_idx)


def read_indexed_array_h5(
Expand Down
7 changes: 5 additions & 2 deletions emcfile/_patterns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from collections.abc import Sequence
from pathlib import Path
from typing import Optional, TypeVar, cast
Expand All @@ -11,7 +12,7 @@
from ._h5helper import PATH_TYPE, H5Path
from ._misc import divide_range
from ._pattern_sone import SPARSE_PATTERN, PatternsSOne, _full, _ones, _zeros
from ._pattern_sone_file import file_patterns
from ._pattern_sone_file import _PatternsSOneBytes, file_patterns

__all__ = ["patterns"]

Expand Down Expand Up @@ -69,6 +70,7 @@ def _from_sparse_patterns(src: Sequence[SPARSE_PATTERN]) -> PatternsSOne:

def patterns(
src: "PATH_TYPE"
"| io.BytesIO"
"| npt.NDArray[np.integer[T1]]"
"| spmatrix"
"| int"
Expand Down Expand Up @@ -124,7 +126,8 @@ def patterns(
- Another `PatternsSOne` object: The function returns a subset or a copy of the input object,
starting from index `start` and ending at index `end`.
"""

if isinstance(src, io.BytesIO):
return _PatternsSOneBytes(src)[start:end]
if isinstance(src, (str, Path, H5Path)):
return file_patterns(src)[start:end]
if isinstance(src, PatternsSOne):
Expand Down
7 changes: 7 additions & 0 deletions emcfile/tests/test_patterns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import gc
import io
import itertools
import logging
import tempfile
Expand Down Expand Up @@ -208,6 +209,12 @@ def test_fileio(suffix, kargs, big_data):
assert d_read == big_data[start:end]


def test_bytesio(small_data, tmp_path: Path):
bio = io.BytesIO()
small_data.write(bio)
assert ef.patterns(bio) == small_data


def gen_write_patterns():
data = ef.patterns(np.random.randint(0, 10, size=(16, 256)))
for i in 2 ** np.arange(0, 10, 2):
Expand Down

0 comments on commit 02919b1

Please sign in to comment.