diff --git a/src/emcfile/_pattern_sone.py b/src/emcfile/_pattern_sone.py index 635bcd0..554e87f 100644 --- a/src/emcfile/_pattern_sone.py +++ b/src/emcfile/_pattern_sone.py @@ -24,6 +24,7 @@ from ._h5helper import PATH_TYPE, H5Path, check_remove_groups, make_path from ._misc import pretty_size +from ._utils import concat_continous _log = logging.getLogger(__name__) @@ -218,14 +219,33 @@ def __getitem__( self, *index: Union[slice, npt.NDArray[np.bool_], npt.NDArray["np.integer[T1]"]] ) -> PatternsSOne: ... + def _get_subdataset0(self, i) -> PatternsSOne: + if len(i) == 0: + return _zeros((0, self.num_pix)) + c = concat_continous(i) + multi_s = self.multi_idx[c] + return PatternsSOne( + num_pix=self.num_pix, + ones=self.ones[i], + place_ones=np.concat([self.place_ones[s:e] for s, e in self.ones_idx[c]]), + multi=self.multi[i], + place_multi=np.concat([self.place_multi[s:e] for s, e in multi_s]), + count_multi=np.concat([self.count_multi[s:e] for s, e in multi_s]), + ) + def __getitem__( self, *index: "int | slice | npt.NDArray[np.bool_] | npt.NDArray[np.integer[T1]]", ) -> Union[npt.NDArray[np.int32], PatternsSOne]: if len(index) == 1 and isinstance(index[0], (int, np.integer)): return self._get_pattern(int(index[0])) - else: - return self._get_subdataset(index) + if len(index) == 1 and isinstance(index[0], np.ndarray): + if index[0].dtype == bool: + i = np.where(index[0])[0] + else: + i = index[0] + return self._get_subdataset0(i) + return self._get_subdataset(index) def write( self, diff --git a/tests/test_patterns.py b/tests/test_patterns.py index 1c00f21..3b33f16 100644 --- a/tests/test_patterns.py +++ b/tests/test_patterns.py @@ -117,7 +117,7 @@ def test_shape(big_data): assert big_data.shape == (big_data.num_data, big_data.num_pix) -def test_getitem(big_data): +def test_getitem(big_data, big_dense): for i in np.random.choice(big_data.num_data, 5): assert np.sum(big_data[i] == 1) == big_data.ones[i] @@ -126,11 +126,17 @@ def test_getitem(big_data): t0 = time.time() subdata1 = big_data[mask] subdata2 = big_data[idx] - logging.info(f"Select dataset: {(time.time()-t0)/2}s") + logging.info(f"Select dataset: {(time.time() - t0) / 2}s") assert subdata1 == subdata2 for _ in np.random.choice(subdata1.num_data, 5): assert np.all(subdata1[_] == big_data[idx[_]]) + for _ in range(10): + i = np.random.choice( + big_data.num_data, size=np.random.randint(big_data.num_data) + ) + assert np.all(big_data[i].todense() == big_dense[i]) + def test_concatenate(small_data, big_data): patterns = [ef.patterns(big_data.num_pix)] + [ @@ -205,7 +211,7 @@ def test_fileio(suffix, kargs, big_data): t0 = time.time() d_read = ef.patterns(f.name, start=start, end=end) logging.info( - f"Reading {d_read.num_data} patterns from h5 file(v1): {time.time()-t0}" + f"Reading {d_read.num_data} patterns from h5 file(v1): {time.time() - t0}" ) assert d_read == big_data[start:end] @@ -233,7 +239,7 @@ def test_write_patterns(suffix, data_list): all_data = np.concatenate(data_list) all_data.write(f1.name, overwrite=True) t1 = time.time() - t - logging.info(f"speed[single]: {all_data.nbytes * 1e-9 /t1:.2f} GB/s") + logging.info(f"speed[single]: {all_data.nbytes * 1e-9 / t1:.2f} GB/s") t = time.time() ef.write_patterns(data_list, f0.name, buffer_size=2**12, overwrite=True) @@ -354,3 +360,7 @@ def test_pattern_list(data_emc, data_h5): np.testing.assert_equal(plst.ones, np.concatenate([p0.ones, p1.ones])) plst2 = ef.PatternsSOneFileList([plst, p0]) assert plst2[: len(plst)][: len(p0)] == p0[:] + + +def test_aaa(small_data): + print(small_data)