@@ -74,6 +74,7 @@ def __init__(
7474 data_shape : tuple ,
7575 compute_lat_weights_fn : Callable [[int ], torch .tensor ] = compute_lat_weights_weatherbench ,
7676 compute_per_gridpoint : bool = False ,
77+ compute_per_hemisphere : bool = False ,
7778 ):
7879 """
7980 Args:
@@ -87,6 +88,8 @@ def __init__(
8788 Default function assumes latitudes are ordered -90 to 90.
8889 compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with aggregated globally mse and rmse).
8990 Default: only compute globally aggregated mse and rmse.
91+ compute_per_hemisphere: Whether to also compute mse and rmse per north and sount hemisphere
92+ (along with aggregated globally mse and rmse).
9093 """
9194 Metric .__init__ (self )
9295 MetricBase .__init__ (
@@ -97,6 +100,7 @@ def __init__(
97100 # rollout_iterations=rollout_iterations,
98101 )
99102 self .compute_per_gridpoint = compute_per_gridpoint
103+ self .compute_per_hemisphere = compute_per_hemisphere
100104
101105 # Call `self.add_state`for every internal state that is needed for the metrics computations.
102106 # `dist_reduce_fx` indicates the function that should be used to reduce.
@@ -108,6 +112,11 @@ def __init__(
108112 "rmse_before_time_avg" , default = torch .zeros (data_shape ), dist_reduce_fx = "sum"
109113 )
110114
115+ if self .compute_per_gridpoint :
116+ # Aggregated over north and south hemispheres.
117+ self .add_state ("mse_north" , default = torch .zeros (data_shape ), dist_reduce_fx = "sum" )
118+ self .add_state ("mse_south" , default = torch .zeros (data_shape ), dist_reduce_fx = "sum" )
119+
111120 if self .compute_per_gridpoint :
112121 # Per gridpoint.
113122 self .add_state (
@@ -134,6 +143,12 @@ def update(self, targets: torch.Tensor, preds: torch.Tensor) -> None:
134143 targets , preds
135144 ).sqrt ().sum (0 )
136145
146+ if self .compute_per_hemisphere :
147+ num_lats = targets .shape [- 2 ]
148+ equator_index = num_lats // 2
149+ self .mse_north = self .wmse (targets , preds , lat_range = (0 , equator_index )).sum (0 )
150+ self .mse_south = self .wmse (targets , preds , lat_range = (equator_index , num_lats )).sum (0 )
151+
137152 if self .compute_per_gridpoint :
138153 self .mse_per_gridpt = self .mse_per_gridpt + self .spatial_mse (targets , preds ).sum (0 )
139154
@@ -149,6 +164,16 @@ def compute(self) -> Dict[str, torch.Tensor]:
149164 rmse = (self .mse / self .nsamples ).sqrt (),
150165 )
151166
167+ if self .compute_per_hemisphere :
168+ all_metrics .update (
169+ dict (
170+ mse_north = self .mse_north / self .nsamples ,
171+ rmse_north = (self .mse_north / self .nsamples ).sqrt (),
172+ mse_south = self .mse_south / self .nsamples ,
173+ rmse_south = (self .mse_south / self .nsamples ).sqrt (),
174+ )
175+ )
176+
152177 if self .compute_per_gridpoint :
153178 all_metrics .update (
154179 dict (
@@ -182,6 +207,7 @@ def __init__(
182207 surface_variables = era5 .arches_default_surface_variables ,
183208 level_variables = era5 .arches_default_level_variables ,
184209 pressure_levels = era5 .arches_default_pressure_levels ,
210+ headline_variables = ("Z500" , "T850" , "Q700" , "U850" , "V850" ),
185211 compute_lat_weights_fn : Callable [[int ], torch .tensor ] = compute_lat_weights_weatherbench ,
186212 compute_per_gridpoint : bool = False ,
187213 lead_time_hours : int = 24 ,
@@ -192,6 +218,7 @@ def __init__(
192218 surface_variables: Names of surface variables (to select quantiles).
193219 level_variables: Names of level variables (used to get `variable_indices`).
194220 pressure_levels: pressure levels in data (used to get `variable_indices`).
221+ headline_variables: Short names of level variables to output (used to get 'variable_indices').
195222 level_data_shape: (var, lev) shape for level variables.
196223 num_level_variables: Number of level variables (used to compute data_shape).
197224 compute_per_gridpoint: Whether to also compute mse and rmse per gridpoint (along with aggregated globally mse and rmse).
@@ -228,7 +255,9 @@ def __init__(
228255 compute_per_gridpoint = compute_per_gridpoint ,
229256 ),
230257 variable_indices = add_timedelta_index (
231- era5 .get_headline_level_variable_indices (pressure_levels , level_variables ),
258+ era5 .get_headline_level_variable_indices (
259+ pressure_levels , level_variables , headline_variables
260+ ),
232261 lead_time_hours = lead_time_hours ,
233262 rollout_iterations = rollout_iterations ,
234263 ),
0 commit comments