Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 6, 2024
1 parent dfca19f commit bb9164a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 60 deletions.
7 changes: 0 additions & 7 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,21 +784,14 @@ def shape(self):

@debug_indexing
def _get_tuple(self, index):
print(index, self.shape)
index, changes = index_to_slices(index, self.shape)
lengths = [d.shape[self.axis] for d in self.datasets]
slices = length_to_slices(index[self.axis], lengths)

print("SLICES", slices, self.axis, index, lengths)
before = index[: self.axis]

result = [
d[before + (i,)] for (d, i) in zip(self.datasets, slices) if i is not None
]
print([d.shape for d in result])
result = np.concatenate(result, axis=self.axis)
print(result.shape)

return apply_index_to_slices_changes(result, changes)

@debug_indexing
Expand Down
81 changes: 28 additions & 53 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def slices(ds, start=None, end=None, step=None):


def make_row(*args, ensemble=False, grid=False):
assert not isinstance(args[0], (list, tuple))
# assert not isinstance(args[0], (list, tuple))
if grid:

def _(x):
Expand Down Expand Up @@ -298,17 +298,17 @@ def run(
metadata(self.ds)


def simple_row(date, vars):
values = [_(date, v) for v in vars]
return make_row(*values)


def test_simple():
test = DatasetTester("test-2021-2022-6h-o96-abcd")
test.run(
expected_class=Zarr,
expected_length=365 * 2 * 4,
date_to_row=lambda date: make_row(
_(date, "a"),
_(date, "b"),
_(date, "c"),
_(date, "d"),
),
date_to_row=lambda date: simple_row(date, "abcd"),
start_date=datetime.datetime(2021, 1, 1),
time_increment=datetime.timedelta(hours=6),
excepted_shape=(365 * 2 * 4, 4, 1, VALUES),
Expand All @@ -335,9 +335,7 @@ def test_concat():
excepted_shape=(365 * 3 * 4, 4, 1, VALUES),
excepted_variables=["a", "b", "c", "d"],
excepted_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3},
date_to_row=lambda date: make_row(
_(date, "a"), _(date, "b"), _(date, "c"), _(date, "d")
),
date_to_row=lambda date: simple_row(date, "abcd"),
statistics_reference_dataset="test-2021-2022-6h-o96-abcd",
statistics_reference_variables="abcd",
)
Expand Down Expand Up @@ -368,66 +366,43 @@ def test_join_1():
"g": 6,
"h": 7,
},
date_to_row=lambda date: make_row(
_(date, "a"),
_(date, "b"),
_(date, "c"),
_(date, "d"),
_(date, "e"),
_(date, "f"),
_(date, "g"),
_(date, "h"),
),
date_to_row=lambda date: simple_row(date, "abcdefgh"),
# TODO: test second stats
statistics_reference_dataset="test-2021-2021-6h-o96-abcd",
statistics_reference_variables="abcd",
)


def test_join_2():
ds = open_dataset(
"test-2021-2021-6h-o96-abcd-1",
"test-2021-2021-6h-o96-bdef-2",
test = DatasetTester(
[
"test-2021-2021-6h-o96-abcd-1",
"test-2021-2021-6h-o96-bdef-2",
]
)

assert isinstance(ds, Select)
assert len(ds) == 365 * 4
assert len([row for row in ds]) == len(ds)

dates = []
date = datetime.datetime(2021, 1, 1)

for row in ds:
expect = make_row(
test.run(
expected_class=Select,
expected_length=365 * 4,
start_date=datetime.datetime(2021, 1, 1),
time_increment=datetime.timedelta(hours=6),
excepted_shape=(365 * 4, 6, 1, VALUES),
excepted_variables=["a", "b", "c", "d", "e", "f"],
excepted_name_to_index={"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5},
date_to_row=lambda date: make_row(
_(date, "a", 1),
_(date, "b", 2),
_(date, "c", 1),
_(date, "d", 2),
_(date, "e", 2),
_(date, "f", 2),
)
assert (row == expect).all()
dates.append(date)
date += datetime.timedelta(hours=6)

assert (ds.dates == np.array(dates, dtype="datetime64")).all()

assert ds.variables == ["a", "b", "c", "d", "e", "f"]
assert ds.name_to_index == {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5}

assert ds.shape == (365 * 4, 6, 1, VALUES)

same_stats(
ds,
open_dataset(
),
statistics_reference_dataset=[
"test-2021-2021-6h-o96-ac-1",
"test-2021-2021-6h-o96-bdef-2",
),
"abcdef",
],
statistics_reference_variables="abcdef",
)
slices(ds)
indexing(ds)
metadata(ds)


def test_join_3():
Expand Down Expand Up @@ -1432,4 +1407,4 @@ def test_statistics():


if __name__ == "__main__":
test_simple()
test_ensemble_1()

0 comments on commit bb9164a

Please sign in to comment.