19
19
from .check import DatasetName
20
20
from .config import build_output , loader_config
21
21
from .input import build_input
22
- from .statistics import (
23
- StatisticsRegistry ,
24
- compute_aggregated_statistics ,
25
- compute_statistics ,
26
- )
22
+ from .statistics import TempStatistics
27
23
from .utils import (
28
24
bytes ,
29
25
compute_directory_sizes ,
30
26
normalize_and_check_dates ,
31
27
progress_bar ,
32
- to_datetime ,
33
28
)
34
29
from .writer import CubesFilter , DataWriter
35
30
from .zarr import ZarrBuiltRegistry , add_zarr_dataset
@@ -52,10 +47,7 @@ def __init__(self, *, path, print=print, **kwargs):
52
47
53
48
statistics_tmp = kwargs .get ("statistics_tmp" ) or self .path + ".statistics"
54
49
55
- self .statistics_registry = StatisticsRegistry (
56
- statistics_tmp ,
57
- history_callback = self .registry .add_to_history ,
58
- )
50
+ self .statistics_registry = TempStatistics (statistics_tmp )
59
51
60
52
@classmethod
61
53
def from_config (cls , * , config , path , print = print , ** kwargs ):
@@ -94,16 +86,17 @@ def read_dataset_metadata(self):
94
86
ds = open_dataset (self .path )
95
87
self .dataset_shape = ds .shape
96
88
self .variables_names = ds .variables
89
+ assert len (self .variables_names ) == ds .shape [1 ], self .dataset_shape
90
+ self .dates = ds .dates
97
91
98
92
z = zarr .open (self .path , "r" )
99
- start = z .attrs .get ("statistics_start_date" )
100
- end = z .attrs .get ("statistics_end_date" )
101
- if start :
102
- start = to_datetime (start )
103
- if end :
104
- end = to_datetime (end )
105
- self ._statistics_start_date_from_dataset = start
106
- self ._statistics_end_date_from_dataset = end
93
+ self .missing_dates = z .attrs .get ("missing_dates" )
94
+ if self .missing_dates :
95
+ self .missing_dates = [np .datetime64 (d ) for d in self .missing_dates ]
96
+ assert type (self .missing_dates [0 ]) == type (self .dates [0 ]), (
97
+ self .missing_dates [0 ],
98
+ self .dates [0 ],
99
+ )
107
100
108
101
@cached_property
109
102
def registry (self ):
@@ -283,10 +276,29 @@ def initialise(self, check_name=True):
283
276
self .statistics_registry .create (exist_ok = False )
284
277
self .registry .add_to_history ("statistics_registry_initialised" , version = self .statistics_registry .version )
285
278
279
+ statistics_start , statistics_end = self ._build_statistics_dates (
280
+ self .main_config .output .get ("statistics_start" ),
281
+ self .main_config .output .get ("statistics_end" ),
282
+ )
283
+ self .update_metadata (
284
+ statistics_start_date = statistics_start ,
285
+ statistics_end_date = statistics_end ,
286
+ )
287
+ print (f"Will compute statistics from { statistics_start } to { statistics_end } " )
288
+
286
289
self .registry .add_to_history ("init finished" )
287
290
288
291
assert chunks == self .get_zarr_chunks (), (chunks , self .get_zarr_chunks ())
289
292
293
def _build_statistics_dates(self, start, end):
    """Resolve a requested statistics interval to dates present in the dataset.

    Opens the dataset at ``self.path``, maps ``start``/``end`` to indices
    with ``dates_interval_to_indices`` and returns the first and last
    selected dates.

    Args:
        start: Requested start of the interval; forwarded unchanged to
            ``dates_interval_to_indices``.
        end: Requested end of the interval; forwarded unchanged to
            ``dates_interval_to_indices``.

    Returns:
        A ``(start, end)`` tuple of ISO-8601 formatted strings.

    Raises:
        ValueError: If the requested interval selects no date.
    """
    ds = open_dataset(self.path)
    subset = ds.dates_interval_to_indices(start, end)

    # Fail with an explicit message rather than an IndexError on subset[0].
    if len(subset) == 0:
        raise ValueError("Cannot compute statistics on an empty interval.")

    first, last = ds.dates[subset[0]], ds.dates[subset[-1]]

    # NOTE(review): assumes the dates are numpy datetime64 values whose unit
    # converts to datetime.datetime (nanosecond precision would convert to an
    # int and break .isoformat()) -- TODO confirm against the dataset dtype.
    return (
        first.astype(datetime.datetime).isoformat(),
        last.astype(datetime.datetime).isoformat(),
    )
