262 changes: 178 additions & 84 deletions geoarches/dataloaders/dcpp.py
@@ -1,43 +1,18 @@
import importlib.resources

import numpy as np
import pandas as pd
import torch
from tensordict.tensordict import TensorDict

from .. import stats as geoarches_stats
from .netcdf import XarrayDataset

filename_filters = dict(
all=(lambda _: True),
train=lambda x: any(
substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(1, 9)]
),
val=lambda x: any(
substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(9, 10)]
),
test=lambda x: any(
substring in x for substring in [f"_{str(x)}_tos_included.nc" for x in range(10, 11)]
),
empty=lambda x: False,
from geoarches.utils.tensordict_utils import (
apply_nan_to_num,
get_non_nan_mask,
replace_inf_and_large_values,
)

pressure_levels = [85000, 70000, 50000]
surface_variables = ["tas", "npp", "nbp", "gpp", "cVeg", "evspsbl", "mrfso", "mrro", "ps", "tos"]
level_variables = ["hur", "hus", "o3", "ta", "ua", "va", "wap", "zg"]


def replace_nans(tensordict, value=0):
return tensordict.apply(
lambda x: torch.where(torch.isnan(x), torch.tensor(value, dtype=x.dtype), x)
)


default_dimension_indexers = {
"level": ("plev", pressure_levels),
"latitude": ("lat", slice(None)),
"longitude": ("lon", slice(None)),
"time": ("time", slice(None)),
}
from .. import stats as geoarches_stats
from .netcdf import XarrayDataset


class DCPPForecast(XarrayDataset):
@@ -50,19 +25,25 @@ class DCPPForecast(XarrayDataset):

def __init__(
self,
path="data/batch_with_tos/",
path="/path/to/data/",
forcings_path="data/",
domain="train",
filename_filter=None,
lead_time_months=1,
multistep=1,
load_prev=True,
load_prev=1, # how many previous states to load
load_clim=False,
norm_scheme="spatial_norm",
limit_examples: int = 0,
mask_value=0,
variables=None,
dimension_indexers: dict = default_dimension_indexers,
surface_variables=None,
level_variables=None,
surface_variable_indices=[],
level_variable_indices=[],
pressure_levels=[85000, 70000, 50000, 25000],
filename_filter_type="dcpp", # choose train/test split with filename filter: "dcpp" and "dcpp_alt"
        indexed_year_forcings=False,  # toggle forcings indexing between zero-indexed and year-indexed
):
"""
Args:
@@ -81,27 +62,65 @@ def __init__(

"""
self.__dict__.update(locals()) # concise way to update self with input arguments

self.timedelta = 1
self.current_multistep = 1
if self.filename_filter_type == "dcpp_alt":
train_filter = range(1960, 2000)
val_filter = range(2000, 2010)
test_filter = range(2010, 2016)
            filename_filters = dict(
                all=(lambda _: True),
                train=lambda name: any(f"{year}_" in name for year in train_filter),
                test=lambda name: any(f"{year}_" in name for year in test_filter),
                val=lambda name: any(f"{year}_" in name for year in val_filter),
                empty=lambda _: False,
            )
elif self.filename_filter_type == "dcpp":
train_filter = [x for i, x in enumerate(range(1960, 2010)) if (i + 1) % 10 != 0]
val_filter = range(2010, 2016)
test_filter = [1969, 1979, 1989, 1999, 2009]
            filename_filters = dict(
                all=(lambda _: True),
                train=lambda name: any(f"{year}_" in name for year in train_filter),
                val=lambda name: any(f"{year}_" in name for year in val_filter),
                test=lambda name: any(f"{year}_" in name for year in test_filter),
                empty=lambda _: False,
            )
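        # In the "dcpp" split, every 10th start year (1969, 1979, ..., 2009) is held out
        # for test and 2010-2015 for validation; "dcpp_alt" instead splits chronologically.
        # Filenames are matched on a "<year>_" substring.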
if filename_filter is None:
filename_filter = filename_filters[domain]
if variables is None:
variables = dict(surface=surface_variables, level=level_variables)
stats_file_path = "dcpp_stats.pt"
dimension_indexers = {"plev": ("plev", pressure_levels)}

super().__init__(
path,
filename_filter=filename_filter,
variables=variables,
limit_examples=limit_examples,
dimension_indexers=dimension_indexers,
timestamp_key=lambda x: (x[0], x[1]),
)

geoarches_stats_path = importlib.resources.files(geoarches_stats)
norm_file_path = geoarches_stats_path / "dcpp_spatial_norm_stats.pt"
norm_file_path = geoarches_stats_path / stats_file_path

spatial_norm_stats = torch.load(norm_file_path)

        # normalization stats with climatology removed (for predicting anomalies)
clim_removed_file_path = geoarches_stats_path / "dcpp_clim_removed_norm_stats.pt"
clim_removed_norm_stats = torch.load(clim_removed_file_path)

if self.norm_scheme is None:
self.data_mean = TensorDict(
surface=torch.tensor(0),
@@ -111,36 +130,69 @@ def __init__(
surface=torch.tensor(1),
level=torch.tensor(1),
)

# both mean and std_dev have spatial dimensions
elif self.norm_scheme == "spatial_norm":
self.data_mean = TensorDict(
surface=spatial_norm_stats["surface_mean"],
level=spatial_norm_stats["level_mean"],
surface=torch.stack(
[spatial_norm_stats["surface_mean"][i] for i in surface_variable_indices]
),
level=torch.stack(
[spatial_norm_stats["level_mean"][i] for i in level_variable_indices]
),
)
self.data_std = TensorDict(
surface=spatial_norm_stats["surface_std"],
level=spatial_norm_stats["level_std"],
surface=torch.stack(
[spatial_norm_stats["surface_std"][i] for i in surface_variable_indices]
),
level=torch.stack(
[spatial_norm_stats["level_std"][i] for i in level_variable_indices]
),
)
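            # surface_variable_indices / level_variable_indices pick out the subset of
            # variables kept from the precomputed stats tensors, in dataset order.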
# only mean is spatial, std_dev is averaged over space
elif self.norm_scheme == "mean_only_spatial_norm":
self.data_mean = TensorDict(
surface=torch.stack(
[spatial_norm_stats["surface_mean"][i] for i in surface_variable_indices]
),
level=torch.stack(
[spatial_norm_stats["level_mean"][i] for i in level_variable_indices]
),
)

self.surface_variables = [
"tas",
"npp",
"nbp",
"gpp",
"cVeg",
"evspsbl",
"mrfso",
"mrro",
"ps",
"tos",
]
self.data_std = TensorDict(
surface=torch.stack(
[spatial_norm_stats["surface_std"][i] for i in surface_variable_indices]
).nanmean(axis=(-1, -2), keepdim=True),
level=torch.stack(
[spatial_norm_stats["level_std"][i] for i in level_variable_indices]
).nanmean(axis=(-1, -2), keepdim=True),
)
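            # nanmean over the trailing two (lat, lon) dims reduces the std to one value
            # per variable/level; keepdim=True keeps the dims so it broadcasts over the grid.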
# statistics for predicting the anomaly, both mean and std_dev have spatial dimensions
elif self.norm_scheme == "clim_removed":
self.data_mean = TensorDict(
surface=clim_removed_norm_stats["surface_mean"],
level=clim_removed_norm_stats["level_mean"],
)
            self.data_mean.batch_size = [
                12
            ]  # so the tensordict can be indexed by month for both surface/level
self.data_std = TensorDict(
surface=clim_removed_norm_stats["surface_std"],
level=clim_removed_norm_stats["level_std"],
)
self.data_std.batch_size = [12]

self.surface_variables = surface_variables
self.level_variables = [
a + str(p)
for a in ["hur_", "hus_", "o3_", "ta_", "ua_", "va_", "wap_", "zg_"]
for p in pressure_levels
a + " " + str(p // 100) for a in level_variables for p in pressure_levels
]
self.atmos_forcings = torch.load(f"{forcings_path}/full_atmos_normal.pt")
self.solar_forcings = torch.load(f"{forcings_path}/full_solar_normal.pt")
self.atmos_forcings = torch.load(f"{forcings_path}/cmip_ghg_forcings_ssp245.pt")
self.solar_forcings = torch.tensor(np.load(f"{forcings_path}/solar_forcings_normed.npy"))
times_seconds = [v[2].item() // 10**9 for k, v in self.id2pt.items()]
        self.next_timestamp_map = dict(zip(times_seconds, times_seconds[1:]))
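        # maps each timestamp (in seconds) to its successor, in id2pt iteration order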

# override netcdf functionality
self.timestamps = sorted(self.timestamps, key=lambda x: (x[0], x[1])) # sort by timestamp

def convert_to_tensordict(self, xr_dataset):
"""
@@ -163,28 +215,27 @@ def __getitem__(self, i, normalize=True):
out = dict()
# load current state
out["state"] = super().__getitem__(i)

out["timestamp"] = torch.tensor(
self.id2pt[i][2].item() // 10**9,
dtype=torch.int32,
dtype=torch.int64,
) # time in seconds
times = pd.to_datetime(out["timestamp"].cpu().numpy() * 10**9).tz_localize(None)
current_year = (
torch.tensor(times.month) + 1970 - 1961
) # plus 1970 for the timestep, -1961 to zero index
current_month = torch.tensor(times.year) % 12

times = pd.to_datetime(out["timestamp"].cpu().numpy(), unit="s").tz_localize(None)
        current_month = (torch.tensor(times.month) - 1) % 12
if self.indexed_year_forcings:
current_year = torch.tensor(times.year)
current_solar_year = torch.tensor(times.year) - 1961
else:
current_year = torch.tensor(times.year) - 1961
current_solar_year = torch.tensor(times.year) - 1961 # -1961 to zero index
out["forcings"] = torch.concatenate(
[
self.atmos_forcings[:, (current_year * 12) + current_month],
self.solar_forcings[current_year, current_month, :],
self.atmos_forcings[current_year, :],
self.solar_forcings[(current_solar_year * 12) + current_month, :],
]
)
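        # the conditioning vector concatenates this year's GHG forcings with the
        # current month's normalized solar forcings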
        # next state is lead_time_months ahead
        t = self.lead_time_months  # multistep
out["next_state"] = super().__getitem__(i + t // self.timedelta)

# Load multiple future timestamps if specified.
if self.multistep > 1:
future_states = []
@@ -193,24 +244,67 @@ def __getitem__(self, i, normalize=True):
out["future_states"] = torch.stack(future_states, dim=0)

if self.load_prev:
out["prev_state"] = super().__getitem__(i - self.lead_time_months // self.timedelta)
            if self.load_prev > 1:
                prev_states = []
                for k in range(0, self.load_prev):
                    prev_states.append(
                        super().__getitem__(
                            i - (self.lead_time_months * (k + 1)) // self.timedelta
                        )
                    )
                out["prev_state"] = torch.stack(prev_states, dim=0)
else:
out["prev_state"] = super().__getitem__(
i - self.lead_time_months // self.timedelta
)
prev_timestamp = (
self.id2pt[i - self.lead_time_months // self.timedelta][2].item() // 10**9
)
times = pd.to_datetime(prev_timestamp, unit="s").tz_localize(None)
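            # with load_prev > 1, prev_state stacks the last load_prev states along
            # dim 0, most recent first (k = 0 is the closest previous state)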

if normalize:
out = self.normalize(out)
out = self.normalize(out, month=current_month)

        # replace NaNs and zero out masked grid points in the states
out = {k: replace_nans(v, self.mask_value) if "state" in k else v for k, v in out.items()}
mask = {k: (get_non_nan_mask(v) if "state" in k else v) for k, v in out.items()}

out = {k: apply_nan_to_num(v) if "state" in k else v for k, v in out.items()}
out = {k: (v * mask[k] if "state" in k else v) for k, v in out.items()}
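        # zeroing via the mask after nan_to_num keeps masked points exactly zero
        # (assumes get_non_nan_mask returns 1.0 at valid points and 0.0 elsewhere)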
return out

def normalize(self, batch):
def normalize(self, batch, stateless=False, month=None):
device = list(batch.values())[0].device
means = self.data_mean.to(device)
stds = self.data_std.to(device)
out = {k: ((v - means) / stds if "state" in k else v) for k, v in batch.items()}
return out

def denormalize(self, batch):
device = list(batch.values())[0].device
if self.norm_scheme == "clim_removed":
return {
k: ((v - means[month]) / stds[month] if "state" in k else v)
for k, v in batch.items()
}
elif stateless:
return (batch - means) / stds
else:
dict_out = {k: ((v - means) / stds if "state" in k else v) for k, v in batch.items()}
dict_out = {
k: replace_inf_and_large_values(v, 1e35) if "state" in k else v
for k, v in dict_out.items()
}
return dict_out

def denormalize(self, batch, stateless=False, month=None):
if stateless:
device = batch.device
else:
device = list(batch.values())[0].device
means = self.data_mean.to(device)
stds = self.data_std.to(device)
out = {k: (v * stds + means if "state" in k else v) for k, v in batch.items()}
return out
if self.norm_scheme == "clim_removed" and stateless:
return batch * stds[month] + means[month]
elif self.norm_scheme == "clim_removed":
return {
k: (v * stds[month] + means[month] if "state" in k else v)
for k, v in batch.items()
}
elif stateless:
return (batch * stds) + means
else:
return {k: ((v * stds) + means if "state" in k else v) for k, v in batch.items()}
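# Minimal usage sketch (paths and split are placeholders, not shipped defaults):
#   ds = DCPPForecast(path="/path/to/data/", forcings_path="data/", domain="train")
#   sample = ds[0]  # dict with "state", "prev_state", "next_state", "forcings", ...
#   restored = ds.denormalize(sample)  # inverse of normalize (non-"clim_removed" schemes)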