Skip to content

Commit dcee04d

Browse files
committed
plot rmse per year, per gridpoint (spatial)
1 parent 52b6854 commit dcee04d

File tree

2 files changed

+354
-0
lines changed

2 files changed

+354
-0
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import argparse
2+
import os
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import torch
7+
8+
9+
def plot_rmse_by_year(
10+
base_dir,
11+
metric_filename,
12+
save_path,
13+
year_range=(1979, 2018),
14+
metric_keys_left=["rmse_Z500", "rmse_Z850"],
15+
metric_keys_right=["rmse_U850", "rmse_V850"],
16+
ylabel_left=r"RMSE $[m^2/s^2]$",
17+
ylabel_right=r"RMSE $[m/s]$",
18+
ref_year=2020,
19+
force=False,
20+
):
21+
"""
22+
Plots the RMSE for a given range of years with a reference year's RMSE as a dashed line.
23+
24+
Args:
25+
base_dir (str): The base directory containing the yearly data subdirectories.
26+
save_path (str or Path): The path to save the plot. If None, the plot will be shown but not saved.
27+
year_range (tuple): A tuple (start_year, end_year) for the plot's x-axis.
28+
metric_keys_left (list): A list of metric keys (e.g., ['rmse_Z500', 'rmse_Z850']) for the left subplot.
29+
metric_keys_right (list): A list of metric keys (e.g., ['rmse_U850', 'rmse_V850']) for the right subplot.
30+
ref_year (int): The year to use for the dashed reference line.
31+
force (bool): If True, forces saving the plot even if the file already exists.
32+
"""
33+
if Path(save_path).exists() and not force:
34+
print(f"Plot file {save_path} already exists. Use --force to overwrite.")
35+
return
36+
37+
# Generate the list of years to plot
38+
years = list(range(year_range[0], year_range[1] + 1))
39+
40+
# Dictionaries to store the loaded data
41+
data_left = {metric: [] for metric in metric_keys_left}
42+
data_right = {metric: [] for metric in metric_keys_right}
43+
44+
# Load data for the specified year range
45+
for year in years:
46+
file_path = os.path.join(base_dir, str(year), metric_filename)
47+
if not os.path.exists(file_path):
48+
print(f"Warning: File not found for year {year} at {file_path}. Skipping.")
49+
continue
50+
51+
year_data = torch.load(file_path, map_location=torch.device("cpu"), weights_only=False)
52+
53+
# Store the data for the left subplot
54+
for metric in metric_keys_left:
55+
data_left[metric].append(year_data[metric].item())
56+
57+
# Store the data for the right subplot
58+
for metric in metric_keys_right:
59+
data_right[metric].append(year_data[metric].item())
60+
61+
# Load the reference year data separately
62+
ref_file_path = os.path.join(base_dir, str(ref_year), metric_filename)
63+
if not os.path.exists(ref_file_path):
64+
print(f"Error: Reference year data not found at {ref_file_path}.")
65+
return
66+
67+
ref_data = torch.load(ref_file_path, map_location=torch.device("cpu"), weights_only=False)
68+
69+
# Set font family and use LaTeX for consistent plotting style
70+
plt.rc("font", family="serif")
71+
# plt.rc("text", usetex=True)
72+
73+
# Create the plot
74+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 2.5))
75+
76+
# Define the desired x-axis ticks
77+
desired_ticks = [1980, 1990, 2000, 2010]
78+
ax1.set_xticks(desired_ticks)
79+
ax2.set_xticks(desired_ticks)
80+
81+
# --- Left Subplot (Geopotential) ---
82+
# ax1.set_title("Geopotential RMSE")
83+
ax1.set_xlabel("Year")
84+
ax1.set_ylabel(ylabel_left)
85+
ax1.grid(True, linestyle="-", color="lightgray")
86+
87+
# Plot the RMSE lines
88+
colors = plt.cm.tab10.colors
89+
for i, metric in enumerate(metric_keys_left):
90+
if data_left[metric]:
91+
label = metric.split("_")[-1] # e.g., 'rmse_Z500' -> 'Z500'
92+
ax1.plot(years, data_left[metric], label=label, color=colors[i])
93+
94+
# Plot the dashed reference line
95+
ref_value = ref_data[metric].item()
96+
ax1.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5)
97+
98+
ax1.legend()
99+
100+
# --- Right Subplot (Wind Speed) ---
101+
# ax2.set_title("Wind Speed RMSE")
102+
ax2.set_xlabel("Year")
103+
ax2.set_ylabel(ylabel_right)
104+
ax2.grid(True, linestyle="-", color="lightgray")
105+
106+
# Plot the RMSE lines
107+
for i, metric in enumerate(metric_keys_right):
108+
if data_right[metric]:
109+
label = metric.split("_")[-1] # e.g., 'rmse_Z500' -> 'Z500'
110+
ax2.plot(years, data_right[metric], label=label, color=colors[i])
111+
112+
# Plot the dashed reference line
113+
ref_value = ref_data[metric].item()
114+
ax2.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5)
115+
116+
ax2.legend()
117+
118+
plt.tight_layout()
119+
120+
plt.savefig(save_path)
121+
print(f"Plot saved to {save_path}")
122+
123+
124+
# Example usage:
125+
# This part assumes a directory structure and some dummy data for demonstration.
126+
# You would replace this with your actual data path.
127+
128+
if __name__ == "__main__":
129+
parser = argparse.ArgumentParser(description="Plot RMSE per year.")
130+
parser.add_argument(
131+
"--base_dir",
132+
type=str,
133+
default="/scratch/resingh/weather/evaluation/era5_pred_archesweather-S/",
134+
help="Base directory containing yearly data subdirectories.",
135+
)
136+
parser.add_argument(
137+
"--metric_filename",
138+
type=str,
139+
default="test-multistep=1-era5_deterministic_metrics_with_spatial_and_hemisphere.pt",
140+
help="Filename of the metric data to load.",
141+
)
142+
parser.add_argument(
143+
"--save_dir",
144+
type=str,
145+
default="plots",
146+
help="Path to save the plot. If None, the plot will be shown but not saved.",
147+
)
148+
parser.add_argument(
149+
"--force",
150+
action="store_true",
151+
help="Force saving plot even if the plot file already exists.",
152+
)
153+
args = parser.parse_args()
154+
155+
# Replace 'dummy_data' with your actual path, e.g., ''
156+
plot_rmse_by_year(
157+
base_dir=args.base_dir,
158+
metric_filename=args.metric_filename,
159+
save_path=Path(args.save_dir) / "rmse_per_year.png",
160+
year_range=(1979, 2018),
161+
metric_keys_left=["rmse_Z500", "rmse_Z850"],
162+
metric_keys_right=["rmse_U850", "rmse_V850"],
163+
ref_year=2020,
164+
force=args.force,
165+
)
166+
plot_rmse_by_year(
167+
base_dir=args.base_dir,
168+
metric_filename=args.metric_filename,
169+
save_path=Path(args.save_dir) / "north_rmse_per_year.png",
170+
year_range=(1979, 2018),
171+
metric_keys_left=["rmse-north_Z500", "rmse-north_Z850"],
172+
metric_keys_right=["rmse-north_U850", "rmse-north_V850"],
173+
ref_year=2020,
174+
force=args.force,
175+
)
176+
plot_rmse_by_year(
177+
base_dir=args.base_dir,
178+
metric_filename=args.metric_filename,
179+
save_path=Path(args.save_dir) / "south_rmse_per_year.png",
180+
year_range=(1979, 2018),
181+
metric_keys_left=["rmse-south_Z500", "rmse-south_Z850"],
182+
metric_keys_right=["rmse-south_U850", "rmse-south_V850"],
183+
ref_year=2020,
184+
force=args.force,
185+
)
186+
# plot_rmse_by_year(
187+
# base_dir=args.base_dir,
188+
# metric_filename=args.metric_filename,
189+
# save_path=Path(args.save_dir) / "all_rmse_per_year.png",
190+
# year_range=(1979, 2018),
191+
# metric_keys_left=["rmse-all_Z500", "rmse-all_Z850"],
192+
# metric_keys_right=["rmse-all_U850", "rmse-all_V850"],
193+
# ref_year=2020,
194+
# force=args.force,
195+
# )
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import cartopy.crs as ccrs
5+
import cartopy.feature as cfeature
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import torch
9+
10+
11+
def plot_spatial_rmse(
12+
base_dirs: list[str] | list[Path],
13+
metric_filename: str | Path,
14+
save_path: str | Path,
15+
metric_key: str = "rmse_per_gridpoint_V850",
16+
titles: list[str] | None = None,
17+
force: bool = False,
18+
cbar_label: str | None = None,
19+
):
20+
"""
21+
Plots a 2D spatial RMSE array as a world map grid using Cartopy.
22+
23+
Args:
24+
spatial_data (torch.Tensor or np.ndarray): A 2D array of spatial RMSE values
25+
with shape (lat, lon).
26+
"""
27+
if Path(save_path).exists() and not force:
28+
print(f"Plot file {save_path} already exists. Use --force to overwrite.")
29+
return
30+
31+
spatial_datas = []
32+
for base_dir in base_dirs:
33+
spatial_data = torch.load(
34+
Path(base_dir) / metric_filename,
35+
map_location=torch.device("cpu"),
36+
weights_only=False,
37+
)
38+
spatial_datas.append(spatial_data[metric_key])
39+
40+
# Determine global min and max for a consistent color scale across all plots
41+
all_data = np.concatenate([data.flatten() for data in spatial_datas])
42+
vmin, vmax = np.min(all_data), np.max(all_data)
43+
44+
# Set font family and use LaTeX for consistent plotting style
45+
plt.rc("font", family="serif")
46+
47+
# Set up the plot with a PlateCarree projection, suitable for global data.
48+
fig = plt.figure(figsize=(4, 2))
49+
num_plots = len(spatial_datas)
50+
fig, axes = plt.subplots(
51+
1, num_plots, figsize=(6 * num_plots, 5), subplot_kw={"projection": ccrs.PlateCarree()}
52+
)
53+
# Ensure axes is an array even for a single subplot
54+
if num_plots == 1:
55+
axes = [axes]
56+
57+
axes[0].set_ylabel("Latitude")
58+
59+
for i, data in enumerate(spatial_datas):
60+
ax = axes[i]
61+
62+
# Add map features
63+
ax.add_feature(cfeature.COASTLINE)
64+
ax.add_feature(cfeature.BORDERS, linestyle=":")
65+
ax.add_feature(cfeature.LAND, edgecolor="black")
66+
ax.add_feature(cfeature.OCEAN)
67+
ax.set_global()
68+
69+
# Convert to numpy if needed
70+
if isinstance(data, torch.Tensor):
71+
data = data.numpy()
72+
73+
# Convert to numpy if needed
74+
if isinstance(data, torch.Tensor):
75+
data = data.numpy()
76+
77+
# num_lat, num_lon = data.shape
78+
# lons = np.arange(0, 360, 360 / num_lon)
79+
# lats = np.linspace(90, -90, num_lat)
80+
81+
# Use imshow to plot the data on top of the map with a normalized color scale
82+
im = ax.imshow(
83+
data,
84+
cmap="plasma",
85+
origin="upper",
86+
extent=[-180, 180, -90, 90],
87+
transform=ccrs.PlateCarree(),
88+
vmin=vmin,
89+
vmax=vmax,
90+
)
91+
92+
# Set plot title and labels
93+
if titles:
94+
ax.set_title(titles[i])
95+
96+
ax.set_xlabel("Longitude")
97+
98+
# Set ticks and gridlines for clarity
99+
ax.set_xticks(np.arange(-180, 181, 60), crs=ccrs.PlateCarree())
100+
ax.set_yticks(np.arange(-90, 91, 30), crs=ccrs.PlateCarree())
101+
ax.grid(True, linestyle="-", color="gray")
102+
103+
# Add a single color bar for the entire figure
104+
# Adjust subplots and color bar for better proportions
105+
plt.tight_layout(rect=[0, 0, 0.9, 1])
106+
# fig.subplots_adjust(right=0.85)
107+
# cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
108+
109+
cbar_ax = fig.add_axes([0.91, axes[0].get_position().y0, 0.02, axes[0].get_position().height])
110+
cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical")
111+
if cbar_label:
112+
cbar.set_label(cbar_label)
113+
114+
# Add an overall title to the figure
115+
fig.suptitle(metric_key.split("_")[-1].upper(), y=0.9, fontsize=16)
116+
117+
plt.savefig(save_path)
118+
print(f"Plot saved to {save_path}")
119+
120+
121+
if __name__ == "__main__":
122+
parser = argparse.ArgumentParser(description="Plot RMSE per year.")
123+
parser.add_argument(
124+
"--base_dir",
125+
type=str,
126+
default="/scratch/resingh/weather/evaluation/era5_pred_archesweather-S/",
127+
help="Base directory containing yearly data subdirectories.",
128+
)
129+
parser.add_argument(
130+
"--metric_filename",
131+
type=str,
132+
default="test-multistep=1-era5_deterministic_metrics_with_spatial.pt",
133+
help="Filename of the metric data to load.",
134+
)
135+
parser.add_argument(
136+
"--save_dir",
137+
type=str,
138+
default="plots",
139+
help="Path to save the plot. If None, the plot will be shown but not saved.",
140+
)
141+
parser.add_argument(
142+
"--force",
143+
action="store_true",
144+
help="Force saving plot even if the plot file already exists.",
145+
)
146+
args = parser.parse_args()
147+
148+
for var, cbar_label in zip(
149+
["V850", "U850", "Z500"], ["RMSE $[m/s]$", "RMSE $[m/s]$", "RMSE $[m^2/s^2]$"]
150+
):
151+
plot_spatial_rmse(
152+
base_dirs=[Path(args.base_dir) / "1979_1999", Path(args.base_dir) / "2000_2018"],
153+
metric_filename=args.metric_filename,
154+
metric_key=f"rmse_per_gridpoint_{var}",
155+
save_path=Path(args.save_dir) / f"spatial_rmse_{var}.png",
156+
force=args.force,
157+
titles=["1979-1999", "2000-2018"],
158+
cbar_label=cbar_label,
159+
)

0 commit comments

Comments
 (0)