Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit b3973e7

Browse files
committed
Merge branch 'develop' of https://github.com/ecmwf-lab/ecml-tools into ens
2 parents bdf2bfb + bb9164a commit b3973e7

File tree

4 files changed

+570
-272
lines changed

4 files changed

+570
-272
lines changed

ecml_tools/data.py

+123-18
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,45 @@
2020

2121
import ecml_tools
2222

23+
from .indexing import (
24+
apply_index_to_slices_changes,
25+
index_to_slices,
26+
length_to_slices,
27+
update_tuple,
28+
)
29+
2330
LOG = logging.getLogger(__name__)
2431

2532
__all__ = ["open_dataset", "open_zarr", "debug_zarr_loading"]
2633

2734
DEBUG_ZARR_LOADING = int(os.environ.get("DEBUG_ZARR_LOADING", "0"))
2835

36+
DEPTH = 0
37+
38+
39+
def _debug_indexing(method):
40+
def wrapper(self, index):
41+
global DEPTH
42+
if isinstance(index, tuple):
43+
print(" " * DEPTH, "->", self, method.__name__, index)
44+
DEPTH += 1
45+
result = method(self, index)
46+
DEPTH -= 1
47+
if isinstance(index, tuple):
48+
print(" " * DEPTH, "<-", self, method.__name__, result.shape)
49+
return result
50+
51+
return wrapper
52+
53+
54+
if True:
55+
56+
def debug_indexing(x):
57+
return x
58+
59+
else:
60+
debug_indexing = _debug_indexing
61+
2962

3063
def debug_zarr_loading(on_off):
3164
global DEBUG_ZARR_LOADING
@@ -190,11 +223,18 @@ def metadata_specific(self, **kwargs):
190223
def __repr__(self):
191224
return self.__class__.__name__ + "()"
192225

226+
@debug_indexing
193227
def _get_tuple(self, n):
194-
raise NotImplementedError(f"Tuple not supported: {n} (class {self.__class__.__name__})")
228+
raise NotImplementedError(
229+
f"Tuple not supported: {n} (class {self.__class__.__name__})"
230+
)
195231

196232

197233
class Source:
234+
"""
235+
Class used to follow the provenance of a data point.
236+
"""
237+
198238
def __init__(self, dataset, index, source=None, info=None):
199239
self.dataset = dataset
200240
self.index = index
@@ -340,6 +380,7 @@ def __init__(self, path):
340380
def __len__(self):
341381
return self.data.shape[0]
342382

