Skip to content

Commit 2d51778

Browse files
14renus (Renu Singh)
authored and committed
Add deterministic metrics to eval script, fix 1 multistep support (no prediction_timedelta dim)
1 parent a4dc345 commit 2d51778

File tree

4 files changed

+80
-24
lines changed

4 files changed

+80
-24
lines changed

geoarches/dataloaders/era5.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
last_train=lambda x: ("2018" in x),
1818
last_train_z0012=lambda x: ("2018" in x and ("0h" in x or "12h" in x)),
1919
train=lambda x: not ("2019" in x or "2020" in x or "2021" in x),
20+
# Before and after 2000. Need to load timestamp after to account for offset.
21+
train_before_2000=lambda x: any([str(y) in x for y in range(1979, 2001)]), # 1979-1999
22+
train_after_2000=lambda x: any([str(y) in x for y in range(2000, 2020)]), # 2000-2018
2023
# Splits val and test are from 2019 and 2020 respectively, but
2124
# we read the years before and after to account for offsets when
2225
# loading previous and future timestamps for an example.
23-
val=lambda x: ("2018" in x or "2019" in x or "2020" in x),
24-
test=lambda x: ("2019" in x or "2020" in x or "2021" in x),
26+
val=lambda x: ("2018" in x or "2019" in x or "2020" in x), # 2019
27+
test=lambda x: ("2019" in x or "2020" in x or "2021" in x), # 2020
2528
test_z0012=lambda x: ("2019" in x or "2020" in x or "2021" in x) and ("0h" in x or "12h" in x),
2629
test2022_z0012=lambda x: ("2022" in x) and ("0h" in x or "12h" in x),  # check if that works?
2730
recent2=lambda x: any([str(y) in x for y in range(2007, 2019)]),
@@ -274,7 +277,6 @@ def __init__(
274277
)
275278

276279
# depending on domain, re-set timestamp bounds
277-
278280
if domain in ("val", "test", "test_z0012"):
279281
# re-select timestamps
280282
year = 2019 if domain.startswith("val") else 2020

geoarches/evaluation/eval_multistep.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,17 @@ def main():
151151
required=True,
152152
help="Directory or file path to read groundtruth.",
153153
)
154+
parser.add_argument(
155+
"--groundtruth_dataset_domain",
156+
type=str,
157+
default="test_z0012",
158+
help="Domain (all, train, val, test) for groundtruth dataset. Should be a key in filename_filters. Determines filename_filter used.",
159+
)
154160
parser.add_argument(
155161
"--multistep",
156162
default=10,
157163
type=int,
158-
help="Number of future timesteps model is rolled out for evaluation. In days "
164+
help="Number of future timesteps model is rolled out for evaluation. Set to 1 if just one step. "
159165
"(This script assumes lead time is 24 hours).",
160166
)
161167
parser.add_argument(
@@ -198,6 +204,16 @@ def main():
198204
action="store_true",
199205
help="Whether to evaluate climatology.",
200206
)
207+
parser.add_argument(
208+
"--verbose",
209+
action="store_true",
210+
help="Whether to print more verbose debug logs.",
211+
)
212+
parser.add_argument(
213+
"--breakpoint",
214+
action="store_true",
215+
help="Whether to add breakpoint for debug.",
216+
)
201217

202218
args = parser.parse_args()
203219

@@ -231,7 +247,7 @@ def main():
231247
surface_variables=args.surface_vars,
232248
level_variables=args.level_vars,
233249
pressure_levels=[500, 700, 850],
234-
lead_time_hours=24 if args.multistep else None,
250+
lead_time_hours=24 if args.multistep and args.multistep > 1 else None,
235251
rollout_iterations=args.multistep,
236252
).to(device)
237253
print(f"Computing: {metrics.keys()}")
@@ -240,7 +256,7 @@ def main():
240256
ds_test = era5.Era5Forecast(
241257
path=args.groundtruth_path,
242258
# filename_filter=lambda x: ("2020" in x) and ("0h" in x or "12h" in x),
243-
domain="test_z0012",
259+
domain=args.groundtruth_dataset_domain,
244260
lead_time_hours=24,
245261
multistep=args.multistep,
246262
load_prev=False,
@@ -251,30 +267,36 @@ def main():
251267
)
252268

