Skip to content

Commit 8fd6795

Browse files
committed
Option to compute mse over north/south hemisphere
1 parent de0e90e commit 8fd6795

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

geoarches/dataloaders/era5.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,13 @@ def get_level_variable_indices(
100100

101101

102102
def get_headline_level_variable_indices(
103-
pressure_levels=arches_default_pressure_levels, level_variables=arches_default_level_variables
103+
pressure_levels=arches_default_pressure_levels,
104+
level_variables=arches_default_level_variables,
105+
headline_variables=("Z500", "T850", "Q700", "U850", "V850"),
104106
):
105107
"""Mapping for main level variables."""
106108
out = get_level_variable_indices(pressure_levels, level_variables)
107-
return {k: v for k, v in out.items() if k in ("Z500", "T850", "Q700", "U850", "V850")}
109+
return {k: v for k, v in out.items() if k in headline_variables}
108110

109111

110112
class Era5Dataset(XarrayDataset):

geoarches/evaluation/metric_registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def instantiate_metric(metric_name: str, **extra_kwargs):
4646
register_metric(
4747
"era5_deterministic_metrics_with_spatial", Era5DeterministicMetrics, compute_per_gridpoint=True
4848
)
49+
register_metric(
50+
"era5_deterministic_metrics_with_spatial_and_hemisphere",
51+
Era5DeterministicMetrics,
52+
compute_per_gridpoint=True,
53+
compute_per_hemisphere=True,
54+
headline_variables=("Z500", "Z850", "T850", "Q700", "U850", "V850"),
55+
)
4956
register_metric(
5057
"era5_ensemble_metrics",
5158
Era5EnsembleMetrics,

geoarches/metrics/deterministic_metrics.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
data_shape: tuple,
7575
compute_lat_weights_fn: Callable[[int], torch.tensor] = compute_lat_weights_weatherbench,
7676
compute_per_gridpoint: bool = False,
77+
compute_per_hemisphere: bool = False,
7778
):
7879
"""
7980
Args:
@@ -87,6 +88,8 @@ def __init__(
8788
Default function assumes latitudes are ordered -90 to 90.
8889
compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with aggregated globally mse and rmse).
8990
Default: only compute globally aggregated mse and rmse.
91+
compute_per_hemisphere: Whether to also compute mse and rmse per north and sount hemisphere
92+
(along with aggregated globally mse and rmse).
9093
"""
9194
Metric.__init__(self)
9295
MetricBase.__init__(
@@ -97,6 +100,7 @@ def __init__(
97100
# rollout_iterations=rollout_iterations,
98101
)
99102
self.compute_per_gridpoint = compute_per_gridpoint
103+
self.compute_per_hemisphere = compute_per_hemisphere
100104

101105
# Call `self.add_state`for every internal state that is needed for the metrics computations.
102106
# `dist_reduce_fx` indicates the function that should be used to reduce.
@@ -108,6 +112,11 @@ def __init__(
108112
"rmse_before_time_avg", default=torch.zeros(data_shape), dist_reduce_fx="sum"
109113
)
110114

115+
if self.compute_per_gridpoint:
116+
# Aggregated over north and south hemispheres.
117+
self.add_state("mse_north", default=torch.zeros(data_shape), dist_reduce_fx="sum")
118+
self.add_state("mse_south", default=torch.zeros(data_shape), dist_reduce_fx="sum")
119+
111120
if self.compute_per_gridpoint:
112121
# Per gridpoint.
113122
self.add_state(
@@ -134,6 +143,12 @@ def update(self, targets: torch.Tensor, preds: torch.Tensor) -> None:
134143
targets, preds
135144
).sqrt().sum(0)
136145

146+
if self.compute_per_hemisphere:
147+
num_lats = targets.shape[-2]
148+
equator_index = num_lats // 2
149+
self.mse_north = self.wmse(targets, preds, lat_range=(0, equator_index)).sum(0)
150+
self.mse_south = self.wmse(targets, preds, lat_range=(equator_index, num_lats)).sum(0)
151+
137152
if self.compute_per_gridpoint:
138153
self.mse_per_gridpt = self.mse_per_gridpt + self.spatial_mse(targets, preds).sum(0)
139154

@@ -149,6 +164,16 @@ def compute(self) -> Dict[str, torch.Tensor]:
149164
rmse=(self.mse / self.nsamples).sqrt(),
150165
)
151166

167+
if self.compute_per_hemisphere:
168+
all_metrics.update(
169+
dict(
170+
mse_north=self.mse_north / self.nsamples,
171+
rmse_north=(self.mse_north / self.nsamples).sqrt(),
172+
mse_south=self.mse_south / self.nsamples,
173+
rmse_south=(self.mse_south / self.nsamples).sqrt(),
174+
)
175+
)
176+
152177
if self.compute_per_gridpoint:
153178
all_metrics.update(
154179
dict(
@@ -182,6 +207,7 @@ def __init__(
182207
surface_variables=era5.arches_default_surface_variables,
183208
level_variables=era5.arches_default_level_variables,
184209
pressure_levels=era5.arches_default_pressure_levels,
210+
headline_variables=("Z500", "T850", "Q700", "U850", "V850"),
185211
compute_lat_weights_fn: Callable[[int], torch.tensor] = compute_lat_weights_weatherbench,
186212
compute_per_gridpoint: bool = False,
187213
lead_time_hours: int = 24,
@@ -192,6 +218,7 @@ def __init__(
192218
surface_variables: Names of surface variables (to select quantiles).
193219
level_variables: Names of level variables (used to get `variable_indices`).
194220
pressure_levels: pressure levels in data (used to get `variable_indices`).
221+
headline_variables: Short names of level variables to output (used to get 'variable_indices').
195222
level_data_shape: (var, lev) shape for level variables.
196223
num_level_variables: Number of level variables (used to compute data_shape).
197224
compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with aggregated globally mse and rmse).
@@ -228,7 +255,9 @@ def __init__(
228255
compute_per_gridpoint=compute_per_gridpoint,
229256
),
230257
variable_indices=add_timedelta_index(
231-
era5.get_headline_level_variable_indices(pressure_levels, level_variables),
258+
era5.get_headline_level_variable_indices(
259+
pressure_levels, level_variables, headline_variables
260+
),
232261
lead_time_hours=lead_time_hours,
233262
rollout_iterations=rollout_iterations,
234263
),

geoarches/metrics/metric_base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,27 @@ def __init__(
6262
super().__init__()
6363
self.compute_lat_weights_fn = compute_lat_weights_fn
6464

65-
def wmse(self, x: torch.Tensor, y: torch.Tensor | int = 0):
65+
def wmse(
66+
self, x: torch.Tensor, y: torch.Tensor | int = 0, lat_range: tuple[int, int] | None = None
67+
):
6668
"""Latitude weighted mse error.
6769
6870
Args:
6971
x: preds with shape (..., lat, lon)
7072
y: targets with shape (..., lat, lon)
73+
lat_range: Optional tuple of (min_lat, max_lat) to restrict the latitude range for the computation.
74+
If None, uses the full latitude range.
7175
"""
7276
lat_coeffs = self.compute_lat_weights_fn(latitude_resolution=x.shape[-2]).to(x.device)
77+
78+
if lat_range is not None:
79+
start_lat, end_lat = lat_range
80+
x = x[..., start_lat:end_lat, :]
81+
lat_coeffs = lat_coeffs[start_lat:end_lat, :]
82+
83+
if not isinstance(y, int):
84+
y = y[..., start_lat:end_lat, :]
85+
7386
return (x - y).pow(2).mul(lat_coeffs).nanmean((-2, -1))
7487

7588
def spatial_mse(self, x: torch.Tensor, y: torch.Tensor | int = 0):

0 commit comments

Comments
 (0)