Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions geoarches/evaluation/eval_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from geoarches.dataloaders import era5
from geoarches.metrics.label_wrapper import convert_metric_dict_to_xarray
from geoarches.metrics.spherical_power_spectrum import _remove_south_pole_lat

from . import metric_registry

Expand Down Expand Up @@ -66,7 +67,7 @@ def _custom_collate_fn(batch):
def cache_metrics(output_dir, timestamp, nbatches, metrics):
"""
    Saves the evaluation state (accumulated metrics) to disk.
:param filepath: Path to save the checkpoint file.
:param output_dir: Path to save the checkpoint file.
    :param timestamp: The timestamp of the current evaluation iteration.
:param nbatches: Number of batches already processed.
:param metrics: A dictionary of metrics to save.
Expand All @@ -75,9 +76,12 @@ def cache_metrics(output_dir, timestamp, nbatches, metrics):
output_dir = Path(output_dir).joinpath("tmp").joinpath("_".join(metrics.keys()))
output_dir.mkdir(parents=True, exist_ok=True)

if "era5_rank_histogram_50_members" in metrics:
for metric_name in metrics:
# Hack: Can't pickle lambda functions.
metrics["era5_rank_histogram_50_members"].metrics["surface"].metric.preprocess = None
if "rank_histogram" in metric_name:
metrics[metric_name].metrics["surface"].metric.preprocess = None
if "power_spectrum" in metric_name:
metrics[metric_name].metrics["surface"].metric.preprocess = None

# Need to save seed for rank_hist reproducibility.
torch.save(
Expand All @@ -92,7 +96,8 @@ def cache_metrics(output_dir, timestamp, nbatches, metrics):
def load_metrics(output_dir, metric_names):
"""
    Loads the evaluation state (accumulated metrics) from disk.
:param dir: Directory to load the checkpoint files from.
:param output_dir: Directory to load the checkpoint files from.
:param metric_names: List of metric names to load.
:return: A dictionary of metrics loaded from the checkpoint files.
"""
output_dir = Path(output_dir).joinpath("tmp").joinpath("_".join(metric_names))
Expand All @@ -108,9 +113,13 @@ def load_metrics(output_dir, metric_names):
np.random.set_state(cached_dict["np_random_state"])

for metric_name in metric_names:
# Hack: Add back lambda function.
if "rank_histogram" in metric_name:
# Hack: Add back lambda function.
metrics[metric_name].metrics["surface"].metric.preprocess = lambda x: x.squeeze(-3)
if "power_spectrum" in metric_name:
metrics[metric_name].metrics["surface"].metric.preprocess = (
lambda x: _remove_south_pole_lat(x.squeeze(-3))
)

timestamp = np.datetime64(int(file.stem), "s")
print(
Expand Down