Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 0fd9c4f

Browse files
committed
better indexing
1 parent 453766f commit 0fd9c4f

File tree

3 files changed

+103
-67
lines changed

3 files changed

+103
-67
lines changed

ecml_tools/data.py

+70-45
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,36 @@
2020

2121
import ecml_tools
2222

23-
from .indexing import apply_index_to_slices_changes, index_to_slices, length_to_slices
23+
from .indexing import (
24+
apply_index_to_slices_changes,
25+
index_to_slices,
26+
length_to_slices,
27+
update_tuple,
28+
)
2429

2530
LOG = logging.getLogger(__name__)
2631

2732
__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
2833

2934
DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))
3035

36+
DEPTH = 0
37+
38+
39+
def debug_indexing(method):
40+
def wrapper(self, index):
41+
global DEPTH
42+
if isinstance(index, tuple):
43+
print(" " * DEPTH, "->", self, method.__name__, index)
44+
DEPTH += 1
45+
result = method(self, index)
46+
DEPTH -= 1
47+
if isinstance(index, tuple):
48+
print(" " * DEPTH, "<-", self, method.__name__, result.shape)
49+
return result
50+
51+
return wrapper
52+
3153

3254
def debug_zarr_loading(on_off):
3355
global DEBUG_ZARR_LOADING
@@ -192,6 +214,7 @@ def metadata_specific(self, **kwargs):
192214
def __repr__(self):
193215
return self.__class__.__name__ + "()"
194216

217+
@debug_indexing
195218
def _get_tuple(self, n):
196219
raise NotImplementedError(
197220
f"Tuple not supported: {n} (class {self.__class__.__name__})"
@@ -344,6 +367,7 @@ def __init__(self, path):
344367
def __len__(self):
345368
return self.data.shape[0]
346369

370+
@debug_indexing
347371
def __getitem__(self, n):
348372
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
349373
return self._getitem_extended(n)
@@ -638,6 +662,7 @@ class Concat(Combined):
638662
def __len__(self):
639663
return sum(len(i) for i in self.datasets)
640664

665+
@debug_indexing
641666
def _get_tuple(self, index):
642667
index, changes = index_to_slices(index, self.shape)
643668
result = []
@@ -661,6 +686,7 @@ def _get_tuple(self, index):
661686

662687
return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)
663688

689+
@debug_indexing
664690
def __getitem__(self, n):
665691
if isinstance(n, tuple):
666692
return self._get_tuple(n)
@@ -675,6 +701,7 @@ def __getitem__(self, n):
675701
k += 1
676702
return self.datasets[k][n]
677703

704+
@debug_indexing
678705
def _get_slice(self, s):
679706
result = []
680707

@@ -742,24 +769,30 @@ def shape(self):
742769
assert False not in result, result
743770
return result
744771

772+
@debug_indexing
745773
def _get_tuple(self, index):
774+
print(index, self.shape)
746775
index, changes = index_to_slices(index, self.shape)
747-
selected = index[self.axis]
748776
lengths = [d.shape[self.axis] for d in self.datasets]
749-
slices = length_to_slices(selected, lengths)
750-
print("per_dataset_index", slices)
777+
slices = length_to_slices(index[self.axis], lengths)
751778

752-
result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
779+
print("SLICES", slices, self.axis, index, lengths)
780+
before = index[: self.axis]
753781

754-
x = tuple([slice(None)] * self.axis + [selected])
782+
result = [
783+
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
784+
]
785+
print([d.shape for d in result])
786+
result = np.concatenate(result, axis=self.axis)
787+
print(result.shape)
755788

756-
return apply_index_to_slices_changes(
757-
np.concatenate(result, axis=self.axis)[x], changes
758-
)
789+
return apply_index_to_slices_changes(result, changes)
759790

791+
@debug_indexing
760792
def _get_slice(self, s):
761793
return np.stack([self[i] for i in range(*s.indices(self._len))])
762794

795+
@debug_indexing
763796
def __getitem__(self, n):
764797
if isinstance(n, tuple):
765798
return self._get_tuple(n)
@@ -810,42 +843,22 @@ def check_same_variables(self, d1, d2):
810843
def __len__(self):
811844
return len(self.datasets[0])
812845

