Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 5, 2024
1 parent fac9909 commit 61eb457
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
38 changes: 29 additions & 9 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def _make_slice_or_index_from_list_or_tuple(indices):
return indices


def _tuple_with_slices(t):
"""
Replace all integers in a tuple with slices, so we preserve the dimensionality.
"""

result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
changes = [j for (j, i) in enumerate(t) if isinstance(i, int)]

return result, changes


class Dataset:
arguments = {}

Expand Down Expand Up @@ -354,8 +365,8 @@ def _getitem_extended(self, index):
Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves.
"""

if not isinstance(index, tuple):
return self[index]
assert False, index


shape = self.data.shape

Expand All @@ -379,7 +390,7 @@ def _unwind(self, index, rest, shape, axis, axes):
if isinstance(index, (list, tuple)):
axes.append(axis) # Dimension of the concatenation
for i in index:
yield from self._unwind(i, rest, shape, axis, axes)
yield from self._unwind((slice(i, i + 1),), rest, shape, axis, axes)
return

if len(rest) == 0:
Expand Down Expand Up @@ -771,20 +782,29 @@ def check_same_variables(self, d1, d2):
def __len__(self):
return len(self.datasets[0])

def _get_tuple(self, n):
def _get_tuple(self, index):
assert len(index) > 1, index

selected_variables = index[1]

index, changed = _tuple_with_slices(index)

index = list(index)
index[1] = slice(None)
index = tuple(index)

p = (n[0], slice(None), n[2:])
result = [d[p] for d in self.datasets]
result = [d[index] for d in self.datasets]

print([type(s) for s in p], p)
print(self.shape, [r.shape for r in result])
print(self.shape, [r.shape for r in result], selected_variables, changed)
result = np.stack(result)
print(result.shape)

raise NotImplementedError()

result = np.concatenate(result)
# result = np.stack(result)

return result[n[1]]
return result[index[1]]

def _get_slice(self, s):
return np.stack([self[i] for i in range(*s.indices(self._len))])
Expand Down
13 changes: 9 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,16 @@ def slices(ds, start=None, end=None, step=None):
t[0:10, 0:3, 0]
t[:, :, :]

# t[:,(1,3),:]
# t[:,(1,3)]
# t[:, (1, 3), :]
# t[:, (1, 3)]

if ds.shape[2] > 1: # Ensemble dimension
t[0:10, :, (0, 1)]
t[0]
t[0, :]
t[0, 0, :]
t[0, 0, 0, :]

# if ds.shape[2] > 1: # Ensemble dimension
# t[0:10, :, (0, 1)]


def make_row(args, ensemble=False, grid=False):
Expand Down

0 comments on commit 61eb457

Please sign in to comment.