
Commit 84d7f1f: statistics with missing data
1 parent fa749b8

File tree: 14 files changed, +263 -453 lines

ecml_tools/create/loaders.py (68 additions, 137 deletions)
@@ -19,17 +19,12 @@
 from .check import DatasetName
 from .config import build_output, loader_config
 from .input import build_input
-from .statistics import (
-    StatisticsRegistry,
-    compute_aggregated_statistics,
-    compute_statistics,
-)
+from .statistics import TempStatistics
 from .utils import (
     bytes,
     compute_directory_sizes,
     normalize_and_check_dates,
     progress_bar,
-    to_datetime,
 )
 from .writer import CubesFilter, DataWriter
 from .zarr import ZarrBuiltRegistry, add_zarr_dataset
@@ -52,10 +47,7 @@ def __init__(self, *, path, print=print, **kwargs):

         statistics_tmp = kwargs.get("statistics_tmp") or self.path + ".statistics"

-        self.statistics_registry = StatisticsRegistry(
-            statistics_tmp,
-            history_callback=self.registry.add_to_history,
-        )
+        self.statistics_registry = TempStatistics(statistics_tmp)

     @classmethod
     def from_config(cls, *, config, path, print=print, **kwargs):
@@ -94,16 +86,17 @@ def read_dataset_metadata(self):
         ds = open_dataset(self.path)
         self.dataset_shape = ds.shape
         self.variables_names = ds.variables
+        assert len(self.variables_names) == ds.shape[1], self.dataset_shape
+        self.dates = ds.dates

         z = zarr.open(self.path, "r")
-        start = z.attrs.get("statistics_start_date")
-        end = z.attrs.get("statistics_end_date")
-        if start:
-            start = to_datetime(start)
-        if end:
-            end = to_datetime(end)
-        self._statistics_start_date_from_dataset = start
-        self._statistics_end_date_from_dataset = end
+        self.missing_dates = z.attrs.get("missing_dates")
+        if self.missing_dates:
+            self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
+            assert type(self.missing_dates[0]) == type(self.dates[0]), (
+                self.missing_dates[0],
+                self.dates[0],
+            )

     @cached_property
     def registry(self):
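Note: the new `read_dataset_metadata` normalises the `missing_dates` attribute (stored in the zarr attributes as strings) to `np.datetime64` values so they compare consistently with `ds.dates`. A minimal, self-contained sketch of that normalisation, using made-up sample data rather than the ecml_tools API:

```python
import numpy as np

# Stand-ins for the real objects: the dataset's date axis, and the
# "missing_dates" attribute as zarr returns it (a list of ISO strings).
dates = np.arange("2020-01-01", "2020-01-06", dtype="datetime64[D]")
missing_attr = ["2020-01-03"]

# Normalise to np.datetime64 so membership tests are type-consistent.
missing_dates = [np.datetime64(d) for d in missing_attr]
assert type(missing_dates[0]) == type(dates[0]), (missing_dates[0], dates[0])

# Dates that will actually contribute to the statistics.
kept = [d for d in dates if d not in missing_dates]
print(kept)  # 2020-01-03 is excluded
```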
@@ -283,10 +276,29 @@ def initialise(self, check_name=True):
         self.statistics_registry.create(exist_ok=False)
         self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)

+        statistics_start, statistics_end = self._build_statistics_dates(
+            self.main_config.output.get("statistics_start"),
+            self.main_config.output.get("statistics_end"),
+        )
+        self.update_metadata(
+            statistics_start_date=statistics_start,
+            statistics_end_date=statistics_end,
+        )
+        print(f"Will compute statistics from {statistics_start} to {statistics_end}")
+
         self.registry.add_to_history("init finished")

         assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())

+    def _build_statistics_dates(self, start, end):
+        ds = open_dataset(self.path)
+        subset = ds.dates_interval_to_indices(start, end)
+        start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
+        return (
+            start.astype(datetime.datetime).isoformat(),
+            end.astype(datetime.datetime).isoformat(),
+        )
+

 class ContentLoader(Loader):
     def __init__(self, config, **kwargs):
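Note: `_build_statistics_dates` clips the configured `statistics_start`/`statistics_end` to dates actually present in the dataset and stores them in the metadata as ISO strings. The conversion it relies on is `numpy.datetime64 -> datetime.datetime -> isoformat()`; a small sketch of just that step, with sample values rather than the ecml_tools API:

```python
import datetime

import numpy as np

# Assume the configured interval has already been resolved to the first
# and last dates present in the dataset (second precision here).
start = np.datetime64("2020-01-01T00:00:00")
end = np.datetime64("2020-12-31T18:00:00")

# datetime64 -> datetime.datetime -> ISO string, as written to the zarr attrs.
statistics_start = start.astype(datetime.datetime).isoformat()
statistics_end = end.astype(datetime.datetime).isoformat()
print(statistics_start, statistics_end)  # 2020-01-01T00:00:00 2020-12-31T18:00:00
```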
@@ -340,24 +352,20 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        assert statistics_start is None, statistics_start
+        assert statistics_end is None, statistics_end
+
         self.recompute = recompute

         self._write_to_dataset = True

         self.statistics_output = statistics_output
-        if self.statistics_output:
-            self._write_to_dataset = False

         if config:
             self.main_config = loader_config(config)

-        self._statistics_start = statistics_start
-        self._statistics_end = statistics_end
-
         self.check_complete(force=force)
-
         self.read_dataset_metadata()
-        self.read_dataset_dates_metadata()

     def run(self):
         # if requested, recompute statistics from data
@@ -366,19 +374,33 @@ def run(self):
         if self.recompute:
             self.recompute_temporary_statistics()

-        # compute the detailed statistics from temporary statistics directory
-        detailed = self.get_detailed_stats()
+        dates = [d for d in self.dates if d not in self.missing_dates]

-        if self._write_to_dataset:
-            self.write_detailed_statistics(detailed)
+        if self.missing_dates:
+            assert type(self.missing_dates[0]) == type(dates[0]), (type(self.missing_dates[0]), type(dates[0]))

-        # compute the aggregated statistics from the detailed statistics
-        # for the selected dates
-        selected = {k: v[self.i_start : self.i_end + 1] for k, v in detailed.items()}
-        stats = compute_aggregated_statistics(selected, self.variables_names)
+        dates_computed = self.statistics_registry.dates_computed
+        for d in dates:
+            if d in self.missing_dates:
+                assert d not in dates_computed, (d, dates_computed)
+            else:
+                assert d in dates_computed, (d, dates_computed)

-        if self._write_to_dataset:
-            self.write_aggregated_statistics(stats)
+        z = zarr.open(self.path, mode="r")
+        start = z.attrs.get("statistics_start_date")
+        end = z.attrs.get("statistics_end_date")
+        start = np.datetime64(start)
+        end = np.datetime64(end)
+        dates = [d for d in dates if d >= start and d <= end]
+        assert type(start) == type(dates[0]), (type(start), type(dates[0]))
+
+        stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
+
+        writer = {
+            None: self.write_stats_to_dataset,
+            "-": self.write_stats_to_stdout,
+        }.get(self.statistics_output, self.write_stats_to_file)
+        writer(stats)

     def check_complete(self, force):
         if self._complete:
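Note: `run()` now selects the output target with a small dict dispatch on `statistics_output`: `None` writes into the dataset, `"-"` prints to stdout, and any other value is treated as a file path. A standalone sketch of the same pattern, with hypothetical stand-in functions for the loader's `write_stats_to_*` methods:

```python
# Minimal sketch of the dispatch used in run(); the writer functions here
# are hypothetical stand-ins for the loader's write_stats_to_* methods.
def write_stats_to_dataset(stats):
    print("writing into the dataset:", stats)

def write_stats_to_stdout(stats):
    print(stats)

def write_stats_to_file(stats):
    print("writing to a file:", stats)

def pick_writer(statistics_output):
    # None -> dataset, "-" -> stdout, any other value -> a file path.
    return {
        None: write_stats_to_dataset,
        "-": write_stats_to_stdout,
    }.get(statistics_output, write_stats_to_file)

pick_writer(None)({"mean": 0.0})        # dataset
pick_writer("-")({"mean": 0.0})         # stdout
pick_writer("out.json")({"mean": 0.0})  # file
```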
@@ -389,57 +411,12 @@ def check_complete(self, force):
             print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
             self._write_to_dataset = False

-    @property
-    def statistics_start(self):
-        user = self._statistics_start
-        config = self.main_config.get("output", {}).get("statistics_start")
-        dataset = self._statistics_start_date_from_dataset
-        return user or config or dataset
-
-    @property
-    def statistics_end(self):
-        user = self._statistics_end
-        config = self.main_config.get("output", {}).get("statistics_end")
-        dataset = self._statistics_end_date_from_dataset
-        return user or config or dataset
-
     @property
     def _complete(self):
         return all(self.registry.get_flags(sync=False))

-    def read_dataset_dates_metadata(self):
-        ds = open_dataset(self.path)
-        subset = ds.dates_interval_to_indices(self.statistics_start, self.statistics_end)
-        self.i_start = subset[0]
-        self.i_end = subset[-1]
-        self.date_start = ds.dates[subset[0]]
-        self.date_end = ds.dates[subset[-1]]
-
-        # do not write statistics to dataset if dates do not match the ones in the dataset metadata
-        start = self._statistics_start_date_from_dataset
-        end = self._statistics_end_date_from_dataset
-
-        start_ok = start is None or to_datetime(self.date_start) == start
-        end_ok = end is None or to_datetime(self.date_end) == end
-        if not (start_ok and end_ok):
-            print(
-                f"Statistics start/end dates {self.date_start}/{self.date_end} "
-                f"do not match dates in the dataset metadata {start}/{end}. "
-                f"Will not write statistics to dataset."
-            )
-            self._write_to_dataset = False
-
-        def check():
-            i_len = self.i_end + 1 - self.i_start
-            self.print(f"Statistics computed on {i_len}/{len(ds.dates)} samples ")
-            print(f"Requested ({i_len}): from {self.date_start} to {self.date_end}.")
-            print(f"Available ({len(ds.dates)}): from {ds.dates[0]} to {ds.dates[-1]}.")
-            if i_len < 1:
-                raise ValueError("Cannot compute statistics on an empty interval.")
-
-        check()
-
     def recompute_temporary_statistics(self):
+        raise NotImplementedError("Untested code")
         self.statistics_registry.create(exist_ok=True)

         self.print(
@@ -471,67 +448,21 @@ def recompute_temporary_statistics(self):
             self.statistics_registry[key] = detailed_stats
         self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)

-    def get_detailed_stats(self):
-        expected_shape = (self.dataset_shape[0], self.dataset_shape[1])
-        try:
-            return self.statistics_registry.as_detailed_stats(expected_shape)
-        except self.statistics_registry.MissingDataException as e:
-            missing_index = e.args[1]
-            dates = open_dataset(self.path).dates
-            missing_dates = dates[missing_index[0]]
-            print(
-                f"Missing dates: "
-                f"{missing_dates[0]} ... {missing_dates[len(missing_dates)-1]} "
-                f"({missing_dates.shape[0]} missing)"
-            )
-            raise
-
-    def write_detailed_statistics(self, detailed_stats):
-        z = zarr.open(self.path)["_build"]
-        for k, v in detailed_stats.items():
-            if k == "variables_names":
-                continue
-            add_zarr_dataset(zarr_root=z, name=k, array=v)
-        print("Wrote detailed statistics to zarr.")
-
-    def write_aggregated_statistics(self, stats):
-        if self.statistics_output == "-":
-            print(stats)
-            return
-
-        if self.statistics_output:
-            stats.save(self.statistics_output, provenance=dict(config=self.main_config))
-            print(f"✅ Statistics written in {self.statistics_output}")
-            return
-
-        if not self._write_to_dataset:
-            return
+    def write_stats_to_file(self, stats):
+        stats.save(self.statistics_output, provenance=dict(config=self.main_config))
+        print(f"✅ Statistics written in {self.statistics_output}")
+        return

-        for k in [
-            "mean",
-            "stdev",
-            "minimum",
-            "maximum",
-            "sums",
-            "squares",
-            "count",
-        ]:
+    def write_stats_to_dataset(self, stats):
+        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]:
             self._add_dataset(name=k, array=stats[k])

-        self.update_metadata(
-            statistics_start_date=str(self.date_start),
-            statistics_end_date=str(self.date_end),
-        )
-
-        self.registry.add_to_history(
-            "compute_statistics_end",
-            start=str(self.date_start),
-            end=str(self.date_end),
-            i_start=self.i_start,
-            i_end=self.i_end,
-        )
+        self.registry.add_to_history("compute_statistics_end")
         print(f"Wrote statistics in {self.path}")

+    def write_stats_to_stdout(self, stats):
+        print(stats)
+

 class SizeLoader(Loader):
     def __init__(self, path, print):
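Note: the aggregation itself now lives in `TempStatistics.get_aggregated`, which is not part of this file. Purely as an illustration of what combining per-date partial statistics into the fields written by `write_stats_to_dataset` can look like (sums, squares and counts added up, then mean and stdev derived); the actual implementation in `statistics.py` may differ:

```python
import numpy as np

# Illustrative only: one plausible way to aggregate per-date partial
# statistics (sums, squares, counts, minima, maxima) into the fields
# written by write_stats_to_dataset.
def aggregate(per_date):
    sums = np.sum([p["sums"] for p in per_date], axis=0)
    squares = np.sum([p["squares"] for p in per_date], axis=0)
    count = np.sum([p["count"] for p in per_date], axis=0)
    mean = sums / count
    stdev = np.sqrt(squares / count - mean * mean)
    minimum = np.min([p["minimum"] for p in per_date], axis=0)
    maximum = np.max([p["maximum"] for p in per_date], axis=0)
    return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum,
                sums=sums, squares=squares, count=count)

# Two fake "dates", one variable each, for demonstration.
x1, x2 = np.array([1.0, 2.0]), np.array([3.0, 5.0])
per_date = [
    dict(sums=x.sum(), squares=(x * x).sum(), count=x.size, minimum=x.min(), maximum=x.max())
    for x in (x1, x2)
]
print(aggregate(per_date))
```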