253269
print(f"Reading {len(ds_test.files)} files from groundtruth path: {args.groundtruth_path}.")
270+
if args.verbose:
271+
print(ds_test.files)
254272

255273
# Predictions.
256274
def _pred_filename_filter(filename):
257275
if "metric" in filename:
258276
return False
259277
if args.pred_filename_filter is None:
260278
return True
261-
for substring in args.pred_filename_filter:
262-
if substring not in filename:
263-
return False
264-
return True
279+
return any([str(y) in filename for y in args.pred_filename_filter])
265280

266281
if not args.eval_clim:
282+
dimension_indexers = dict(level=[500, 700, 850])
283+
if args.multistep > 1:
284+
dimension_indexers["prediction_timedelta"] = [
285+
timedelta(days=i) for i in range(1, args.multistep + 1)
286+
]
287+
267288
ds_pred = era5.Era5Dataset(
268289
path=args.pred_path,
269290
filename_filter=_pred_filename_filter, # Update filename_filter to filter within pred_path.
270291
variables=variables,
271292
return_timestamp=True,
272-
dimension_indexers=dict(
273-
prediction_timedelta=[timedelta(days=i) for i in range(1, args.multistep + 1)],
274-
level=[500, 700, 850],
275-
),
293+
dimension_indexers=dimension_indexers,
276294
)
277295
print(f"Reading {len(ds_pred.files)} files from pred_path: {args.pred_path}.")
296+
if args.verbose:
297+
print(ds_pred.files)
298+
print("# prediction examples:", len(ds_pred))
299+
print("# test examples:", len(ds_test))
278300

279301
if reloaded_timestamp is not None:
280302
# Don't include the reloaded timestamp.
@@ -315,8 +337,13 @@ def __getitem__(self, idx):
315337
collate_fn=_custom_collate_fn,
316338
)
317339

340+
if args.breakpoint:
341+
breakpoint()
342+
318343
# iterable = tqdm(dl_test) if args.eval_clim else tqdm(zip(dl_test, dl_pred))
319344
for next_batch in tqdm(dl_test) if args.eval_clim else tqdm(zip(dl_test, dl_pred)):
345+
if args.verbose:
346+
print(f"{nbatches} batch")
320347
nbatches += 1
321348

322349
if args.eval_clim:
@@ -333,7 +360,7 @@ def __getitem__(self, idx):
333360
pred = pred.apply(
334361
lambda tensor: rearrange(
335362
tensor,
336-
"batch var mem ... lev lat lon -> batch mem ... var lev lat lon",
363+
"batch var ... lev lat lon -> batch ... var lev lat lon",
337364
)
338365
)
339366
timestamps = target["timestamp"]
@@ -344,9 +371,14 @@ def __getitem__(self, idx):
344371
else:
345372
target = target["future_states"]
346373

374+
if args.breakpoint:
375+
breakpoint()
376+
347377
# Update metrics.
348378
for metric in metrics.values():
349379
metric.update(target.to(device), pred.to(device))
380+
if args.breakpoint:
381+
breakpoint()
350382

351383
if args.cache_metrics_every_nbatches and nbatches % args.cache_metrics_every_nbatches == 0:
352384
print(f"Processed {nbatches} batches.")
@@ -370,26 +402,35 @@ def __getitem__(self, idx):
370402
else:
371403
output_filename = f"test-multistep={args.multistep}-{metric_name}"
372404