846+
@debug_indexing
813847
def _get_tuple(self, index):
814-
print("Join._get_tuple", index)
815-
assert len(index) > 1, index
816-
817848
index, changes = index_to_slices(index, self.shape)
818-
819-
selected_variables = index[1]
820-
821-
index = list(index)
822-
index[1] = slice(None)
823-
index = tuple(index)
824-
print("Join._get_tuple", index)
849+
index, previous = update_tuple(index, 1, slice(None))
825850

826851
# TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
827852
result = [d[index] for d in self.datasets]
828853

829-
print(
830-
"Join._get_tuple",
831-
self.shape,
832-
[r.shape for r in result],
833-
selected_variables,
834-
changes,
835-
)
836854
result = np.concatenate(result, axis=1)
837-
print("Join._get_tuple", result.shape)
838-
839-
# raise NotImplementedError()
840-
841-
# result = np.concatenate(result)
842-
# result = np.stack(result)
843-
844-
return apply_index_to_slices_changes(result[:, selected_variables], changes)
855+
return apply_index_to_slices_changes(result[:, previous], changes)
845856

857+
@debug_indexing
846858
def _get_slice(self, s):
847859
return np.stack([self[i] for i in range(*s.indices(self._len))])
848860

861+
@debug_indexing
849862
def __getitem__(self, n):
850863
if isinstance(n, tuple):
851864
return self._get_tuple(n)
@@ -931,10 +944,14 @@ def __init__(self, dataset, indices):
931944

932945
self.dataset = dataset
933946
self.indices = list(indices)
947+
self.slice = _make_slice_or_index_from_list_or_tuple(self.indices)
948+
assert isinstance(self.slice, slice)
949+
print("SUBSET", self.slice)
934950

935951
# Forward other properties to the super dataset
936952
super().__init__(dataset)
937953

954+
@debug_indexing
938955
def __getitem__(self, n):
939956
if isinstance(n, tuple):
940957
return self._get_tuple(n)
@@ -945,25 +962,22 @@ def __getitem__(self, n):
945962
n = self.indices[n]
946963
return self.dataset[n]
947964

965+
@debug_indexing
948966
def _get_slice(self, s):
949967
# TODO: check if the indices can be simplified to a slice
950968
# the time checking maybe be longer than the time saved
951969
# using a slice
952970
indices = [self.indices[i] for i in range(*s.indices(self._len))]
953971
return np.stack([self.dataset[i] for i in indices])
954972

973+
@debug_indexing
955974
def _get_tuple(self, n):
956-
first, rest = n[0], n[1:]
957-
958-
if isinstance(first, int):
959-
return self.dataset[(self.indices[first],) + rest]
960-
961-
if isinstance(first, slice):
962-
indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
963-
indices = _make_slice_or_index_from_list_or_tuple(indices)
964-
return self.dataset[(indices,) + rest]
965-
966-
raise NotImplementedError(f"Only int and slice supported not {type(first)}")
975+
index, changes = index_to_slices(n, self.shape)
976+
index, previous = update_tuple(index, 0, self.slice)
977+
result = self.dataset[index]
978+
result = result[previous]
979+
result = apply_index_to_slices_changes(result, changes)
980+
return result
967981

968982
def __len__(self):
969983
return len(self.indices)
@@ -1003,6 +1017,17 @@ def __init__(self, dataset, indices):
10031017
# Forward other properties to the main dataset
10041018
super().__init__(dataset)
10051019

1020+
@debug_indexing
1021+
def _get_tuple(self, index):
1022+
index, changes = index_to_slices(index, self.shape)
1023+
index, previous = update_tuple(index, 1, slice(None))
1024+
result = self.dataset[index]
1025+
result = result[:, self.indices]
1026+
result = result[:, previous]
1027+
result = apply_index_to_slices_changes(result, changes)
1028+
return result
1029+
1030+
@debug_indexing
10061031
def __getitem__(self, n):
10071032
if isinstance(n, tuple):
10081033
return self._get_tuple(n)

