From c4eb7ecb8b218ea07f0612f0ce4c0808852a2498 Mon Sep 17 00:00:00 2001 From: Aymeric Delefosse Date: Wed, 22 Oct 2025 11:49:18 +0200 Subject: [PATCH] fix: make caching metrics more robust to lambda preprocess serialization errors --- geoarches/evaluation/eval_multistep.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/geoarches/evaluation/eval_multistep.py b/geoarches/evaluation/eval_multistep.py index 1a6f6c2..ff4e70c 100644 --- a/geoarches/evaluation/eval_multistep.py +++ b/geoarches/evaluation/eval_multistep.py @@ -30,6 +30,7 @@ from geoarches.dataloaders import era5 from geoarches.metrics.label_wrapper import convert_metric_dict_to_xarray +from geoarches.metrics.spherical_power_spectrum import _remove_south_pole_lat from . import metric_registry @@ -66,7 +67,7 @@ def _custom_collate_fn(batch): def cache_metrics(output_dir, timestamp, nbatches, metrics): """ Saves the training state to disk. - :param filepath: Path to save the checkpoint file. + :param output_dir: Path to save the checkpoint file. :param timestamp: The timestamp of the current training iteration. :param nbatches: Number of batches already processed. :param metrics: A dictionary of metrics to save. @@ -75,9 +76,12 @@ def cache_metrics(output_dir, timestamp, nbatches, metrics): output_dir = Path(output_dir).joinpath("tmp").joinpath("_".join(metrics.keys())) output_dir.mkdir(parents=True, exist_ok=True) - if "era5_rank_histogram_50_members" in metrics: + for metric_name in metrics: # Hack: Can't pickle lambda functions. - metrics["era5_rank_histogram_50_members"].metrics["surface"].metric.preprocess = None + if "rank_histogram" in metric_name: + metrics[metric_name].metrics["surface"].metric.preprocess = None + if "power_spectrum" in metric_name: + metrics[metric_name].metrics["surface"].metric.preprocess = None # Need to save seed for rank_hist reproducibility. torch.save( @@ -92,7 +96,8 @@ def cache_metrics(output_dir, timestamp, nbatches, metrics): def load_metrics(output_dir, metric_names): """ Loads the training state from disk. - :param dir: Directory to load the checkpoint files from. + :param output_dir: Directory to load the checkpoint files from. + :param metric_names: List of metric names to load. :return: A dictionary of metrics loaded from the checkpoint files. """ output_dir = Path(output_dir).joinpath("tmp").joinpath("_".join(metric_names)) @@ -108,9 +113,13 @@ def load_metrics(output_dir, metric_names): np.random.set_state(cached_dict["np_random_state"]) for metric_name in metric_names: + # Hack: Add back lambda function. if "rank_histogram" in metric_name: - # Hack: Add back lambda function. metrics[metric_name].metrics["surface"].metric.preprocess = lambda x: x.squeeze(-3) + if "power_spectrum" in metric_name: + metrics[metric_name].metrics["surface"].metric.preprocess = ( + lambda x: _remove_south_pole_lat(x.squeeze(-3)) + ) timestamp = np.datetime64(int(file.stem), "s") print(