@@ -151,11 +151,17 @@ def main():
151151 required = True ,
152152 help = "Directory or file path to read groundtruth." ,
153153 )
154+ parser .add_argument (
155+ "--groundtruth_dataset_domain" ,
156+ type = str ,
157+ default = "test_z0012" ,
158+ help = "Domain (all, train, val, test) for groundtruth dataset. Should be a key in filename_filters. Determines filename_filter used." ,
159+ )
154160 parser .add_argument (
155161 "--multistep" ,
156162 default = 10 ,
157163 type = int ,
158- help = "Number of future timesteps model is rolled out for evaluation. In days "
164+ help = "Number of future timesteps model is rolled out for evaluation. Set to 1 if just one step. "
159165 "(This script assumes lead time is 24 hours)." ,
160166 )
161167 parser .add_argument (
@@ -198,6 +204,16 @@ def main():
198204 action = "store_true" ,
199205 help = "Whether to evaluate climatology." ,
200206 )
207+ parser .add_argument (
208+ "--verbose" ,
209+ action = "store_true" ,
210+ help = "Whether to print more verbose debug logs." ,
211+ )
212+ parser .add_argument (
213+ "--breakpoint" ,
214+ action = "store_true" ,
215+ help = "Whether to add breakpoint for debug." ,
216+ )
201217
202218 args = parser .parse_args ()
203219
@@ -231,7 +247,7 @@ def main():
231247 surface_variables = args .surface_vars ,
232248 level_variables = args .level_vars ,
233249 pressure_levels = [500 , 700 , 850 ],
234- lead_time_hours = 24 if args .multistep else None ,
250+ lead_time_hours = 24 if args .multistep and args . multistep > 1 else None ,
235251 rollout_iterations = args .multistep ,
236252 ).to (device )
237253 print (f"Computing: { metrics .keys ()} " )
@@ -240,7 +256,7 @@ def main():
240256 ds_test = era5 .Era5Forecast (
241257 path = args .groundtruth_path ,
242258 # filename_filter=lambda x: ("2020" in x) and ("0h" in x or "12h" in x),
243- domain = "test_z0012" ,
259+ domain = args . groundtruth_dataset_domain ,
244260 lead_time_hours = 24 ,
245261 multistep = args .multistep ,
246262 load_prev = False ,
@@ -251,30 +267,36 @@ def main():
251267 )
252268
253269 print (f"Reading { len (ds_test .files )} files from groundtruth path: { args .groundtruth_path } ." )
270+ if args .verbose :
271+ print (ds_test .files )
254272
255273 # Predictions.
256274 def _pred_filename_filter (filename ):
257275 if "metric" in filename :
258276 return False
259277 if args .pred_filename_filter is None :
260278 return True
261- for substring in args .pred_filename_filter :
262- if substring not in filename :
263- return False
264- return True
279+ return any ([str (y ) in filename for y in args .pred_filename_filter ])
265280
266281 if not args .eval_clim :
282+ dimension_indexers = dict (level = [500 , 700 , 850 ])
283+ if args .multistep > 1 :
284+ dimension_indexers ["prediction_timedelta" ] = [
285+ timedelta (days = i ) for i in range (1 , args .multistep + 1 )
286+ ]
287+
267288 ds_pred = era5 .Era5Dataset (
268289 path = args .pred_path ,
269290 filename_filter = _pred_filename_filter , # Update filename_filter to filter within pred_path.
270291 variables = variables ,
271292 return_timestamp = True ,
272- dimension_indexers = dict (
273- prediction_timedelta = [timedelta (days = i ) for i in range (1 , args .multistep + 1 )],
274- level = [500 , 700 , 850 ],
275- ),
293+ dimension_indexers = dimension_indexers ,
276294 )
277295 print (f"Reading { len (ds_pred .files )} files from pred_path: { args .pred_path } ." )
296+ if args .verbose :
297+ print (ds_pred .files )
298+ print ("# prediction examples:" , len (ds_pred ))
299+ print ("# test examples:" , len (ds_test ))
278300
279301 if reloaded_timestamp is not None :
280302 # Don't include the reloaded timestamp.
@@ -315,8 +337,13 @@ def __getitem__(self, idx):
315337 collate_fn = _custom_collate_fn ,
316338 )
317339
340+ if args .breakpoint :
341+ breakpoint ()
342+
318343 # iterable = tqdm(dl_test) if args.eval_clim else tqdm(zip(dl_test, dl_pred))
319344 for next_batch in tqdm (dl_test ) if args .eval_clim else tqdm (zip (dl_test , dl_pred )):
345+ if args .verbose :
346+ print (f"{ nbatches } batch" )
320347 nbatches += 1
321348
322349 if args .eval_clim :
@@ -333,7 +360,7 @@ def __getitem__(self, idx):
333360 pred = pred .apply (
334361 lambda tensor : rearrange (
335362 tensor ,
336- "batch var mem ... lev lat lon -> batch mem ... var lev lat lon" ,
363+ "batch var ... lev lat lon -> batch ... var lev lat lon" ,
337364 )
338365 )
339366 timestamps = target ["timestamp" ]
@@ -344,9 +371,14 @@ def __getitem__(self, idx):
344371 else :
345372 target = target ["future_states" ]
346373
374+ if args .breakpoint :
375+ breakpoint ()
376+
347377 # Update metrics.
348378 for metric in metrics .values ():
349379 metric .update (target .to (device ), pred .to (device ))
380+ if args .breakpoint :
381+ breakpoint ()
350382
351383 if args .cache_metrics_every_nbatches and nbatches % args .cache_metrics_every_nbatches == 0 :
352384 print (f"Processed { nbatches } batches." )
@@ -370,26 +402,35 @@ def __getitem__(self, idx):
370402 else :
371403 output_filename = f"test-multistep={ args .multistep } -{ metric_name } "
372404
373- # Get xr dataset.
374405 if isinstance (labelled_metric_output , dict ):
375406 labelled_dict = {
376407 k : (v .cpu () if hasattr (v , "cpu" ) else v ) for k , v in labelled_metric_output .items ()
377408 }
378- extra_dimensions = ["prediction_timedelta" ]
379- if "brier" in metric_name :
380- extra_dimensions = ["quantile" , "prediction_timedelta" ]
381- if "rankhist" in metric_name or "rank_hist" in metric_name :
382- extra_dimensions = ["bins" , "prediction_timedelta" ]
383- ds = convert_metric_dict_to_xarray (labelled_dict , extra_dimensions )
384-
385409 # Write labeled dict.
386410 labelled_dict ["metadata" ] = dict (
387411 groundtruth_path = args .groundtruth_path , predictions_path = args .pred_path
388412 )
389413 torch .save (labelled_dict , Path (output_dir ).joinpath (f"{ output_filename } .pt" ))
414+
415+ # Convert to xr dataset.
416+ extra_dimensions = []
417+ if args .multistep > 1 :
418+ extra_dimensions = ["prediction_timedelta" ]
419+ if "brier" in metric_name :
420+ extra_dimensions .insert (0 , "quantile" ) # ["quantile", "prediction_timedelta"]
421+ if "rankhist" in metric_name or "rank_hist" in metric_name :
422+ extra_dimensions .insert (0 , "bins" ) # ["bins", "prediction_timedelta"]
423+ if "spatial" in metric_name :
424+ # Does not yet handle extra lat/lon dims.
425+ continue
426+
427+ ds = convert_metric_dict_to_xarray (labelled_dict , extra_dimensions )
390428 else :
391429 ds = labelled_metric_output
392430 # Write xr dataset.
431+ ds .attrs ["groundtruth_path" ] = args .groundtruth_path
432+ ds .attrs ["predictions_path" ] = args .pred_path
433+ ds .attrs ["groundtruth_dataset_domain" ] = args .groundtruth_dataset_domain
393434 ds .to_netcdf (Path (output_dir ).joinpath (f"{ output_filename } .nc" ))
394435
395436
0 commit comments