ecml_tools/indexing.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,29 @@ def length_to_slices(index, lengths):
8282
"""
8383
total = sum(lengths)
8484
start, stop, step = index.indices(total)
85-
print(start, stop, step)
8685

87-
# TODO: combine loops
88-
p = []
86+
result = []
87+
8988
pos = 0
9089
for length in lengths:
9190
end = pos + length
92-
p.append((pos, end))
93-
pos = end
9491

95-
result = []
92+
b = max(pos, start)
93+
e = min(end, stop)
9694

97-
for i, (s, e) in enumerate(p):
98-
pos = s
99-
if s % step:
100-
s = s + step - s % step
101-
assert s % step == 0
102-
assert s >= pos
103-
if max(s, start) <= min(e, stop):
104-
result.append((i, slice(s - pos, e - pos, step)))
105-
else:
106-
result.append(None)
95+
p = None
96+
if b <= e:
97+
if (b - start) % step != 0:
98+
b = b + step - (b - start) % step
99+
b -= pos
100+
e -= pos
101+
102+
if 0 <= b < e:
103+
p = slice(b, e, step)
104+
105+
result.append(p)
106+
107+
pos = end
107108

108109
return result
109110

tests/test_data.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,16 @@ def __init__(self, ds):
179179
self.np = ds[:] # Numpy array
180180

181181
assert self.ds.shape == self.np.shape
182+
assert (self.ds == self.np).all()
182183

183184
def __getitem__(self, index):
184-
assert (self.ds[index] == self.np[index]).all()
185+
if self.ds[index] is None:
186+
assert False, (self.ds, index)
187+
188+
if not (self.ds[index] == self.np[index]).all():
189+
# print("DS", self.ds[index])
190+
# print("NP", self.np[index])
191+
assert (self.ds[index] == self.np[index]).all()
185192

186193

187194
def indexing(ds):
@@ -310,7 +317,7 @@ def test_concat():
310317
def test_join_1():
311318
ds = open_dataset(
312319
"test-2021-2021-6h-o96-abcd",
313-
"test-2021-2021-6h-o96-efg",
320+
"test-2021-2021-6h-o96-efgh",
314321
)
315322

316323
assert isinstance(ds, Join)
@@ -330,6 +337,7 @@ def test_join_1():
330337
_(date, "e"),
331338
_(date, "f"),
332339
_(date, "g"),
340+
_(date, "h"),
333341
]
334342
)
335343
assert (row == expect).all()
@@ -338,7 +346,7 @@ def test_join_1():
338346

339347
assert (ds.dates == np.array(dates, dtype="datetime64")).all()
340348

341-
assert ds.variables == ["a", "b", "c", "d", "e", "f", "g"]
349+
assert ds.variables == ["a", "b", "c", "d", "e", "f", "g", "h"]
342350
assert ds.name_to_index == {
343351
"a": 0,
344352
"b": 1,
@@ -347,9 +355,10 @@ def test_join_1():
347355
"e": 4,
348356
"f": 5,
349357
"g": 6,
358+
"h": 7,
350359
}
351360

352-
assert ds.shape == (365 * 4, 7, 1, VALUES)
361+
assert ds.shape == (365 * 4, 8, 1, VALUES)
353362

354363
same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
355364
slices(ds)
@@ -1275,6 +1284,7 @@ def test_ensemble_1():
12751284

12761285
dates = []
12771286
date = datetime.datetime(2021, 1, 1)
1287+
indexing(ds)
12781288

12791289
for row in ds:
12801290
expect = make_row(
@@ -1299,7 +1309,7 @@ def test_ensemble_1():
12991309
assert ds.shape == (365 * 4, 4, 11, VALUES)
13001310
# same_stats(ds, open_dataset("test-2021-2021-6h-o96-abcd"), "abcd")
13011311
slices(ds)
1302-
indexing(ds)
1312+
13031313
metadata(ds)
13041314

13051315

@@ -1432,4 +1442,4 @@ def test_statistics():
14321442

14331443

14341444
if __name__ == "__main__":
1435-
test_constructor_5()
1445+
test_ensemble_1()

0 commit comments

Comments
 (0)