Skip to content

Commit

Permalink
feat: avoid building CSR in getitem
Browse files Browse the repository at this point in the history
  • Loading branch information
szsdk committed Feb 10, 2025
1 parent eb22bf6 commit d14f8ae
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
24 changes: 22 additions & 2 deletions src/emcfile/_pattern_sone.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ._h5helper import PATH_TYPE, H5Path, check_remove_groups, make_path
from ._misc import pretty_size
from ._utils import concat_continous

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -218,14 +219,33 @@ def __getitem__(
self, *index: Union[slice, npt.NDArray[np.bool_], npt.NDArray["np.integer[T1]"]]
) -> PatternsSOne: ...

def _get_subdataset0(self, i) -> PatternsSOne:
    """Select the patterns at integer positions *i* as a new PatternsSOne.

    Instead of slicing the payload arrays once per selected pattern, the
    index array is first compacted by ``concat_continous`` — presumably
    into contiguous [start, end) runs (TODO confirm against _utils) — so
    each payload array is gathered with a few large slices.
    """
    if len(i) == 0:
        # Empty selection: an empty dataset with the same pixel count.
        return _zeros((0, self.num_pix))

    runs = concat_continous(i)

    def gather(arr, ranges):
        # Concatenate the [start, end) slices of `arr` named by `ranges`.
        return np.concat([arr[start:end] for start, end in ranges])

    multi_ranges = self.multi_idx[runs]
    return PatternsSOne(
        num_pix=self.num_pix,
        ones=self.ones[i],
        place_ones=gather(self.place_ones, self.ones_idx[runs]),
        multi=self.multi[i],
        place_multi=gather(self.place_multi, multi_ranges),
        count_multi=gather(self.count_multi, multi_ranges),
    )

def __getitem__(
    self,
    *index: "int | slice | npt.NDArray[np.bool_] | npt.NDArray[np.integer[T1]]",
) -> Union[npt.NDArray[np.int32], PatternsSOne]:
    """Index the dataset.

    A single scalar index returns one dense pattern (int32 array); any
    other index returns a PatternsSOne subset.

    Note: the scraped diff left the pre-commit ``else: return
    self._get_subdataset(index)`` branch in place, which made the ndarray
    fast path below unreachable; this restores the post-commit control
    flow.
    """
    # Scalar index -> one dense pattern row.
    if len(index) == 1 and isinstance(index[0], (int, np.integer)):
        return self._get_pattern(int(index[0]))
    # Fast path for a single ndarray index (boolean mask or integer
    # array): go straight to _get_subdataset0, avoiding an intermediate
    # CSR build.
    if len(index) == 1 and isinstance(index[0], np.ndarray):
        if index[0].dtype == bool:
            # Convert a boolean mask to the equivalent integer positions.
            i = np.where(index[0])[0]
        else:
            i = index[0]
        return self._get_subdataset0(i)
    # Everything else (slices, tuples of indices) takes the general path.
    return self._get_subdataset(index)

def write(
self,
Expand Down
18 changes: 14 additions & 4 deletions tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_shape(big_data):
assert big_data.shape == (big_data.num_data, big_data.num_pix)


def test_getitem(big_data):
def test_getitem(big_data, big_dense):
for i in np.random.choice(big_data.num_data, 5):
assert np.sum(big_data[i] == 1) == big_data.ones[i]

Expand All @@ -126,11 +126,17 @@ def test_getitem(big_data):
t0 = time.time()
subdata1 = big_data[mask]
subdata2 = big_data[idx]
logging.info(f"Select dataset: {(time.time()-t0)/2}s")
logging.info(f"Select dataset: {(time.time() - t0) / 2}s")
assert subdata1 == subdata2
for _ in np.random.choice(subdata1.num_data, 5):
assert np.all(subdata1[_] == big_data[idx[_]])

for _ in range(10):
i = np.random.choice(
big_data.num_data, size=np.random.randint(big_data.num_data)
)
assert np.all(big_data[i].todense() == big_dense[i])


def test_concatenate(small_data, big_data):
patterns = [ef.patterns(big_data.num_pix)] + [
Expand Down Expand Up @@ -205,7 +211,7 @@ def test_fileio(suffix, kargs, big_data):
t0 = time.time()
d_read = ef.patterns(f.name, start=start, end=end)
logging.info(
f"Reading {d_read.num_data} patterns from h5 file(v1): {time.time()-t0}"
f"Reading {d_read.num_data} patterns from h5 file(v1): {time.time() - t0}"
)
assert d_read == big_data[start:end]

Expand Down Expand Up @@ -233,7 +239,7 @@ def test_write_patterns(suffix, data_list):
all_data = np.concatenate(data_list)
all_data.write(f1.name, overwrite=True)
t1 = time.time() - t
logging.info(f"speed[single]: {all_data.nbytes * 1e-9 /t1:.2f} GB/s")
logging.info(f"speed[single]: {all_data.nbytes * 1e-9 / t1:.2f} GB/s")

t = time.time()
ef.write_patterns(data_list, f0.name, buffer_size=2**12, overwrite=True)
Expand Down Expand Up @@ -354,3 +360,7 @@ def test_pattern_list(data_emc, data_h5):
np.testing.assert_equal(plst.ones, np.concatenate([p0.ones, p1.ones]))
plst2 = ef.PatternsSOneFileList([plst, p0])
assert plst2[: len(plst)][: len(p0)] == p0[:]


def test_aaa(small_data):
print(small_data)

0 comments on commit d14f8ae

Please sign in to comment.