diff --git a/.gitignore b/.gitignore
index f94e82f..72bccf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,8 +2,15 @@
 .vscode
 outputs
 lightning_logs
+logs
 .DS_Store
 
+# softlinks
+evalstore
+modelstore
+data
+wandblogs
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -12,6 +19,9 @@ __pycache__/
 # C extensions
 *.so
 
+# era5 quantiles
+era5-quantiles-*.nc
+
 # Distribution / packaging
 .Python
 build/
diff --git a/geoarches/configs/config.yaml b/geoarches/configs/config.yaml
index fc80aa3..07eefda 100644
--- a/geoarches/configs/config.yaml
+++ b/geoarches/configs/config.yaml
@@ -3,6 +3,7 @@ defaults:
   - cluster: local # Tells hydra to use cluster/local.yaml when composing the cfg object.
   - dataloader: era5
   - module: archesweather
+  - stats: pangu # Normalization scheme (pangu or graphcast or None)
   - override hydra/job_logging: none
   - override hydra/hydra_logging: none
   - _self_
diff --git a/geoarches/configs/dataloader/era5.yaml b/geoarches/configs/dataloader/era5.yaml
index 2b7091f..47c62fe 100644
--- a/geoarches/configs/dataloader/era5.yaml
+++ b/geoarches/configs/dataloader/era5.yaml
@@ -3,8 +3,16 @@ dataset:
   path: data/era5_240/full/
   lead_time_hours: 24
   multistep: ${oc.select:module.train.rollout_iterations,1}
-  norm_scheme: pangu
   load_prev: True
+  variables:
+    surface: ${stats.module.variables.surface}
+    level: ${stats.module.variables.level}
+  dimension_indexers:
+    level:
+      - 'level'
+      - ${stats.module.levels}
+  warning_on_nan: True
+  interpolate_nans: True # Remove NaNs in the model input (i.e. SST)
 
 validation_args:
   multistep: ${oc.select:module.val.rollout_iterations,1}
diff --git a/geoarches/configs/dataloader/era5pred.yaml b/geoarches/configs/dataloader/era5pred.yaml
index 406062d..b7947b0 100644
--- a/geoarches/configs/dataloader/era5pred.yaml
+++ b/geoarches/configs/dataloader/era5pred.yaml
@@ -3,7 +3,6 @@ dataset:
   path: data/era5_240/full/
   pred_path: data/outputs/deterministic/jzh-geoaw-m-seed0
   lead_time_hours: 24 # mixed
-  norm_scheme: pangu
   load_prev: True
   load_hard_neg: False
 
diff --git a/geoarches/configs/module/metrics/era5_brier.yaml b/geoarches/configs/module/metrics/era5_brier.yaml
index bea1fef..aafd352 100644
--- a/geoarches/configs/module/metrics/era5_brier.yaml
+++ b/geoarches/configs/module/metrics/era5_brier.yaml
@@ -1,3 +1,6 @@
 era5_brier_metric:
   _target_: geoarches.metrics.brier_skill_score.Era5BrierSkillScore
+  surface_variables: ${stats.module.variables.surface}
+  level_variables: ${stats.module.variables.level}
+  pressure_levels: ${stats.module.levels}
   lead_time_hours: ${dataloader.dataset.lead_time_hours}
\ No newline at end of file
diff --git a/geoarches/configs/module/metrics/era5_deterministic.yaml b/geoarches/configs/module/metrics/era5_deterministic.yaml
index f5f8191..28f2aa5 100644
--- a/geoarches/configs/module/metrics/era5_deterministic.yaml
+++ b/geoarches/configs/module/metrics/era5_deterministic.yaml
@@ -1,3 +1,6 @@
 era5_deterministic_metrics:
   _target_: geoarches.metrics.deterministic_metrics.Era5DeterministicMetrics
+  surface_variables: ${stats.module.variables.surface}
+  level_variables: ${stats.module.variables.level}
+  pressure_levels: ${stats.module.levels}
   lead_time_hours: ${dataloader.dataset.lead_time_hours}
\ No newline at end of file
diff --git a/geoarches/configs/module/metrics/era5_ensemble.yaml b/geoarches/configs/module/metrics/era5_ensemble.yaml
index 395404d..0605ddf 100644
--- a/geoarches/configs/module/metrics/era5_ensemble.yaml
+++ b/geoarches/configs/module/metrics/era5_ensemble.yaml
@@ -1,3 +1,6 @@
 era5_ensemble_metrics:
   _target_: geoarches.metrics.ensemble_metrics.Era5EnsembleMetrics
+  surface_variables: ${stats.module.variables.surface}
+  level_variables: ${stats.module.variables.level}
+  pressure_levels: ${stats.module.levels}
   lead_time_hours: ${dataloader.dataset.lead_time_hours}
\ No newline at end of file
diff --git a/geoarches/configs/stats/graphcast.yaml b/geoarches/configs/stats/graphcast.yaml
new file mode 100644
index 0000000..c6b2e9b
--- /dev/null
+++ b/geoarches/configs/stats/graphcast.yaml
@@ -0,0 +1,27 @@
+module:
+  _target_: geoarches.utils.normalization.NormalizationStatistics
+  variables:
+    surface:
+      - 10m_u_component_of_wind
+      - 10m_v_component_of_wind
+      - 2m_temperature
+      - mean_sea_level_pressure
+    level:
+      - geopotential
+      - u_component_of_wind
+      - v_component_of_wind
+      - temperature
+      - specific_humidity
+      - vertical_velocity
+  loss_weight_per_variable:
+    surface: [0.1, 0.1, 1.0, 0.1]
+    level: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+
+  levels: [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
+  norm_scheme: graphcast
+
+compute_loss_coeffs_args:
+  latitude: 121
+  pow: 2
+  use_weatherbench_lat_coeffs: true
+  loss_delta_normalization: true
diff --git a/geoarches/configs/stats/pangu.yaml b/geoarches/configs/stats/pangu.yaml
new file mode 100644
index 0000000..473565e
--- /dev/null
+++ b/geoarches/configs/stats/pangu.yaml
@@ -0,0 +1,26 @@
+module:
+  _target_: geoarches.utils.normalization.NormalizationStatistics
+  variables:
+    surface:
+      - 10m_u_component_of_wind
+      - 10m_v_component_of_wind
+      - 2m_temperature
+      - mean_sea_level_pressure
+    level:
+      - geopotential
+      - u_component_of_wind
+      - v_component_of_wind
+      - temperature
+      - specific_humidity
+      - vertical_velocity
+  loss_weight_per_variable:
+    surface: [0.1, 0.1, 1.0, 0.1]
+    level: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+  levels: [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
+  norm_scheme: pangu
+
+compute_loss_coeffs_args:
+  latitude: 121
+  pow: 2
+  use_weatherbench_lat_coeffs: true
+  loss_delta_normalization: true
diff --git a/geoarches/dataloaders/dcpp.py b/geoarches/dataloaders/dcpp.py
index 5c04f4f..eabae1f 100644
--- a/geoarches/dataloaders/dcpp.py
+++ b/geoarches/dataloaders/dcpp.py
@@ -32,6 +32,14 @@ def replace_nans(tensordict, value=0):
     )
 
 
+default_dimension_indexers = {
+    "level": ("plev", pressure_levels),
+    "latitude": ("lat", slice(None)),
+    "longitude": ("lon", slice(None)),
+    "time": ("time", slice(None)),
+}
+
+
 class DCPPForecast(XarrayDataset):
     """
     Load DCPP data for the forecast task.
@@ -54,6 +62,7 @@ def __init__(
         limit_examples: int = 0,
         mask_value=0,
         variables=None,
+        dimension_indexers: dict = default_dimension_indexers,
     ):
         """
         Args:
@@ -67,6 +76,9 @@ def __init__(
             load_clim: Whether to load climatology.
             limit_examples: Return set number of examples in dataset
             mask_value: what value to use as mask for nan values in dataset
+            dimension_indexers: dict, dimension indexers for the dataset.
+                Default is set to pressure levels, latitude, longitude, and time.
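+                Each entry maps a logical dimension to a (dataset_dim_name, indexer) tuple,
+                e.g. {"level": ("plev", pressure_levels), "latitude": ("lat", slice(None))}.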
+ """ self.__dict__.update(locals()) # concise way to update self with input arguments @@ -76,7 +88,7 @@ def __init__( filename_filter = filename_filters[domain] if variables is None: variables = dict(surface=surface_variables, level=level_variables) - dimension_indexers = {"plev": pressure_levels} + super().__init__( path, filename_filter=filename_filter, diff --git a/geoarches/dataloaders/era5.py b/geoarches/dataloaders/era5.py index 2fb393b..786cf23 100644 --- a/geoarches/dataloaders/era5.py +++ b/geoarches/dataloaders/era5.py @@ -1,4 +1,3 @@ -import importlib.resources from datetime import timedelta from pathlib import Path from typing import Callable, Dict, List @@ -7,9 +6,13 @@ import pandas as pd import torch import xarray as xr -from tensordict.tensordict import TensorDict +from hydra.utils import instantiate -from .. import stats as geoarches_stats +from .era5_constants import ( + arches_default_level_variables, + arches_default_pressure_levels, + arches_default_surface_variables, +) from .netcdf import XarrayDataset filename_filters = dict( @@ -17,41 +20,50 @@ last_train=lambda x: ("2018" in x), last_train_z0012=lambda x: ("2018" in x and ("0h" in x or "12h" in x)), train=lambda x: not ("2019" in x or "2020" in x or "2021" in x), + # Before and after 2000. Need to load timestamp after to account for offset.. + train_before_2000=lambda x: any([str(y) in x for y in range(1979, 2001)]), # 1979-1999 + train_after_2000=lambda x: any([str(y) in x for y in range(2000, 2020)]), # 2000-2018 # Splits val and test are from 2019 and 2020 respectively, but # we read the years before and after to account for offsets when # loading previous and future timestamps for an example. - val=lambda x: ("2018" in x or "2019" in x or "2020" in x), - test=lambda x: ("2019" in x or "2020" in x or "2021" in x), + val=lambda x: ("2018" in x or "2019" in x or "2020" in x), # 2019 + test=lambda x: ("2019" in x or "2020" in x or "2021" in x), # 2020 test_z0012=lambda x: ("2019" in x or "2020" in x or "2021" in x) and ("0h" in x or "12h" in x), test2022_z0012=lambda x: ("2022" in x) and ("0h" in x or "12h" in x), # check if that works ? 
     recent2=lambda x: any([str(y) in x for y in range(2007, 2019)]),
     empty=lambda x: False,
 )
 
-pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
-
-level_variables = [
-    "geopotential",
-    "u_component_of_wind",
-    "v_component_of_wind",
-    "temperature",
-    "specific_humidity",
-    "vertical_velocity",
-]
-
-surface_variables = [
-    "10m_u_component_of_wind",
-    "10m_v_component_of_wind",
-    "2m_temperature",
-    "mean_sea_level_pressure",
-]
-
+# Short names for variables used in tensordicts and metrics
 surface_variables_short = {
     "10m_u_component_of_wind": "U10m",
     "10m_v_component_of_wind": "V10m",
     "2m_temperature": "T2m",
-    "mean_sea_level_pressure": "SP",
+    "mean_sea_level_pressure": "MSLP",
+    "low_vegetation_cover": "CVL",
+    "high_vegetation_cover": "CVH",
+    "type_of_low_vegetation_cover": "TVL",
+    "type_of_high_vegetation_cover": "TVH",
+    "soil_type": "SLT",
+    "standard_deviation_of_filtered_subgrid_orography": "SDFSOR",
+    "angle_of_sub_gridscale_orography": "ANOR",
+    "anisotropy_of_subgridscale_orography": "ASOR",
+    "geopotential_at_surface": "Z0",
+    "lake_cover": "LC",
+    "lake_depth": "LD",
+    "sea_ice_cover": "SIC",
+    "sea_surface_temperature": "SST",
+    "slope_of_subgridscale_orography": "SSOR",
+    "standard_deviation_of_orography": "SDFO",
+    "surface_pressure": "SP",
+    "toa_incident_solar_radiation": "SIS",
+    "toa_incident_solar_radiation_12hr": "SIS12",
+    "toa_incident_solar_radiation_24hr": "SIS24",
+    "total_cloud_cover": "TCC",
+    "total_precipitation_12hr": "TP",
+    "total_precipitation_24hr": "TP24",
+    "total_column_water_vapour": "TCWV",
+    "wind_speed": "WS",
 }
 
 level_variables_short = {
@@ -63,13 +75,21 @@
     "vertical_velocity": "W",
 }
 
+default_dimension_indexers = {
+    "latitude": ("latitude", np.arange(90, -90 - 1e-6, -180 / 120)),  # decreasing lats
+    "longitude": ("longitude", np.arange(0, 360, 360 / 240)),
+    "level": ("level", arches_default_pressure_levels),
+}
+
 
-def get_surface_variable_indices(variables=surface_variables):
+def get_surface_variable_indices(variables=arches_default_surface_variables):
     """Mapping from surface variable name to (var, lev) index in ERA5 dataset."""
     return {surface_variables_short[var]: (i, 0) for i, var in enumerate(variables)}
 
 
-def get_level_variable_indices(pressure_levels=pressure_levels, variables=level_variables):
+def get_level_variable_indices(
+    pressure_levels=arches_default_pressure_levels, variables=arches_default_level_variables
+):
     """Mapping from level variable name to (var, lev) index in ERA5 dataset."""
     out = {}
     for var_idx, var in enumerate(variables):
@@ -80,11 +100,13 @@
 
 
 def get_headline_level_variable_indices(
-    pressure_levels=pressure_levels, level_variables=level_variables
+    pressure_levels=arches_default_pressure_levels,
+    level_variables=arches_default_level_variables,
+    headline_variables=("Z500", "T850", "Q700", "U850", "V850"),
 ):
     """Mapping for main level variables."""
     out = get_level_variable_indices(pressure_levels, level_variables)
-    return {k: v for k, v in out.items() if k in ("Z500", "T850", "Q700", "U850", "V850")}
+    return {k: v for k, v in out.items() if k in headline_variables}
 
 
 class Era5Dataset(XarrayDataset):
@@ -99,8 +121,10 @@ def __init__(
         domain: str = "train",
         filename_filter: Callable | None = None,
         variables: Dict[str, List[str]] | None = None,
-        dimension_indexers: Dict[str, list] | None = None,
+        dimension_indexers: Dict[str, list] = default_dimension_indexers,
         return_timestamp: bool = False,
+        warning_on_nan: bool = True,
+        interpolate_nans: bool = False,
     ):
         """
         Args:
@@ -112,32 +136,47 @@
             variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be
                 processed into tensordict. e.g. {surface:[...], level:[...]}. By default uses standard 6 level and 4 surface vars.
             dimension_indexers: Dict of dimensions to select using Dataset.sel(dimension_indexers).
+                Used to select levels and lat/lon resolution.
             return_timestamp: Whether to return tuple of (example, timestamp) from __getitem__().
         """
         if filename_filter is None:
             filename_filter = filename_filters[domain]
         if variables is None:
-            variables = dict(surface=surface_variables, level=level_variables)
+            variables = dict(
+                surface=arches_default_surface_variables, level=arches_default_level_variables
+            )
+
+        all_indexers = default_dimension_indexers.copy()
+        all_indexers.update(dimension_indexers or {})
 
         super().__init__(
             path,
             filename_filter=filename_filter,
             variables=variables,
-            dimension_indexers=dimension_indexers,
+            dimension_indexers=all_indexers,
             return_timestamp=return_timestamp,
-            warning_on_nan=True,
+            warning_on_nan=warning_on_nan,
+            interpolate_nans=interpolate_nans,
         )
 
     def convert_to_tensordict(self, xr_dataset):
         """
         input xarr should be a single time slice
         """
-        if self.dimension_indexers:
-            xr_dataset = xr_dataset.sel(self.dimension_indexers)
-            # Workaround to avoid calling sel() after transponse() to avoid OOM.
-            self.already_ran_index_selection = True
-        xr_dataset = xr_dataset.transpose(..., "level", "latitude", "longitude")
+        if self.slice_indexers:
+            xr_dataset = xr_dataset.sel(**self.slice_indexers)
+        if self.other_indexers:
+            xr_dataset = xr_dataset.sel(**self.other_indexers, method="nearest", tolerance=1e-6)
+        # Workaround to avoid calling sel() after transpose() to avoid OOM.
+        self.already_ran_index_selection = True
+        xr_dataset = xr_dataset.transpose(
+            ...,
+            self.level_dim_name,
+            self.latitude_dim_name,
+            self.longitude_dim_name,
+        )
+
         tdict = super().convert_to_tensordict(xr_dataset)
 
         # we don't do operations on xr datasets since it takes more time than on tensors
@@ -175,25 +214,46 @@
         surface = tdict["surface"].squeeze(-3)
         level = tdict["level"]
 
+        # Xarray coordinates.
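+        # Coordinates are taken from dimension_indexers, so datasets written back out
+        # keep the grid that was selected at load time.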
         times = pd.to_datetime(timestamp.cpu().numpy(), unit="s").tz_localize(None)
+        coords = {self.time_dim_name: times}
+
+        if self.latitude_dim_name in self.other_indexers:
+            coords[self.latitude_dim_name] = self.dimension_indexers["latitude"][1]
+        if self.longitude_dim_name in self.other_indexers:
+            coords[self.longitude_dim_name] = self.dimension_indexers["longitude"][1]
+        if self.level_dim_name in self.other_indexers:
+            coords[self.level_dim_name] = self.dimension_indexers["level"][1]
+
         xr_dataset = xr.Dataset(
             data_vars=dict(
                 **{
-                    v: (["time", "level", "latitude", "longitude"], level[:, i])
+                    v: (
+                        [
+                            self.time_dim_name,
+                            self.level_dim_name,
+                            self.latitude_dim_name,
+                            self.longitude_dim_name,
+                        ],
+                        level[:, i],
+                    )
                     for (i, v) in enumerate(self.variables["level"])
                 },
                 **{
-                    v: (["time", "latitude", "longitude"], surface[:, i])
+                    v: (
+                        [
+                            self.time_dim_name,
+                            self.latitude_dim_name,
+                            self.longitude_dim_name,
+                        ],
+                        surface[:, i],
+                    )
                     for (i, v) in enumerate(self.variables["surface"])
                 },
             ),
-            coords=dict(
-                time=times,
-                latitude=np.arange(90, -90 - 1e-6, -180 / 120),  # decreasing lats
-                longitude=np.arange(0, 360, 360 / 240),
-                level=pressure_levels,
-            ),
+            coords=coords,
         )
+
         if levels is not None:
             xr_dataset = xr_dataset.sel(level=levels)
 
@@ -233,21 +293,24 @@ class Era5Forecast(Era5Dataset):
 
     def __init__(
         self,
+        stats_cfg,
         path: str = "data/era5_240/full/",
         domain: str = "train",
         filename_filter: Callable | None = None,
         timedelta_hours: int = None,
         variables: Dict[str, List[str]] | None = None,
-        dimension_indexers: Dict[str, list] | None = None,
-        norm_scheme: str | None = "pangu",
+        dimension_indexers: Dict[str, list] = default_dimension_indexers,
         lead_time_hours: int = 24,
         multistep: int = 1,
         load_prev: bool = True,
         load_clim: bool = False,
         switch_recent_data_after_steps: int = 250000,
+        warning_on_nan: bool = True,
+        interpolate_nans: bool = False,
     ):
         """
         Args:
+            stats_cfg: Configuration for normalization statistics. None if no normalization is needed.
             path: Single filepath or directory holding files.
             domain: Specify data split for the default filename filters (eg. train, val, test, testz0012..)
             filename_filter: To filter files within `path` based on filename. If set, does not use `domain` param.
@@ -256,25 +319,31 @@
             multistep: Number of future states to load. By default, loads next state only (current time + lead_time_hours).
             load_prev: Whether to load state at previous timestamp (current time - lead_time_hours).
             load_clim: Whether to load climatology.
-            norm_scheme: Normalization scheme to use. Can be None to perform no normalization.
             timedelta_hours: Time difference (hours) between 2 consecutive timestamps. If not specified, default is 6 or 12, depending on domain.
             variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be
                 processed into tensordict. e.g. {surface:[...], level:[...]} By default uses standard 6 level and 4 surface vars.
             dimension_indexers: Dict of dimensions to select using Dataset.sel(dimension_indexers).
+                Used to select levels and lat/lon resolution.
+            warning_on_nan: Whether to raise a warning if NaN values are encountered in model input (prev and current state).
+            interpolate_nans: Whether to interpolate NaN values for model input (prev and current state).
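+                NaNs are filled with the per-variable mean over latitude/longitude (see XarrayDataset).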
""" self.__dict__.update(locals()) + all_indexers = default_dimension_indexers.copy() + all_indexers.update(dimension_indexers or {}) + super().__init__( path, filename_filter=filename_filter, domain=domain, variables=variables, - dimension_indexers=dimension_indexers, + dimension_indexers=all_indexers, + warning_on_nan=warning_on_nan, + interpolate_nans=interpolate_nans, ) # depending on domain, re-set timestamp bounds - if domain in ("val", "test", "test_z0012"): # re-select timestamps year = 2019 if domain.startswith("val") else 2020 @@ -293,28 +362,17 @@ def __init__( self.timedelta = 6 if "z0012" not in domain else 12 self.current_multistep = 1 - # include vertical component by default - geoarches_stats_path = importlib.resources.files(geoarches_stats) - norm_file_path = geoarches_stats_path / "pangu_norm_stats2_with_w.pt" - pangu_stats = torch.load(norm_file_path, weights_only=True) + # Load normalization statistics. + self.norm_scheme = None + if stats_cfg: + stats = instantiate(stats_cfg.module) + self.data_mean, self.data_std = stats.load_normalization_stats() + self.norm_scheme = stats.norm_scheme - # normalization, - if self.norm_scheme == "pangu": - self.data_mean = TensorDict( - surface=pangu_stats["surface_mean"], - level=pangu_stats["level_mean"], + # Check levels. + assert np.equal(self.dimension_indexers["level"][1], stats.levels).all(), ( + "Levels passed to NormalizationStatistics do not match the levels passed to the dataset dimension_indexers." ) - self.data_std = TensorDict( - surface=pangu_stats["surface_std"], - level=pangu_stats["level_std"], - ) - - # variable names - # TODO: fix this below - self.surface_variables = ["U10m", "V10m", "T2m", "SP"] - self.level_variables = [ - a + str(p) for a in ["Z", "U", "V", "T", "Q"] for p in pressure_levels - ] # Load climatology. self.clim_path = Path(path).parent.joinpath("era5_240_clim.nc") @@ -333,10 +391,12 @@ def __getitem__(self, i, normalize=True): # load current state out["timestamp"] = torch.tensor( self.id2pt[i][2].item() // 10**9, # how to convert to tensor ? - dtype=torch.int32, + dtype=torch.int64, ) # time in seconds - out["state"] = super().__getitem__(i) + out["state"] = super().__getitem__( + i, interpolate_nans=self.interpolate_nans, warning_on_nan=self.warning_on_nan + ) out["lead_time_hours"] = torch.tensor(self.lead_time_hours * int(self.multistep)).int() @@ -344,17 +404,27 @@ def __getitem__(self, i, normalize=True): T = self.lead_time_hours # multistep if self.multistep > 0: - out["next_state"] = super().__getitem__(i + T // self.timedelta) + out["next_state"] = super().__getitem__( + i + T // self.timedelta, interpolate_nans=False, warning_on_nan=False + ) # Load multiple future timestamps if specified. 
         if self.multistep > 1:
             future_states = []
             for k in range(1, self.multistep + 1):
-                future_states.append(super().__getitem__(i + k * T // self.timedelta))
+                future_states.append(
+                    super().__getitem__(
+                        i + k * T // self.timedelta, interpolate_nans=False, warning_on_nan=False
+                    )
+                )
             out["future_states"] = torch.stack(future_states, dim=0)
 
         if self.load_prev:
-            out["prev_state"] = super().__getitem__(i - self.lead_time_hours // self.timedelta)
+            out["prev_state"] = super().__getitem__(
+                i - self.lead_time_hours // self.timedelta,
+                interpolate_nans=self.interpolate_nans,
+                warning_on_nan=self.warning_on_nan,
+            )
 
         if self.load_clim:
             clim_xr = xr.open_dataset(self.clim_path)
diff --git a/geoarches/dataloaders/era5_constants.py b/geoarches/dataloaders/era5_constants.py
new file mode 100644
index 0000000..ea5a1a6
--- /dev/null
+++ b/geoarches/dataloaders/era5_constants.py
@@ -0,0 +1,62 @@
+# Constants for ERA5 dataset
+
+# 37 pressure levels available from graphcast stats
+pressure_levels = [
+    1,
+    2,
+    3,
+    5,
+    7,
+    10,
+    20,
+    30,
+    50,
+    70,
+    100,
+    125,
+    150,
+    175,
+    200,
+    225,
+    250,
+    300,
+    350,
+    400,
+    450,
+    500,
+    550,
+    600,
+    650,
+    700,
+    750,
+    775,
+    800,
+    825,
+    850,
+    875,
+    900,
+    925,
+    950,
+    975,
+    1000,
+]
+
+
+# ArchesWeather default settings for ERA5 dataset.
+arches_default_pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
+
+arches_default_level_variables = [
+    "geopotential",
+    "u_component_of_wind",
+    "v_component_of_wind",
+    "temperature",
+    "specific_humidity",
+    "vertical_velocity",
+]
+
+arches_default_surface_variables = [
+    "10m_u_component_of_wind",
+    "10m_v_component_of_wind",
+    "2m_temperature",
+    "mean_sea_level_pressure",
+]
diff --git a/geoarches/dataloaders/era5pred.py b/geoarches/dataloaders/era5pred.py
index 857f201..bd536c8 100644
--- a/geoarches/dataloaders/era5pred.py
+++ b/geoarches/dataloaders/era5pred.py
@@ -11,13 +11,13 @@ class Era5ForecastWithPrediction(era5.Era5Forecast):
 
     def __init__(
         self,
+        stats_cfg,
         path="data/era5_240/full/",
        domain="train",
         filename_filter=None,
         lead_time_hours=24,
         pred_path="data/era5_pred_archesweather-S/",
         load_prev=False,
-        norm_scheme="pangu",
         load_hard_neg=False,
         variables=None,
         **kwargs,
@@ -31,17 +31,16 @@
             lead_time_hours: Time difference between current state and previous and future states.
             pred_path: Single filepath or directory holding model prediction files to also load.
             load_prev: Whether to load state at previous timestamp (current time - lead_time_hours).
-            norm_scheme: Normalization scheme to use. Can be None to perform no normalization.
             load_hard_neg: Whether to additionally load hard negative example for contrastive learning.
             variables: Variables to load from dataset. Dict holding variable lists mapped by their keys to be
                 processed into tensordict. e.g. {surface:[...], level:[...]} By default uses standard 6 level and 4 surface vars.
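+            **kwargs: Remaining keyword arguments forwarded to Era5Forecast (e.g. multistep, dimension_indexers).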
""" super().__init__( + stats_cfg=stats_cfg, path=path, domain=domain, lead_time_hours=lead_time_hours, filename_filter=filename_filter, - norm_scheme=norm_scheme, load_prev=load_prev, **kwargs, ) diff --git a/geoarches/dataloaders/netcdf.py b/geoarches/dataloaders/netcdf.py index 52aa17d..9a987f9 100644 --- a/geoarches/dataloaders/netcdf.py +++ b/geoarches/dataloaders/netcdf.py @@ -18,6 +18,14 @@ } +default_dimension_indexers = { + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + "level": ("level", slice(None)), + "time": ("time", slice(None)), +} + + class XarrayDataset(torch.utils.data.Dataset): """ dataset to read a list of xarray files and iterate through it by timestamp. @@ -34,26 +42,53 @@ def __init__( dimension_indexers: Dict[str, list] | None = None, filename_filter: Callable = lambda _: True, # condition to keep file in dataset return_timestamp: bool = False, - warning_on_nan: bool = False, + warning_on_nan: bool = True, limit_examples: int | None = None, + interpolate_nans: bool = False, ): """ Args: - path: Single filepath or directory holding xarray files. - variables: Dict holding xarray data variable lists mapped by their keys to be processed into tensordict. - e.g. {surface: [data_var1, datavar2, ...], level: [...]} - Used in convert_to_tensordict() to read data arrays in the xarray dataset and convert to tensordict. - dimension_indexers: Dict of dimensions to select in xarray using Dataset.sel(dimension_indexers). - filename_filter: To filter files within `path` based on filename. - return_timestamp: Whether to return timestamp in __getitem__() along with tensordict. - warning_on_nan: Whether to log warning if nan data found. - limit_examples: Return set number of examples in dataset + path: Single filepath or directory holding xarray files. + variables: Dict holding xarray data variable lists mapped by their keys to be processed into tensordict. + e.g. {surface: [data_var1, datavar2, ...], level: [...]} + Used in convert_to_tensordict() to read data arrays in the xarray dataset and convert to tensordict. + dimension_indexers: Dict of dimensions to select in xarray using Dataset.sel(dimension_indexers). Also provides + the dimension names to treat the xarray dataset as tensordict. + Defaults to default_dimension_indexers. + If not provided, defaults to selecting all data in all dimensions. + First value is the dimension name in xarray, second value is the indexer + If None is used as the indexer, all coordinates in that dimension are used. + filename_filter: To filter files within `path` based on filename. + return_timestamp: Whether to return timestamp in __getitem__() along with tensordict. + warning_on_nan: Whether to log warning if nan data found. + limit_examples: Return set number of examples in dataset + interpolate_nans: Whether to fill NaN values in the data with the mean of the + data across latitude and longitude dimensions. Defaults to True. """ self.filename_filter = filename_filter self.variables = variables - self.dimension_indexers = dimension_indexers + + self.dimension_indexers = default_dimension_indexers.copy() + self.dimension_indexers.update(dimension_indexers or {}) + + self.time_dim_name = self.dimension_indexers["time"][0] + self.latitude_dim_name = self.dimension_indexers["latitude"][0] + self.longitude_dim_name = self.dimension_indexers["longitude"][0] + self.level_dim_name = self.dimension_indexers["level"][0] + + # Separate indexers with slice and those without. 
+        indexers = {v[0]: v[1] for k, v in self.dimension_indexers.items() if k != "time"}
+        self.slice_indexers = {k: v for k, v in indexers.items() if isinstance(v, slice)}
+        self.other_indexers = {k: v for k, v in indexers.items() if not isinstance(v, slice)}
+        if not self.slice_indexers:
+            self.slice_indexers = None
+        if not self.other_indexers:
+            self.other_indexers = None
+
         self.return_timestamp = return_timestamp
         self.warning_on_nan = warning_on_nan
+        self.interpolate_nans = interpolate_nans
+
         # Workaround to avoid calling ds.sel() after ds.transpose() to avoid OOM.
         self.already_ran_index_selection = False
 
@@ -70,7 +105,7 @@
 
         self.files = sorted(
             [str(x) for x in files if filename_filter(x.name)],
-            key=lambda x: x.replace("6h", "06h").replace("0h", "00h"),
+            key=lambda x: x.replace("_6h", "_06h").replace("_0h", "_00h"),
         )
         if len(self.files) == 0:
             raise ValueError("filename_filter filtered all files under path:", path)
@@ -127,7 +162,7 @@ def set_timestamp_bounds(self, low, high, debug=False):
     def __len__(self):
         return len(self.id2pt)
 
-    def convert_to_tensordict(self, xr_dataset):
+    def convert_to_tensordict(self, xr_dataset, debug=False):
         """
         Convert xarray dataset to tensordict.
 
@@ -135,20 +170,38 @@
             e.g. {surface:[data_var1, data_var2, ...], level:[...]}
         """
         # Optionally select dimensions.
-        if self.dimension_indexers and not self.already_ran_index_selection:
-            xr_dataset = xr_dataset.sel(self.dimension_indexers)
+        if not self.already_ran_index_selection:
+            if debug:
+                print(xr_dataset)
+                print(self.slice_indexers)
+                print(self.other_indexers)
+
+            # Apply sel for non-slice indexers with method and tolerance
+            if self.other_indexers:
+                xr_dataset = xr_dataset.sel(
+                    **self.other_indexers, method="nearest", tolerance=1e-6
+                )
+            # Apply sel for slice indexers without method and tolerance
+            if self.slice_indexers:
+                xr_dataset = xr_dataset.sel(**self.slice_indexers)
+        self.already_ran_index_selection = False  # Reset for next call.
         np_arrays = {
             key: xr_dataset[list(variables)].to_array().to_numpy()
             for key, variables in self.variables.items()
         }
+
         tdict = TensorDict(
             {key: torch.from_numpy(np_array).float() for key, np_array in np_arrays.items()}
         )
+
         return tdict
 
-    def __getitem__(self, i, return_timestamp=False):
+    def __getitem__(self, i, return_timestamp=False, interpolate_nans=None, warning_on_nan=None):
+        interpolate_nans = self.interpolate_nans if interpolate_nans is None else interpolate_nans
+        warning_on_nan = self.warning_on_nan if warning_on_nan is None else warning_on_nan
+
         file_id, line_id, timestamp = self.id2pt[i]
 
         if self.cached_fileid != file_id:
@@ -158,14 +211,26 @@
             self.cached_fileid = file_id
 
         obsi = self.cached_xrdataset.isel(time=line_id)
+
+        if interpolate_nans:
+            obsi = obsi.fillna(
+                value=obsi.mean(
+                    dim=[
+                        self.dimension_indexers["latitude"][0],
+                        self.dimension_indexers["longitude"][0],
+                    ],
+                    skipna=True,
+                )
+            )
+
         tdict = self.convert_to_tensordict(obsi)
 
-        if self.warning_on_nan:
+        if warning_on_nan:
             if any([x.isnan().any().item() for x in tdict.values()]):
                 warnings.warn(f"NaN values detected in {file_id} {line_id} {self.files[file_id]}")
 
         if return_timestamp or self.return_timestamp:
             timestamp = self.cached_xrdataset.time[line_id].values.item()
-            timestamp = torch.tensor(timestamp // 10**9, dtype=torch.int32)
+            timestamp = torch.tensor(timestamp // 10**9, dtype=torch.int64)
             return tdict, timestamp
+
         return tdict
diff --git a/geoarches/evaluation/eval_multistep.py b/geoarches/evaluation/eval_multistep.py
index 1b64fe6..06fa7fc 100644
--- a/geoarches/evaluation/eval_multistep.py
+++ b/geoarches/evaluation/eval_multistep.py
@@ -29,6 +29,7 @@
 from tqdm import tqdm
 
 from geoarches.dataloaders import era5
+from geoarches.dataloaders.netcdf import default_dimension_indexers
 from geoarches.metrics.label_wrapper import convert_metric_dict_to_xarray
 
 from . import metric_registry
@@ -151,11 +152,17 @@
         required=True,
         help="Directory or file path to read groundtruth.",
     )
+    parser.add_argument(
+        "--groundtruth_dataset_domain",
+        type=str,
+        default="test_z0012",
+        help="Domain (all, train, val, test) for groundtruth dataset. Should be a key in filename_filters. Determines filename_filter used.",
+    )
     parser.add_argument(
         "--multistep",
         default=10,
         type=int,
-        help="Number of future timesteps model is rolled out for evaluation. In days "
+        help="Number of future timesteps model is rolled out for evaluation. Set to 1 if just one step. "
        "(This script assumes lead time is 24 hours).",
     )
    parser.add_argument(
@@ -179,13 +186,13 @@
     parser.add_argument(
         "--level_vars",
         nargs="*",  # Accepts 0 or more arguments as a list.
-        default=era5.level_variables,
+        default=era5.arches_default_level_variables,
         help="Level vars to load from preds. Order is respected when read into tensors. Can be empty.",
     )
     parser.add_argument(
         "--surface_vars",
         nargs="*",  # Accepts 0 or more arguments as a list.
-        default=era5.surface_variables,
+        default=era5.arches_default_surface_variables,
         help="Surface vars to load from preds. Order is respected when read into tensors. Can be empty.",
     )
     parser.add_argument(
@@ -198,6 +205,16 @@
         action="store_true",
         help="Whether to evaluate climatology.",
     )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Whether to print more verbose debug logs.",
+    )
+    parser.add_argument(
+        "--breakpoint",
+        action="store_true",
+        help="Whether to add breakpoint for debug.",
+    )
 
     args = parser.parse_args()
 
@@ -231,26 +248,30 @@
         surface_variables=args.surface_vars,
         level_variables=args.level_vars,
         pressure_levels=[500, 700, 850],
-        lead_time_hours=24 if args.multistep else None,
+        lead_time_hours=24 if args.multistep and args.multistep > 1 else None,
         rollout_iterations=args.multistep,
     ).to(device)
     print(f"Computing: {metrics.keys()}")
 
     # Groundtruth.
+    dimension_indexers = default_dimension_indexers.copy()
+    dimension_indexers["level"] = ("level", [500, 700, 850])  # Use only these pressure levels.
     ds_test = era5.Era5Forecast(
+        stats_cfg=None,  # No normalization.
         path=args.groundtruth_path,
         # filename_filter=lambda x: ("2020" in x) and ("0h" in x or "12h" in x),
-        domain="test_z0012",
+        domain=args.groundtruth_dataset_domain,
         lead_time_hours=24,
         multistep=args.multistep,
         load_prev=False,
-        norm_scheme=None,
         variables=variables,
-        dimension_indexers=dict(level=[500, 700, 850]),
+        dimension_indexers=dimension_indexers,
         load_clim=True if args.eval_clim else False,  # Set if evaluating climatology.
     )
     print(f"Reading {len(ds_test.files)} files from groundtruth path: {args.groundtruth_path}.")
+    if args.verbose:
+        print(ds_test.files)
 
     # Predictions.
     def _pred_filename_filter(filename):
@@ -258,23 +279,28 @@
             return False
         if args.pred_filename_filter is None:
             return True
-        for substring in args.pred_filename_filter:
-            if substring not in filename:
-                return False
-        return True
+        return any([str(y) in filename for y in args.pred_filename_filter])
 
     if not args.eval_clim:
+        if args.multistep > 1:
+            dimension_indexers["prediction_timedelta"] = (
+                "prediction_timedelta",
+                [timedelta(days=i) for i in range(1, args.multistep + 1)],
+            )
+
+        # Load predictions.
         ds_pred = era5.Era5Dataset(
             path=args.pred_path,
             filename_filter=_pred_filename_filter,  # Update filename_filter to filter within pred_path.
             variables=variables,
             return_timestamp=True,
-            dimension_indexers=dict(
-                prediction_timedelta=[timedelta(days=i) for i in range(1, args.multistep + 1)],
-                level=[500, 700, 850],
-            ),
+            dimension_indexers=dimension_indexers,
         )
         print(f"Reading {len(ds_pred.files)} files from pred_path: {args.pred_path}.")
+        if args.verbose:
+            print(ds_pred.files)
+            print("# prediction examples:", len(ds_pred))
+            print("# test examples:", len(ds_test))
 
     if reloaded_timestamp is not None:
         # Don't include the reloaded timestamp.
@@ -302,7 +328,7 @@ def __getitem__(self, idx):
     dl_test = torch.utils.data.DataLoader(
         ds_test,
         batch_size=args.eval_batch_size,
-        num_workers=args.num_workers,
+        num_workers=args.num_workers if not args.breakpoint else 0,
         shuffle=False,
         collate_fn=_custom_collate_fn,
     )
@@ -310,13 +336,18 @@
     dl_pred = torch.utils.data.DataLoader(
         ds_pred,
         batch_size=args.eval_batch_size,
-        num_workers=args.num_workers,
+        num_workers=args.num_workers if not args.breakpoint else 0,
         shuffle=False,
         collate_fn=_custom_collate_fn,
     )
 
+    if args.breakpoint:
+        breakpoint()
+
+    # iterable = tqdm(dl_test) if args.eval_clim else tqdm(zip(dl_test, dl_pred))
     for next_batch in tqdm(dl_test) if args.eval_clim else tqdm(zip(dl_test, dl_pred)):
+        if args.verbose:
+            print(f"{nbatches} batch")
         nbatches += 1
 
         if args.eval_clim:
@@ -333,7 +364,7 @@
             pred = pred.apply(
                 lambda tensor: rearrange(
                     tensor,
-                    "batch var mem ... lev lat lon -> batch mem ... var lev lat lon",
+                    "batch var ... lev lat lon -> batch ... var lev lat lon",
                 )
             )
             timestamps = target["timestamp"]
@@ -344,9 +375,14 @@
         else:
             target = target["future_states"]
 
+        if args.breakpoint:
+            breakpoint()
+
         # Update metrics.
         for metric in metrics.values():
             metric.update(target.to(device), pred.to(device))
+        if args.breakpoint:
+            breakpoint()
 
         if args.cache_metrics_every_nbatches and nbatches % args.cache_metrics_every_nbatches == 0:
             print(f"Processed {nbatches} batches.")
@@ -361,6 +397,7 @@
     print(
         f"Finished computation. Computed until {np.datetime64(timestamp, 's')} ({timestamps[-1]})"
     )
+    print(f"Total of {nbatches} batches.")
 
     for metric_name, metric in metrics.items():
         labelled_metric_output = metric.compute()
@@ -370,26 +407,35 @@
         else:
             output_filename = f"test-multistep={args.multistep}-{metric_name}"
 
-        # Get xr dataset.
         if isinstance(labelled_metric_output, dict):
             labelled_dict = {
                 k: (v.cpu() if hasattr(v, "cpu") else v)
                 for k, v in labelled_metric_output.items()
             }
-            extra_dimensions = ["prediction_timedelta"]
-            if "brier" in metric_name:
-                extra_dimensions = ["quantile", "prediction_timedelta"]
-            if "rankhist" in metric_name or "rank_hist" in metric_name:
-                extra_dimensions = ["bins", "prediction_timedelta"]
-            ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions)
 
             # Write labeled dict.
             labelled_dict["metadata"] = dict(
                 groundtruth_path=args.groundtruth_path, predictions_path=args.pred_path
             )
             torch.save(labelled_dict, Path(output_dir).joinpath(f"{output_filename}.pt"))
+
+            # Convert to xr dataset.
+            extra_dimensions = []
+            if args.multistep > 1:
+                extra_dimensions = ["prediction_timedelta"]
+            if "brier" in metric_name:
+                extra_dimensions.insert(0, "quantile")  # ["quantile", "prediction_timedelta"]
+            if "rankhist" in metric_name or "rank_hist" in metric_name:
+                extra_dimensions.insert(0, "bins")  # ["bins", "prediction_timedelta"]
+            if "spatial" in metric_name:
+                # Does not yet handle extra lat/lon dims.
+                continue
+
+            ds = convert_metric_dict_to_xarray(labelled_dict, extra_dimensions)
         else:
             ds = labelled_metric_output
 
         # Write xr dataset.
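+        # Record provenance so saved metrics can be traced back to their inputs.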
+        ds.attrs["groundtruth_path"] = args.groundtruth_path
+        ds.attrs["predictions_path"] = args.pred_path
+        ds.attrs["groundtruth_dataset_domain"] = args.groundtruth_dataset_domain
         ds.to_netcdf(Path(output_dir).joinpath(f"{output_filename}.nc"))
diff --git a/geoarches/evaluation/metric_registry.py b/geoarches/evaluation/metric_registry.py
index 0f923ae..10fb436 100644
--- a/geoarches/evaluation/metric_registry.py
+++ b/geoarches/evaluation/metric_registry.py
@@ -5,6 +5,7 @@
 import torchmetrics
 
 from geoarches.metrics.brier_skill_score import Era5BrierSkillScore
+from geoarches.metrics.deterministic_metrics import Era5DeterministicMetrics
 from geoarches.metrics.ensemble_metrics import Era5EnsembleMetrics
 from geoarches.metrics.rank_histogram import Era5RankHistogram
 from geoarches.metrics.spherical_power_spectrum import Era5PowerSpectrum
@@ -38,6 +39,20 @@ def instantiate_metric(metric_name: str, **extra_kwargs):
 #######################################################
 ###### Registering classes with their arguments. ######
 #######################################################
+register_metric(
+    "era5_deterministic_metrics",
+    Era5DeterministicMetrics,
+)
+register_metric(
+    "era5_deterministic_metrics_with_spatial", Era5DeterministicMetrics, compute_per_gridpoint=True
+)
+register_metric(
+    "era5_deterministic_metrics_with_spatial_and_hemisphere",
+    Era5DeterministicMetrics,
+    compute_per_gridpoint=True,
+    compute_per_hemisphere=True,
+    headline_variables=("Z500", "Z850", "T850", "Q700", "U850", "V850"),
+)
 register_metric(
     "era5_ensemble_metrics",
     Era5EnsembleMetrics,
diff --git a/geoarches/evaluation/plot_rmse_per_year.py b/geoarches/evaluation/plot_rmse_per_year.py
new file mode 100644
index 0000000..8ac6366
--- /dev/null
+++ b/geoarches/evaluation/plot_rmse_per_year.py
@@ -0,0 +1,185 @@
+import argparse
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot_rmse_by_year(
+    base_dir,
+    metric_filename,
+    save_path,
+    year_range=(1979, 2018),
+    metric_keys_left=["rmse_Z500", "rmse_Z850"],
+    metric_keys_right=["rmse_U850", "rmse_V850"],
+    ylabel_left=r"RMSE $[m^2/s^2]$",
+    ylabel_right=r"RMSE $[m/s]$",
+    ref_year=2020,
+    force=False,
+):
+    """
+    Plots the RMSE for a given range of years with a reference year's RMSE as a dashed line.
+
+    Args:
+        base_dir (str): The base directory containing the yearly data subdirectories.
+        metric_filename (str): Filename of the metric file to load from each year's subdirectory.
+        save_path (str or Path): The path to save the plot. If None, the plot will be shown but not saved.
+        year_range (tuple): A tuple (start_year, end_year) for the plot's x-axis.
+        metric_keys_left (list): A list of metric keys (e.g., ['rmse_Z500', 'rmse_Z850']) for the left subplot.
+        metric_keys_right (list): A list of metric keys (e.g., ['rmse_U850', 'rmse_V850']) for the right subplot.
+        ref_year (int): The year to use for the dashed reference line.
+        force (bool): If True, forces saving the plot even if the file already exists.
+    """
+    if Path(save_path).exists() and not force:
+        print(f"Plot file {save_path} already exists. Use --force to overwrite.")
+        return
+
+    # Generate the list of years to plot
+    years = list(range(year_range[0], year_range[1] + 1))
+    # Years for which a metric file was actually found (keeps x and y lengths aligned
+    # when some years are skipped).
+    loaded_years = []
+
+    # Dictionaries to store the loaded data
+    data_left = {metric: [] for metric in metric_keys_left}
+    data_right = {metric: [] for metric in metric_keys_right}
+
+    # Load data for the specified year range
+    for year in years:
+        file_path = os.path.join(base_dir, str(year), metric_filename)
+        if not os.path.exists(file_path):
+            print(f"Warning: File not found for year {year} at {file_path}. Skipping.")
+            continue
+
+        year_data = torch.load(file_path, map_location=torch.device("cpu"), weights_only=False)
+        loaded_years.append(year)
+
+        # Store the data for the left subplot
+        for metric in metric_keys_left:
+            data_left[metric].append(year_data[metric].item())
+
+        # Store the data for the right subplot
+        for metric in metric_keys_right:
+            data_right[metric].append(year_data[metric].item())
+
+    # Load the reference year data separately
+    ref_file_path = os.path.join(base_dir, str(ref_year), metric_filename)
+    if not os.path.exists(ref_file_path):
+        print(f"Error: Reference year data not found at {ref_file_path}.")
+        return
+
+    ref_data = torch.load(ref_file_path, map_location=torch.device("cpu"), weights_only=False)
+
+    # Set font family and use LaTeX for consistent plotting style
+    plt.rc("font", family="serif")
+    # plt.rc("text", usetex=True)
+
+    # Create the plot
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 2.5))
+
+    # Define the desired x-axis ticks
+    desired_ticks = [1980, 1990, 2000, 2010]
+    ax1.set_xticks(desired_ticks)
+    ax2.set_xticks(desired_ticks)
+
+    # --- Left Subplot (Geopotential) ---
+    # ax1.set_title("Geopotential RMSE")
+    ax1.set_xlabel("Year")
+    ax1.set_ylabel(ylabel_left)
+    ax1.grid(True, linestyle="-", color="lightgray")
+
+    # Plot the RMSE lines
+    colors = plt.cm.tab10.colors
+    for i, metric in enumerate(metric_keys_left):
+        if data_left[metric]:
+            label = metric.split("_")[-1]  # e.g., 'rmse_Z500' -> 'Z500'
+            ax1.plot(loaded_years, data_left[metric], label=label, color=colors[i])
+
+            # Plot the dashed reference line
+            ref_value = ref_data[metric].item()
+            ax1.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5)
+
+    ax1.legend()
+
+    # --- Right Subplot (Wind Speed) ---
+    # ax2.set_title("Wind Speed RMSE")
+    ax2.set_xlabel("Year")
+    ax2.set_ylabel(ylabel_right)
+    ax2.grid(True, linestyle="-", color="lightgray")
+
+    # Plot the RMSE lines
+    for i, metric in enumerate(metric_keys_right):
+        if data_right[metric]:
+            label = metric.split("_")[-1]  # e.g., 'rmse_Z500' -> 'Z500'
+            ax2.plot(loaded_years, data_right[metric], label=label, color=colors[i])
+
+            # Plot the dashed reference line
+            ref_value = ref_data[metric].item()
+            ax2.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5)
+
+    ax2.legend()
+
+    plt.tight_layout()
+
+    plt.savefig(save_path)
+    print(f"Plot saved to {save_path}")
+
+
+# Example usage:
+# This part assumes a directory structure and some dummy data for demonstration.
+# You would replace this with your actual data path.
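+# For instance (paths here are hypothetical):
+#   python -m geoarches.evaluation.plot_rmse_per_year \
+#       --base_dir /path/to/evaluation/era5_pred_archesweather-S --save_dir plots/ --force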
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Plot RMSE per year.")
+    parser.add_argument(
+        "--base_dir",
+        type=str,
+        default="/scratch/resingh/weather/evaluation/era5_pred_archesweather-S/",
+        help="Base directory containing yearly data subdirectories (named `${base_dir}/{year}`)",
+    )
+    parser.add_argument(
+        "--metric_filename",
+        type=str,
+        default="test-multistep=1-era5_deterministic_metrics_with_spatial_and_hemisphere.pt",
+        help="Filename of the metric data to load.",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="plots/",
+        help="Directory to save the plots.",
+    )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Force saving plot even if the plot file already exists.",
+    )
+    args = parser.parse_args()
+
+    plot_rmse_by_year(
+        base_dir=args.base_dir,
+        metric_filename=args.metric_filename,
+        save_path=Path(args.save_dir) / "rmse_per_year.png",
+        year_range=(1979, 2018),
+        metric_keys_left=["rmse_Z500", "rmse_Z850"],
+        metric_keys_right=["rmse_U850", "rmse_V850"],
+        ref_year=2020,
+        force=args.force,
+    )
+    plot_rmse_by_year(
+        base_dir=args.base_dir,
+        metric_filename=args.metric_filename,
+        save_path=Path(args.save_dir) / "north_rmse_per_year.png",
+        year_range=(1979, 2018),
+        metric_keys_left=["rmse-north_Z500", "rmse-north_Z850"],
+        metric_keys_right=["rmse-north_U850", "rmse-north_V850"],
+        ref_year=2020,
+        force=args.force,
+    )
+    plot_rmse_by_year(
+        base_dir=args.base_dir,
+        metric_filename=args.metric_filename,
+        save_path=Path(args.save_dir) / "south_rmse_per_year.png",
+        year_range=(1979, 2018),
+        metric_keys_left=["rmse-south_Z500", "rmse-south_Z850"],
+        metric_keys_right=["rmse-south_U850", "rmse-south_V850"],
+        ref_year=2020,
+        force=args.force,
+    )
diff --git a/geoarches/evaluation/plot_spatial_rmse.py b/geoarches/evaluation/plot_spatial_rmse.py
new file mode 100644
index 0000000..7b878fc
--- /dev/null
+++ b/geoarches/evaluation/plot_spatial_rmse.py
@@ -0,0 +1,159 @@
+import argparse
+from pathlib import Path
+
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def plot_spatial_rmse(
+    base_dirs: list[str] | list[Path],
+    metric_filename: str | Path,
+    save_path: str | Path,
+    metric_key: str = "rmse_per_gridpoint_V850",
+    titles: list[str] | None = None,
+    force: bool = False,
+    cbar_label: str | None = None,
+):
+    """
+    Plots 2D spatial RMSE arrays of shape (lat, lon) as world maps using Cartopy.
+
+    Args:
+        base_dirs: Directories holding the metric files to compare (one map panel per directory).
+        metric_filename: Filename of the metric file to load from each base directory.
+        save_path: The path to save the plot.
+        metric_key: Key of the (lat, lon) array to plot from each loaded metric dict.
+        titles: Optional list of per-panel titles.
+        force: If True, overwrites an existing plot file.
+        cbar_label: Optional label for the shared colorbar.
+    """
+    if Path(save_path).exists() and not force:
+        print(f"Plot file {save_path} already exists. Use --force to overwrite.")
+        return
+
+    spatial_datas = []
+    for base_dir in base_dirs:
+        spatial_data = torch.load(
+            Path(base_dir) / metric_filename,
+            map_location=torch.device("cpu"),
+            weights_only=False,
+        )
+        spatial_datas.append(spatial_data[metric_key])
+
+    # Determine global min and max for a consistent color scale across all plots
+    all_data = np.concatenate([data.flatten() for data in spatial_datas])
+    vmin, vmax = np.min(all_data), np.max(all_data)
+
+    # Set font family for consistent plotting style
+    plt.rc("font", family="serif")
+
+    # Set up the plot with a PlateCarree projection, suitable for global data.
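+    # One panel per entry in base_dirs; all panels share the color scale (vmin/vmax above).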
+    num_plots = len(spatial_datas)
+    fig, axes = plt.subplots(
+        1, num_plots, figsize=(6 * num_plots, 5), subplot_kw={"projection": ccrs.PlateCarree()}
+    )
+    # Ensure axes is an array even for a single subplot
+    if num_plots == 1:
+        axes = [axes]
+
+    axes[0].set_ylabel("Latitude")
+
+    for i, data in enumerate(spatial_datas):
+        ax = axes[i]
+
+        # Add map features
+        ax.add_feature(cfeature.COASTLINE)
+        ax.add_feature(cfeature.BORDERS, linestyle=":")
+        ax.add_feature(cfeature.LAND, edgecolor="black")
+        ax.add_feature(cfeature.OCEAN)
+        ax.set_global()
+
+        # Convert to numpy if needed
+        if isinstance(data, torch.Tensor):
+            data = data.numpy()
+
+        # num_lat, num_lon = data.shape
+        # lons = np.arange(0, 360, 360 / num_lon)
+        # lats = np.linspace(90, -90, num_lat)
+
+        # Use imshow to plot the data on top of the map with a normalized color scale
+        im = ax.imshow(
+            data,
+            cmap="plasma",
+            origin="upper",
+            extent=[-180, 180, -90, 90],
+            transform=ccrs.PlateCarree(),
+            vmin=vmin,
+            vmax=vmax,
+        )
+
+        # Set plot title and labels
+        if titles:
+            ax.set_title(titles[i])
+
+        ax.set_xlabel("Longitude")
+
+        # Set ticks and gridlines for clarity
+        ax.set_xticks(np.arange(-180, 181, 60), crs=ccrs.PlateCarree())
+        ax.set_yticks(np.arange(-90, 91, 30), crs=ccrs.PlateCarree())
+        ax.grid(True, linestyle="-", color="gray")
+
+    # Add a single color bar for the entire figure
+    # Adjust subplots and color bar for better proportions
+    plt.tight_layout(rect=[0, 0, 0.9, 1])
+    # fig.subplots_adjust(right=0.85)
+    # cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
+
+    cbar_ax = fig.add_axes([0.91, axes[0].get_position().y0, 0.02, axes[0].get_position().height])
+    cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical")
+    if cbar_label:
+        cbar.set_label(cbar_label)
+
+    # Add an overall title to the figure
+    fig.suptitle(metric_key.split("_")[-1].upper(), y=0.9, fontsize=16)
+
+    plt.savefig(save_path)
+    print(f"Plot saved to {save_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Plot spatial RMSE maps.")
+    parser.add_argument(
+        "--base_dir",
+        type=str,
+        default="/scratch/resingh/weather/evaluation/era5_pred_archesweather-S/",
+        help="Base directory containing yearly data subdirectories.",
+    )
+    parser.add_argument(
+        "--metric_filename",
+        type=str,
+        default="test-multistep=1-era5_deterministic_metrics_with_spatial.pt",
+        help="Filename of the metric data to load.",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="plots",
+        help="Directory to save the plots.",
+    )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Force saving plot even if the plot file already exists.",
+    )
+    args = parser.parse_args()
+
+    for var, cbar_label in zip(
+        ["V850", "U850", "Z500"], ["RMSE $[m/s]$", "RMSE $[m/s]$", "RMSE $[m^2/s^2]$"]
+    ):
+        plot_spatial_rmse(
+            base_dirs=[Path(args.base_dir) / "1979_1999", Path(args.base_dir) / "2000_2018"],
+            metric_filename=args.metric_filename,
+            metric_key=f"rmse_per_gridpoint_{var}",
+            save_path=Path(args.save_dir) / f"spatial_rmse_{var}.png",
+            force=args.force,
+            titles=["1979-1999", "2000-2018"],
+            cbar_label=cbar_label,
+        )
diff --git a/geoarches/lightning_modules/base_module.py b/geoarches/lightning_modules/base_module.py
index 923da7e..8701b0b 100644
--- a/geoarches/lightning_modules/base_module.py
+++ b/geoarches/lightning_modules/base_module.py
@@ -31,7 +31,7 @@ def load_module(
     path = Path(path)
     cfg = OmegaConf.load(path / "config.yaml")
     cfg.merge_with_dotlist(dotlist)
-    module = instantiate(cfg.module.module, cfg.module, **kwargs)
+    module = instantiate(cfg.module.module, cfg.module, cfg.stats, **kwargs)
     module.init_from_ckpt(path, ckpt_fname=ckpt_fname, missing_warning=False)
     if device == "auto":
         device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/geoarches/lightning_modules/forecast.py b/geoarches/lightning_modules/forecast.py
index c1b98fe..4914d9c 100644
--- a/geoarches/lightning_modules/forecast.py
+++ b/geoarches/lightning_modules/forecast.py
@@ -8,10 +8,13 @@
 import torch.nn as nn
 import torch.utils.checkpoint as gradient_checkpoint
 from hydra.utils import instantiate
-from tensordict.tensordict import TensorDict
 
-from geoarches.dataloaders import era5, zarr
-from geoarches.metrics.metric_base import compute_lat_weights, compute_lat_weights_weatherbench
+from geoarches.dataloaders import zarr
+from geoarches.utils.tensordict_utils import (
+    apply_mask_from_gt_nans,
+    check_pred_has_no_nans,
+    tensordict_apply,
+)
 
 from .. import stats as geoarches_stats
 from .base_module import BaseLightningModule
@@ -22,7 +25,8 @@ class ForecastModule(BaseLightningModule):
 
     def __init__(
         self,
-        cfg,  # instead of backbone
+        cfg,  # module config, instead of backbone
+        stats_cfg,
         name="forecast",
         dataset=None,
         pow=2,  # 2 is standard mse
@@ -41,6 +45,7 @@ def __init__(
         lead_time_hours=24,
         rollout_iterations=1,
         test_filename_suffix="",
+        replace_nan_by=torch.nan,
         **kwargs,
     ):
         """should create self.encoder and self.decoder in subclasses"""
@@ -50,53 +55,14 @@ def __init__(
         self.backbone = instantiate(cfg.backbone)  # necessary to put it on device
         self.embedder = instantiate(cfg.embedder)
 
-        # define coeffs for loss
+        # Instantiate stats module for loss coeffs
+        stats = instantiate(stats_cfg.module)
+        self.variables = stats.variables
+        self.levels = stats.levels
 
-        compute_weights_fn = (
-            compute_lat_weights_weatherbench
-            if use_weatherbench_lat_coeffs
-            else compute_lat_weights
-        )
-        area_weights = compute_weights_fn(121)
-
-        pressure_levels = torch.tensor(era5.pressure_levels).float()
-        vertical_coeffs = (pressure_levels / pressure_levels.mean()).reshape(-1, 1, 1)
-
-        # define relative surface and level weights
-        total_coeff = 6 + 1.3
-        surface_coeffs = 4 * torch.tensor([0.1, 0.1, 1.0, 0.1]).reshape(
-            -1, 1, 1, 1
-        )  # graphcast, mul 4 because we do a mean
-        level_coeffs = 6 * torch.tensor(1).reshape(-1, 1, 1, 1)
-
-        self.loss_coeffs = TensorDict(
-            surface=area_weights * surface_coeffs / total_coeff,
-            level=area_weights * level_coeffs * vertical_coeffs / total_coeff,
-        )
-
-        if loss_delta_normalization:
-            # assumes include vertical wind component
-
-            pangu_stats = torch.load(
-                geoarches_stats_path / "pangu_norm_stats2_with_w.pt", weights_only=True
-            )
-
-            # mul by first to remove norm, div by second to apply fake delta normalization
-            self.loss_delta_scaler = TensorDict(
-                level=pangu_stats["level_std"]
-                / torch.tensor(
-                    [5.9786e02, 7.4878e00, 8.9492e00, 2.7132e00, 9.5222e-04, 0.3]
-                ).reshape(-1, 1, 1, 1),
-                surface=pangu_stats["surface_std"]
-                / torch.tensor([3.8920, 4.5422, 2.0727, 584.0980]).reshape(-1, 1, 1, 1),
-            )
-            self.loss_coeffs = self.loss_coeffs * self.loss_delta_scaler.pow(self.pow)
+        self.loss_coeffs = stats.compute_loss_coeffs(**stats_cfg.compute_loss_coeffs_args)
 
-        compute_lat_weights_fn = (
-            compute_lat_weights_weatherbench
-            if use_weatherbench_lat_coeffs
-            else compute_lat_weights
-        )
+        # Instantiate metric modules
         self.train_metrics = nn.ModuleList(
             [instantiate(metric, **cfg.train.metrics_kwargs) for metric in cfg.train.metrics]
         )
@@ -116,6 +82,8 @@ def forward(self, batch, *args, **kwargs):
         x = self.backbone(x, *args, **kwargs)
         out = self.embedder.decode(x)  # we get tdict
 
+        _ = tensordict_apply(check_pred_has_no_nans, pred=out, target=batch["next_state"])
+
         if self.add_input_state:
             out += batch["state"]
 
@@ -143,12 +111,15 @@ def forward_multistep(self, batch, iters=None, return_format="tensordict", use_a
             loop_batch = dict(
                 prev_state=loop_batch["state"],
                 state=pred,
+                # Used only to obtain NaN mask (not true next state)
+                next_state=loop_batch["next_state"],
                 timestamp=loop_batch["timestamp"] + batch["lead_time_hours"] * 3600,
             )
 
         if return_format == "list":
             return preds_future
         preds_future = torch.stack(preds_future, dim=1)
+
         return preds_future
 
     def loss(self, pred, gt, multistep=False, **kwargs):
@@ -165,9 +136,12 @@
 
             loss_coeffs.apply(lambda x: x * future_coeffs)
 
+        # Set pred to value where gt is NaN and set gt to value where itself is NaN
+        pred, gt = apply_mask_from_gt_nans(pred, gt, value=self.replace_nan_by)
+
         weighted_error = (pred - gt).abs().pow(self.pow).mul(loss_coeffs)
 
-        loss = sum(weighted_error.mean().values())
+        loss = sum(weighted_error.nanmean().values())
 
         return loss
diff --git a/geoarches/metrics/brier_skill_score.py b/geoarches/metrics/brier_skill_score.py
index d2218bb..1f7c0a8 100644
--- a/geoarches/metrics/brier_skill_score.py
+++ b/geoarches/metrics/brier_skill_score.py
@@ -175,12 +175,12 @@ class Era5BrierSkillScore(TensorDictMetricBase):
 
     def __init__(
         self,
+        surface_variables=era5.arches_default_surface_variables,
+        level_variables=era5.arches_default_level_variables,
+        pressure_levels=era5.arches_default_pressure_levels,
         quantiles_filepath="era5-quantiles-2016_2022.nc",
         high_quantile_levels=[0.99, 0.999, 0.9999],
         low_quantiles_levels=[0.01, 0.001, 0.0001],
-        surface_variables=era5.surface_variables,
-        level_variables=era5.level_variables,
-        pressure_levels=era5.pressure_levels,
         lead_time_hours: None | int = None,
         rollout_iterations: None | int = None,
         save_memory: bool = False,
@@ -206,13 +206,13 @@ def __init__(
         with resources.as_file(resources.files(geoarches_stats).joinpath(quantiles_filepath)) as f:
             q = xr.open_dataset(f).transpose(..., "latitude", "longitude")
         self.surface_high_quantiles = torch.from_numpy(
-            q[era5.surface_variables]
+            q[surface_variables]
             .sel({"quantile": high_quantile_levels}, method="nearest")
             .to_array()
             .to_numpy()
         ).unsqueeze(-3)  # Add level dimension.
         self.surface_low_quantiles = torch.from_numpy(
-            q[era5.surface_variables]
+            q[surface_variables]
             .sel({"quantile": low_quantiles_levels}, method="nearest")
             .to_array()
             .to_numpy()
diff --git a/geoarches/metrics/deterministic_metrics.py b/geoarches/metrics/deterministic_metrics.py
index 9e653cc..fbdb9b4 100644
--- a/geoarches/metrics/deterministic_metrics.py
+++ b/geoarches/metrics/deterministic_metrics.py
@@ -73,6 +73,8 @@ def __init__(
         self,
         data_shape: tuple,
         compute_lat_weights_fn: Callable[[int], torch.tensor] = compute_lat_weights_weatherbench,
+        compute_per_gridpoint: bool = False,
+        compute_per_hemisphere: bool = False,
     ):
         """
         Args:
@@ -84,6 +86,10 @@
                 Used for error and variance calculations.
                 Expected shape of weights: [..., lat, 1].
                 See function example in metric_base.MetricBase. Default function assumes latitudes are ordered -90 to 90.
+            compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with the globally aggregated mse and rmse).
+                Default: only compute globally aggregated mse and rmse.
+            compute_per_hemisphere: Whether to also compute mse and rmse per north and south hemisphere
+                (along with the globally aggregated mse and rmse).
         """
         Metric.__init__(self)
         MetricBase.__init__(
@@ -93,15 +99,30 @@
             # lead_time_hours=lead_time_hours,
             # rollout_iterations=rollout_iterations,
         )
+        self.compute_per_gridpoint = compute_per_gridpoint
+        self.compute_per_hemisphere = compute_per_hemisphere
 
         # Call `self.add_state` for every internal state that is needed for the metrics computations.
         # `dist_reduce_fx` indicates the function that should be used to reduce.
         self.add_state("nsamples", default=torch.tensor(0), dist_reduce_fx="sum")
+
+        # Aggregated over gridpoints (lat weighted).
         self.add_state("mse", default=torch.zeros(data_shape), dist_reduce_fx="sum")
         self.add_state(
             "rmse_before_time_avg", default=torch.zeros(data_shape), dist_reduce_fx="sum"
         )
 
+        if self.compute_per_hemisphere:
+            # Aggregated over north and south hemispheres.
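+            # With latitudes ordered 90 to -90, the first half of the lat axis is the
+            # northern hemisphere (split at equator_index = num_lats // 2 in update()).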
+ self.add_state("mse_north", default=torch.zeros(data_shape), dist_reduce_fx="sum") + self.add_state("mse_south", default=torch.zeros(data_shape), dist_reduce_fx="sum") + + if self.compute_per_gridpoint: + # Per gridpoint. + self.add_state( + "mse_per_gridpt", default=torch.zeros((*data_shape, 1, 1)), dist_reduce_fx="sum" + ) + def update(self, targets: torch.Tensor, preds: torch.Tensor) -> None: """Update internal state with a batch of targets and predictions. @@ -122,6 +143,19 @@ def update(self, targets: torch.Tensor, preds: torch.Tensor) -> None: targets, preds ).sqrt().sum(0) + if self.compute_per_hemisphere: + num_lats = targets.shape[-2] + equator_index = num_lats // 2 + self.mse_north = self.mse_north + self.wmse( + targets, preds, lat_range=(0, equator_index) + ).sum(0) + self.mse_south = self.mse_south + self.wmse( + targets, preds, lat_range=(equator_index, num_lats) + ).sum(0) + + if self.compute_per_gridpoint: + self.mse_per_gridpt = self.mse_per_gridpt + self.spatial_mse(targets, preds).sum(0) + def compute(self) -> Dict[str, torch.Tensor]: """Compute final metrics utilizing internal states. Returns: @@ -134,6 +168,24 @@ rmse=(self.mse / self.nsamples).sqrt(), ) + if self.compute_per_hemisphere: + all_metrics.update( + dict( + mse_north=self.mse_north / self.nsamples, + rmse_north=(self.mse_north / self.nsamples).sqrt(), + mse_south=self.mse_south / self.nsamples, + rmse_south=(self.mse_south / self.nsamples).sqrt(), + ) + ) + + if self.compute_per_gridpoint: + all_metrics.update( + dict( + mse_per_gridpoint=self.mse_per_gridpt / self.nsamples, + rmse_per_gridpoint=(self.mse_per_gridpt / self.nsamples).sqrt(), + ) + ) + # out = dict() # for var, index in self.indices.items(): # for metric_name, metric in all_metrics.items(): @@ -143,55 +195,76 @@ class Era5DeterministicMetrics(TensorDictMetricBase): - """Wrapper class around EnsembleMetrics for computing over surface and level variables. + """Wrapper class around DeterministicRMSE for computing over surface and level variables. Handles batches coming from Era5 Dataloader. Accepted tensor shapes: targets: (batch, ..., timedelta, var, level, lat, lon) - preds: (batch, nmembers, ..., timedelta, var, level, lat, lon) + preds: (batch, ..., timedelta, var, level, lat, lon) Return dictionary of metrics reduced over batch, lat, lon. """ def __init__( self, + surface_variables=era5.arches_default_surface_variables, + level_variables=era5.arches_default_level_variables, + pressure_levels=era5.arches_default_pressure_levels, + headline_variables=("Z500", "T850", "Q700", "U850", "V850"), compute_lat_weights_fn: Callable[[int], torch.tensor] = compute_lat_weights_weatherbench, - pressure_levels=era5.pressure_levels, - num_level_variables=len(era5.level_variables), + compute_per_gridpoint: bool = False, + compute_per_hemisphere: bool = False, lead_time_hours: int = 24, rollout_iterations: int = 1, ): """ Args: + surface_variables: Names of surface variables (used to get `variable_indices`). + level_variables: Names of level variables (used to get `variable_indices`). pressure_levels: pressure levels in data (used to get `variable_indices`). + headline_variables: Short names of level variables to output (used to get `variable_indices`). - level_data_shape: (var, lev) shape for level variables. - num_level_variables: Number of level variables (used to compute data_shape). - rollout_iterations: Number of rollout iterations in multistep predictions.
- this option labels each timestep separately in output metric dict. - Assumes that data shape of predictions/targets are [batch, ..., multistep, var, lev, lat, lon] + compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with the globally aggregated mse and rmse). + Default: only compute globally aggregated mse and rmse. + lead_time_hours: Timedelta between timestamps in multistep rollout. + Set to explicitly handle predictions from multistep rollout. + This option labels each timestep separately in output metric dict. + Assumes that data shape of predictions/targets is [batch, ..., multistep, var, lev, lat, lon]. + Note: when set to None, Era5DeterministicMetrics still natively handles any extra dimensions in targets/preds. + Set to None if there is no multistep dimension. + rollout_iterations: Size of timedelta dimension (number of rollout iterations in multistep predictions). + Set to explicitly handle metrics computed on predictions from multistep rollout. + See param `lead_time_hours`. """ super().__init__( surface=LabelDictWrapper( DeterministicRMSE( - data_shape=(len(era5.surface_variables), 1), + data_shape=(len(surface_variables), 1), compute_lat_weights_fn=compute_lat_weights_fn, + compute_per_gridpoint=compute_per_gridpoint, + compute_per_hemisphere=compute_per_hemisphere, ), variable_indices=add_timedelta_index( - era5.get_surface_variable_indices(), + era5.get_surface_variable_indices(surface_variables), lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), ), level=LabelDictWrapper( DeterministicRMSE( - data_shape=(num_level_variables, len(pressure_levels)), + data_shape=(len(level_variables), len(pressure_levels)), compute_lat_weights_fn=compute_lat_weights_fn, + compute_per_gridpoint=compute_per_gridpoint, + compute_per_hemisphere=compute_per_hemisphere, ), variable_indices=add_timedelta_index( - era5.get_headline_level_variable_indices(pressure_levels), + era5.get_headline_level_variable_indices( + pressure_levels, level_variables, headline_variables + ), lead_time_hours=lead_time_hours, rollout_iterations=rollout_iterations, ), diff --git a/geoarches/metrics/ensemble_metrics.py b/geoarches/metrics/ensemble_metrics.py index 9782264..2e022f9 100644 --- a/geoarches/metrics/ensemble_metrics.py +++ b/geoarches/metrics/ensemble_metrics.py @@ -172,9 +172,9 @@ class Era5EnsembleMetrics(TensorDictMetricBase): def __init__( self, - surface_variables=era5.surface_variables, - level_variables=era5.level_variables, - pressure_levels=era5.pressure_levels, + surface_variables=era5.arches_default_surface_variables, + level_variables=era5.arches_default_level_variables, + pressure_levels=era5.arches_default_pressure_levels, save_memory: bool = False, lead_time_hours: None | int = None, rollout_iterations: None | int = None, diff --git a/geoarches/metrics/label_wrapper.py b/geoarches/metrics/label_wrapper.py index 1d4c31e..0b7c464 100644 --- a/geoarches/metrics/label_wrapper.py +++ b/geoarches/metrics/label_wrapper.py @@ -1,4 +1,5 @@ import itertools +import warnings from collections import defaultdict from typing import Any, Dict, List, Sequence @@ -59,7 +60,16 @@ def _convert(self, raw_metric_dict: Dict[str, Tensor]): labeled_dict = dict() for var, index in self.variable_indices.items(): for metric_name, metric in raw_metric_dict.items(): - labeled_dict[f"{metric_name}_{var}"] = metric.__getitem__((..., *index)) + # Replace underscores in the metric name with dashes so that convert_metric_dict_to_xarray() can still split labels on underscores.
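+ # e.g. "mse_per_gridpt" becomes "mse-per-gridpt", so a labeled key like "mse-per-gridpt_Z500_24h" still splits into (metric, variable, timedelta) on "_".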
+ metric_name = metric_name.replace("_", "-") + if any(s in metric_name for s in ["spatial", "per-gridpt", "per-gridpoint"]): + # Account for lat, lon dims + warnings.warn( + f"Assuming that metric {metric_name} has lat, lon dimensions. Not supported for WandB logging." + ) + labeled_dict[f"{metric_name}_{var}"] = metric[..., *index, :, :] + else: + labeled_dict[f"{metric_name}_{var}"] = metric[..., *index] return labeled_dict def update(self, *args: Any, **kwargs: Any) -> None: @@ -134,6 +144,7 @@ def _convert_coord(name, value): labels = label.split("_") if len(labels) - 2 != len(extra_dimensions): raise ValueError( + f"Assumes metric name {label} is in format <metric>_<var>_<extra_dim>_.... " f"Expected length of extra_dimensions for key {label} to be: {len(labels) - 2}. Got extra_dimensions={extra_dimensions}." ) metrics.add(labels[0]) diff --git a/geoarches/metrics/metric_base.py b/geoarches/metrics/metric_base.py index cbcf591..761fa3f 100644 --- a/geoarches/metrics/metric_base.py +++ b/geoarches/metrics/metric_base.py @@ -62,15 +62,39 @@ def __init__( super().__init__() self.compute_lat_weights_fn = compute_lat_weights_fn - def wmse(self, x: torch.Tensor, y: torch.Tensor | int = 0): + def wmse( + self, x: torch.Tensor, y: torch.Tensor | int = 0, lat_range: tuple[int, int] | None = None + ): """Latitude weighted mse error. Args: x: preds with shape (..., lat, lon) y: targets with shape (..., lat, lon) + lat_range: Optional (start, end) latitude indices to restrict the computation to. + If None, uses the full latitude range. """ lat_coeffs = self.compute_lat_weights_fn(latitude_resolution=x.shape[-2]).to(x.device) - return (x - y).pow(2).mul(lat_coeffs).mean((-2, -1)) + + if lat_range is not None: + start_lat, end_lat = lat_range + x = x[..., start_lat:end_lat, :] + lat_coeffs = lat_coeffs[start_lat:end_lat, :] + # Renormalize the weights for the specified latitude range + lat_coeffs = lat_coeffs / lat_coeffs.mean() + + if not isinstance(y, int): + y = y[..., start_lat:end_lat, :] + + return (x - y).pow(2).mul(lat_coeffs).nanmean((-2, -1)) + + def spatial_mse(self, x: torch.Tensor, y: torch.Tensor | int = 0): + """Per gridpoint mse error. + + Args: + x: preds with shape (..., lat, lon) + y: targets with shape (..., lat, lon) + """ + return (x - y).pow(2) def wmae(self, x: torch.Tensor, y: torch.Tensor | int = 0): """Latitude weighted mae error. @@ -80,7 +104,7 @@ y: targets with shape (..., lat, lon) """ lat_coeffs = self.compute_lat_weights_fn(latitude_resolution=x.shape[-2]).to(x.device) - return (x - y).abs().mul(lat_coeffs).mean((-2, -1)) + return (x - y).abs().mul(lat_coeffs).nanmean((-2, -1)) def wvar(self, x: torch.Tensor, dim: int = 1): """Latitude weighted variance along axis. @@ -90,7 +114,7 @@ x: preds with shape (..., lat, lon) dim: over which dimension to compute variance. """ lat_coeffs = self.compute_lat_weights_fn(latitude_resolution=x.shape[-2]).to(x.device) - return x.var(dim).mul(lat_coeffs).mean((-2, -1)) + return x.var(dim).mul(lat_coeffs).nanmean((-2, -1)) def weighted_mean(self, x: torch.Tensor): """Latitude weighted mean over grid. 
@@ -99,7 +123,7 @@ def weighted_mean(self, x: torch.Tensor): x: preds with shape (..., lat, lon) """ lat_coeffs = self.compute_lat_weights_fn(latitude_resolution=x.shape[-2]).to(x.device) - return x.mul(lat_coeffs).mean((-2, -1)) + return x.mul(lat_coeffs).nanmean((-2, -1)) class TensorDictMetricBase(Metric): diff --git a/geoarches/metrics/rank_histogram.py b/geoarches/metrics/rank_histogram.py index fb25d76..5697b89 100644 --- a/geoarches/metrics/rank_histogram.py +++ b/geoarches/metrics/rank_histogram.py @@ -142,9 +142,9 @@ class Era5RankHistogram(TensorDictMetricBase): def __init__( self, n_members, - surface_variables=era5.surface_variables, - level_variables=era5.level_variables, - pressure_levels=era5.pressure_levels, + surface_variables=era5.arches_default_surface_variables, + level_variables=era5.arches_default_level_variables, + pressure_levels=era5.arches_default_pressure_levels, lead_time_hours: None | int = None, rollout_iterations: None | int = None, ): diff --git a/geoarches/metrics/spherical_power_spectrum.py b/geoarches/metrics/spherical_power_spectrum.py index c2ef102..2edf19a 100644 --- a/geoarches/metrics/spherical_power_spectrum.py +++ b/geoarches/metrics/spherical_power_spectrum.py @@ -112,10 +112,10 @@ class Era5PowerSpectrum(TensorDictMetricBase): def __init__( self, + surface_variables=era5.arches_default_surface_variables, + level_variables=era5.arches_default_level_variables, + pressure_levels=era5.arches_default_pressure_levels, compute_target_spectrum: bool = False, - surface_variables: str = era5.surface_variables, - level_variables: str = era5.level_variables, - pressure_levels: str = era5.pressure_levels, lead_time_hours: None | int = None, rollout_iterations: None | int = None, ): diff --git a/geoarches/stats/gc_stats_diffs_stddev_by_level.json b/geoarches/stats/gc_stats_diffs_stddev_by_level.json new file mode 100644 index 0000000..4b5f578 --- /dev/null +++ b/geoarches/stats/gc_stats_diffs_stddev_by_level.json @@ -0,0 +1,425 @@ +{ + "coords": { + "level": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 1, + 2, + 3, + 5, + 7, + 10, + 20, + 30, + 50, + 70, + 100, + 125, + 150, + 175, + 200, + 225, + 250, + 300, + 350, + 400, + 450, + 500, + 550, + 600, + 650, + 700, + 750, + 775, + 800, + 825, + 850, + 875, + 900, + 925, + 950, + 975, + 1000 + ] + } + }, + "attrs": { + "date_start": "1979-01-02", + "date_end": "2015", + "timestep_hours": 6, + "window_stride": 1 + }, + "dims": { + "level": 37 + }, + "data_vars": { + "geopotential": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 1153.3236615667074, + 795.034664163933, + 622.2635763211008, + 478.373765021565, + 412.1088415325711, + 354.5595833736702, + 275.78971681197083, + 242.008008778255, + 208.6322187978729, + 193.91983480697874, + 191.01766200299144, + 199.44811918240407, + 213.40448931176704, + 233.16085997614195, + 258.7060994441645, + 285.0746020385599, + 306.04455387420177, + 322.2987078377359, + 311.2782390217717, + 288.08836922513115, + 262.7422047313381, + 239.71967148713998, + 220.75604230741158, + 206.1707200120217, + 195.78388562164778, + 189.29747149956302, + 186.2235844639424, + 185.83245720044118, + 186.1962264276176, + 187.278693666876, + 189.10903841711763, + 191.6343772396133, + 194.72962905103913, + 198.26483322933004, + 202.14035543665807, + 206.24134707573242, + 210.3374969299931 + ] + }, + "specific_humidity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 6.640175966281834e-08, + 3.909006251718587e-08, + 3.746183855796142e-08, + 3.919451822231693e-08, + 
4.2966211315339e-08, + 4.4283665996826096e-08, + 4.657199605110932e-08, + 4.551214689788184e-08, + 4.658684637305686e-08, + 6.327991852757859e-08, + 1.6348547152715437e-07, + 4.811058349756144e-07, + 1.468294037839677e-06, + 4.086185389502715e-06, + 9.728637696402969e-06, + 1.9833229365576873e-05, + 3.401754512971197e-05, + 7.333047947780854e-05, + 0.0001332051964476865, + 0.00021465164726080453, + 0.0003153737888009881, + 0.0004252841762649919, + 0.0005466285846394803, + 0.0006451832142818177, + 0.0007632778913723108, + 0.00088359784474792, + 0.001014776541884741, + 0.0010677954238679668, + 0.001111234976759102, + 0.0011407654863650864, + 0.0011410079751381372, + 0.0010991055821180482, + 0.0010283050589989723, + 0.0009334517609154307, + 0.0008229854515383136, + 0.0007657228740416388, + 0.0007414803632021611 + ] + }, + "temperature": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 3.680318851123672, + 3.8318878636112537, + 3.439696466621194, + 2.6775682108289356, + 2.3325500928338134, + 1.9628536615340624, + 1.544838747767884, + 1.3789373106295892, + 1.2641932026228335, + 1.2511386380433243, + 1.0794865862462029, + 1.026442103783078, + 1.1286996633737063, + 1.3486876467289826, + 1.5433826760468057, + 1.600130619234847, + 1.5106945778627725, + 1.239533708236731, + 1.2344538947320591, + 1.3358459591532381, + 1.3962095373924943, + 1.4025345615324933, + 1.3897261313626144, + 1.3727767252702563, + 1.3699250308872553, + 1.3738020283273578, + 1.4090072632670816, + 1.4383353706567221, + 1.4811674144272151, + 1.5473749898531692, + 1.6135235252296107, + 1.6688769608436818, + 1.7103621433004004, + 1.7519710596016853, + 1.8131106133282429, + 1.864745761493253, + 1.8598407151530196 + ] + }, + "u_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 8.712805693229797, + 7.032393744920325, + 6.002627010378335, + 4.915242124904145, + 4.479708439819066, + 3.8863529878507115, + 3.1666837711565052, + 2.8976018516965945, + 2.689749850403594, + 2.6225576045109404, + 2.729538119017451, + 2.9333758472760474, + 3.203373046243629, + 3.6125061968160552, + 4.187357872051589, + 4.818180856614245, + 5.31504874860131, + 5.742670144020393, + 5.539882908851254, + 5.037096291508256, + 4.494289689428559, + 4.01707120000842, + 3.6504411103869145, + 3.3727291331654956, + 3.1797119020393625, + 3.051410220642815, + 2.99553106072978, + 2.9906446228081496, + 3.0025600666665637, + 3.030525096377766, + 3.067396956012487, + 3.112904447446578, + 3.1645209534084033, + 3.188282362973559, + 3.103898367018071, + 2.8306912036515204, + 2.4377055965863574 + ] + }, + "v_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 9.787924018305782, + 7.719212981558986, + 6.545831988501556, + 5.4273948271568715, + 4.975125902008625, + 4.311111607324932, + 3.5443400048373426, + 3.2352403837083483, + 2.9714437622009395, + 2.8965168950589684, + 3.0147651212288107, + 3.2682052334583602, + 3.61406643264312, + 4.153454626072052, + 4.913166146256208, + 5.731074527029968, + 6.369596367263378, + 6.898778176680501, + 6.647530062691671, + 6.039969536482651, + 5.378332599125726, + 4.7827964396242, + 4.311030174373271, + 3.9427348452274877, + 3.673836331864481, + 3.4836884930764236, + 3.3801714526560542, + 3.3578618131375833, + 3.3587029392109855, + 3.384515382655561, + 3.431874882963668, + 3.4993830543159126, + 3.5729775883550365, + 3.605402560936616, + 3.511518297213606, + 3.204441290129302, + 2.7203565981704765 + ] + }, + "vertical_velocity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 
0.0013592488325373294, + 0.002457127538198585, + 0.0031517314809615784, + 0.004090714767134668, + 0.005163979691796069, + 0.006525739370190636, + 0.009776381993628246, + 0.012260173301951522, + 0.01659557615407214, + 0.02112807324669416, + 0.03277275065027288, + 0.047538646440606076, + 0.06487119484161766, + 0.08260116832833243, + 0.09959469217092194, + 0.11618038836622262, + 0.13317716383919723, + 0.16992147669193855, + 0.20396427984330565, + 0.22966619134241795, + 0.24645575938385053, + 0.2555158673980838, + 0.2614232242099868, + 0.2665193235464868, + 0.27110805357040507, + 0.2730423896580946, + 0.27228178362329436, + 0.2697365154717864, + 0.2657063763299476, + 0.25948414386809376, + 0.2504589940087339, + 0.23740944015127696, + 0.21798347077148827, + 0.1897924335401282, + 0.15288209647098705, + 0.11325332308569894, + 0.08681145052560271 + ] + }, + "seconds_since_epoch": { + "dims": [], + "attrs": {}, + "data": 0.0 + }, + "10m_u_component_of_wind": { + "dims": [], + "attrs": {}, + "data": 2.2480854089011153 + }, + "10m_v_component_of_wind": { + "dims": [], + "attrs": {}, + "data": 2.4782441572496183 + }, + "2m_temperature": { + "dims": [], + "attrs": {}, + "data": 2.609049099295341 + }, + "mean_sea_level_pressure": { + "dims": [], + "attrs": {}, + "data": 266.85836144413696 + }, + "sea_ice_cover": { + "dims": [], + "attrs": {}, + "data": 0.010026612577502175 + }, + "sea_surface_temperature": { + "dims": [], + "attrs": {}, + "data": 0.08066415910025554 + }, + "surface_pressure": { + "dims": [], + "attrs": {}, + "data": 249.3593433435223 + }, + "toa_incident_solar_radiation": { + "dims": [], + "attrs": {}, + "data": 1958440.1202389547 + }, + "toa_incident_solar_radiation_6hr": { + "dims": [], + "attrs": {}, + "data": 9991075.912200263 + }, + "total_cloud_cover": { + "dims": [], + "attrs": {}, + "data": 0.2845560385935181 + }, + "total_column_water_vapour": { + "dims": [], + "attrs": {}, + "data": 2.7629454848907633 + }, + "total_precipitation_6hr": { + "dims": [], + "attrs": {}, + "data": 0.0019004865687748437 + }, + "year_progress": { + "dims": [], + "attrs": {}, + "data": 0.024697753562180874 + }, + "year_progress_sin": { + "dims": [], + "attrs": {}, + "data": 0.0030342521761048467 + }, + "year_progress_cos": { + "dims": [], + "attrs": {}, + "data": 0.0030474038590028816 + }, + "day_progress": { + "dims": [], + "attrs": {}, + "data": 0.4330127018922193 + }, + "day_progress_sin": { + "dims": [], + "attrs": {}, + "data": 0.9999999974440369 + }, + "day_progress_cos": { + "dims": [], + "attrs": {}, + "data": 1.0 + } + } +} \ No newline at end of file diff --git a/geoarches/stats/gc_stats_mean_by_level.json b/geoarches/stats/gc_stats_mean_by_level.json new file mode 100644 index 0000000..80d7a33 --- /dev/null +++ b/geoarches/stats/gc_stats_mean_by_level.json @@ -0,0 +1,495 @@ +{ + "coords": { + "level": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 1, + 2, + 3, + 5, + 7, + 10, + 20, + 30, + 50, + 70, + 100, + 125, + 150, + 175, + 200, + 225, + 250, + 300, + 350, + 400, + 450, + 500, + 550, + 600, + 650, + 700, + 750, + 775, + 800, + 825, + 850, + 875, + 900, + 925, + 950, + 975, + 1000 + ] + } + }, + "attrs": { + "date_start": "1979-01-02", + "date_end": "2015", + "timestep_hours": 6, + "window_stride": 1 + }, + "dims": { + "level": 37 + }, + "data_vars": { + "geopotential": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 464158.64456721913, + 412344.9988950276, + 382754.07047268265, + 346888.1404542664, + 324173.7900552486, + 300691.9739717618, + 256205.5386126458, + 
230782.08275015347, + 199314.46531614487, + 178930.67354205032, + 157578.78060159608, + 144167.7750767342, + 133076.5632903622, + 123581.44628606507, + 115266.74229588704, + 107859.16464088397, + 101163.90030693678, + 89360.0687538367, + 79090.53996316758, + 69937.41632903622, + 61655.36451810927, + 54081.917249846534, + 47100.39729895641, + 40622.95690607735, + 34579.49465930019, + 28911.316390423573, + 23570.866052793124, + 21012.628115408224, + 18523.833517495397, + 16101.10349907919, + 13741.475751995089, + 11441.921178637202, + 9199.416574585635, + 7010.71430325353, + 4872.758256599141, + 2782.7472068753837, + 737.6906077348066 + ] + }, + "specific_humidity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 3.877200730524596e-06, + 3.711241448928126e-06, + 3.5958521255618446e-06, + 3.4336574682068282e-06, + 3.323139379257216e-06, + 3.2100865378417874e-06, + 2.99326137869983e-06, + 2.8548591831405154e-06, + 2.6766878556736853e-06, + 2.587460406796928e-06, + 2.6269185261319502e-06, + 3.2536406470193973e-06, + 5.218535753171674e-06, + 9.94390984705579e-06, + 1.922827728545468e-05, + 3.4665040548335124e-05, + 5.7214977845156245e-05, + 0.0001262438524127958, + 0.00023270570990349926, + 0.00038189147424668606, + 0.000581101042277393, + 0.0008452774338227536, + 0.001175620498797731, + 0.0015292930485940259, + 0.0019281458752224093, + 0.0024123642033342207, + 0.0029917979840488797, + 0.0033213395879194, + 0.0036804206243975165, + 0.004084971728245857, + 0.0045401004723349256, + 0.005019580520245837, + 0.005500134490910595, + 0.005990029200957259, + 0.006479751090464818, + 0.006812593613467091, + 0.006980203777240638 + ] + }, + "temperature": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 261.2464702271332, + 257.53311847759363, + 250.4352977286679, + 238.69211172498464, + 231.9219459791283, + 227.11554634745244, + 220.25475751995089, + 216.60999079189688, + 212.50360650705954, + 209.52572130141192, + 208.42917434008595, + 210.60148864333948, + 213.31437998772253, + 215.82381829343154, + 218.0219766728054, + 220.21935236341315, + 222.68821362799264, + 228.7509361571516, + 235.53321055862492, + 241.99353898096993, + 247.75807243707797, + 252.79929404542665, + 257.1949508901166, + 260.97596685082874, + 264.253591160221, + 267.2038674033149, + 269.9390730509515, + 271.1902394106814, + 272.35432780847145, + 273.4095150399018, + 274.36709637814613, + 275.2785758133824, + 276.1978207489257, + 277.1601289134438, + 278.19060773480663, + 279.391006752609, + 280.80564763658685 + ] + }, + "u_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 10.965264349294046, + 8.705731085021485, + 7.513738777624309, + 6.173290745856353, + 5.480376668968693, + 4.70942775475752, + 3.560477814226519, + 3.840542510742787, + 5.59145996393493, + 7.31334311694291, + 10.224550145794966, + 12.264089356967464, + 13.47607236034377, + 14.052869858809085, + 14.14782362645795, + 13.858489679251074, + 13.294243017188458, + 11.75768780693677, + 10.175375038367097, + 8.783528046347453, + 7.578708179864948, + 6.534649324739104, + 5.616337668815224, + 4.79007299340086, + 4.035468941835482, + 3.3248875364487414, + 2.6552785930785756, + 2.331686185926949, + 2.0158978859729895, + 1.7065739621700429, + 1.4066210817602824, + 1.1233328297651934, + 0.8577126256522406, + 0.6098378031000614, + 0.38745932488298035, + 0.1904675000719383, + -0.036453252690492634 + ] + }, + "v_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + -0.06572874103169123, + -0.011399316758447955, + 
-0.026981349544630526, + -0.024366201045982964, + -0.017284208312950813, + -0.01092347203038674, + 0.008022661015912752, + 0.005625177878701226, + 0.002710293230650394, + 0.0076611850801752415, + 0.01671418341630218, + -0.00047295434552515264, + -0.03201648391339587, + -0.051579529145804555, + -0.041990580128481815, + -0.03524014804903315, + -0.026850069630285068, + -0.021695000617470456, + -0.017722657299100292, + -0.018336570182411565, + -0.02146871717814802, + -0.0237674701426057, + -0.030470080109298265, + -0.02788166493415247, + -0.012904150900187998, + 0.019581333298563153, + 0.06047136225972606, + 0.08149375095917741, + 0.10072268022943523, + 0.11818829671673572, + 0.13653230989103746, + 0.16059385970591622, + 0.18577430495606967, + 0.19757737264521946, + 0.19209840141094997, + 0.18057166673860497, + 0.1803672570403622 + ] + }, + "vertical_velocity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + -1.7325796914144554e-06, + -8.170435661329969e-06, + -1.4213529776180354e-05, + -2.3953679736367466e-05, + -3.202668796513548e-05, + -4.1203955479967876e-05, + -5.737479074078882e-05, + -6.106777598039733e-05, + -5.841953685636386e-05, + -4.851190571114649e-05, + -1.5877243051359154e-05, + 6.465635338467714e-06, + 2.4713397465138583e-05, + 3.049958147396699e-05, + 2.4945242702046224e-05, + 2.643944223355925e-05, + 4.424439609380058e-05, + 0.00014191689558597047, + 0.0002506863492774846, + 0.00031892484942413384, + 0.0003315187949502007, + 0.00032194783308522114, + 0.0003344866013219563, + 0.0003959916280485356, + 0.0007813948952690672, + 0.002984888909271409, + 0.00543766431533317, + 0.0069271314795431435, + 0.00835428161808433, + 0.009874159108648422, + 0.011665140038403066, + 0.0136386269510268, + 0.015577679019972472, + 0.017481574052092923, + 0.019408168055075968, + 0.02080165612470745, + 0.02085827194597913 + ] + }, + "seconds_since_epoch": { + "dims": [], + "attrs": {}, + "data": 870987547.265625 + }, + "angle_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.5139553718730816 + }, + "anisotropy_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.16101400938555094 + }, + "geopotential_at_surface": { + "dims": [], + "attrs": {}, + "data": 3764.3027624309393 + }, + "high_vegetation_cover": { + "dims": [], + "attrs": {}, + "data": 0.08262674929980049 + }, + "lake_cover": { + "dims": [], + "attrs": {}, + "data": 0.006582444426909482 + }, + "lake_depth": { + "dims": [], + "attrs": {}, + "data": 2272.2055248618785 + }, + "land_sea_mask": { + "dims": [], + "attrs": {}, + "data": 0.33665826187461634 + }, + "low_vegetation_cover": { + "dims": [], + "attrs": {}, + "data": 0.11036909746201658 + }, + "slope_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.0034532487282656395 + }, + "soil_type": { + "dims": [], + "attrs": {}, + "data": 0.6712915803406998 + }, + "standard_deviation_of_filtered_subgrid_orography": { + "dims": [], + "attrs": {}, + "data": 14.084696324432167 + }, + "standard_deviation_of_orography": { + "dims": [], + "attrs": {}, + "data": 20.350742403314918 + }, + "type_of_high_vegetation": { + "dims": [], + "attrs": {}, + "data": 1.8231297239487416 + }, + "type_of_low_vegetation": { + "dims": [], + "attrs": {}, + "data": 1.4019073242786986 + }, + "10m_u_component_of_wind": { + "dims": [], + "attrs": {}, + "data": -0.051680525823366505 + }, + "10m_v_component_of_wind": { + "dims": [], + "attrs": {}, + "data": 0.18364298303787094 + }, + "2m_temperature": { + "dims": [], + "attrs": {}, + "data": 
278.2418983529725 + }, + "mean_sea_level_pressure": { + "dims": [], + "attrs": {}, + "data": 100960.6314111418 + }, + "sea_ice_cover": { + "dims": [], + "attrs": {}, + "data": 0.1716348353283821 + }, + "sea_surface_temperature": { + "dims": [], + "attrs": {}, + "data": 286.77197881520186 + }, + "surface_pressure": { + "dims": [], + "attrs": {}, + "data": 96604.64498733886 + }, + "toa_incident_solar_radiation": { + "dims": [], + "attrs": {}, + "data": 1072008.31131062 + }, + "toa_incident_solar_radiation_6hr": { + "dims": [], + "attrs": {}, + "data": 6432047.826764886 + }, + "total_cloud_cover": { + "dims": [], + "attrs": {}, + "data": 0.6748096574516367 + }, + "total_column_water_vapour": { + "dims": [], + "attrs": {}, + "data": 18.17574520918985 + }, + "total_precipitation_6hr": { + "dims": [], + "attrs": {}, + "data": 0.0005949484786685348 + }, + "year_progress": { + "dims": [], + "attrs": {}, + "data": 0.49975101137533784 + }, + "year_progress_sin": { + "dims": [], + "attrs": {}, + "data": -0.0019232822626236157 + }, + "year_progress_cos": { + "dims": [], + "attrs": {}, + "data": 0.01172127404282719 + }, + "day_progress": { + "dims": [], + "attrs": {}, + "data": 0.49861110098039113 + }, + "day_progress_sin": { + "dims": [], + "attrs": {}, + "data": -1.0231613285011715e-08 + }, + "day_progress_cos": { + "dims": [], + "attrs": {}, + "data": 2.679492657383283e-08 + } + } +} \ No newline at end of file diff --git a/geoarches/stats/gc_stats_stddev_by_level.json b/geoarches/stats/gc_stats_stddev_by_level.json new file mode 100644 index 0000000..21e71aa --- /dev/null +++ b/geoarches/stats/gc_stats_stddev_by_level.json @@ -0,0 +1,495 @@ +{ + "coords": { + "level": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 1, + 2, + 3, + 5, + 7, + 10, + 20, + 30, + 50, + 70, + 100, + 125, + 150, + 175, + 200, + 225, + 250, + 300, + 350, + 400, + 450, + 500, + 550, + 600, + 650, + 700, + 750, + 775, + 800, + 825, + 850, + 875, + 900, + 925, + 950, + 975, + 1000 + ] + } + }, + "attrs": { + "date_start": "1979-01-02", + "date_end": "2015", + "timestep_hours": 6, + "window_stride": 1 + }, + "dims": { + "level": 37 + }, + "data_vars": { + "geopotential": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 17112.606642371204, + 14927.383052899238, + 13569.137888056317, + 12021.936671181445, + 11014.847827394611, + 9987.536359651194, + 8028.360716699242, + 6968.466916837494, + 5891.323074395919, + 5493.193936494461, + 5516.180068127431, + 5709.441486145994, + 5831.1099295406375, + 5873.22321319918, + 5833.658588186441, + 5717.928894409594, + 5546.009475860004, + 5104.0355432686865, + 4619.871467831329, + 4158.276237951773, + 3739.084388677641, + 3357.5037623703215, + 3013.5978477577146, + 2698.288651604121, + 2405.9854640637536, + 2136.090565397523, + 1890.4576872450148, + 1775.9228074078367, + 1667.2767065598096, + 1564.4008184932661, + 1468.2921274640337, + 1379.5975591962606, + 1298.6494821698027, + 1226.6568351847818, + 1163.919502837812, + 1111.5790141902128, + 1070.6871348533682 + ] + }, + "specific_humidity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 1.4884940898294625e-07, + 2.1982487237760237e-07, + 2.708466995774636e-07, + 3.2041869036098746e-07, + 3.4175658419972106e-07, + 3.4726404357349843e-07, + 3.0677010931776993e-07, + 3.0864184800485564e-07, + 3.603612255954032e-07, + 4.061716586143391e-07, + 5.68540894911217e-07, + 1.1833470976675956e-06, + 3.761409657170274e-06, + 1.0044122391083973e-05, + 2.2482128519836652e-05, + 4.34226182936993e-05, + 7.389704879977555e-05, + 
0.00016727781432090424, + 0.00031071731485568093, + 0.0005042878342727835, + 0.0007538476507841843, + 0.0010727691169100044, + 0.0014474621834160808, + 0.0017626674134673527, + 0.002104651525190109, + 0.002541037639035409, + 0.0030431828192282614, + 0.0033042261888893147, + 0.0035668913237675765, + 0.0038318789675385436, + 0.0041065214504702845, + 0.004395623208311555, + 0.004706688533403947, + 0.005069831709104547, + 0.005502426689700594, + 0.005797175523138732, + 0.005913643139614111 + ] + }, + "temperature": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 12.694518236220443, + 15.232124812693131, + 14.893106549723116, + 13.744933295760056, + 13.259734148222678, + 12.736646404214579, + 11.475200142540812, + 10.650490364149563, + 10.32188678115599, + 11.575040681380434, + 12.555584297313873, + 10.883391821976442, + 8.95102108605434, + 7.645066691718888, + 7.215256210446501, + 7.607862396061363, + 8.536004836972927, + 10.725001586837827, + 12.091980542365603, + 12.724797115535194, + 12.994872113909757, + 13.116685543572038, + 13.199557689053842, + 13.47399011118827, + 14.207634221102976, + 14.884047343980265, + 15.206750968420238, + 15.341975401237484, + 15.472462756066035, + 15.575436984960715, + 15.688536836489513, + 15.802840314282813, + 15.971918502192493, + 16.188619834256205, + 16.46959019048637, + 16.855691408906022, + 17.249412015525827 + ] + }, + "u_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 38.26052570667428, + 34.270842763336, + 31.556841010179124, + 28.8371658227183, + 27.236206776930036, + 25.51038958019535, + 21.849073702916307, + 19.073217298659614, + 15.225951208534651, + 13.282016354190807, + 13.474027890736593, + 14.750113167590195, + 16.003829090626503, + 16.968605622896536, + 17.636894579718373, + 17.955508203885937, + 17.922898511655788, + 17.070389975003465, + 15.701449946687415, + 14.295561510119098, + 13.033535604034563, + 11.942143637838862, + 11.034022881543997, + 10.294104291346015, + 9.672860922415467, + 9.132481633312917, + 8.692430656553018, + 8.512381917809037, + 8.36059115389644, + 8.241769861117577, + 8.155295236725324, + 8.092590831220376, + 8.030901300986454, + 7.906749246228666, + 7.605810416900687, + 6.9879960260758995, + 6.112984842398069 + ] + }, + "v_component_of_wind": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 14.608977711095875, + 13.428561027565353, + 12.747354089352685, + 12.066761236149306, + 11.576466977862362, + 10.827751590060249, + 9.124197035822519, + 8.08250070531511, + 6.993648209307545, + 6.72062010415169, + 7.4283190032699595, + 8.435765928416192, + 9.52049395823155, + 10.670286452174034, + 11.82078103045397, + 12.76242162634319, + 13.315289370023564, + 13.275832117046342, + 12.341097758618329, + 11.17402563327411, + 10.078543576770732, + 9.135353088563589, + 8.370656752710701, + 7.762638554467759, + 7.276696313223785, + 6.837367162469687, + 6.508181767896796, + 6.387550546590763, + 6.296602131621005, + 6.242267241239323, + 6.232468779935759, + 6.275756624856855, + 6.363589876889152, + 6.439280975877655, + 6.379126186651189, + 6.000091613245538, + 5.279551671970607 + ] + }, + "vertical_velocity": { + "dims": [ + "level" + ], + "attrs": {}, + "data": [ + 0.0009727713106113812, + 0.0017794115860712873, + 0.00229529832568355, + 0.003001654190599906, + 0.0037905665204143388, + 0.004802021707268748, + 0.0072524042958936556, + 0.009156352505735892, + 0.01256159727708323, + 0.016308131027882197, + 0.025933847020698887, + 0.03800127075583799, + 0.0522796279026779, + 
0.0674782586889616, + 0.08294118053757693, + 0.09850344825524196, + 0.11413099044961492, + 0.14551438939055283, + 0.17330064037863915, + 0.19442601010697674, + 0.20876988018542028, + 0.21716122226434, + 0.22270439414970705, + 0.22782697210236988, + 0.2335986832922111, + 0.2380936300575962, + 0.24230701485231876, + 0.24368538827080208, + 0.2433665092550341, + 0.24154118053832827, + 0.23816984602645516, + 0.23224732861204292, + 0.22121198853427576, + 0.2036532931473489, + 0.1808025833492825, + 0.15757211186080092, + 0.14350719996485117 + ] + }, + "seconds_since_epoch": { + "dims": [], + "attrs": {}, + "data": NaN + }, + "angle_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.5657162704458326 + }, + "anisotropy_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.251740002489918 + }, + "geopotential_at_surface": { + "dims": [], + "attrs": {}, + "data": 8403.270176549393 + }, + "high_vegetation_cover": { + "dims": [], + "attrs": {}, + "data": 0.24549546218862894 + }, + "lake_cover": { + "dims": [], + "attrs": {}, + "data": 0.04530940497505089 + }, + "lake_depth": { + "dims": [], + "attrs": {}, + "data": 2117.047295251493 + }, + "land_sea_mask": { + "dims": [], + "attrs": {}, + "data": 0.4609399796233273 + }, + "low_vegetation_cover": { + "dims": [], + "attrs": {}, + "data": 0.28147005509312106 + }, + "slope_of_sub_gridscale_orography": { + "dims": [], + "attrs": {}, + "data": 0.009944139529737528 + }, + "soil_type": { + "dims": [], + "attrs": {}, + "data": 1.1670392783453456 + }, + "standard_deviation_of_filtered_subgrid_orography": { + "dims": [], + "attrs": {}, + "data": 42.55937783744212 + }, + "standard_deviation_of_orography": { + "dims": [], + "attrs": {}, + "data": 59.27252261204485 + }, + "type_of_high_vegetation": { + "dims": [], + "attrs": {}, + "data": 5.1262587537852164 + }, + "type_of_low_vegetation": { + "dims": [], + "attrs": {}, + "data": 3.5754746482577477 + }, + "10m_u_component_of_wind": { + "dims": [], + "attrs": {}, + "data": 5.523581165393907 + }, + "10m_v_component_of_wind": { + "dims": [], + "attrs": {}, + "data": 4.7417550189604984 + }, + "2m_temperature": { + "dims": [], + "attrs": {}, + "data": 21.40771678509187 + }, + "mean_sea_level_pressure": { + "dims": [], + "attrs": {}, + "data": 1330.9428153554159 + }, + "sea_ice_cover": { + "dims": [], + "attrs": {}, + "data": 0.3563553361538059 + }, + "sea_surface_temperature": { + "dims": [], + "attrs": {}, + "data": 11.649041841838898 + }, + "surface_pressure": { + "dims": [], + "attrs": {}, + "data": 9654.282969821455 + }, + "toa_incident_solar_radiation": { + "dims": [], + "attrs": {}, + "data": 1437888.7082678971 + }, + "toa_incident_solar_radiation_6hr": { + "dims": [], + "attrs": {}, + "data": 7718870.639704755 + }, + "total_cloud_cover": { + "dims": [], + "attrs": {}, + "data": 0.3658826881372015 + }, + "total_column_water_vapour": { + "dims": [], + "attrs": {}, + "data": 16.35455894238812 + }, + "total_precipitation_6hr": { + "dims": [], + "attrs": {}, + "data": 0.0019469866832673617 + }, + "year_progress": { + "dims": [], + "attrs": {}, + "data": 0.29067483157079654 + }, + "year_progress_sin": { + "dims": [], + "attrs": {}, + "data": 0.7085840482846367 + }, + "year_progress_cos": { + "dims": [], + "attrs": {}, + "data": 0.7055264413169846 + }, + "day_progress": { + "dims": [], + "attrs": {}, + "data": 0.28867401335991755 + }, + "day_progress_sin": { + "dims": [], + "attrs": {}, + "data": 0.7071067811865475 + }, + "day_progress_cos": { + "dims": [], + "attrs": {}, + "data": 
0.7071067888988349 + } + } +} \ No newline at end of file diff --git a/geoarches/stats/graphcast_diffs_stats.nc b/geoarches/stats/graphcast_diffs_stats.nc new file mode 100644 index 0000000..91dcdc2 Binary files /dev/null and b/geoarches/stats/graphcast_diffs_stats.nc differ diff --git a/geoarches/stats/graphcast_norm_stats.nc b/geoarches/stats/graphcast_norm_stats.nc new file mode 100644 index 0000000..e26be41 Binary files /dev/null and b/geoarches/stats/graphcast_norm_stats.nc differ diff --git a/geoarches/utils/normalization.py b/geoarches/utils/normalization.py new file mode 100644 index 0000000..c005953 --- /dev/null +++ b/geoarches/utils/normalization.py @@ -0,0 +1,299 @@ +import importlib +from typing import Dict, List + +import torch +import xarray as xr +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from omegaconf.listconfig import ListConfig +from tensordict import TensorDict + +import geoarches.stats as geoarches_stats +from geoarches.dataloaders.era5_constants import ( + arches_default_level_variables, + arches_default_pressure_levels, + arches_default_surface_variables, +) +from geoarches.metrics.metric_base import compute_lat_weights, compute_lat_weights_weatherbench + +# Stats path +geoarches_stats_path = importlib.resources.files(geoarches_stats) + +# Default loss weights used for ArchesWeather +default_var_weights = { + "surface": [0.1, 0.1, 1.0, 0.1], + "level": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], +} + + +class NormalizationStatistics: + def __init__( + self, + variables: Dict[str, List[str]] | None = None, + levels: List[int] = arches_default_pressure_levels, + norm_scheme: str = "pangu", + loss_weight_per_variable: Dict[str, List[float]] = default_var_weights, + ): + """ + Initializes the normalization module with the specified normalization scheme, variables, + pressure levels, and loss weights per variable. + + The module supports two normalization schemes: 'graphcast' and 'pangu'. + Pangu normalization is the scheme used for the models presented in the papers. + The Graphcast normalization scheme allows the use of more variables and pressure levels, + but it is not the default scheme used in the ArchesWeather models. + + The normalization module loads the mean and standard deviation statistics + for the specified variables and pressure levels from the precomputed stats files. + Further, the module computes the loss coefficients based on the provided variables, + pressure levels, and loss weights per variable. + The loss coefficients are computed from the area weights, the surface and level variables, + and the vertical coefficients derived from the pressure levels. + The normalization module also supports delta normalization, which scales + the loss coefficients by the standard deviation of the difference between successive states. + + Parameters + ---------- + variables : dict, optional + A dictionary containing the variables to be normalized. The keys should be 'surface' and 'level', + and the values should be lists of variable names. If None, the default surface and level variables + of ArchesWeather will be used. + levels : list, optional + A list of pressure levels to be used for normalization. If not given, the default pressure levels + of ArchesWeather will be used. + norm_scheme : str, optional + The normalization scheme to be used. It can be either 'graphcast' or 'pangu'. + If None, no normalization will be applied.
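+ For example, 'graphcast' reads geoarches/stats/graphcast_norm_stats.nc, while 'pangu' reads + the precomputed pangu_norm_stats2_with_w.pt file.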
+ loss_weight_per_variable : dict, optional + A dictionary containing the loss weights for each variable. The keys should be 'surface' and 'level', + and the values should be lists with one weight per variable, in the same order as `variables`. + By default, the weights defined in `default_var_weights` are used. + Raises + ------ + ValueError + If the provided normalization scheme is not supported. Supported schemes are 'graphcast' + and 'pangu'. + AssertionError + If the normalization scheme is 'pangu' and the provided variables or levels do not match + the default values required for this scheme. + Notes + ----- + The default surface variables for the 'pangu' normalization scheme are: + - 10m_u_component_of_wind + - 10m_v_component_of_wind + - 2m_temperature + - mean_sea_level_pressure + The default level variables for the 'pangu' normalization scheme are: + - geopotential + - u_component_of_wind + - v_component_of_wind + - temperature + - specific_humidity + - vertical_velocity + The default pressure levels for the 'pangu' normalization scheme are: + - 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000 + """ + + print("##### NORM SCHEME: ", norm_scheme, " #####") + + if variables is None: + variables = { + "surface": arches_default_surface_variables, + "level": arches_default_level_variables, + } + print("##### VARIABLES: ", variables, " #####") + + if norm_scheme and norm_scheme not in ["graphcast", "pangu"]: + raise ValueError( + f"Normalization scheme {norm_scheme} not supported. Choose from ['graphcast', 'pangu']" + ) + self.norm_scheme = norm_scheme + + self.variables = ( + OmegaConf.to_object(variables) if isinstance(variables, DictConfig) else variables + ) + self.levels = OmegaConf.to_object(levels) if isinstance(levels, ListConfig) else levels + + if norm_scheme == "pangu": + assert self.variables["surface"] == arches_default_surface_variables, ( + "Pangu normalization scheme requires the default surface variables.\n" + "Surf. 
Vars: 10m_u_component_of_wind, 10m_v_component_of_wind, 2m_temperature, " + "mean_sea_level_pressure" + ) + assert self.variables["level"] == arches_default_level_variables, ( + "Pangu normalization scheme requires the default level variables.\n" + "Level Vars: geopotential, u_component_of_wind, v_component_of_wind, " + "temperature, specific_humidity, vertical_velocity" + ) + assert self.levels == arches_default_pressure_levels, ( + "Pangu normalization scheme requires the default pressure levels.\n" + "Pressure Levels: 50, 100, 150, 200, 250, 300, 400, " + "500, 600, 700, 850, 925, 1000" + ) + + self.loss_weight_per_variable = loss_weight_per_variable + + self.mean = None + self.std = None + self.loss_coeffs = None + + def _graphcast_normalization_stats(self): + norm_file_path = geoarches_stats_path / "graphcast_norm_stats.nc" + stats_ds = xr.open_dataset(norm_file_path) + + stats = { + "surface_mean": torch.from_numpy( + stats_ds[self.variables["surface"]].sel(statistic="mean").to_array().to_numpy() + )[..., None, None, None], + "surface_std": torch.from_numpy( + stats_ds[self.variables["surface"]].sel(statistic="std").to_array().to_numpy() + )[..., None, None, None], + "level_mean": torch.from_numpy( + stats_ds[self.variables["level"]] + .sel(statistic="mean") + .sel(level=self.levels) + .to_array() + .to_numpy() + )[..., None, None], + "level_std": torch.from_numpy( + stats_ds[self.variables["level"]] + .sel(statistic="std") + .sel(level=self.levels) + .to_array() + .to_numpy() + )[..., None, None], + } + + data_mean = TensorDict( + surface=stats["surface_mean"], + level=stats["level_mean"], + ) + + data_std = TensorDict( + surface=stats["surface_std"], + level=stats["level_std"], + ) + + return data_mean, data_std + + def _pangu_normalization_stats(self): + # include vertical component by default + norm_file_path = geoarches_stats_path / "pangu_norm_stats2_with_w.pt" + pangu_stats = torch.load(norm_file_path, weights_only=True) + + data_mean = TensorDict( + surface=pangu_stats["surface_mean"], + level=pangu_stats["level_mean"], + ) + + data_std = TensorDict( + surface=pangu_stats["surface_std"], + level=pangu_stats["level_std"], + ) + + return data_mean, data_std + + def load_normalization_stats(self): + if self.norm_scheme is None: + return None, None + + if self.norm_scheme == "pangu": + mean, std = self._pangu_normalization_stats() + elif self.norm_scheme == "graphcast": + mean, std = self._graphcast_normalization_stats() + + self.mean = mean + self.std = std + + return mean, std + + def load_graphcast_timedelta_stats(self): + """Loads the standard deviation of the difference between successive states.""" + file = geoarches_stats_path / "graphcast_diffs_stats.nc" + stats_ds = xr.open_dataset(file) + + surface_stds = torch.from_numpy( + stats_ds[self.variables["surface"]].sel(statistic="std").to_array().to_numpy() + )[..., None, None, None] + level_stds = torch.from_numpy( + stats_ds[self.variables["level"]] + .sel(statistic="std") + .sel(level=self.levels) + .to_array() + .to_numpy() + )[..., None, None] + + return surface_stds, level_stds + + def compute_loss_coeffs( + self, latitude=121, pow=2, loss_delta_normalization=True, use_weatherbench_lat_coeffs=False + ): + compute_weights_fn = ( + compute_lat_weights_weatherbench + if use_weatherbench_lat_coeffs + else compute_lat_weights + ) + + area_weights = compute_weights_fn(latitude) + pressure_levels = torch.tensor(self.levels).float() + vertical_coeffs = (pressure_levels / pressure_levels.mean()).reshape(-1, 1, 1) + + 
n_surface_vars = len(self.variables["surface"]) + n_level_vars = len(self.variables["level"]) + + surf_weights = torch.tensor(self.loss_weight_per_variable["surface"]).reshape(-1, 1, 1, 1) + level_weights = torch.tensor(self.loss_weight_per_variable["level"]).reshape(-1, 1, 1, 1) + + total_coeff = sum(surf_weights) + sum(level_weights) + + surface_coeffs = n_surface_vars * surf_weights + level_coeffs = n_level_vars * level_weights + + loss_coeffs = TensorDict( + surface=area_weights * surface_coeffs / total_coeff, + level=area_weights * level_coeffs * vertical_coeffs / total_coeff, + ) + + # Get standard deviation for normalization + if self.std is not None: + data_std = self.std + else: + _, data_std = self.load_normalization_stats() + + if loss_delta_normalization: + if self.norm_scheme == "graphcast": + delta_surface_stds, delta_level_stds = self.load_graphcast_timedelta_stats() + else: + # For Pangu, we use the precomputed stats + delta_surface_stds = torch.tensor([3.8920, 4.5422, 2.0727, 584.0980]).reshape( + -1, 1, 1, 1 + ) + delta_level_stds = torch.tensor( + [5.9786e02, 7.4878e00, 8.9492e00, 2.7132e00, 9.5222e-04, 0.3] + ).reshape(-1, 1, 1, 1) + + assert data_std["surface"].shape[0] == delta_surface_stds.shape[0], ( + "Surface stds shape mismatch" + ) + assert data_std["level"].shape[0] == delta_level_stds.shape[0], ( + "Level stds shape mismatch" + ) + + loss_delta_scaler = TensorDict( + surface=data_std["surface"] / delta_surface_stds, + level=data_std["level"] / delta_level_stds, + ) + + loss_coeffs = loss_coeffs * loss_delta_scaler.pow(pow) + + print( + f"Loss coefficients computed with normalization scheme: {self.norm_scheme}, " + f"pow: {pow}, delta normalization: {loss_delta_normalization}, " + f"use_weatherbench_lat_coeffs: {use_weatherbench_lat_coeffs}" + ) + + self.loss_coeffs = loss_coeffs + + return loss_coeffs diff --git a/geoarches/utils/tensordict_utils.py b/geoarches/utils/tensordict_utils.py index 46ceed0..1761212 100644 --- a/geoarches/utils/tensordict_utils.py +++ b/geoarches/utils/tensordict_utils.py @@ -1,7 +1,75 @@ +import warnings + import torch from tensordict.tensordict import TensorDict + +def apply_mask_from_gt_nans(pred: TensorDict, ground_truth: TensorDict, value) -> tuple[TensorDict, TensorDict]: + """ + Applies a mask derived from the ground truth to the predictions. + The mask marks positions where the ground truth has NaNs; those entries are replaced + with `value` in both the predictions and the ground truth, in place. + + This function is useful to ensure that variables like sea_ice_cover are treated correctly, + where the ground truth may have NaNs in certain areas, and we mask the predictions + so that those NaNs do not enter the loss calculation.
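+ + Illustrative example (a minimal sketch, not from the test suite; assumes 1D tensors and an + empty batch size): + + >>> import torch + >>> from tensordict import TensorDict + >>> gt = TensorDict({"surface": torch.tensor([1.0, float("nan")])}, batch_size=[]) + >>> pred = TensorDict({"surface": torch.tensor([0.5, 2.0])}, batch_size=[]) + >>> pred, gt = apply_mask_from_gt_nans(pred, gt, value=0.0) + >>> pred["surface"], gt["surface"] # both now hold 0.0 at the masked position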
+ """ + + # Where the ground truth is NaN, overwrite both pred and ground_truth with `value`, in place. + for k, v in ground_truth.items(): + nan_mask = torch.isnan(v) + pred[k][nan_mask] = value + ground_truth[k][nan_mask] = value + + return pred, ground_truth + + +def check_pred_has_no_nans(pred: torch.Tensor, target: torch.Tensor): + """ + Checks the prediction tensor for NaNs and warns, distinguishing positions where the + target is itself NaN (expected, e.g. for masked variables like sea_ice_cover) from + positions where the target is valid. + """ + target_nans = target.isnan() + + # NaNs in the prediction where the target is valid indicate a real problem. + if pred[~target_nans].isnan().any(): + warnings.warn("Prediction has NaNs where target data has no NaNs") + + # NaNs in the prediction where the target is NaN are expected for masked variables. + if pred[target_nans].isnan().any(): + warnings.warn("Prediction has NaNs where target data has NaNs") + + return pred + + def tensordict_apply(f, *args, **kwargs): tdicts = [a for a in args if isinstance(a, TensorDict)] tdicts += [v for v in kwargs.values() if isinstance(v, TensorDict)] diff --git a/tests/dataloaders/test_era5.py b/tests/dataloaders/test_era5.py index 16b44a7..04e91a3 100644 --- a/tests/dataloaders/test_era5.py +++ b/tests/dataloaders/test_era5.py @@ -2,15 +2,23 @@ import pandas as pd import pytest import xarray as xr +from hydra import compose, initialize +from omegaconf import OmegaConf -from geoarches.dataloaders import era5 +from geoarches.dataloaders import era5, era5_constants + +with initialize(version_base=None, config_path="../../geoarches/configs", job_name="test"): + cfg = compose(config_name="config") + OmegaConf.resolve(cfg) # Dimension sizes. LAT, LON = 2, 4 -LEVEL = len(era5.pressure_levels) +# Need real levels to load correct normalization stats. 
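+# (compose() resolves the full hydra config, so the test uses the same levels the stats module expects.)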
+all_levels = OmegaConf.to_object(cfg.stats.module.levels) +LEVEL = len(all_levels) -class TestEra5Forecast: +class TestBase: @classmethod @pytest.fixture(autouse=True) def setup_class(self, tmp_path_factory): @@ -31,42 +39,120 @@ def setup_class(self, tmp_path_factory): data_vars=dict( **{ var_name: (["time", "level", "longitude", "latitude"], level_var_data) - for var_name in era5.level_variables + for var_name in era5_constants.arches_default_level_variables }, **{ var_name: (["time", "latitude", "longitude"], surface_var_data) - for var_name in era5.surface_variables + for var_name in era5_constants.arches_default_surface_variables }, ), - coords={"time": time, "latitude": np.arange(0, LAT), "level": np.arange(0, LEVEL)}, + coords={ + "time": time, + "latitude": np.arange(0, LAT), + "longitude": np.arange(0, LON), + "level": all_levels, + }, ) ds.to_netcdf(file_path) + +class TestEra5Dataset(TestBase): def test_load_current_state(self): ds = era5.Era5Dataset( path=str(self.test_dir), domain="all", + # Select all values in each dimension. + dimension_indexers={ + "level": ("level", all_levels), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, ) example = ds[0] assert len(ds) == 6 assert example["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) - assert example["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + assert example["level"].shape == (6, LEVEL, LAT, LON) # (var, lev, lat, lon) def test_load_current_state_with_timestamp(self): ds = era5.Era5Dataset( path=str(self.test_dir), domain="all", return_timestamp=True, + # Select all values in each dimension. + dimension_indexers={ + "level": ("level", all_levels), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, ) example, timestamp = ds[0] assert len(ds) == 6 # Current state assert example["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) - assert example["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + assert example["level"].shape == (6, LEVEL, LAT, LON) # (var, lev, lat, lon) assert timestamp == 1704067200 # 2024-01-01-00-00 + @pytest.mark.parametrize( + "indexers, expected_lat, expected_lon", + [ + # Filter by level only. + ( + { + "level": ("level", all_levels[3:]), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, + LAT, + LON, + ), + # Filter by level and latitude. + ( + { + "level": ("level", all_levels[3:]), + "latitude": ("latitude", np.arange(0, LAT - 1)), + "longitude": ("longitude", slice(None)), + }, + LAT - 1, + LON, + ), + # Filter by level and longitude. 
+ ( + { + "level": ("level", all_levels[3:]), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", np.arange(0, LON - 1)), + }, + LAT, + LON - 1, + ), + ], + ) + def test_dimension_indexers(self, indexers, expected_lat, expected_lon): + ds = era5.Era5Dataset( + path=str(self.test_dir), + domain="all", + dimension_indexers=indexers, + ) + example = ds[0] + + assert len(ds) == 6 + assert example["surface"].shape == ( + 4, + 1, + expected_lat, + expected_lon, + ) # (var, 1, lat, lon) + assert example["level"].shape == ( + 6, + len(indexers["level"][1]), + expected_lat, + expected_lon, + ) # (var, lev, lat, lon) + + +class TestEra5Forecast(TestBase): @pytest.mark.parametrize( "lead_time_hours, expected_len, expected_next_timestamp", [(6, 5, 1704088800), (12, 4, 1704110400), (24, 2, 1704153600)], @@ -75,11 +161,18 @@ def test_load_current_and_next_state( self, lead_time_hours, expected_len, expected_next_timestamp ): ds = era5.Era5Forecast( + stats_cfg=None, path=str(self.test_dir), domain="all", lead_time_hours=lead_time_hours, load_prev=False, load_clim=False, + # Select all values in each dimension. + dimension_indexers={ + "level": ("level", all_levels), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, ) example = ds[0] @@ -99,12 +192,19 @@ def test_load_current_and_next_state( @pytest.mark.parametrize("multistep, expected_len", [(2, 4), (3, 3), (4, 2)]) def test_multistep(self, multistep, expected_len): ds = era5.Era5Forecast( + stats_cfg=None, path=str(self.test_dir), domain="all", lead_time_hours=6, multistep=multistep, load_prev=False, load_clim=False, + # Select all values in each dimension. + dimension_indexers={ + "level": ("level", all_levels), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, ) example = ds[0] @@ -125,12 +225,19 @@ def test_multistep(self, multistep, expected_len): @pytest.mark.parametrize("multistep, expected_len", [(2, 3), (3, 2), (4, 1)]) def test_multistep_and_load_prev(self, multistep, expected_len): ds = era5.Era5Forecast( + stats_cfg=None, path=str(self.test_dir), domain="all", lead_time_hours=6, multistep=multistep, load_prev=True, load_clim=False, + # Select all values in each dimension. + dimension_indexers={ + "level": ("level", all_levels), + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, ) example = ds[0] @@ -149,16 +256,146 @@ def test_multistep_and_load_prev(self, multistep, expected_len): assert example["prev_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) assert example["prev_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) - @pytest.mark.parametrize("indexers", [{"level": [2, 4, 8]}]) - def test_dimension_indexers(self, indexers): - ds = era5.Era5Dataset(path=str(self.test_dir), domain="all", dimension_indexers=indexers) + +class TestEra5ForecastWithGraphcastNormalization(TestBase): + @pytest.mark.parametrize( + "lead_time_hours, expected_len, expected_next_timestamp", + [(6, 5, 1704088800), (12, 4, 1704110400), (24, 2, 1704153600)], + ) + def test_load_current_and_next_state( + self, lead_time_hours, expected_len, expected_next_timestamp + ): + cfg.stats.module.norm_scheme = "graphcast" + ds = era5.Era5Forecast( + stats_cfg=cfg.stats, + path=str(self.test_dir), + domain="all", + lead_time_hours=lead_time_hours, + load_prev=False, + load_clim=False, + # Select all lat/lon. 
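+ # Each indexer maps a dimension to (its name in the dataset, the selection); slice(None) keeps every index.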
+ dimension_indexers={ + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, + ) example = ds[0] - assert len(ds) == 6 - assert example["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) - assert example["level"].shape == ( - 6, - len(indexers["level"]), - LAT, - LON, - ) # (var, lev, lat, lon) + assert len(ds) == expected_len + # Current state + assert example["timestamp"] == 1704067200 # 2024-01-01-00-00 + assert example["state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Next state + assert example["next_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["next_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # No multistep + assert "future_states" not in example + # No prev state + assert "prev_state" not in example + + @pytest.mark.parametrize("multistep, expected_len", [(2, 4), (3, 3), (4, 2)]) + def test_multistep(self, multistep, expected_len): + cfg.stats.module.norm_scheme = "graphcast" + ds = era5.Era5Forecast( + stats_cfg=cfg.stats, + path=str(self.test_dir), + domain="all", + lead_time_hours=6, + multistep=multistep, + load_prev=False, + load_clim=False, + # Select all lat/lon. + dimension_indexers={ + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, + ) + example = ds[0] + + assert len(ds) == expected_len + # Current state + assert example["timestamp"] == 1704067200 # 2024-01-01-00-00 + assert example["state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Next state + assert example["next_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["next_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Future states + assert example["future_states"]["surface"].shape[0] == multistep + assert example["future_states"]["level"].shape[0] == multistep + # No prev state + assert "prev_state" not in example + + @pytest.mark.parametrize("multistep, expected_len", [(2, 3), (3, 2), (4, 1)]) + def test_multistep_and_load_prev(self, multistep, expected_len): + cfg.stats.module.norm_scheme = "graphcast" + ds = era5.Era5Forecast( + stats_cfg=cfg.stats, + path=str(self.test_dir), + domain="all", + lead_time_hours=6, + multistep=multistep, + load_prev=True, + load_clim=False, + # Select all lat/lon. 
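+            # Levels again come from stats_cfg (assumed), as in the tests above.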
+ dimension_indexers={ + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, + ) + example = ds[0] + + assert len(ds) == expected_len + # Current state + assert example["timestamp"] == 1704088800 # 2024-01-01-06-00 + assert example["state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Next state + assert example["next_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["next_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Future states + assert example["future_states"]["surface"].shape[0] == multistep + assert example["future_states"]["level"].shape[0] == multistep + # Prev state + assert example["prev_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["prev_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + + +class TestEra5ForecastWithPanguNormalization(TestBase): + @pytest.mark.parametrize( + "lead_time_hours, expected_len, expected_next_timestamp", + [(6, 5, 1704088800), (12, 4, 1704110400), (24, 2, 1704153600)], + ) + def test_load_current_and_next_state( + self, lead_time_hours, expected_len, expected_next_timestamp + ): + cfg.stats.module.norm_scheme = "pangu" + ds = era5.Era5Forecast( + stats_cfg=cfg.stats, + path=str(self.test_dir), + domain="all", + lead_time_hours=lead_time_hours, + load_prev=False, + load_clim=False, + # Select all lat/lon. + dimension_indexers={ + "latitude": ("latitude", slice(None)), + "longitude": ("longitude", slice(None)), + }, + ) + + example = ds[0] + + assert len(ds) == expected_len + # Current state + assert example["timestamp"] == 1704067200 # 2024-01-01-00-00 + assert example["state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # Next state + assert example["next_state"]["surface"].shape == (4, 1, LAT, LON) # (var, 1, lat, lon) + assert example["next_state"]["level"].shape == (6, 13, LAT, LON) # (var, lev, lat, lon) + # No multistep + assert "future_states" not in example + # No prev state + assert "prev_state" not in example diff --git a/tests/metrics/test_ensemble_metrics.py b/tests/metrics/test_ensemble_metrics.py index c02203c..cb31a2f 100644 --- a/tests/metrics/test_ensemble_metrics.py +++ b/tests/metrics/test_ensemble_metrics.py @@ -221,7 +221,7 @@ def test_output_keys(self): "mse_U10m", "mse_V10m", "mse_T2m", - "mse_SP", + "mse_MSLP", "mse_Z500", "mse_Q700", ] @@ -261,8 +261,8 @@ def test_output_keys_with_timedelta_dimension(self): "mse_V10m_48h", "mse_T2m_24h", "mse_T2m_48h", - "mse_SP_24h", - "mse_SP_48h", + "mse_MSLP_24h", + "mse_MSLP_48h", "mse_Z500_24h", "mse_Z500_48h", "mse_Q700_24h", diff --git a/tests/utils/test_normalization.py b/tests/utils/test_normalization.py new file mode 100644 index 0000000..6bdeca1 --- /dev/null +++ b/tests/utils/test_normalization.py @@ -0,0 +1,128 @@ +import pytest +from tensordict import TensorDict + +from geoarches.dataloaders.era5 import ( + arches_default_level_variables, + arches_default_pressure_levels, + arches_default_surface_variables, +) +from geoarches.utils import normalization + + +def test_init_defaults(): + norm_stats = normalization.NormalizationStatistics() + assert norm_stats.norm_scheme == "pangu" + assert norm_stats.variables["surface"] == arches_default_surface_variables + assert norm_stats.variables["level"] == arches_default_level_variables + assert 
norm_stats.levels == arches_default_pressure_levels + assert norm_stats.loss_weight_per_variable == normalization.default_var_weights + + +def test_init_graphcast(): + variables = { + "surface": ["T2m", "U10m"], + "level": ["Z", "U"], + } + levels = [500, 850] + norm_stats = normalization.NormalizationStatistics( + variables=variables, levels=levels, norm_scheme="graphcast" + ) + assert norm_stats.norm_scheme == "graphcast" + assert norm_stats.variables == variables + assert norm_stats.levels == levels + + +def test_init_invalid_scheme(): + with pytest.raises(ValueError): + normalization.NormalizationStatistics(norm_scheme="invalid") + + +def test_init_pangu_invalid_vars(): + variables = { + "surface": ["T2m", "U10m"], + "level": ["Z", "U"], + } + with pytest.raises(AssertionError): + normalization.NormalizationStatistics(variables=variables, norm_scheme="pangu") + + +def test_load_normalization_stats_pangu(): + norm_stats = normalization.NormalizationStatistics(norm_scheme="pangu") + mean, std = norm_stats.load_normalization_stats() + assert isinstance(mean, TensorDict) + assert isinstance(std, TensorDict) + assert "surface" in mean + assert "level" in mean + assert "surface" in std + assert "level" in std + assert mean["surface"].shape == (4, 1, 1, 1) + assert mean["level"].shape == (6, 13, 1, 1) + assert std["surface"].shape == (4, 1, 1, 1) + assert std["level"].shape == (6, 13, 1, 1) + + +def test_load_normalization_stats_graphcast(): + variables = { + "surface": ["2m_temperature", "10m_u_component_of_wind"], + "level": ["geopotential", "u_component_of_wind"], + } + levels = [500, 850] + norm_stats = normalization.NormalizationStatistics( + variables=variables, levels=levels, norm_scheme="graphcast" + ) + mean, std = norm_stats.load_normalization_stats() + assert isinstance(mean, TensorDict) + assert isinstance(std, TensorDict) + assert "surface" in mean + assert "level" in mean + assert "surface" in std + assert "level" in std + assert mean["surface"].shape == (2, 1, 1, 1) + assert mean["level"].shape == (2, 2, 1, 1) + assert std["surface"].shape == (2, 1, 1, 1) + assert std["level"].shape == (2, 2, 1, 1) + + +def test_load_graphcast_timedelta_stats(): + variables = { + "surface": ["2m_temperature", "10m_u_component_of_wind"], + "level": ["geopotential", "u_component_of_wind"], + } + levels = [500, 850] + norm_stats = normalization.NormalizationStatistics( + variables=variables, levels=levels, norm_scheme="graphcast" + ) + surface_stds, level_stds = norm_stats.load_graphcast_timedelta_stats() + assert surface_stds.shape == (2, 1, 1, 1) + assert level_stds.shape == (2, 2, 1, 1) + + +def test_compute_loss_coeffs_pangu(): + norm_stats = normalization.NormalizationStatistics(norm_scheme="pangu") + loss_coeffs = norm_stats.compute_loss_coeffs(latitude=121) + assert isinstance(loss_coeffs, TensorDict) + assert "surface" in loss_coeffs + assert "level" in loss_coeffs + assert loss_coeffs["surface"].shape == (4, 1, 121, 1) + assert loss_coeffs["level"].shape == (6, 13, 121, 1) + + +def test_compute_loss_coeffs_graphcast(): + levels = [500, 850] + norm_stats = normalization.NormalizationStatistics(levels=levels, norm_scheme="graphcast") + loss_coeffs = norm_stats.compute_loss_coeffs(latitude=121) + assert isinstance(loss_coeffs, TensorDict) + assert "surface" in loss_coeffs + assert "level" in loss_coeffs + assert loss_coeffs["surface"].shape == (4, 1, 121, 1) + assert loss_coeffs["level"].shape == (6, 2, 121, 1) + + +def test_compute_loss_coeffs_pangu_no_delta(): + norm_stats = 
normalization.NormalizationStatistics(norm_scheme="pangu") + loss_coeffs = norm_stats.compute_loss_coeffs(latitude=121, loss_delta_normalization=False) + assert isinstance(loss_coeffs, TensorDict) + assert "surface" in loss_coeffs + assert "level" in loss_coeffs + assert loss_coeffs["surface"].shape == (4, 1, 121, 1) + assert loss_coeffs["level"].shape == (6, 13, 121, 1)
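
For reviewers, a minimal usage sketch of the NormalizationStatistics API that the new tests above exercise. This is a sketch, not authoritative: it relies only on the constructor, load_normalization_stats(), and compute_loss_coeffs() calls visible in tests/utils/test_normalization.py and on the variable/level lists from geoarches/configs/stats/graphcast.yaml; the input state tensor and grid size below are hypothetical.

    import torch
    from tensordict import TensorDict

    from geoarches.utils import normalization

    # Stats for a reduced graphcast-normalized variable set (same names as in
    # test_load_normalization_stats_graphcast).
    norm_stats = normalization.NormalizationStatistics(
        variables={
            "surface": ["2m_temperature", "10m_u_component_of_wind"],
            "level": ["geopotential", "u_component_of_wind"],
        },
        levels=[500, 850],
        norm_scheme="graphcast",
    )

    # mean/std are TensorDicts keyed by "surface" and "level", shaped
    # (var, level, 1, 1) so they broadcast over (var, level, lat, lon) states.
    mean, std = norm_stats.load_normalization_stats()

    # Hypothetical state on a 121 x 240 grid: 2 surface variables and
    # 2 level variables on 2 pressure levels.
    state = TensorDict(
        {
            "surface": torch.randn(2, 1, 121, 240),
            "level": torch.randn(2, 2, 121, 240),
        },
        batch_size=[],
    )
    normed = state.apply(lambda x, m, s: (x - m) / s, mean, std)

    # Latitude-weighted loss coefficients, shaped (var, level, latitude, 1),
    # as asserted in test_compute_loss_coeffs_graphcast.
    loss_coeffs = norm_stats.compute_loss_coeffs(latitude=121)

Note that test_init_pangu_invalid_vars suggests the constructor rejects non-default variable sets under norm_scheme="pangu", so the reduced variable lists above are only valid with the graphcast scheme.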