|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import torch |
| 7 | + |
| 8 | + |
| 9 | +def plot_rmse_by_year( |
| 10 | + base_dir, |
| 11 | + metric_filename, |
| 12 | + save_path, |
| 13 | + year_range=(1979, 2018), |
| 14 | + metric_keys_left=["rmse_Z500", "rmse_Z850"], |
| 15 | + metric_keys_right=["rmse_U850", "rmse_V850"], |
| 16 | + ylabel_left=r"RMSE $[m^2/s^2]$", |
| 17 | + ylabel_right=r"RMSE $[m/s]$", |
| 18 | + ref_year=2020, |
| 19 | + force=False, |
| 20 | +): |
| 21 | + """ |
| 22 | + Plots the RMSE for a given range of years with a reference year's RMSE as a dashed line. |
| 23 | +
|
| 24 | + Args: |
| 25 | + base_dir (str): The base directory containing the yearly data subdirectories. |
| 26 | + save_path (str or Path): The path to save the plot. If None, the plot will be shown but not saved. |
| 27 | + year_range (tuple): A tuple (start_year, end_year) for the plot's x-axis. |
| 28 | + metric_keys_left (list): A list of metric keys (e.g., ['rmse_Z500', 'rmse_Z850']) for the left subplot. |
| 29 | + metric_keys_right (list): A list of metric keys (e.g., ['rmse_U850', 'rmse_V850']) for the right subplot. |
| 30 | + ref_year (int): The year to use for the dashed reference line. |
| 31 | + force (bool): If True, forces saving the plot even if the file already exists. |
| 32 | + """ |
| 33 | + if Path(save_path).exists() and not force: |
| 34 | + print(f"Plot file {save_path} already exists. Use --force to overwrite.") |
| 35 | + return |
| 36 | + |
| 37 | + # Generate the list of years to plot |
| 38 | + years = list(range(year_range[0], year_range[1] + 1)) |
| 39 | + |
| 40 | + # Dictionaries to store the loaded data |
| 41 | + data_left = {metric: [] for metric in metric_keys_left} |
| 42 | + data_right = {metric: [] for metric in metric_keys_right} |
| 43 | + |
| 44 | + # Load data for the specified year range |
| 45 | + for year in years: |
| 46 | + file_path = os.path.join(base_dir, str(year), metric_filename) |
| 47 | + if not os.path.exists(file_path): |
| 48 | + print(f"Warning: File not found for year {year} at {file_path}. Skipping.") |
| 49 | + continue |
| 50 | + |
| 51 | + year_data = torch.load(file_path, map_location=torch.device("cpu"), weights_only=False) |
| 52 | + |
| 53 | + # Store the data for the left subplot |
| 54 | + for metric in metric_keys_left: |
| 55 | + data_left[metric].append(year_data[metric].item()) |
| 56 | + |
| 57 | + # Store the data for the right subplot |
| 58 | + for metric in metric_keys_right: |
| 59 | + data_right[metric].append(year_data[metric].item()) |
| 60 | + |
| 61 | + # Load the reference year data separately |
| 62 | + ref_file_path = os.path.join(base_dir, str(ref_year), metric_filename) |
| 63 | + if not os.path.exists(ref_file_path): |
| 64 | + print(f"Error: Reference year data not found at {ref_file_path}.") |
| 65 | + return |
| 66 | + |
| 67 | + ref_data = torch.load(ref_file_path, map_location=torch.device("cpu"), weights_only=False) |
| 68 | + |
| 69 | + # Set font family and use LaTeX for consistent plotting style |
| 70 | + plt.rc("font", family="serif") |
| 71 | + # plt.rc("text", usetex=True) |
| 72 | + |
| 73 | + # Create the plot |
| 74 | + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 2.5)) |
| 75 | + |
| 76 | + # Define the desired x-axis ticks |
| 77 | + desired_ticks = [1980, 1990, 2000, 2010] |
| 78 | + ax1.set_xticks(desired_ticks) |
| 79 | + ax2.set_xticks(desired_ticks) |
| 80 | + |
| 81 | + # --- Left Subplot (Geopotential) --- |
| 82 | + # ax1.set_title("Geopotential RMSE") |
| 83 | + ax1.set_xlabel("Year") |
| 84 | + ax1.set_ylabel(ylabel_left) |
| 85 | + ax1.grid(True, linestyle="-", color="lightgray") |
| 86 | + |
| 87 | + # Plot the RMSE lines |
| 88 | + colors = plt.cm.tab10.colors |
| 89 | + for i, metric in enumerate(metric_keys_left): |
| 90 | + if data_left[metric]: |
| 91 | + label = metric.split("_")[-1] # e.g., 'rmse_Z500' -> 'Z500' |
| 92 | + ax1.plot(years, data_left[metric], label=label, color=colors[i]) |
| 93 | + |
| 94 | + # Plot the dashed reference line |
| 95 | + ref_value = ref_data[metric].item() |
| 96 | + ax1.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5) |
| 97 | + |
| 98 | + ax1.legend() |
| 99 | + |
| 100 | + # --- Right Subplot (Wind Speed) --- |
| 101 | + # ax2.set_title("Wind Speed RMSE") |
| 102 | + ax2.set_xlabel("Year") |
| 103 | + ax2.set_ylabel(ylabel_right) |
| 104 | + ax2.grid(True, linestyle="-", color="lightgray") |
| 105 | + |
| 106 | + # Plot the RMSE lines |
| 107 | + for i, metric in enumerate(metric_keys_right): |
| 108 | + if data_right[metric]: |
| 109 | + label = metric.split("_")[-1] # e.g., 'rmse_Z500' -> 'Z500' |
| 110 | + ax2.plot(years, data_right[metric], label=label, color=colors[i]) |
| 111 | + |
| 112 | + # Plot the dashed reference line |
| 113 | + ref_value = ref_data[metric].item() |
| 114 | + ax2.axhline(y=ref_value, color=colors[i], linestyle=":", linewidth=1.5) |
| 115 | + |
| 116 | + ax2.legend() |
| 117 | + |
| 118 | + plt.tight_layout() |
| 119 | + |
| 120 | + plt.savefig(save_path) |
| 121 | + print(f"Plot saved to {save_path}") |
| 122 | + |
| 123 | + |
| 124 | +# Example usage: |
| 125 | +# This part assumes a directory structure and some dummy data for demonstration. |
| 126 | +# You would replace this with your actual data path. |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + parser = argparse.ArgumentParser(description="Plot RMSE per year.") |
| 130 | + parser.add_argument( |
| 131 | + "--base_dir", |
| 132 | + type=str, |
| 133 | + default="/scratch/resingh/weather/evaluation/era5_pred_archesweather-S/", |
| 134 | + help="Base directory containing yearly data subdirectories.", |
| 135 | + ) |
| 136 | + parser.add_argument( |
| 137 | + "--metric_filename", |
| 138 | + type=str, |
| 139 | + default="test-multistep=1-era5_deterministic_metrics_with_spatial_and_hemisphere.pt", |
| 140 | + help="Filename of the metric data to load.", |
| 141 | + ) |
| 142 | + parser.add_argument( |
| 143 | + "--save_dir", |
| 144 | + type=str, |
| 145 | + default="plots", |
| 146 | + help="Path to save the plot. If None, the plot will be shown but not saved.", |
| 147 | + ) |
| 148 | + parser.add_argument( |
| 149 | + "--force", |
| 150 | + action="store_true", |
| 151 | + help="Force saving plot even if the plot file already exists.", |
| 152 | + ) |
| 153 | + args = parser.parse_args() |
| 154 | + |
| 155 | + # Replace 'dummy_data' with your actual path, e.g., '' |
| 156 | + plot_rmse_by_year( |
| 157 | + base_dir=args.base_dir, |
| 158 | + metric_filename=args.metric_filename, |
| 159 | + save_path=Path(args.save_dir) / "rmse_per_year.png", |
| 160 | + year_range=(1979, 2018), |
| 161 | + metric_keys_left=["rmse_Z500", "rmse_Z850"], |
| 162 | + metric_keys_right=["rmse_U850", "rmse_V850"], |
| 163 | + ref_year=2020, |
| 164 | + force=args.force, |
| 165 | + ) |
| 166 | + plot_rmse_by_year( |
| 167 | + base_dir=args.base_dir, |
| 168 | + metric_filename=args.metric_filename, |
| 169 | + save_path=Path(args.save_dir) / "north_rmse_per_year.png", |
| 170 | + year_range=(1979, 2018), |
| 171 | + metric_keys_left=["rmse-north_Z500", "rmse-north_Z850"], |
| 172 | + metric_keys_right=["rmse-north_U850", "rmse-north_V850"], |
| 173 | + ref_year=2020, |
| 174 | + force=args.force, |
| 175 | + ) |
| 176 | + plot_rmse_by_year( |
| 177 | + base_dir=args.base_dir, |
| 178 | + metric_filename=args.metric_filename, |
| 179 | + save_path=Path(args.save_dir) / "south_rmse_per_year.png", |
| 180 | + year_range=(1979, 2018), |
| 181 | + metric_keys_left=["rmse-south_Z500", "rmse-south_Z850"], |
| 182 | + metric_keys_right=["rmse-south_U850", "rmse-south_V850"], |
| 183 | + ref_year=2020, |
| 184 | + force=args.force, |
| 185 | + ) |
| 186 | + # plot_rmse_by_year( |
| 187 | + # base_dir=args.base_dir, |
| 188 | + # metric_filename=args.metric_filename, |
| 189 | + # save_path=Path(args.save_dir) / "all_rmse_per_year.png", |
| 190 | + # year_range=(1979, 2018), |
| 191 | + # metric_keys_left=["rmse-all_Z500", "rmse-all_Z850"], |
| 192 | + # metric_keys_right=["rmse-all_U850", "rmse-all_V850"], |
| 193 | + # ref_year=2020, |
| 194 | + # force=args.force, |
| 195 | + # ) |
0 commit comments