383+
@debug_indexing
343384
def __getitem__(self, n):
344385
if isinstance(n, tuple) and any(not isinstance(i, (int, slice)) for i in n):
345386
return self._getitem_extended(n)
@@ -352,8 +393,7 @@ def _getitem_extended(self, index):
352393
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
353394
"""
354395

355-
if not isinstance(index, tuple):
356-
return self[index]
396+
assert False, index
357397

358398
shape = self.data.shape
359399

@@ -377,7 +417,7 @@ def _unwind(self, index, rest, shape, axis, axes):
377417
if isinstance(index, (list, tuple)):
378418
axes.append(axis) # Dimension of the concatenation
379419
for i in index:
380-
yield from self._unwind(i, rest, shape, axis, axes)
420+
yield from self._unwind((slice(i, i + 1),), rest, shape, axis, axes)
381421
return
382422

383423
if len(rest) == 0:
@@ -635,6 +675,31 @@ class Concat(Combined):
635675
def __len__(self):
636676
return sum(len(i) for i in self.datasets)
637677

678+
@debug_indexing
679+
def _get_tuple(self, index):
680+
index, changes = index_to_slices(index, self.shape)
681+
result = []
682+
683+
first, rest = index[0], index[1:]
684+
start, stop, step = first.start, first.stop, first.step
685+
686+
for d in self.datasets:
687+
length = d._len
688+
689+
result.append(d[(slice(start, stop, step),) + rest])
690+
691+
start -= length
692+
while start < 0:
693+
start += step
694+
695+
stop -= length
696+
697+
if start > stop:
698+
break
699+
700+
return apply_index_to_slices_changes(np.concatenate(result, axis=0), changes)
701+
702+
@debug_indexing
638703
def __getitem__(self, n):
639704
if isinstance(n, tuple):
640705
return self._get_tuple(n)
@@ -649,6 +714,7 @@ def __getitem__(self, n):
649714
k += 1
650715
return self.datasets[k][n]
651716

717+
@debug_indexing
652718
def _get_slice(self, s):
653719
result = []
654720

@@ -716,9 +782,23 @@ def shape(self):
716782
assert False not in result, result
717783
return result
718784

785+
@debug_indexing
786+
def _get_tuple(self, index):
787+
index, changes = index_to_slices(index, self.shape)
788+
lengths = [d.shape[self.axis] for d in self.datasets]
789+
slices = length_to_slices(index[self.axis], lengths)
790+
before = index[: self.axis]
791+
result = [
792+
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
793+
]
794+
result = np.concatenate(result, axis=self.axis)
795+
return apply_index_to_slices_changes(result, changes)
796+
797+
@debug_indexing
719798
def _get_slice(self, s):
720799
return np.stack([self[i] for i in range(*s.indices(self._len))])
721800

801+
@debug_indexing
722802
def __getitem__(self, n):
723803
if isinstance(n, tuple):
724804
return self._get_tuple(n)
@@ -769,9 +849,22 @@ def check_same_variables(self, d1, d2):
769849
def __len__(self):
770850
return len(self.datasets[0])
771851

852+
@debug_indexing
853+
def _get_tuple(self, index):
854+
index, changes = index_to_slices(index, self.shape)
855+
index, previous = update_tuple(index, 1, slice(None))
856+
857+
# TODO: optimize if index does not access all datasets, so we don't load chunks we don't need
858+
result = [d[index] for d in self.datasets]
859+
860+
result = np.concatenate(result, axis=1)
861+
return apply_index_to_slices_changes(result[:, previous], changes)
862+
863+
@debug_indexing
772864
def _get_slice(self, s):
773865
return np.stack([self[i] for i in range(*s.indices(self._len))])
774866

867+
@debug_indexing
775868
def __getitem__(self, n):
776869
if isinstance(n, tuple):
777870
return self._get_tuple(n)
@@ -857,10 +950,14 @@ def __init__(self, dataset, indices):
857950

858951
self.dataset = dataset
859952
self.indices = list(indices)
953+
self.slice = _make_slice_or_index_from_list_or_tuple(self.indices)
954+
assert isinstance(self.slice, slice)
955+
print("SUBSET", self.slice)
860956

861957
# Forward other properties to the super dataset
862958
super().__init__(dataset)
863959

960+
@debug_indexing
864961
def __getitem__(self, n):
865962
if isinstance(n, tuple):
866963
return self._get_tuple(n)
@@ -871,25 +968,22 @@ def __getitem__(self, n):
871968
n = self.indices[n]
872969
return self.dataset[n]
873970

971+
@debug_indexing
874972
def _get_slice(self, s):
875973
# TODO: check if the indices can be simplified to a slice
876974
# the time checking maybe be longer than the time saved
877975
# using a slice
878976
indices = [self.indices[i] for i in range(*s.indices(self._len))]
879977
return np.stack([self.dataset[i] for i in indices])
880978

979+
@debug_indexing
881980
def _get_tuple(self, n):
882-
first, rest = n[0], n[1:]
883-
884-
if isinstance(first, int):
885-
return self.dataset[(self.indices[first],) + rest]
886-
887-
if isinstance(first, slice):
888-
indices = tuple(self.indices[i] for i in range(*first.indices(self._len)))
889-
indices = _make_slice_or_index_from_list_or_tuple(indices)
890-
return self.dataset[(indices,) + rest]
891-
892-
raise NotImplementedError(f"Only int and slice supported not {type(first)}")
981+
index, changes = index_to_slices(n, self.shape)
982+
index, previous = update_tuple(index, 0, self.slice)
983+
result = self.dataset[index]
984+
result = result[previous]
985+
result = apply_index_to_slices_changes(result, changes)
986+
return result
893987

894988
def __len__(self):
895989
return len(self.indices)
@@ -929,12 +1023,23 @@ def __init__(self, dataset, indices):
9291023
# Forward other properties to the main dataset
9301024
super().__init__(dataset)
9311025

1026+
@debug_indexing
1027+
def _get_tuple(self, index):
1028+
index, changes = index_to_slices(index, self.shape)
1029+
index, previous = update_tuple(index, 1, slice(None))
1030+
result = self.dataset[index]
1031+
result = result[:, self.indices]
1032+
result = result[:, previous]
1033+
result = apply_index_to_slices_changes(result, changes)
1034+
return result
1035+
1036+
@debug_indexing
9321037
def __getitem__(self, n):
933-
# if isinstance(n, tuple):
934-
# return self._get_tuple(n)
1038+
if isinstance(n, tuple):
1039+
return self._get_tuple(n)
9351040

9361041
row = self.dataset[n]
937-
if isinstance(n, (slice, tuple)):
1042+
if isinstance(n, slice):
9381043
return row[:, self.indices]
9391044

9401045
return row[self.indices]

ecml_tools/indexing.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2+
# This software is licensed under the terms of the Apache Licence Version 2.0
3+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4+
# In applying this licence, ECMWF does not waive the privileges and immunities
5+
# granted to it by virtue of its status as an intergovernmental organisation
6+
# nor does it submit to any jurisdiction.
7+
8+
9+
import numpy as np
10+
11+
12+
def _tuple_with_slices(t, shape):
13+
"""
14+
Replace all integers in a tuple with slices, so we preserve the dimensionality.
15+
"""
16+
17+
result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
18+
changes = tuple(j for (j, i) in enumerate(t) if isinstance(i, int))
19+
result = tuple(slice(*s.indices(shape[i])) for (i, s) in enumerate(result))
20+
21+
return result, changes
22+
23+
24+
def _extend_shape(index, shape):
25+
if Ellipsis in index:
26+
if index.count(Ellipsis) > 1:
27+
raise IndexError("Only one Ellipsis is allowed")
28+
ellipsis_index = index.index(Ellipsis)
29+
index = list(index)
30+
index[ellipsis_index] = slice(None)
31+
while len(index) < len(shape):
32+
index.insert(ellipsis_index, slice(None))
33+
index = tuple(index)
34+
35+
while len(index) < len(shape):
36+
index = index + (slice(None),)
37+
38+
return index
39+
40+
41+
def _index_to_tuple(index, shape):
42+
if isinstance(index, int):
43+
return _extend_shape((index,), shape)
44+
if isinstance(index, slice):
45+
return _extend_shape((index,), shape)
46+
if isinstance(index, tuple):
47+
return _extend_shape(index, shape)
48+
if index is Ellipsis:
49+
return _extend_shape((Ellipsis,), shape)
50+
raise ValueError(f"Invalid index: {index}")
51+
52+
53+
def index_to_slices(index, shape):
54+
"""
55+
Convert an index to a tuple of slices, with the same dimensionality as the shape.
56+
"""
57+
return _tuple_with_slices(_index_to_tuple(index, shape), shape)
58+
59+
60+
def apply_index_to_slices_changes(result, changes):
61+
if changes:
62+
shape = result.shape
63+
for i in changes:
64+
assert shape[i] == 1, (i, changes, shape)
65+
result = np.squeeze(result, axis=changes)
66+
return result
67+
68+
69+
def update_tuple(t, index, value):
70+
"""
71+
Replace the elements of a tuple at the given index with a new value.
72+
"""
73+
t = list(t)
74+
prev = t[index]
75+
t[index] = value
76+
return tuple(t), prev
77+
78+
79+
def length_to_slices(index, lengths):
80+
"""
81+
Convert an index to a list of slices, given the lengths of the dimensions.
82+
"""
83+
total = sum(lengths)
84+
start, stop, step = index.indices(total)
85+
86+
result = []
87+
88+
pos = 0
89+
for length in lengths:
90+
end = pos + length
91+
92+
b = max(pos, start)
93+
e = min(end, stop)
94+
95+
p = None
96+
if b <= e:
97+
if (b - start) % step != 0:
98+
b = b + step - (b - start) % step
99+
b -= pos
100+
e -= pos
101+
102+
if 0 <= b < e:
103+
p = slice(b, e, step)
104+
105+
result.append(p)
106+
107+
pos = end
108+
109+
return result
110+
111+
112+
class IndexTester:
113+
def __init__(self, shape):
114+
self.shape = shape
115+
116+
def __getitem__(self, index):
117+
return index_to_slices(index, self.shape)
118+
119+
120+
if __name__ == "__main__":
121+
t = IndexTester((1000, 8, 10, 20000))
122+
i = t[0, 1, 2, 3]
123+
print(i)
124+
125+
# print(t[0])
126+
# print(t[0, 1, 2, 3])
127+
# print(t[0:10])
128+
# print(t[...])
129+
# print(t[:-1])

0 commit comments

Comments
 (0)