diff --git a/geoarches/dataloaders/dcpp.py b/geoarches/dataloaders/dcpp.py index eabae1f..e9d78b2 100644 --- a/geoarches/dataloaders/dcpp.py +++ b/geoarches/dataloaders/dcpp.py @@ -1,43 +1,18 @@ import importlib.resources +import numpy as np import pandas as pd import torch from tensordict.tensordict import TensorDict -from .. import stats as geoarches_stats -from .netcdf import XarrayDataset - -filename_filters = dict( - all=(lambda _: True), - train=lambda x: any( - substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(1, 9)] - ), - val=lambda x: any( - substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(9, 10)] - ), - test=lambda x: any( - substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(10, 11)] - ), - empty=lambda x: False, +from geoarches.utils.tensordict_utils import ( + apply_nan_to_num, + get_non_nan_mask, + replace_inf_and_large_values, ) -pressure_levels = [85000, 70000, 50000] -surface_variables = ["tas", "npp", "nbp", "gpp", "cVeg", "evspsbl", "mrfso", "mrro", "ps", "tos"] -level_variables = ["hur", "hus", "o3", "ta", "ua", "va", "wap", "zg"] - - -def replace_nans(tensordict, value=0): - return tensordict.apply( - lambda x: torch.where(torch.isnan(x), torch.tensor(value, dtype=x.dtype), x) - ) - - -default_dimension_indexers = { - "level": ("plev", pressure_levels), - "latitude": ("lat", slice(None)), - "longitude": ("lon", slice(None)), - "time": ("time", slice(None)), -} +from .. import stats as geoarches_stats +from .netcdf import XarrayDataset class DCPPForecast(XarrayDataset): @@ -50,19 +25,25 @@ class DCPPForecast(XarrayDataset): def __init__( self, - path="data/batch_with_tos/", + path="/path/to/data/", forcings_path="data/", domain="train", filename_filter=None, lead_time_months=1, multistep=1, - load_prev=True, + load_prev=1, # how many previous states to load load_clim=False, norm_scheme="spatial_norm", limit_examples: int = 0, mask_value=0, variables=None, - dimension_indexers: dict = default_dimension_indexers, + surface_variables=None, + level_variables=None, + surface_variable_indices=[], + level_variable_indices=[], + pressure_levels=[85000, 70000, 50000, 25000], + filename_filter_type="dcpp", # choose train/test split with filename filter: "dcpp" and "dcpp_alt" + indexed_year_forcings=False, # toggle index for forcings between zero-indexed and and year-indexed ): """ Args: @@ -81,13 +62,47 @@ def __init__( """ self.__dict__.update(locals()) # concise way to update self with input arguments - self.timedelta = 1 - self.current_multistep = 1 + if self.filename_filter_type == "dcpp_alt": + train_filter = range(1960, 2000) + val_filter = range(2000, 2010) + test_filter = range(2010, 2016) + filename_filters = dict( + all=(lambda _: True), + train=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in train_filter] + ), + test=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in test_filter] + ), + val=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in val_filter] + ), + empty=lambda x: False, + ) + elif self.filename_filter_type == "dcpp": + train_filter = [x for i, x in enumerate(range(1960, 2010)) if (i + 1) % 10 != 0] + val_filter = range(2010, 2016) + test_filter = [1969, 1979, 1989, 1999, 2009] + filename_filters = dict( + all=(lambda _: True), + train=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in train_filter] + ), + val=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in val_filter] + ), + test=lambda x: any( + substring in x for substring in [f"{str(x)}_" for x in test_filter] + ), + empty=lambda x: False, + ) if filename_filter is None: filename_filter = filename_filters[domain] if variables is None: variables = dict(surface=surface_variables, level=level_variables) + stats_file_path = "dcpp_stats.pt" + dimension_indexers = {"plev": ("plev", pressure_levels)} super().__init__( path, @@ -95,13 +110,17 @@ def __init__( variables=variables, limit_examples=limit_examples, dimension_indexers=dimension_indexers, + timestamp_key=lambda x: (x[0], x[1]), ) geoarches_stats_path = importlib.resources.files(geoarches_stats) - norm_file_path = geoarches_stats_path / "dcpp_spatial_norm_stats.pt" + norm_file_path = geoarches_stats_path / stats_file_path + spatial_norm_stats = torch.load(norm_file_path) - # normalization, + clim_removed_file_path = geoarches_stats_path / "dcpp_clim_removed_norm_stats.pt" + clim_removed_norm_stats = torch.load(clim_removed_file_path) + if self.norm_scheme is None: self.data_mean = TensorDict( surface=torch.tensor(0), @@ -111,36 +130,69 @@ def __init__( surface=torch.tensor(1), level=torch.tensor(1), ) - + # both mean and std_dev have spatial dimensions elif self.norm_scheme == "spatial_norm": self.data_mean = TensorDict( - surface=spatial_norm_stats["surface_mean"], - level=spatial_norm_stats["level_mean"], + surface=torch.stack( + [spatial_norm_stats["surface_mean"][i] for i in surface_variable_indices] + ), + level=torch.stack( + [spatial_norm_stats["level_mean"][i] for i in level_variable_indices] + ), ) self.data_std = TensorDict( - surface=spatial_norm_stats["surface_std"], - level=spatial_norm_stats["level_std"], + surface=torch.stack( + [spatial_norm_stats["surface_std"][i] for i in surface_variable_indices] + ), + level=torch.stack( + [spatial_norm_stats["level_std"][i] for i in level_variable_indices] + ), + ) + # only mean is spatial, std_dev is averaged over space + elif self.norm_scheme == "mean_only_spatial_norm": + self.data_mean = TensorDict( + surface=torch.stack( + [spatial_norm_stats["surface_mean"][i] for i in surface_variable_indices] + ), + level=torch.stack( + [spatial_norm_stats["level_mean"][i] for i in level_variable_indices] + ), ) - self.surface_variables = [ - "tas", - "npp", - "nbp", - "gpp", - "cVeg", - "evspsbl", - "mrfso", - "mrro", - "ps", - "tos", - ] + self.data_std = TensorDict( + surface=torch.stack( + [spatial_norm_stats["surface_std"][i] for i in surface_variable_indices] + ).nanmean(axis=(-1, -2), keepdim=True), + level=torch.stack( + [spatial_norm_stats["level_std"][i] for i in level_variable_indices] + ).nanmean(axis=(-1, -2), keepdim=True), + ) + # statistics for predicting the anomaly, both mean and std_dev have spatial dimensions + elif self.norm_scheme == "clim_removed": + self.data_mean = TensorDict( + surface=clim_removed_norm_stats["surface_mean"], + level=clim_removed_norm_stats["level_mean"], + ) + self.data_mean.batch_size = [ + 12 + ] # this is so the tensordict can be indexed by one value for both surface/level + self.data_std = TensorDict( + surface=clim_removed_norm_stats["surface_std"], + level=clim_removed_norm_stats["level_std"], + ) + self.data_std.batch_size = [12] + + self.surface_variables = surface_variables self.level_variables = [ - a + str(p) - for a in ["hur_", "hus_", "o3_", "ta_", "ua_", "va_", "wap_", "zg_"] - for p in pressure_levels + a + " " + str(p // 100) for a in level_variables for p in pressure_levels ] - self.atmos_forcings = torch.load(f"{forcings_path}/full_atmos_normal.pt") - self.solar_forcings = torch.load(f"{forcings_path}/full_solar_normal.pt") + self.atmos_forcings = torch.load(f"{forcings_path}/cmip_ghg_forcings_ssp245.pt") + self.solar_forcings = torch.tensor(np.load(f"{forcings_path}/solar_forcings_normed.npy")) + times_seconds = [v[2].item() // 10**9 for k, v in self.id2pt.items()] + self.next_timestamp_map = {k: v for k, v in list(zip(times_seconds, times_seconds[1:]))} + + # override netcdf functionality + self.timestamps = sorted(self.timestamps, key=lambda x: (x[0], x[1])) # sort by timestamp def convert_to_tensordict(self, xr_dataset): """ @@ -163,28 +215,27 @@ def __getitem__(self, i, normalize=True): out = dict() # load current state out["state"] = super().__getitem__(i) - out["timestamp"] = torch.tensor( self.id2pt[i][2].item() // 10**9, - dtype=torch.int32, + dtype=torch.int64, ) # time in seconds - times = pd.to_datetime(out["timestamp"].cpu().numpy() * 10**9).tz_localize(None) - current_year = ( - torch.tensor(times.month) + 1970 - 1961 - ) # plus 1970 for the timestep, -1961 to zero index - current_month = torch.tensor(times.year) % 12 - + times = pd.to_datetime(out["timestamp"].cpu().numpy(), unit="s").tz_localize(None) + current_month = torch.tensor(times.month) - 1 % 12 + if self.indexed_year_forcings: + current_year = torch.tensor(times.year) + current_solar_year = torch.tensor(times.year) - 1961 + else: + current_year = torch.tensor(times.year) - 1961 + current_solar_year = torch.tensor(times.year) - 1961 # -1961 to zero index out["forcings"] = torch.concatenate( [ - self.atmos_forcings[:, (current_year * 12) + current_month], - self.solar_forcings[current_year, current_month, :], + self.atmos_forcings[current_year, :], + self.solar_forcings[(current_solar_year * 12) + current_month, :], ] ) - # next obsi. has function of - t = self.lead_time_months # multistep + t = self.lead_time_months # multistep out["next_state"] = super().__getitem__(i + t // self.timedelta) - # Load multiple future timestamps if specified. if self.multistep > 1: future_states = [] @@ -193,24 +244,67 @@ def __getitem__(self, i, normalize=True): out["future_states"] = torch.stack(future_states, dim=0) if self.load_prev: - out["prev_state"] = super().__getitem__(i - self.lead_time_months // self.timedelta) + if self.load_prev > 1: + prev_states = [] + for k in range(0, self.load_prev): + prev_states.append( + super().__getitem__(i - (self.lead_time_months * k + 1) // self.timedelta) + ) + out["prev_state"] = torch.stack(prev_states, dim=0) + else: + out["prev_state"] = super().__getitem__( + i - self.lead_time_months // self.timedelta + ) + prev_timestamp = ( + self.id2pt[i - self.lead_time_months // self.timedelta][2].item() // 10**9 + ) + times = pd.to_datetime(prev_timestamp, unit="s").tz_localize(None) + if normalize: - out = self.normalize(out) + out = self.normalize(out, month=current_month) # need to replace nans with mask_value - out = {k: replace_nans(v, self.mask_value) if "state" in k else v for k, v in out.items()} + mask = {k: (get_non_nan_mask(v) if "state" in k else v) for k, v in out.items()} + + out = {k: apply_nan_to_num(v) if "state" in k else v for k, v in out.items()} + out = {k: (v * mask[k] if "state" in k else v) for k, v in out.items()} return out - def normalize(self, batch): + def normalize(self, batch, stateless=False, month=None): device = list(batch.values())[0].device means = self.data_mean.to(device) stds = self.data_std.to(device) - out = {k: ((v - means) / stds if "state" in k else v) for k, v in batch.items()} - return out - def denormalize(self, batch): - device = list(batch.values())[0].device + if self.norm_scheme == "clim_removed": + return { + k: ((v - means[month]) / stds[month] if "state" in k else v) + for k, v in batch.items() + } + elif stateless: + return (batch - means) / stds + else: + dict_out = {k: ((v - means) / stds if "state" in k else v) for k, v in batch.items()} + dict_out = { + k: replace_inf_and_large_values(v, 1e35) if "state" in k else v + for k, v in dict_out.items() + } + return dict_out + + def denormalize(self, batch, stateless=False, month=None): + if stateless: + device = batch.device + else: + device = list(batch.values())[0].device means = self.data_mean.to(device) stds = self.data_std.to(device) - out = {k: (v * stds + means if "state" in k else v) for k, v in batch.items()} - return out + if self.norm_scheme == "clim_removed" and stateless: + return batch * stds[month] + means[month] + elif self.norm_scheme == "clim_removed": + return { + k: (v * stds[month] + means[month] if "state" in k else v) + for k, v in batch.items() + } + elif stateless: + return (batch * stds) + means + else: + return {k: ((v * stds) + means if "state" in k else v) for k, v in batch.items()} diff --git a/geoarches/dataloaders/netcdf.py b/geoarches/dataloaders/netcdf.py index b56ddbb..9d287e4 100644 --- a/geoarches/dataloaders/netcdf.py +++ b/geoarches/dataloaders/netcdf.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Callable, Dict, List +import numpy as np import torch import xarray as xr from tensordict.tensordict import TensorDict @@ -44,7 +45,8 @@ def __init__( return_timestamp: bool = False, warning_on_nan: bool = True, limit_examples: int | None = None, - interpolate_nans: bool = True, + timestamp_key: Callable = lambda x: (x[-1]), + interpolate_nans: bool = False, ): """ Args: @@ -118,7 +120,7 @@ def __init__( self.timestamps = self.timestamps[:limit_examples] break - self.timestamps = sorted(self.timestamps, key=lambda x: x[-1]) # sort by timestamp + self.timestamps = sorted(self.timestamps, key=timestamp_key) # sort by timestamp self.id2pt = dict(enumerate(self.timestamps)) self.cached_xrdataset = None @@ -164,17 +166,17 @@ def convert_to_tensordict(self, xr_dataset): if self.dimension_indexers and not self.already_ran_index_selection: indexers = {v[0]: v[1] for k, v in self.dimension_indexers.items() if k != "time"} - print(xr_dataset) - print(indexers) - xr_dataset = xr_dataset.sel(**indexers) self.already_ran_index_selection = False # Reset for next call. - np_arrays = { - key: xr_dataset[list(variables)].to_array().to_numpy() - for key, variables in self.variables.items() - } + # Make np arrays for each key and make an empty array if no variables for this list. Needed for running experiments with different variable sets + np_arrays = {} + for key, variables in self.variables.items(): + if variables: # non-empty + np_arrays[key] = xr_dataset[list(variables)].to_array().to_numpy() + else: # empty list -> create an empty array + np_arrays[key] = np.empty((0,)) tdict = TensorDict( {key: torch.from_numpy(np_array).float() for key, np_array in np_arrays.items()} diff --git a/geoarches/lightning_modules/base_module.py b/geoarches/lightning_modules/base_module.py index 8701b0b..ff4a3ad 100644 --- a/geoarches/lightning_modules/base_module.py +++ b/geoarches/lightning_modules/base_module.py @@ -16,6 +16,7 @@ def load_module( dotlist: list = [], return_config: bool = True, ckpt_fname: str | None = None, + cfg=None, **kwargs, ): """ @@ -29,9 +30,10 @@ def load_module( path = Path("modelstore").joinpath(path) else: path = Path(path) - cfg = OmegaConf.load(path / "config.yaml") - cfg.merge_with_dotlist(dotlist) - module = instantiate(cfg.module.module, cfg.module, cfg.stats, **kwargs) + if cfg is None: + cfg = OmegaConf.load(path / "config.yaml") + cfg.merge_with_dotlist(dotlist) + module = instantiate(cfg.module.module, cfg.module, **kwargs) module.init_from_ckpt(path, ckpt_fname=ckpt_fname, missing_warning=False) if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/geoarches/metrics/deterministic_metrics.py b/geoarches/metrics/deterministic_metrics.py index 82ffa4d..23760a0 100644 --- a/geoarches/metrics/deterministic_metrics.py +++ b/geoarches/metrics/deterministic_metrics.py @@ -14,7 +14,7 @@ lat_coeffs_equi = (lat_coeffs_equi / lat_coeffs_equi.mean())[None, None, :, None] -def wrmse(pred, gt, weights=None): +def wrmse(pred, gt, weights=None, ignore_nans=False): """Weighted root mean square error. Expects inputs of shape: [..., lat, lon] @@ -26,8 +26,10 @@ def wrmse(pred, gt, weights=None): """ if weights is None: weights = lat_coeffs_equi.to(pred.device) - - err = (pred - gt).pow(2).mul(weights).mean((-2, -1)).sqrt() + if ignore_nans: + err = (pred - gt).pow(2).mul(weights).mean((-2, -1)).sqrt() + else: + err = (pred - gt).pow(2).mul(weights).nanmean((-2, -1)).sqrt() return err diff --git a/geoarches/metrics/deterministic_metrics_legacy.py b/geoarches/metrics/deterministic_metrics_legacy.py index 9b4ea4c..334dd53 100644 --- a/geoarches/metrics/deterministic_metrics_legacy.py +++ b/geoarches/metrics/deterministic_metrics_legacy.py @@ -1,9 +1,12 @@ import torch -lat_coeffs_equi = torch.tensor( - [torch.cos(x) for x in torch.arange(-torch.pi / 2, torch.pi / 2, torch.pi / 120)] -) -lat_coeffs_equi = (lat_coeffs_equi / lat_coeffs_equi.mean())[None, None, :, None] + +def compute_lat_coeffs(lat_size): + lat_coeffs_equi = torch.tensor( + [torch.cos(x) for x in torch.arange(-torch.pi / 2, torch.pi / 2, torch.pi / lat_size)] + ) + lat_coeffs_equi = (lat_coeffs_equi / lat_coeffs_equi.mean())[None, None, :, None] + return lat_coeffs_equi def acc(x, y, z=0): @@ -16,13 +19,13 @@ def acc(x, y, z=0): y: targets z: climatology """ - assert x.shape[-2] == 120, "Wrong shape for ACC computation" + lat_coeffs_equi = compute_lat_coeffs(x.shape[-2]) coeffs = lat_coeffs_equi.to(x.device)[None] x = x - z y = y - z - norm1 = (x * x).mul(coeffs).mean((-2, -1)) ** 0.5 - norm2 = (y * y).mul(coeffs).mean((-2, -1)) ** 0.5 - mean_acc = (x * y).mul(coeffs).mean((-2, -1)) / norm1 / norm2 + norm1 = (x * x).mul(coeffs).nanmean((-2, -1)) ** 0.5 + norm2 = (y * y).mul(coeffs).nanmean((-2, -1)) ** 0.5 + mean_acc = (x * y).mul(coeffs).nanmean((-2, -1)) / norm1 / norm2 return mean_acc @@ -35,9 +38,9 @@ def wrmse(x, y): x: predictions y: targets """ - assert x.shape[-2] == 120, "Wrong shape for WRMSE computation" + lat_coeffs_equi = compute_lat_coeffs(x.shape[-2]) coeffs = lat_coeffs_equi.to(x.device) - err = (x - y).pow(2).mul(coeffs).mean((-2, -1)).sqrt() + err = (x - y).pow(2).mul(coeffs).nanmean((-2, -1)).sqrt() return err diff --git a/geoarches/metrics/metric_base.py b/geoarches/metrics/metric_base.py index ab84314..ea81f19 100644 --- a/geoarches/metrics/metric_base.py +++ b/geoarches/metrics/metric_base.py @@ -103,7 +103,7 @@ def weighted_mean(self, x: torch.Tensor): class TensorDictMetricBase(Metric): - """Wrapper around metric to enable handling of targets and preds that are TensorDicts. + """Wrapper around metric to enable handling of and preds that are TensorDicts. Assumes metric should accept tensor target and pred. Keeps track of a metric instantiation per item in the TensorDict. diff --git a/geoarches/stats/dcpp_clim_removed_norm_stats.pt b/geoarches/stats/dcpp_clim_removed_norm_stats.pt new file mode 100644 index 0000000..5fe2b26 Binary files /dev/null and b/geoarches/stats/dcpp_clim_removed_norm_stats.pt differ diff --git a/geoarches/stats/dcpp_plev_delta_std b/geoarches/stats/dcpp_plev_delta_std new file mode 100644 index 0000000..4cfc430 Binary files /dev/null and b/geoarches/stats/dcpp_plev_delta_std differ diff --git a/geoarches/stats/dcpp_stats.pt b/geoarches/stats/dcpp_stats.pt new file mode 100644 index 0000000..9bfcee6 Binary files /dev/null and b/geoarches/stats/dcpp_stats.pt differ diff --git a/geoarches/stats/dcpp_surface_delta_std b/geoarches/stats/dcpp_surface_delta_std new file mode 100644 index 0000000..97ac77e Binary files /dev/null and b/geoarches/stats/dcpp_surface_delta_std differ diff --git a/geoarches/utils/tensordict_utils.py b/geoarches/utils/tensordict_utils.py index 1761212..e998142 100644 --- a/geoarches/utils/tensordict_utils.py +++ b/geoarches/utils/tensordict_utils.py @@ -103,3 +103,32 @@ def tensordict_cat(tdict_list, dim=0, **kwargs): ), device=tdict_list[0].device, ).auto_batch_size_() + + +def apply_nan_to_num(td: TensorDict): + return TensorDict( + {key: torch.nan_to_num(value) for key, value in td.items()}, batch_size=td.batch_size + ) + + +def get_non_nan_mask(td: TensorDict): + return TensorDict( + {key: ~torch.isnan(value) for key, value in td.items()}, batch_size=td.batch_size + ) + + +def replace_inf_and_large_values(td: TensorDict, threshold): + """ + Replaces `inf` values and values larger than threshold with 0. + """ + return TensorDict( + { + key: value.masked_fill(torch.isinf(value) | (value > threshold), 0) + for key, value in td.items() + }, + batch_size=td.batch_size, + ) + + +def replace_nans(td: TensorDict, value=0): + return td.apply(lambda x: torch.where(torch.isnan(x), torch.tensor(value, dtype=x.dtype), x)) diff --git a/tests/dataloaders/test_dcpp.py b/tests/dataloaders/test_dcpp.py index a0d5509..be6d5b9 100644 --- a/tests/dataloaders/test_dcpp.py +++ b/tests/dataloaders/test_dcpp.py @@ -8,7 +8,7 @@ # Dimension sizes. LAT, LON = 143, 144 -PLEV = 3 +PLEV = 4 class TestDCPPForecast: @@ -19,37 +19,39 @@ def setup_class(self, tmp_path_factory): self.test_dir = tmp_path_factory.mktemp("data") times = pd.date_range("2024-01-01", periods=6, freq="1ME") # datetime64[ns] for i in range(2): - file_path = self.test_dir / f"fake_dcpp_{i}_tos_included.nc" + file_path = self.test_dir / f"1961_{i}.nc" time = times[i * 2 : i * 2 + 2] # Create some dummy data level_var_data = np.zeros((len(time), PLEV, LAT, LON)) surface_var_data = np.zeros((len(time), LAT, LON)) + level_variables = ["va", "ua", "zg", "wap"] + surface_variables = ["psl", "tos"] ds = xr.Dataset( data_vars=dict( **{ var_name: (["time", "plev", "lat", "lon"], level_var_data) - for var_name in dcpp.level_variables + for var_name in level_variables }, **{ var_name: (["time", "lat", "lon"], surface_var_data) - for var_name in dcpp.surface_variables + for var_name in surface_variables }, ), coords={ "time": time, "lat": np.arange(0, LAT), "lon": np.arange(0, LON), - "plev": [85000, 70000, 50000], + "plev": [85000, 70000, 50000, 25000], }, ) ds.to_netcdf(file_path) # make fake atmos forcings - full_atmos_normal = torch.rand((4, 540)) - torch.save(full_atmos_normal, f"{self.test_dir}/full_atmos_normal.pt") - full_solar_normal = torch.rand((340, 12, 6)) - torch.save(full_solar_normal, f"{self.test_dir}/full_solar_normal.pt") + full_atmos_normal = torch.rand((540, 4)) + torch.save(full_atmos_normal, f"{self.test_dir}/cmip_ghg_forcings_ssp245.pt") + full_solar_normal = torch.rand((804, 6)) + np.save(f"{self.test_dir}/solar_forcings_normed.npy", full_solar_normal.numpy()) def test_load_current_state(self): dcpp_model = dcpp.DCPPForecast( @@ -60,19 +62,23 @@ def test_load_current_state(self): load_prev=False, multistep=0, load_clim=False, + surface_variable_indices=[0, 1], + level_variable_indices=[0, 1, 2], + surface_variables=["psl", "tos"], + level_variables=["va", "ua", "zg"], ) example = next(iter(dcpp_model)) - assert len(dcpp_model) == 2 + assert len(dcpp_model) == 4 # Current state - assert example["timestamp"] == 1711843200 # 2024-01-01-00-00 - assert example["state"]["surface"].shape == (10, 1, LAT, LON) # (var, lat, lon) - assert example["state"]["level"].shape == (8, 3, LAT, LON) # (var, lev, lat, lon) + assert example["timestamp"] == 1706659200 # 2024-01-01-00-00 + assert example["state"]["surface"].shape == (2, 1, LAT, LON) # (var, lat, lon) + assert example["state"]["level"].shape == (3, 4, LAT, LON) # (var, lev, lat, lon) assert example["forcings"].shape == torch.Size([10]) # (var) @pytest.mark.parametrize( "lead_time_months, expected_len, expected_next_timestamp", - [(1, 1, 1704088800), (1, 1, 1704110400)], + [(1, 3, 1704088800), (1, 3, 1704110400)], ) def test_load_current_and_next_state( self, lead_time_months, expected_len, expected_next_timestamp @@ -84,17 +90,21 @@ def test_load_current_and_next_state( lead_time_months=lead_time_months, load_prev=False, load_clim=False, + surface_variable_indices=[0, 1], + level_variable_indices=[0, 1, 2], + surface_variables=["psl", "tos"], + level_variables=["va", "ua", "zg"], ) example = ds[0] assert len(ds) == expected_len # Current state - assert example["timestamp"] == 1711843200 # 2024-01-01-00-00 - assert example["state"]["surface"].shape == (10, 1, LAT, LON) # (var, lat, lon) - assert example["state"]["level"].shape == (8, 3, LAT, LON) # (var, lev, lat, lon) + assert example["timestamp"] == 1706659200 # 2024-01-01-00-00 + assert example["state"]["surface"].shape == (2, 1, LAT, LON) # (var, lat, lon) + assert example["state"]["level"].shape == (3, 4, LAT, LON) # (var, lev, lat, lon) # Next state - assert example["next_state"]["surface"].shape == (10, 1, LAT, LON) # (var, lat, lon) - assert example["next_state"]["level"].shape == (8, 3, LAT, LON) # (var, lev, lat, lon) + assert example["next_state"]["surface"].shape == (2, 1, LAT, LON) # (var, lat, lon) + assert example["next_state"]["level"].shape == (3, 4, LAT, LON) # (var, lev, lat, lon) # No multistep assert "future_states" not in example # No prev state