301
+
290
302
291
303
class ContentLoader (Loader ):
292
304
def __init__ (self , config , ** kwargs ):
@@ -340,24 +352,20 @@ def __init__(
340
352
** kwargs ,
341
353
):
342
354
super ().__init__ (** kwargs )
355
+ assert statistics_start is None , statistics_start
356
+ assert statistics_end is None , statistics_end
357
+
343
358
self .recompute = recompute
344
359
345
360
self ._write_to_dataset = True
346
361
347
362
self .statistics_output = statistics_output
348
- if self .statistics_output :
349
- self ._write_to_dataset = False
350
363
351
364
if config :
352
365
self .main_config = loader_config (config )
353
366
354
- self ._statistics_start = statistics_start
355
- self ._statistics_end = statistics_end
356
-
357
367
self .check_complete (force = force )
358
-
359
368
self .read_dataset_metadata ()
360
- self .read_dataset_dates_metadata ()
361
369
362
370
def run(self):
    """Aggregate the temporary statistics and pass them to the selected writer.

    Steps:
      1. Optionally rebuild the temporary statistics (``self.recompute``).
      2. Check that every non-missing dataset date has temporary statistics
         and that no missing date has any.
      3. Keep the non-missing dates inside the ``statistics_start_date`` /
         ``statistics_end_date`` interval stored in the zarr attributes.
      4. Aggregate over those dates and write the result to the dataset
         (``statistics_output is None``), to stdout (``"-"``) or to a file.

    Raises:
        ValueError: If the statistics interval is absent from the zarr
            attributes, or if it selects no date.
    """
    # if requested, recompute statistics from data
    # NOTE(review): the reviewed excerpt elides two lines at this point
    # (diff hunk boundary); confirm nothing else belongs here.
    if self.recompute:
        self.recompute_temporary_statistics()

    # The "missing_dates" attribute is read with attrs.get() and may be None.
    missing_dates = self.missing_dates or []
    dates = [d for d in self.dates if d not in missing_dates]

    if missing_dates and dates:
        assert type(missing_dates[0]) is type(dates[0]), (
            type(missing_dates[0]),
            type(dates[0]),
        )

    dates_computed = self.statistics_registry.dates_computed
    # Iterate over every dataset date (not the filtered list), otherwise the
    # "missing" branch can never be reached.
    for d in self.dates:
        if d in missing_dates:
            assert d not in dates_computed, (d, dates_computed)
        else:
            assert d in dates_computed, (d, dates_computed)

    z = zarr.open(self.path, mode="r")
    start = z.attrs.get("statistics_start_date")
    end = z.attrs.get("statistics_end_date")
    if start is None or end is None:
        # np.datetime64(None) is NaT: every comparison below would be False
        # and the date selection would silently become empty.
        raise ValueError(
            f"Missing statistics_start_date/statistics_end_date in {self.path}: {start}/{end}"
        )
    start = np.datetime64(start)
    end = np.datetime64(end)

    dates = [d for d in dates if start <= d <= end]
    if not dates:
        raise ValueError("Cannot compute statistics on an empty interval.")
    assert type(start) is type(dates[0]), (type(start), type(dates[0]))

    stats = self.statistics_registry.get_aggregated(dates, self.variables_names)

    if self.statistics_output is None and not self._write_to_dataset:
        # check_complete() has already reported that statistics are not
        # written into an incomplete dataset; honour that decision.
        return

    writer = {
        None: self.write_stats_to_dataset,
        "-": self.write_stats_to_stdout,
    }.get(self.statistics_output, self.write_stats_to_file)
    writer(stats)
382
404
383
405
def check_complete (self , force ):
384
406
if self ._complete :
@@ -389,57 +411,12 @@ def check_complete(self, force):
389
411
print (f"❗Zarr { self .path } is not fully built, not writting statistics into dataset." )
390
412
self ._write_to_dataset = False
391
413
392
- @property
393
- def statistics_start (self ):
394
- user = self ._statistics_start
395
- config = self .main_config .get ("output" , {}).get ("statistics_start" )
396
- dataset = self ._statistics_start_date_from_dataset
397
- return user or config or dataset
398
-
399
- @property
400
- def statistics_end (self ):
401
- user = self ._statistics_end
402
- config = self .main_config .get ("output" , {}).get ("statistics_end" )
403
- dataset = self ._statistics_end_date_from_dataset
404
- return user or config or dataset
405
-
406
414
@property
def _complete(self):
    """Return True when every flag returned by the registry is truthy."""
    flags = self.registry.get_flags(sync=False)
    return all(flags)