373-
# Get xr dataset.
374405
if isinstance(labelled_metric_output, dict):
375406
labelled_dict = {
376407
k: (v.cpu() if hasattr(v, "cpu") else v) for k, v in labelled_metric_output.items()
377408
}
378-
extra_dimensions = ["prediction_timedelta"]
379-
if "brier" in metric_name:
380-
extra_dimensions = ["quantile", "prediction_timedelta"]
381-
if "rankhist" in metric_name or "rank_hist" in metric_name:
382-
extra_dimensions = ["bins", "prediction_timedelta"]
383-
ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions)
384-
385409
# Write labeled dict.
386410
labelled_dict["metadata"] = dict(
387411
groundtruth_path=args.groundtruth_path, predictions_path=args.pred_path
388412
)
389413
torch.save(labelled_dict, Path(output_dir).joinpath(f"{output_filename}.pt"))
414+
415+
# Convert to xr dataset.
416+
extra_dimensions = []
417+
if args.multistep > 1:
418+
extra_dimensions = ["prediction_timedelta"]
419+
if "brier" in metric_name:
420+
extra_dimensions.insert(0, "quantile") # ["quantile", "prediction_timedelta"]
421+
if "rankhist" in metric_name or "rank_hist" in metric_name:
422+
extra_dimensions.insert(0, "bins") # ["bins", "prediction_timedelta"]
423+
if "spatial" in metric_name:
424+
# Does not yet handle extra lat/lon dims.
425+
continue
426+
427+
ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions)
390428
else:
391429
ds = labelled_metric_output
392430
# Write xr dataset.
431+
ds.attrs["groundtruth_path"] = args.groundtruth_path
432+
ds.attrs["predictions_path"] = args.pred_path
433+
ds.attrs["groundtruth_dataset_domain"] = args.groundtruth_dataset_domain
393434
ds.to_netcdf(Path(output_dir).joinpath(f"{output_filename}.nc"))
394435

395436

geoarches/evaluation/metric_registry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torchmetrics
66

77
from geoarches.metrics.brier_skill_score import Era5BrierSkillScore
8+
from geoarches.metrics.deterministic_metrics import Era5DeterministicMetrics
89
from geoarches.metrics.ensemble_metrics import Era5EnsembleMetrics
910
from geoarches.metrics.rank_histogram import Era5RankHistogram
1011
from geoarches.metrics.spherical_power_spectrum import Era5PowerSpectrum
@@ -38,6 +39,13 @@ def instantiate_metric(metric_name: str, **extra_kwargs):
3839
#######################################################
3940
###### Registering classes with their arguments. ######
4041
#######################################################
42+
register_metric(
43+
"era5_deterministic_metrics",
44+
Era5DeterministicMetrics,
45+
)
46+
register_metric(
47+
"era5_deterministic_metrics_with_spatial", Era5DeterministicMetrics, compute_per_gridpoint=True
48+
)
4149
register_metric(
4250
"era5_ensemble_metrics",
4351
Era5EnsembleMetrics,

geoarches/metrics/label_wrapper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ def _convert(self, raw_metric_dict: Dict[str, Tensor]):
5959
labeled_dict = dict()
6060
for var, index in self.variable_indices.items():
6161
for metric_name, metric in raw_metric_dict.items():
62-
labeled_dict[f"{metric_name}_{var}"] = metric.__getitem__((..., *index))
62+
if any(s in metric_name for s in ["spatial", "per_gridpt", "per_gridpoint"]):
63+
# Account for lat, lon dims
64+
labeled_dict[f"{metric_name}_{var}"] = metric[..., *index, :, :]
65+
else:
66+
labeled_dict[f"{metric_name}_{var}"] = metric[..., *index]
6367
return labeled_dict
6468

6569
def update(self, *args: Any, **kwargs: Any) -> None:
@@ -134,6 +138,7 @@ def _convert_coord(name, value):
134138
labels = label.split("_")
135139
if len(labels) - 2 != len(extra_dimensions):
136140
raise ValueError(
141+
f"Assumes metric name {label} is in format <metric>_<var>_<dim1>...."
137142
f"Expected length of extra_dimensions for key {label} to be: {len(labels) - 2}. Got extra_dimensions={extra_dimensions}."
138143
)
139144
metrics.add(labels[0])

0 commit comments

Comments (0)