-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpost_process_global_zarr.py
More file actions
235 lines (178 loc) · 12.8 KB
/
post_process_global_zarr.py
File metadata and controls
235 lines (178 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Script to post-process global simulation data and compute histograms for ARC in a more storage and memory efficient manner
import numpy as np
import os
import xarray as xr
import pandas as pd
import os
import gc
from tqdm import tqdm
from argparse import ArgumentParser
# Parse the command-line arguments.
# startposition is converted to int by argparse itself, and simulationtype is restricted
# to the two supported values so a typo fails immediately instead of after loading data.
p = ArgumentParser(description="""Global post-process of parcels simulations into zarr format""")
p.add_argument('-startdate', '--startdate', default='2000-01-01', help='Start date for processing (YYYY-MM-DD)')
p.add_argument('-startposition', '--startposition', type=int, default=0, help='Start position for processing (number of weeks)')
p.add_argument('-simulationtype', '--simulationtype', default='n', choices=['n', 'r'], help='Type of simulation run (normal=n, restart =r)')
parsed_args = p.parse_args()
startdate = parsed_args.startdate
startposition = parsed_args.startposition  # number of complete weeks to skip (int)
simtype = parsed_args.simulationtype  # 'n' = normal run, 'r' = restart run
# Locations of trajectory data and output folder
data_path = "/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/initial_simulations/"
output_dir = '/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/histograms_global_zarr/'
scaling_factor_file = '/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/plastic_scaling_factors.npz'
# Load the plastic scaling factors: a dict (stored as a 0-d object array, hence
# allow_pickle) that is indexed with (year, release class) tuples further down.
# NOTE(review): allow_pickle assumes this file is trusted, locally produced data.
# The context manager closes the underlying .npz file handle.
with np.load(scaling_factor_file, allow_pickle=True) as scaling_npz:
    scaling_factors = scaling_npz['scaling_factors'].item()
# Construct the NWS grid for the histograms by shifting the bathymetry grid.
# Only the coordinate values are needed, so they are copied out and the file is closed.
with xr.open_dataset("/storage/shared/oceanparcels/output_data/data_Michael/NECCTONsimulations/data/copernicus_data/cmems_mod_glo_phy_my_static_fulldomain.nc") as bathy:
    bathy_NWS = bathy.sel(longitude=slice(-15.1, 10.1), latitude=slice(44.9, 61.1))
    nws_lat = bathy_NWS.latitude.values
    nws_lon = bathy_NWS.longitude.values
# Bin edges lie halfway between consecutive grid points, so the histogram cells are
# centred on the interior grid points [1:-1]: there is one more edge than centre.
lat_bins = (nws_lat[:-1] + nws_lat[1:]) / 2
lat_centers = nws_lat[1:-1]
lon_bins = (nws_lon[:-1] + nws_lon[1:]) / 2
lon_centers = nws_lon[1:-1]
# Grid (bin edges) to compute the densities over
nws_x_edges = lon_bins
nws_y_edges = lat_bins
surface_threshhold = 5  # meters
# Save the grid once; create the output folder first so np.savez cannot fail on a
# missing directory.
os.makedirs(output_dir, exist_ok=True)
nws_grid_file = os.path.join(output_dir, "nws_grid.npz")
if not os.path.isfile(nws_grid_file):
    np.savez(nws_grid_file, nws_x_edges=nws_x_edges, nws_y_edges=nws_y_edges, lon_centers=lon_centers, lat_centers=lat_centers)
# Load the trajectory dataset we are going to analyse.
# Normal ('n') and restart ('r') runs are stored under different file-name prefixes.
sim_file_prefixes = {'n': 'particles', 'r': 'restart_particles'}
if simtype not in sim_file_prefixes:
    raise NotImplementedError(f"Unsupported simulation type {simtype!r}; expected 'n' (normal) or 'r' (restart)")
sim_ds = xr.open_zarr(os.path.join(data_path, f'{sim_file_prefixes[simtype]}_{startdate}.zarr'))
# Calendar day (YYYY-MM-DD) of the first observation of the first trajectory
starting_day_str = str(sim_ds.isel(obs=0, trajectory=0).time.values.astype('datetime64[D]').astype(str))
# For a restart file, starting_day_str is when the restart happened, not when the particles
# were released, so re-align to the original release date: the first Tuesday on or after
# `startdate`. periods=1 (rather than a fixed end date) keeps this valid for any start date.
aligned_start_date = pd.date_range(start=startdate, periods=1, freq='W-TUE')[0].strftime('%Y-%m-%d')
starting_day_year = int(aligned_start_date[:4])
# NOTE: for simtype == 'n' this is expected to equal starting_day_str
# The first simulation that everything needs to align back to
origin_start_date = pd.to_datetime('2008-01-08')
# Release day of each trajectory, in whole days after the first trajectory's first observation
release_times = sim_ds.isel(obs=0).time.values
global_start_times = (release_times - release_times[0]).astype('timedelta64[D]').astype(int)
# Trajectory IDs in dataset order
global_trajectory_id = sim_ds.trajectory.values
# Per-trajectory attributes, read once (instead of once per class) and filtered with numpy
plastic_diameters = sim_ds.plastic_diameter.values
release_class_values = sim_ds.release_class.values
# List of plastic sizes we model
plastic_sizes = np.unique(plastic_diameters)
# List of release types (0=river, 1=coastal, 2=fisheries)
release_classes = np.unique(release_class_values).astype(int)
# Trajectory IDs belonging to each plastic size / each release type
plastic_class_traj = {p_size: global_trajectory_id[plastic_diameters == p_size] for p_size in plastic_sizes}
release_class_traj = {r_class: global_trajectory_id[release_class_values == r_class] for r_class in release_classes}
# "Restart day" aligned with the original files: the first whole number of weeks after
# origin_start_date that falls on or after starting_day_str, so 0 <= nday_offset < 7.
starting_day = pd.to_datetime(starting_day_str)
restart_date = origin_start_date + pd.Timedelta(days=np.ceil((starting_day - origin_start_date).days / 7) * 7)
nday_offset = (restart_date - starting_day).days
# Number of days after the offset that form complete 7-day weeks.
# NOTE: because we save in chunks of 7 obs, trailing days that do not make up a complete
# week are not processed here; they are handled in the restart processing step.
ndays_to_process = (sim_ds.obs.size - nday_offset) // 7 * 7
# Flag whether observations remain after the last complete week. Processing starts at
# nday_offset, so the offset has to be included in the comparison.
additional_days_f = nday_offset + ndays_to_process < sim_ds.obs.size
# First obs index of each week to process (startposition skips already-done weeks).
# Since 0 <= nday_offset < 7, stopping at ndays_to_process still yields exactly the
# complete weeks that start at nday_offset.
loop_ndays_to_process = range(nday_offset + startposition * 7, ndays_to_process, 7)
# Number of daily observations averaged into each weekly histogram
_DAYS_PER_WEEK = 7


def _weekly_mean_histogram(ds, scale, x_edges, y_edges):
    """Return the weekly-mean plastic-mass histogram of ``ds`` on the given grid.

    ``lon``, ``lat`` and ``plastic_amount`` are broadcast against each other, so a
    ``plastic_amount`` stored once per trajectory is repeated for every daily position,
    independent of the dimension order. Positions with a NaN longitude (no particle, or
    masked out by ``Dataset.where``) are dropped, and longitudes are wrapped into
    [-180, 180) before binning.

    Parameters
    ----------
    ds : xarray.Dataset
        Selection holding ``lon``, ``lat`` and ``plastic_amount``.
    scale : float
        Factor applied to every ``plastic_amount`` weight.
    x_edges, y_edges : numpy.ndarray
        Longitude and latitude bin edges.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(x_edges) - 1, len(y_edges) - 1)``: summed weights divided by
        ``_DAYS_PER_WEEK``.
    """
    lon, lat, amount = (da.values.ravel() for da in xr.broadcast(ds.lon, ds.lat, ds.plastic_amount))
    keep = ~np.isnan(lon)
    wrapped_lon = ((lon[keep] + 180) % 360) - 180
    hist, _, _ = np.histogram2d(wrapped_lon, lat[keep], weights=amount[keep] * scale,
                                bins=(x_edges, y_edges), density=False)
    # Each particle contributes one position per day, so dividing by the number of days
    # turns the weekly sum into a weekly average
    return hist / _DAYS_PER_WEEK


# Trajectory IDs are not guaranteed to equal their positions in the dataset (e.g. in
# restart files), so release days are looked up by ID via binary search on this order.
_trajectory_sorter = np.argsort(global_trajectory_id)
n_plastic_sizes = plastic_sizes.size
simulation_start_day = pd.to_datetime(starting_day_str)

for obs in tqdm(loop_ndays_to_process):
    sim_day_str = (simulation_start_day + pd.Timedelta(days=obs)).strftime("%Y-%m-%d")
    output_file = os.path.join(output_dir, f"nws_{aligned_start_date}_week_{sim_day_str}.zarr")
    if os.path.exists(output_file):
        continue  # Skip already processed weeks
    # Fresh accumulator for this week: (watercolumn/surface, plastic size, lon, lat)
    accumulator_histograms = np.zeros((2, n_plastic_sizes, lon_centers.size, lat_centers.size), dtype=np.float32)
    # Trajectories already released on the first day of this week
    valid_trajectory_ids = global_trajectory_id[global_start_times <= obs]
    for release_class in release_classes:
        # Released trajectories of this release type
        valid_release_trajectory_ids = np.intersect1d(valid_trajectory_ids, release_class_traj[release_class], assume_unique=True)
        if valid_release_trajectory_ids.size == 0:
            continue  # Nothing of this release type exists yet; its contribution is zero
        # Scaling factor for this year and release type, split evenly over the plastic classes
        rc_scaling_factor = scaling_factors[(starting_day_year, release_class)] / n_plastic_sizes
        # Release day of each selected trajectory, found by mapping IDs to dataset positions
        positions = _trajectory_sorter[np.searchsorted(global_trajectory_id, valid_release_trajectory_ids, sorter=_trajectory_sorter)]
        starttimes_to_use = global_start_times[positions]
        # Pointwise (vectorized) selection - see https://docs.xarray.dev/en/latest/user-guide/interpolation.html#advanced-interpolation
        # x picks the trajectories; y picks, per trajectory, the 7 obs of this week counted
        # from that trajectory's own release day.
        x = xr.DataArray(valid_release_trajectory_ids, dims="traj")
        y = xr.DataArray([obs + i - starttimes_to_use for i in range(_DAYS_PER_WEEK)], dims=["localobs", "traj"])
        release_object_ds = sim_ds.sel(trajectory=x, obs=y)
        # Index the selection by trajectory ID so each plastic size can be sub-selected by label
        release_object_ds = release_object_ds.set_xindex('trajectory')
        release_object_ds.load()  # Load once into memory; the per-size selections below reuse it
        for plastic_i, plastic_size in enumerate(plastic_sizes):
            # Released trajectories of this release type and plastic size
            valid_plastic_trajectory_ids = np.intersect1d(valid_release_trajectory_ids, plastic_class_traj[plastic_size], assume_unique=True)
            object_ds = release_object_ds.sel(trajectory=xr.DataArray(valid_plastic_trajectory_ids, dims="traj"))
            # Histogram for the whole water column
            accumulator_histograms[0, plastic_i] += _weekly_mean_histogram(object_ds, rc_scaling_factor, nws_x_edges, nws_y_edges)
            # Histogram for the surface layer only. Particles may move into/out of it during
            # the week, so each daily position is masked individually.
            # NOTE(review): assumes z is depth in meters, positive downward - confirm.
            surface_object_ds_mask = (object_ds.z <= surface_threshhold).rename("surface_mask").compute()
            accumulator_histograms[1, plastic_i] += _weekly_mean_histogram(object_ds.where(surface_object_ds_mask), rc_scaling_factor, nws_x_edges, nws_y_edges)
    # Now that the accumulators are ready, construct an xarray dataset, and save to file
    histogram_ds = xr.Dataset(
        {
            "plastic_amount": (("watercolumn_surface_flag", "plastic_size", "lon", "lat"), accumulator_histograms, {"units": "kilograms", 'description': 'Average plastic mass per grid cell over the week.'})
        },
        coords={
            "watercolumn_surface_flag": ("watercolumn_surface_flag", ["watercolumn", "surface"], {'description': f'Flag indicating whether the histogram is for the entire water column or just the surface layer (top {surface_threshhold} meters).'}),
            "plastic_size": ("plastic_size", plastic_sizes, {'units': 'm', 'description': 'Diameter of the plastic particles.'}),
            "lon": ("lon", lon_centers, {'units': 'degrees_east', 'description': 'Longitude bin centers for the histogram.'}),
            "lat": ("lat", lat_centers, {'units': 'degrees_north', 'description': 'Latitude bin centers for the histogram.'}),
            "time": ("time", [pd.to_datetime(sim_day_str)], {'description': 'Starting day of the week for which the histogram is computed.'})
        },
        attrs={
            "description": "Weekly averaged plastic mass histograms over the NWS region for different plastic sizes.",
            "starting_day": aligned_start_date
        }
    )
    # Save the output to zarr
    histogram_ds.to_zarr(output_file, mode='w')
    # Cleanup memory before the next week
    del histogram_ds
    gc.collect()
print(f"Processing for: start_day={aligned_start_date}, has been completed.")
# NOTE: additional_days_f flags trailing observations that do not form a complete week;
# those days are not processed by this script.