409
417
410
- def read_dataset_dates_metadata (self ):
411
- ds = open_dataset (self .path )
412
- subset = ds .dates_interval_to_indices (self .statistics_start , self .statistics_end )
413
- self .i_start = subset [0 ]
414
- self .i_end = subset [- 1 ]
415
- self .date_start = ds .dates [subset [0 ]]
416
- self .date_end = ds .dates [subset [- 1 ]]
417
-
418
- # do not write statistics to dataset if dates do not match the ones in the dataset metadata
419
- start = self ._statistics_start_date_from_dataset
420
- end = self ._statistics_end_date_from_dataset
421
-
422
- start_ok = start is None or to_datetime (self .date_start ) == start
423
- end_ok = end is None or to_datetime (self .date_end ) == end
424
- if not (start_ok and end_ok ):
425
- print (
426
- f"Statistics start/end dates { self .date_start } /{ self .date_end } "
427
- f"do not match dates in the dataset metadata { start } /{ end } . "
428
- f"Will not write statistics to dataset."
429
- )
430
- self ._write_to_dataset = False
431
-
432
- def check ():
433
- i_len = self .i_end + 1 - self .i_start
434
- self .print (f"Statistics computed on { i_len } /{ len (ds .dates )} samples " )
435
- print (f"Requested ({ i_len } ): from { self .date_start } to { self .date_end } ." )
436
- print (f"Available ({ len (ds .dates )} ): from { ds .dates [0 ]} to { ds .dates [- 1 ]} ." )
437
- if i_len < 1 :
438
- raise ValueError ("Cannot compute statistics on an empty interval." )
439
-
440
- check ()
441
-
442
418
def recompute_temporary_statistics (self ):
419
+ raise NotImplementedError ("Untested code" )
443
420
self .statistics_registry .create (exist_ok = True )
444
421
445
422
self .print (
@@ -471,67 +448,21 @@ def recompute_temporary_statistics(self):
471
448
self .statistics_registry [key ] = detailed_stats
472
449
self .statistics_registry .add_provenance (name = "provenance_recompute_statistics" , config = self .main_config )
473
450
474
- def get_detailed_stats (self ):
475
- expected_shape = (self .dataset_shape [0 ], self .dataset_shape [1 ])
476
- try :
477
- return self .statistics_registry .as_detailed_stats (expected_shape )
478
- except self .statistics_registry .MissingDataException as e :
479
- missing_index = e .args [1 ]
480
- dates = open_dataset (self .path ).dates
481
- missing_dates = dates [missing_index [0 ]]
482
- print (
483
- f"Missing dates: "
484
- f"{ missing_dates [0 ]} ... { missing_dates [len (missing_dates )- 1 ]} "
485
- f"({ missing_dates .shape [0 ]} missing)"
486
- )
487
- raise
488
-
489
- def write_detailed_statistics (self , detailed_stats ):
490
- z = zarr .open (self .path )["_build" ]
491
- for k , v in detailed_stats .items ():
492
- if k == "variables_names" :
493
- continue
494
- add_zarr_dataset (zarr_root = z , name = k , array = v )
495
- print ("Wrote detailed statistics to zarr." )
496
-
497
- def write_aggregated_statistics (self , stats ):
498
- if self .statistics_output == "-" :
499
- print (stats )
500
- return
501
-
502
- if self .statistics_output :
503
- stats .save (self .statistics_output , provenance = dict (config = self .main_config ))
504
- print (f"✅ Statistics written in { self .statistics_output } " )
505
- return
506
-
507
- if not self ._write_to_dataset :
508
- return
451
def write_stats_to_file(self, stats):
    """Save ``stats`` to the path held in ``self.statistics_output``.

    The main configuration is attached as provenance, and a confirmation
    line is printed once the save call returns.
    """
    provenance = dict(config=self.main_config)
    stats.save(self.statistics_output, provenance=provenance)
    print(f"✅ Statistics written in {self.statistics_output}")
509
455
510
- for k in [
511
- "mean" ,
512
- "stdev" ,
513
- "minimum" ,
514
- "maximum" ,
515
- "sums" ,
516
- "squares" ,
517
- "count" ,
518
- ]:
456
def write_stats_to_dataset(self, stats):
    """Add each aggregated statistic to the dataset and record it in the history.

    Every named statistic is passed to ``self._add_dataset``; afterwards a
    ``compute_statistics_end`` entry is added to the registry history.
    """
    statistic_names = ("mean", "stdev", "minimum", "maximum", "sums", "squares", "count")
    for statistic in statistic_names:
        self._add_dataset(name=statistic, array=stats[statistic])

    self.registry.add_to_history("compute_statistics_end")
    print(f"Wrote statistics in {self.path}")
534
462
463
def write_stats_to_stdout(self, stats):
    """Print the aggregated statistics instead of storing them."""
    print(stats)
465
+
535
466
536
467
class SizeLoader (Loader ):
537
468
def __init__ (self , path , print ):
0 commit comments