diff --git a/parcels/_compat.py b/parcels/_compat.py index 6efab15a76..416ed2905b 100644 --- a/parcels/_compat.py +++ b/parcels/_compat.py @@ -17,3 +17,17 @@ from sklearn.cluster import KMeans # type: ignore[no-redef] except ModuleNotFoundError: pass + + +def add_note(e: Exception, note: str, *, before=False) -> Exception: # TODO: Remove once py3.10 support is dropped + """Implements something similar to PEP 678 but for python <3.11. + + https://stackoverflow.com/a/75549200/15545258 + """ + args = e.args + if not args: + arg0 = note + else: + arg0 = f"{note}\n{args[0]}" if before else f"{args[0]}\n{note}" + e.args = (arg0,) + args[1:] + return e diff --git a/parcels/_index_search.py b/parcels/_index_search.py new file mode 100644 index 0000000000..a6f733fea1 --- /dev/null +++ b/parcels/_index_search.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from parcels._typing import ( + GridIndexingType, + InterpMethodOption, +) +from parcels.tools.statuscodes import ( + FieldOutOfBoundError, + FieldOutOfBoundSurfaceError, + _raise_field_out_of_bound_error, + _raise_field_out_of_bound_surface_error, + _raise_field_sampling_error, +) + +from .grid import GridType + +if TYPE_CHECKING: + from .field import Field + from .grid import Grid + + +def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float): + if grid.depth[-1] > grid.depth[0]: + if z < grid.depth[0]: + # Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0]) + if gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]: + return (-1, z / grid.depth[0]) + else: + _raise_field_out_of_bound_surface_error(z, None, None) + elif z > grid.depth[-1]: + # In case of CROCO, allow particles in last (uppermost) layer using depth[-1] + if gridindexingtype in ["croco"] and z < 0: + return (-2, 1) + _raise_field_out_of_bound_error(z, None, None) + depth_indices = grid.depth < z + if z >= grid.depth[-1]: + zi = len(grid.depth) - 2 + else: + zi = depth_indices.argmin() - 1 if z > grid.depth[0] else 0 + else: + if z > grid.depth[0]: + _raise_field_out_of_bound_surface_error(z, None, None) + elif z < grid.depth[-1]: + _raise_field_out_of_bound_error(z, None, None) + depth_indices = grid.depth > z + if z <= grid.depth[-1]: + zi = len(grid.depth) - 2 + else: + zi = depth_indices.argmin() - 1 if z < grid.depth[0] else 0 + zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + while zeta > 1: + zi += 1 + zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + while zeta < 0: + zi -= 1 + zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + return (zi, zeta) + + +def search_indices_vertical_s( + grid: Grid, + interp_method: InterpMethodOption, + time: float, + z: float, + y: float, + x: float, + ti: int, + yi: int, + xi: int, + eta: float, + xsi: float, +): + if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: + xsi = 1 + eta = 1 + if time < grid.time[ti]: + ti -= 1 + if grid._z4d: + if ti == len(grid.time) - 1: + depth_vector = ( + (1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi] + + xsi * (1 - eta) * grid.depth[-1, :, yi, xi + 1] + + xsi * eta * grid.depth[-1, :, yi + 1, xi + 1] + + (1 - xsi) * eta * grid.depth[-1, :, yi + 1, xi] + ) + else: + dv2 = ( + (1 - xsi) * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi] + + xsi * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi + 1] + + xsi * eta * grid.depth[ti : ti + 2, :, yi + 1, xi + 1] + + (1 - xsi) * eta * grid.depth[ti : ti + 2, :, yi + 1, xi] + ) + tt = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) + assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time" + depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt + else: + depth_vector = ( + (1 - xsi) * (1 - eta) * grid.depth[:, yi, xi] + + xsi * (1 - eta) * grid.depth[:, yi, xi + 1] + + xsi * eta * grid.depth[:, yi + 1, xi + 1] + + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] + ) + z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 + + if depth_vector[-1] > depth_vector[0]: + if z < depth_vector[0]: + _raise_field_out_of_bound_error(z, None, None) + elif z > depth_vector[-1]: + _raise_field_out_of_bound_error(z, None, None) + depth_indices = depth_vector < z + if z >= depth_vector[-1]: + zi = len(depth_vector) - 2 + else: + zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0 + else: + if z > depth_vector[0]: + _raise_field_out_of_bound_error(z, None, None) + elif z < depth_vector[-1]: + _raise_field_out_of_bound_error(z, None, None) + depth_indices = depth_vector > z + if z <= depth_vector[-1]: + zi = len(depth_vector) - 2 + else: + zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0 + zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) + while zeta > 1: + zi += 1 + zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) + while zeta < 0: + zi -= 1 + zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) + return (zi, zeta) + + +def _search_indices_rectilinear( + field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False +): + grid = field.grid + + if grid.xdim > 1 and (not grid.zonal_periodic): + if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: + _raise_field_out_of_bound_error(z, y, x) + if grid.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): + _raise_field_out_of_bound_error(z, y, x) + + if grid.xdim > 1: + if grid.mesh != "spherical": + lon_index = grid.lon < x + if lon_index.all(): + xi = len(grid.lon) - 2 + else: + xi = lon_index.argmin() - 1 if lon_index.any() else 0 + xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + if xsi < 0: + xi -= 1 + xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + elif xsi > 1: + xi += 1 + xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + else: + lon_fixed = grid.lon.copy() + indices = lon_fixed >= lon_fixed[0] + if not indices.all(): + lon_fixed[indices.argmin() :] += 360 + if x < lon_fixed[0]: + lon_fixed -= 360 + + lon_index = lon_fixed < x + if lon_index.all(): + xi = len(lon_fixed) - 2 + else: + xi = lon_index.argmin() - 1 if lon_index.any() else 0 + xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) + if xsi < 0: + xi -= 1 + xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) + elif xsi > 1: + xi += 1 + xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) + else: + xi, xsi = -1, 0 + + if grid.ydim > 1: + lat_index = grid.lat < y + if lat_index.all(): + yi = len(grid.lat) - 2 + else: + yi = lat_index.argmin() - 1 if lat_index.any() else 0 + + eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + if eta < 0: + yi -= 1 + eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + elif eta > 1: + yi += 1 + eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + else: + yi, eta = -1, 0 + + if grid.zdim > 1 and not search2D: + if grid._gtype == GridType.RectilinearZGrid: + try: + (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) + except FieldOutOfBoundError: + _raise_field_out_of_bound_error(z, y, x) + except FieldOutOfBoundSurfaceError: + _raise_field_out_of_bound_surface_error(z, y, x) + elif grid._gtype == GridType.RectilinearSGrid: + (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) + else: + zi, zeta = -1, 0 + + if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): + _raise_field_sampling_error(z, y, x) + + if particle: + particle.xi[field.igrid] = xi + particle.yi[field.igrid] = yi + particle.zi[field.igrid] = zi + + return (zeta, eta, xsi, zi, yi, xi) + + +def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False): + if particle: + xi = particle.xi[field.igrid] + yi = particle.yi[field.igrid] + else: + xi = int(field.grid.xdim / 2) - 1 + yi = int(field.grid.ydim / 2) - 1 + xsi = eta = -1 + grid = field.grid + invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) + maxIterSearch = 1e6 + it = 0 + tol = 1.0e-10 + if not grid.zonal_periodic: + if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: + if grid.lon[0, 0] < grid.lon[0, -1]: + _raise_field_out_of_bound_error(z, y, x) + elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] + _raise_field_out_of_bound_error(z, y, x) + if y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]: + _raise_field_out_of_bound_error(z, y, x) + + while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + if grid.mesh == "spherical": + px[0] = px[0] + 360 if px[0] < x - 225 else px[0] + px[0] = px[0] - 360 if px[0] > x + 225 else px[0] + px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) + px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) + a = np.dot(invA, px) + b = np.dot(invA, py) + + aa = a[3] * b[2] - a[2] * b[3] + bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] + cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] + if abs(aa) < 1e-12: # Rectilinear cell, or quasi + eta = -cc / bb + else: + det2 = bb * bb - 4 * aa * cc + if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter + det = np.sqrt(det2) + eta = (-bb + det) / (2 * aa) + if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg + xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 + else: + xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) + if xsi < 0 and eta < 0 and xi == 0 and yi == 0: + _raise_field_out_of_bound_error(0, y, x) + if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: + _raise_field_out_of_bound_error(0, y, x) + if xsi < -tol: + xi -= 1 + elif xsi > 1 + tol: + xi += 1 + if eta < -tol: + yi -= 1 + elif eta > 1 + tol: + yi += 1 + (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) + it += 1 + if it > maxIterSearch: + print(f"Correct cell not found after {maxIterSearch} iterations") + _raise_field_out_of_bound_error(0, y, x) + xsi = max(0.0, xsi) + eta = max(0.0, eta) + xsi = min(1.0, xsi) + eta = min(1.0, eta) + + if grid.zdim > 1 and not search2D: + if grid._gtype == GridType.CurvilinearZGrid: + try: + (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) + except FieldOutOfBoundError: + _raise_field_out_of_bound_error(z, y, x) + elif grid._gtype == GridType.CurvilinearSGrid: + (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) + else: + zi = -1 + zeta = 0 + + if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): + _raise_field_sampling_error(z, y, x) + + if particle: + particle.xi[field.igrid] = xi + particle.yi[field.igrid] = yi + particle.zi[field.igrid] = zi + + return (zeta, eta, xsi, zi, yi, xi) + + +def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool): + if xi < 0: + if sphere_mesh: + xi = xdim - 2 + else: + xi = 0 + if xi > xdim - 2: + if sphere_mesh: + xi = 0 + else: + xi = xdim - 2 + if yi < 0: + yi = 0 + if yi > ydim - 2: + yi = ydim - 2 + if sphere_mesh: + xi = xdim - xi + return yi, xi diff --git a/parcels/_interpolation.py b/parcels/_interpolation.py new file mode 100644 index 0000000000..a6ebf65950 --- /dev/null +++ b/parcels/_interpolation.py @@ -0,0 +1,280 @@ +from collections.abc import Callable, Mapping +from dataclasses import dataclass + +import numpy as np + +from parcels._typing import GridIndexingType + + +@dataclass +class InterpolationContext2D: + """Information provided by Parcels during 2D spatial interpolation. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019 for more info. + + Attributes + ---------- + data: np.ndarray + field data of shape (time, y, x) + eta: float + y-direction interpolation coordinate in unit cube (between 0 and 1) + xsi: float + x-direction interpolation coordinate in unit cube (between 0 and 1) + ti: int + time index + yi: int + y index of cell containing particle + xi: int + x index of cell containing particle + + """ + + data: np.ndarray + eta: float + xsi: float + ti: int + yi: int + xi: int + + +@dataclass +class InterpolationContext3D: + """Information provided by Parcels during 3D spatial interpolation. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019 for more info. + + Attributes + ---------- + data: np.ndarray + field data of shape (time, z, y, x). This needs to be complete in the vertical + direction as some interpolation methods need to know whether they are at the + surface or bottom. + zeta: float + vertical interpolation coordinate in unit cube + eta: float + y-direction interpolation coordinate in unit cube + xsi: float + x-direction interpolation coordinate in unit cube + zi: int + z index of cell containing particle + ti: int + time index + yi: int + y index of cell containing particle + xi: int + x index of cell containing particle + gridindexingtype: GridIndexingType + grid indexing type + + """ + + data: np.ndarray + zeta: float + eta: float + xsi: float + ti: int + zi: int + yi: int + xi: int + gridindexingtype: GridIndexingType # included in 3D as z-face is indexed differently with MOM5 and POP + + +_interpolator_registry_2d: dict[str, Callable[[InterpolationContext2D], float]] = {} +_interpolator_registry_3d: dict[str, Callable[[InterpolationContext3D], float]] = {} + + +def get_2d_interpolator_registry() -> Mapping[str, Callable[[InterpolationContext2D], float]]: + # See Discussion on Python Discord for more context (function prevents re-alias of global variable) + # _interpolator_registry_2d etc shouldn't be imported directly + # https://discord.com/channels/267624335836053506/1329136004459794483 + return _interpolator_registry_2d + + +def get_3d_interpolator_registry() -> Mapping[str, Callable[[InterpolationContext3D], float]]: + return _interpolator_registry_3d + + +def register_2d_interpolator(name: str): + def decorator(interpolator: Callable[[InterpolationContext2D], float]): + _interpolator_registry_2d[name] = interpolator + return interpolator + + return decorator + + +def register_3d_interpolator(name: str): + def decorator(interpolator: Callable[[InterpolationContext3D], float]): + _interpolator_registry_3d[name] = interpolator + return interpolator + + return decorator + + +@register_2d_interpolator("nearest") +def _nearest_2d(ctx: InterpolationContext2D) -> float: + xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1 + yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1 + return ctx.data[ctx.ti, yii, xii] + + +def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int, xi: int) -> float: + """Interpolation on a unit square. See Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.""" + return ( + (1 - xsi) * (1 - eta) * data[yi, xi] + + xsi * (1 - eta) * data[yi, xi + 1] + + xsi * eta * data[yi + 1, xi + 1] + + (1 - xsi) * eta * data[yi + 1, xi] + ) + + +@register_2d_interpolator("linear") +@register_2d_interpolator("bgrid_velocity") +@register_2d_interpolator("partialslip") +@register_2d_interpolator("freeslip") +def _linear_2d(ctx: InterpolationContext2D) -> float: + return _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi) + + +@register_2d_interpolator("linear_invdist_land_tracer") +def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float: + xsi = ctx.xsi + eta = ctx.eta + data = ctx.data + yi = ctx.yi + xi = ctx.xi + ti = ctx.ti + land = np.isclose(data[ti, yi : yi + 2, xi : xi + 2], 0.0) + nb_land = np.sum(land) + + if nb_land == 4: + return 0 + elif nb_land > 0: + val = 0 + w_sum = 0 + for j in range(2): + for i in range(2): + distance = pow((eta - j), 2) + pow((xsi - i), 2) + if np.isclose(distance, 0): + if land[j][i] == 1: # index search led us directly onto land + return 0 + else: + return data[ti, yi + j, xi + i] + elif land[j][i] == 0: + val += data[ti, yi + j, xi + i] / distance + w_sum += 1 / distance + return val / w_sum + else: + return _interp_on_unit_square(eta=eta, xsi=xsi, data=data[ti, :, :], yi=yi, xi=xi) + + +@register_2d_interpolator("cgrid_tracer") +@register_2d_interpolator("bgrid_tracer") +def _tracer_2d(ctx: InterpolationContext2D) -> float: + return ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1] + + +@register_3d_interpolator("nearest") +def _nearest_3d(ctx: InterpolationContext3D) -> float: + xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1 + yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1 + zii = ctx.zi if ctx.zeta <= 0.5 else ctx.zi + 1 + return ctx.data[ctx.ti, zii, yii, xii] + + +@register_3d_interpolator("cgrid_velocity") +def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float: + # evaluating W velocity in c_grid + if ctx.gridindexingtype == "nemo": + f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1] + f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi + 1, ctx.xi + 1] + elif ctx.gridindexingtype in ["mitgcm", "croco"]: + f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi, ctx.xi] + f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi, ctx.xi] + return (1 - ctx.zeta) * f0 + ctx.zeta * f1 + + +@register_3d_interpolator("linear_invdist_land_tracer") +def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float: + land = np.isclose(ctx.data[ctx.ti, ctx.zi : ctx.zi + 2, ctx.yi : ctx.yi + 2, ctx.xi : ctx.xi + 2], 0.0) + nb_land = np.sum(land) + if nb_land == 8: + return 0 + elif nb_land > 0: + val = 0 + w_sum = 0 + for k in range(2): + for j in range(2): + for i in range(2): + distance = pow((ctx.zeta - k), 2) + pow((ctx.eta - j), 2) + pow((ctx.xsi - i), 2) + if np.isclose(distance, 0): + if land[k][j][i] == 1: # index search led us directly onto land + return 0 + else: + return ctx.data[ctx.ti, ctx.zi + k, ctx.yi + j, ctx.xi + i] + elif land[k][j][i] == 0: + val += ctx.data[ctx.ti, ctx.zi + k, ctx.yi + j, ctx.xi + i] / distance + w_sum += 1 / distance + return val / w_sum + else: + data = ctx.data[ctx.ti, ctx.zi, :, :] + f0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=data, yi=ctx.yi, xi=ctx.xi) + + data = ctx.data[ctx.ti, ctx.zi + 1, :, :] + f1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=data, yi=ctx.yi, xi=ctx.xi) + + return (1 - ctx.zeta) * f0 + ctx.zeta * f1 + + +def _get_3d_f0_f1(*, eta: float, xsi: float, data: np.ndarray, zi: int, yi: int, xi: int) -> tuple[float, float | None]: + data_2d = data[zi, :, :] + f0 = _interp_on_unit_square(eta=eta, xsi=xsi, data=data_2d, yi=yi, xi=xi) + try: + data_2d = data[zi + 1, :, :] + except IndexError: + f1 = None # POP indexing at edge of domain + else: + f1 = _interp_on_unit_square(eta=eta, xsi=xsi, data=data_2d, yi=yi, xi=xi) + + return f0, f1 + + +def _z_layer_interp( + *, zeta: float, f0: float, f1: float | None, zi: int, zdim: int, gridindexingtype: GridIndexingType +): + if gridindexingtype == "pop" and zi >= zdim - 2: + # Since POP is indexed at cell top, allow linear interpolation of W to zero in lowest cell + return (1 - zeta) * f0 + assert f1 is not None, "f1 should not be None for gridindexingtype != 'pop'" + if gridindexingtype == "mom5" and zi == -1: + # Since MOM5 is indexed at cell bottom, allow linear interpolation of W to zero in uppermost cell + return zeta * f1 + return (1 - zeta) * f0 + zeta * f1 + + +@register_3d_interpolator("linear") +@register_3d_interpolator("partialslip") +@register_3d_interpolator("freeslip") +def _linear_3d(ctx: InterpolationContext3D) -> float: + zdim = ctx.data.shape[1] + data_3d = ctx.data[ctx.ti, :, :, :] + f0, f1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi) + + return _z_layer_interp(zeta=ctx.zeta, f0=f0, f1=f1, zi=ctx.zi, zdim=zdim, gridindexingtype=ctx.gridindexingtype) + + +@register_3d_interpolator("bgrid_velocity") +def _linear_3d_bgrid_velocity(ctx: InterpolationContext3D) -> float: + if ctx.gridindexingtype == "mom5": + ctx.zeta = 1.0 + else: + ctx.zeta = 0.0 + return _linear_3d(ctx) + + +@register_3d_interpolator("bgrid_w_velocity") +def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float: + ctx.eta = 1.0 + ctx.xsi = 1.0 + return _linear_3d(ctx) + + +@register_3d_interpolator("bgrid_tracer") +@register_3d_interpolator("cgrid_tracer") +def _tracer_3d(ctx: InterpolationContext3D) -> float: + return ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1] diff --git a/parcels/field.py b/parcels/field.py index a02b53a49e..9cde71efe7 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -11,6 +11,13 @@ import xarray as xr import parcels.tools.interpolation_utils as i_u +from parcels._compat import add_note +from parcels._interpolation import ( + InterpolationContext2D, + InterpolationContext3D, + get_2d_interpolator_registry, + get_3d_interpolator_registry, +) from parcels._typing import ( GridIndexingType, InterpMethod, @@ -22,8 +29,6 @@ ) from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr, timedelta_to_float from parcels.tools.converters import ( - Geographic, - GeographicPolar, TimeConverter, UnitConverter, unitconverters_map, @@ -34,16 +39,18 @@ FieldOutOfBoundSurfaceError, FieldSamplingError, TimeExtrapolationError, + _raise_field_out_of_bound_error, ) from parcels.tools.warnings import FieldSetWarning, _deprecated_param_netcdf_decodewarning +from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear from .fieldfilebuffer import ( DaskFileBuffer, DeferredDaskFileBuffer, DeferredNetcdfFileBuffer, NetcdfFileBuffer, ) -from .grid import CGrid, Grid, GridType +from .grid import CGrid, Grid, GridType, _calc_cell_areas, _calc_cell_edge_sizes if TYPE_CHECKING: from ctypes import _Pointer as PointerType @@ -953,520 +960,79 @@ def set_depth_from_field(self, field): @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def calc_cell_edge_sizes(self): - return self._calc_cell_edge_sizes() - - def _calc_cell_edge_sizes(self): - """Method to calculate cell sizes based on numpy.gradient method. - - Currently only works for Rectilinear Grids - """ - if not self.grid.cell_edge_sizes: - if self.grid._gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid): - self.grid.cell_edge_sizes["x"] = np.zeros((self.grid.ydim, self.grid.xdim), dtype=np.float32) - self.grid.cell_edge_sizes["y"] = np.zeros((self.grid.ydim, self.grid.xdim), dtype=np.float32) - - x_conv = GeographicPolar() if self.grid.mesh == "spherical" else UnitConverter() - y_conv = Geographic() if self.grid.mesh == "spherical" else UnitConverter() - for y, (lat, dy) in enumerate(zip(self.grid.lat, np.gradient(self.grid.lat), strict=False)): - for x, (lon, dx) in enumerate(zip(self.grid.lon, np.gradient(self.grid.lon), strict=False)): - self.grid.cell_edge_sizes["x"][y, x] = x_conv.to_source(dx, self.grid.depth[0], lat, lon) - self.grid.cell_edge_sizes["y"][y, x] = y_conv.to_source(dy, self.grid.depth[0], lat, lon) - else: - raise ValueError( - f"Field.cell_edge_sizes() not implemented for {self.grid._gtype} grids. " - "You can provide Field.grid.cell_edge_sizes yourself by in, e.g., " - "NEMO using the e1u fields etc from the mesh_mask.nc file." - ) + _calc_cell_edge_sizes(self.grid) def cell_areas(self): """Method to calculate cell sizes based on cell_edge_sizes. - Currently only works for Rectilinear Grids + Only works for Rectilinear Grids """ - if not self.grid.cell_edge_sizes: - self._calc_cell_edge_sizes() - return self.grid.cell_edge_sizes["x"] * self.grid.cell_edge_sizes["y"] + return _calc_cell_areas(self.grid) @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def search_indices_vertical_z(self, z): - return self._search_indices_vertical_z(z) - - def _search_indices_vertical_z(self, z): - grid = self.grid - z = np.float32(z) - if grid.depth[-1] > grid.depth[0]: - if z < grid.depth[0]: - # Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0]) - if self.gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]: - return (-1, z / grid.depth[0]) - else: - raise FieldOutOfBoundSurfaceError(z, 0, 0, field=self) - elif z > grid.depth[-1]: - # In case of CROCO, allow particles in last (uppermost) layer using depth[-1] - if self.gridindexingtype in ["croco"] and z < 0: - return (-2, 1) - raise FieldOutOfBoundError(z, 0, 0, field=self) - depth_indices = grid.depth < z - if z >= grid.depth[-1]: - zi = len(grid.depth) - 2 - else: - zi = depth_indices.argmin() - 1 if z > grid.depth[0] else 0 - else: - if z > grid.depth[0]: - raise FieldOutOfBoundSurfaceError(z, 0, 0, field=self) - elif z < grid.depth[-1]: - raise FieldOutOfBoundError(z, 0, 0, field=self) - depth_indices = grid.depth > z - if z <= grid.depth[-1]: - zi = len(grid.depth) - 2 - else: - zi = depth_indices.argmin() - 1 if z < grid.depth[0] else 0 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) - while zeta > 1: - zi += 1 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) - while zeta < 0: - zi -= 1 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) - return (zi, zeta) + def search_indices_vertical_z(self, *_): + raise NotImplementedError @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def search_indices_vertical_s(self, *args, **kwargs): - return self._search_indices_vertical_s(*args, **kwargs) - - def _search_indices_vertical_s( - self, time: float, z: float, y: float, x: float, ti: int, yi: int, xi: int, eta: float, xsi: float - ): - grid = self.grid - if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: - xsi = 1 - eta = 1 - if time < grid.time[ti]: - ti -= 1 - if grid._z4d: - if ti == len(grid.time) - 1: - depth_vector = ( - (1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi] - + xsi * (1 - eta) * grid.depth[-1, :, yi, xi + 1] - + xsi * eta * grid.depth[-1, :, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[-1, :, yi + 1, xi] - ) - else: - dv2 = ( - (1 - xsi) * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi] - + xsi * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi + 1] - + xsi * eta * grid.depth[ti : ti + 2, :, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[ti : ti + 2, :, yi + 1, xi] - ) - tt = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) - assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time" - depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt - else: - depth_vector = ( - (1 - xsi) * (1 - eta) * grid.depth[:, yi, xi] - + xsi * (1 - eta) * grid.depth[:, yi, xi + 1] - + xsi * eta * grid.depth[:, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] - ) - z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 - - if depth_vector[-1] > depth_vector[0]: - if z < depth_vector[0]: - raise FieldOutOfBoundSurfaceError(z, 0, 0, field=self) - elif z > depth_vector[-1]: - raise FieldOutOfBoundError(z, y, x, field=self) - depth_indices = depth_vector < z - if z >= depth_vector[-1]: - zi = len(depth_vector) - 2 - else: - zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0 - else: - if z > depth_vector[0]: - raise FieldOutOfBoundSurfaceError(z, 0, 0, field=self) - elif z < depth_vector[-1]: - raise FieldOutOfBoundError(z, y, x, field=self) - depth_indices = depth_vector > z - if z <= depth_vector[-1]: - zi = len(depth_vector) - 2 - else: - zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0 - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) - while zeta > 1: - zi += 1 - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) - while zeta < 0: - zi -= 1 - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) - return (zi, zeta) + raise NotImplementedError @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def reconnect_bnd_indices(self, *args, **kwargs): - return self._reconnect_bnd_indices(*args, **kwargs) - - def _reconnect_bnd_indices(self, yi, xi, ydim, xdim, sphere_mesh): - if xi < 0: - if sphere_mesh: - xi = xdim - 2 - else: - xi = 0 - if xi > xdim - 2: - if sphere_mesh: - xi = 0 - else: - xi = xdim - 2 - if yi < 0: - yi = 0 - if yi > ydim - 2: - yi = ydim - 2 - if sphere_mesh: - xi = xdim - xi - return yi, xi + raise NotImplementedError @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def search_indices_rectilinear(self, *args, **kwargs): - return self._search_indices_rectilinear(*args, **kwargs) - - def _search_indices_rectilinear( - self, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False - ): - grid = self.grid - - if grid.xdim > 1 and (not grid.zonal_periodic): - if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: - raise FieldOutOfBoundError(z, y, x, field=self) - if grid.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): - raise FieldOutOfBoundError(z, y, x, field=self) - - if grid.xdim > 1: - if grid.mesh != "spherical": - lon_index = grid.lon < x - if lon_index.all(): - xi = len(grid.lon) - 2 - else: - xi = lon_index.argmin() - 1 if lon_index.any() else 0 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) - if xsi < 0: - xi -= 1 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) - elif xsi > 1: - xi += 1 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) - else: - lon_fixed = grid.lon.copy() - indices = lon_fixed >= lon_fixed[0] - if not indices.all(): - lon_fixed[indices.argmin() :] += 360 - if x < lon_fixed[0]: - lon_fixed -= 360 - - lon_index = lon_fixed < x - if lon_index.all(): - xi = len(lon_fixed) - 2 - else: - xi = lon_index.argmin() - 1 if lon_index.any() else 0 - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) - if xsi < 0: - xi -= 1 - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) - elif xsi > 1: - xi += 1 - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) - else: - xi, xsi = -1, 0 - - if grid.ydim > 1: - lat_index = grid.lat < y - if lat_index.all(): - yi = len(grid.lat) - 2 - else: - yi = lat_index.argmin() - 1 if lat_index.any() else 0 - - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) - if eta < 0: - yi -= 1 - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) - elif eta > 1: - yi += 1 - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) - else: - yi, eta = -1, 0 - - if grid.zdim > 1 and not search2D: - if grid._gtype == GridType.RectilinearZGrid: - # Never passes here, because in this case, we work with scipy - try: - (zi, zeta) = self._search_indices_vertical_z(z) - except FieldOutOfBoundError: - raise FieldOutOfBoundError(z, y, x, field=self) - except FieldOutOfBoundSurfaceError: - raise FieldOutOfBoundSurfaceError(z, y, x, field=self) - elif grid._gtype == GridType.RectilinearSGrid: - (zi, zeta) = self._search_indices_vertical_s(time, z, y, x, ti, yi, xi, eta, xsi) - else: - zi, zeta = -1, 0 - - if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): - raise FieldSamplingError(z, y, x, field=self) - - if particle: - particle.xi[self.igrid] = xi - particle.yi[self.igrid] = yi - particle.zi[self.igrid] = zi - - return (zeta, eta, xsi, zi, yi, xi) + def search_indices_rectilinear(self, *_): + raise NotImplementedError @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def search_indices_curvilinear(self, *args, **kwargs): - return self._search_indices_curvilinear(*args, **kwargs) - - def _search_indices_curvilinear(self, time, z, y, x, ti=-1, particle=None, search2D=False): - if particle: - xi = particle.xi[self.igrid] - yi = particle.yi[self.igrid] - else: - xi = int(self.grid.xdim / 2) - 1 - yi = int(self.grid.ydim / 2) - 1 - xsi = eta = -1 - grid = self.grid - invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) - maxIterSearch = 1e6 - it = 0 - tol = 1.0e-10 - if not grid.zonal_periodic: - if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: - if grid.lon[0, 0] < grid.lon[0, -1]: - raise FieldOutOfBoundError(z, y, x, field=self) - elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] - raise FieldOutOfBoundError(z, y, x, field=self) - if y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]: - raise FieldOutOfBoundError(z, y, x, field=self) - - while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: - px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) - if grid.mesh == "spherical": - px[0] = px[0] + 360 if px[0] < x - 225 else px[0] - px[0] = px[0] - 360 if px[0] > x + 225 else px[0] - px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) - px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) - py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) - a = np.dot(invA, px) - b = np.dot(invA, py) - - aa = a[3] * b[2] - a[2] * b[3] - bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] - cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] - if abs(aa) < 1e-12: # Rectilinear cell, or quasi - eta = -cc / bb - else: - det2 = bb * bb - 4 * aa * cc - if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter - det = np.sqrt(det2) - eta = (-bb + det) / (2 * aa) - if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg - xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 - else: - xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) - if xsi < 0 and eta < 0 and xi == 0 and yi == 0: - raise FieldOutOfBoundError(0, y, x, field=self) - if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: - raise FieldOutOfBoundError(0, y, x, field=self) - if xsi < -tol: - xi -= 1 - elif xsi > 1 + tol: - xi += 1 - if eta < -tol: - yi -= 1 - elif eta > 1 + tol: - yi += 1 - (yi, xi) = self._reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) - it += 1 - if it > maxIterSearch: - print(f"Correct cell not found after {maxIterSearch} iterations") - raise FieldOutOfBoundError(0, y, x, field=self) - xsi = max(0.0, xsi) - eta = max(0.0, eta) - xsi = min(1.0, xsi) - eta = min(1.0, eta) - - if grid.zdim > 1 and not search2D: - if grid._gtype == GridType.CurvilinearZGrid: - try: - (zi, zeta) = self._search_indices_vertical_z(z) - except FieldOutOfBoundError: - raise FieldOutOfBoundError(z, y, x, field=self) - elif grid._gtype == GridType.CurvilinearSGrid: - (zi, zeta) = self._search_indices_vertical_s(time, z, y, x, ti, yi, xi, eta, xsi) - else: - zi = -1 - zeta = 0 - - if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): - raise FieldSamplingError(z, y, x, field=self) - - if particle: - particle.xi[self.igrid] = xi - particle.yi[self.igrid] = yi - particle.zi[self.igrid] = zi - - return (zeta, eta, xsi, zi, yi, xi) + def search_indices_curvilinear(self, *_): + raise NotImplementedError @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def search_indices(self, *args, **kwargs): - return self._search_indices(*args, **kwargs) + def search_indices(self, *_): + raise NotImplementedError def _search_indices(self, time, z, y, x, ti=-1, particle=None, search2D=False): if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: - return self._search_indices_rectilinear(time, z, y, x, ti, particle=particle, search2D=search2D) + return _search_indices_rectilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D) else: - return self._search_indices_curvilinear(time, z, y, x, ti, particle=particle, search2D=search2D) + return _search_indices_curvilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D) @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def interpolator2D(self, *args, **kwargs): - return self._interpolator2D(*args, **kwargs) + def interpolator2D(self, *_): + raise NotImplementedError def _interpolator2D(self, ti, z, y, x, particle=None): + """Impelement 2D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019..""" (_, eta, xsi, _, yi, xi) = self._search_indices(-1, z, y, x, particle=particle) - if self.interp_method == "nearest": - xii = xi if xsi <= 0.5 else xi + 1 - yii = yi if eta <= 0.5 else yi + 1 - return self.data[ti, yii, xii] - elif self.interp_method in ["linear", "bgrid_velocity", "partialslip", "freeslip"]: - val = ( - (1 - xsi) * (1 - eta) * self.data[ti, yi, xi] - + xsi * (1 - eta) * self.data[ti, yi, xi + 1] - + xsi * eta * self.data[ti, yi + 1, xi + 1] - + (1 - xsi) * eta * self.data[ti, yi + 1, xi] - ) - return val - elif self.interp_method == "linear_invdist_land_tracer": - land = np.isclose(self.data[ti, yi : yi + 2, xi : xi + 2], 0.0) - nb_land = np.sum(land) - if nb_land == 4: - return 0 - elif nb_land > 0: - val = 0 - w_sum = 0 - for j in range(2): - for i in range(2): - distance = pow((eta - j), 2) + pow((xsi - i), 2) - if np.isclose(distance, 0): - if land[j][i] == 1: # index search led us directly onto land - return 0 - else: - return self.data[ti, yi + j, xi + i] - elif land[j][i] == 0: - val += self.data[ti, yi + j, xi + i] / distance - w_sum += 1 / distance - return val / w_sum - else: - val = ( - (1 - xsi) * (1 - eta) * self.data[ti, yi, xi] - + xsi * (1 - eta) * self.data[ti, yi, xi + 1] - + xsi * eta * self.data[ti, yi + 1, xi + 1] - + (1 - xsi) * eta * self.data[ti, yi + 1, xi] + ctx = InterpolationContext2D(self.data, eta, xsi, ti, yi, xi) + + try: + f = get_2d_interpolator_registry()[self.interp_method] + except KeyError: + if self.interp_method == "cgrid_velocity": + raise RuntimeError( + f"{self.name} is a scalar field. cgrid_velocity interpolation method should be used for vector fields (e.g. FieldSet.UV)" ) - return val - elif self.interp_method in ["cgrid_tracer", "bgrid_tracer"]: - return self.data[ti, yi + 1, xi + 1] - elif self.interp_method == "cgrid_velocity": - raise RuntimeError( - f"{self.name} is a scalar field. cgrid_velocity interpolation method should be used for vector fields (e.g. FieldSet.UV)" - ) - else: - raise RuntimeError(self.interp_method + " is not implemented for 2D grids") + else: + raise RuntimeError(self.interp_method + " is not implemented for 2D grids") + return f(ctx) @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def interpolator3D(self, *args, **kwargs): - return self._interpolator3D(*args, **kwargs) + def interpolator3D(self, *_): + raise NotImplementedError def _interpolator3D(self, ti, z, y, x, time, particle=None): + """Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019..""" (zeta, eta, xsi, zi, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle) - if self.interp_method == "nearest": - xii = xi if xsi <= 0.5 else xi + 1 - yii = yi if eta <= 0.5 else yi + 1 - zii = zi if zeta <= 0.5 else zi + 1 - return self.data[ti, zii, yii, xii] - elif self.interp_method == "cgrid_velocity": - # evaluating W velocity in c_grid - if self.gridindexingtype == "nemo": - f0 = self.data[ti, zi, yi + 1, xi + 1] - f1 = self.data[ti, zi + 1, yi + 1, xi + 1] - elif self.gridindexingtype in ["mitgcm", "croco"]: - f0 = self.data[ti, zi, yi, xi] - f1 = self.data[ti, zi + 1, yi, xi] - return (1 - zeta) * f0 + zeta * f1 - elif self.interp_method == "linear_invdist_land_tracer": - land = np.isclose(self.data[ti, zi : zi + 2, yi : yi + 2, xi : xi + 2], 0.0) - nb_land = np.sum(land) - if nb_land == 8: - return 0 - elif nb_land > 0: - val = 0 - w_sum = 0 - for k in range(2): - for j in range(2): - for i in range(2): - distance = pow((zeta - k), 2) + pow((eta - j), 2) + pow((xsi - i), 2) - if np.isclose(distance, 0): - if land[k][j][i] == 1: # index search led us directly onto land - return 0 - else: - return self.data[ti, zi + k, yi + j, xi + i] - elif land[k][j][i] == 0: - val += self.data[ti, zi + k, yi + j, xi + i] / distance - w_sum += 1 / distance - return val / w_sum - else: - data = self.data[ti, zi, :, :] - f0 = ( - (1 - xsi) * (1 - eta) * data[yi, xi] - + xsi * (1 - eta) * data[yi, xi + 1] - + xsi * eta * data[yi + 1, xi + 1] - + (1 - xsi) * eta * data[yi + 1, xi] - ) - data = self.data[ti, zi + 1, :, :] - f1 = ( - (1 - xsi) * (1 - eta) * data[yi, xi] - + xsi * (1 - eta) * data[yi, xi + 1] - + xsi * eta * data[yi + 1, xi + 1] - + (1 - xsi) * eta * data[yi + 1, xi] - ) - return (1 - zeta) * f0 + zeta * f1 - elif self.interp_method in ["linear", "bgrid_velocity", "bgrid_w_velocity", "partialslip", "freeslip"]: - if self.interp_method == "bgrid_velocity": - if self.gridindexingtype == "mom5": - zeta = 1.0 - else: - zeta = 0.0 - elif self.interp_method == "bgrid_w_velocity": - eta = 1.0 - xsi = 1.0 - data = self.data[ti, zi, :, :] - f0 = ( - (1 - xsi) * (1 - eta) * data[yi, xi] - + xsi * (1 - eta) * data[yi, xi + 1] - + xsi * eta * data[yi + 1, xi + 1] - + (1 - xsi) * eta * data[yi + 1, xi] - ) - if self.gridindexingtype == "pop" and zi >= self.grid.zdim - 2: - # Since POP is indexed at cell top, allow linear interpolation of W to zero in lowest cell - return (1 - zeta) * f0 - data = self.data[ti, zi + 1, :, :] - f1 = ( - (1 - xsi) * (1 - eta) * data[yi, xi] - + xsi * (1 - eta) * data[yi, xi + 1] - + xsi * eta * data[yi + 1, xi + 1] - + (1 - xsi) * eta * data[yi + 1, xi] - ) - if self.interp_method == "bgrid_w_velocity" and self.gridindexingtype == "mom5" and zi == -1: - # Since MOM5 is indexed at cell bottom, allow linear interpolation of W to zero in uppermost cell - return zeta * f1 - else: - return (1 - zeta) * f0 + zeta * f1 - elif self.interp_method in ["cgrid_tracer", "bgrid_tracer"]: - return self.data[ti, zi, yi + 1, xi + 1] - else: + ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype) + + try: + f = get_3d_interpolator_registry()[self.interp_method] + except KeyError: raise RuntimeError(self.interp_method + " is not implemented for 3D grids") + return f(ctx) def temporal_interpolate_fullfield(self, ti, time): """Calculate the data of a field between two snapshots using linear interpolation. @@ -1482,7 +1048,7 @@ def temporal_interpolate_fullfield(self, ti, time): if time == t0: return self.data[ti, :] elif ti + 1 >= len(self.grid.time): - raise TimeExtrapolationError(time, field=self, msg="show_time") + raise TimeExtrapolationError(time, field=self) else: t1 = self.grid.time[ti + 1] f0 = self.data[ti, :] @@ -1495,21 +1061,27 @@ def spatial_interpolation(self, *args, **kwargs): def _spatial_interpolation(self, ti, z, y, x, time, particle=None): """Interpolate horizontal field values using a SciPy interpolator.""" - if self.grid.zdim == 1: - val = self._interpolator2D(ti, z, y, x, particle=particle) - else: - val = self._interpolator3D(ti, z, y, x, time, particle=particle) - if np.isnan(val): - # Detect Out-of-bounds sampling and raise exception - raise FieldOutOfBoundError(z, y, x, field=self) - else: - if isinstance(val, da.core.Array): - val = val.compute() - return val + try: + if self.grid.zdim == 1: + val = self._interpolator2D(ti, z, y, x, particle=particle) + else: + val = self._interpolator3D(ti, z, y, x, time, particle=particle) + + if np.isnan(val): + # Detect Out-of-bounds sampling and raise exception + _raise_field_out_of_bound_error(z, y, x) + else: + if isinstance(val, da.core.Array): + val = val.compute() + return val + + except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: + e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) + raise e @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def time_index(self, *args, **kwargs): - return self._time_index(*args, **kwargs) + def time_index(self, *_): + raise NotImplementedError def _time_index(self, time): """Find the index in the time array associated with a given time. diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py index 341790157e..8c3a55495f 100644 --- a/parcels/fieldfilebuffer.py +++ b/parcels/fieldfilebuffer.py @@ -784,8 +784,7 @@ def _get_initial_chunk_dictionary(self): if len(init_chunk_dict) == 0 and self.chunksize not in [False, None, "auto"]: self.autochunkingfailed = True raise DaskChunkingError( - self.__class__.__name__, - "No correct mapping found between Parcels- and NetCDF dimensions! Please correct the 'FieldSet(..., chunksize={...})' parameter and try again.", + f"[{self.__class__.__name__}]: No correct mapping found between Parcels- and NetCDF dimensions! Please correct the 'FieldSet(..., chunksize=...)' parameter and try again.", ) else: self.autochunkingfailed = False diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 5fa6b13c53..8a2e8ffbe6 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1552,7 +1552,7 @@ def computeTimeChunk(self, time=0.0, dt=1): if f.grid._update_status == "not_updated": nextTime_loc = f.grid._computeTimeChunk(f, time, signdt) if time == nextTime_loc and signdt != 0: - raise TimeExtrapolationError(time, field=f, msg="In fset.computeTimeChunk") + raise TimeExtrapolationError(time, field=f) nextTime = min(nextTime, nextTime_loc) if signdt >= 0 else max(nextTime, nextTime_loc) for f in self.get_fields(): diff --git a/parcels/grid.py b/parcels/grid.py index 38ffdb0a0e..4f211382a9 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -8,7 +8,7 @@ from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh from parcels.tools._helpers import deprecated_made_private -from parcels.tools.converters import TimeConverter +from parcels.tools.converters import Geographic, GeographicPolar, TimeConverter, UnitConverter from parcels.tools.warnings import FieldSetWarning __all__ = [ @@ -821,3 +821,34 @@ def __init__( @property def zdim(self): return self.depth.shape[-3] + + +def _calc_cell_edge_sizes(grid: RectilinearGrid) -> None: + """Method to calculate cell sizes based on numpy.gradient method. + + Currently only works for Rectilinear Grids. Operates in place adding a `cell_edge_sizes` + attribute to the grid. + """ + if not grid.cell_edge_sizes: + if grid._gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid): + grid.cell_edge_sizes["x"] = np.zeros((grid.ydim, grid.xdim), dtype=np.float32) + grid.cell_edge_sizes["y"] = np.zeros((grid.ydim, grid.xdim), dtype=np.float32) + + x_conv = GeographicPolar() if grid.mesh == "spherical" else UnitConverter() + y_conv = Geographic() if grid.mesh == "spherical" else UnitConverter() + for y, (lat, dy) in enumerate(zip(grid.lat, np.gradient(grid.lat), strict=False)): + for x, (lon, dx) in enumerate(zip(grid.lon, np.gradient(grid.lon), strict=False)): + grid.cell_edge_sizes["x"][y, x] = x_conv.to_source(dx, grid.depth[0], lat, lon) + grid.cell_edge_sizes["y"][y, x] = y_conv.to_source(dy, grid.depth[0], lat, lon) + else: + raise ValueError( + f"_cell_edge_sizes() not implemented for {grid._gtype} grids. " + "You can provide Field.grid.cell_edge_sizes yourself by in, e.g., " + "NEMO using the e1u fields etc from the mesh_mask.nc file." + ) + + +def _calc_cell_areas(grid: RectilinearGrid) -> np.ndarray: + if not grid.cell_edge_sizes: + _calc_cell_edge_sizes(grid) + return grid.cell_edge_sizes["x"] * grid.cell_edge_sizes["y"] diff --git a/parcels/kernel.py b/parcels/kernel.py index a8f6d7ff90..9b1c5035c4 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -34,11 +34,11 @@ from parcels.tools.global_statics import get_cache_dir from parcels.tools.loggers import logger from parcels.tools.statuscodes import ( - FieldOutOfBoundError, - FieldOutOfBoundSurfaceError, - FieldSamplingError, StatusCode, TimeExtrapolationError, + _raise_field_out_of_bound_error, + _raise_field_out_of_bound_surface_error, + _raise_field_sampling_error, ) from parcels.tools.warnings import KernelWarning @@ -648,11 +648,11 @@ def execute(self, pset, endtime, dt): elif p.state == StatusCode.ErrorTimeExtrapolation: raise TimeExtrapolationError(p.time) elif p.state == StatusCode.ErrorOutOfBounds: - raise FieldOutOfBoundError(p.depth, p.lat, p.lon) + _raise_field_out_of_bound_error(p.depth, p.lat, p.lon) elif p.state == StatusCode.ErrorThroughSurface: - raise FieldOutOfBoundSurfaceError(p.depth, p.lat, p.lon) + _raise_field_out_of_bound_surface_error(p.depth, p.lat, p.lon) elif p.state == StatusCode.Error: - raise FieldSamplingError(p.depth, p.lat, p.lon) + _raise_field_sampling_error(p.depth, p.lat, p.lon) elif p.state == StatusCode.Delete: pass else: diff --git a/parcels/tools/statuscodes.py b/parcels/tools/statuscodes.py index aea22fadf8..437c83e0d6 100644 --- a/parcels/tools/statuscodes.py +++ b/parcels/tools/statuscodes.py @@ -7,6 +7,9 @@ "KernelError", "StatusCode", "TimeExtrapolationError", + "_raise_field_out_of_bound_error", + "_raise_field_out_of_bound_surface_error", + "_raise_field_sampling_error", ] @@ -29,62 +32,54 @@ class StatusCode: class DaskChunkingError(RuntimeError): """Error indicating to the user that something with setting up Dask and chunked fieldsets went wrong.""" - def __init__(self, src_class_type, message): - msg = f"[{src_class_type}]: {message}" - super().__init__(msg) + pass class FieldSamplingError(RuntimeError): """Utility error class to propagate erroneous field sampling.""" - def __init__(self, z, y, x, field=None): - self.field = field - self.x = x - self.y = y - self.z = z - message = f"{field.name if field else 'Field'} sampled at (depth={self.z}, lat={self.y}, lon={self.x})" - super().__init__(message) + pass class FieldOutOfBoundError(RuntimeError): """Utility error class to propagate out-of-bound field sampling.""" - def __init__(self, z, y, x, field=None): - self.field = field - self.x = x - self.y = y - self.z = z - message = ( - f"{field.name if field else 'Field'} sampled out-of-bound, at (depth={self.z}, lat={self.y}, lon={self.x})" - ) - super().__init__(message) + pass class FieldOutOfBoundSurfaceError(RuntimeError): """Utility error class to propagate out-of-bound field sampling at the surface.""" - def __init__(self, z, y, x, field=None): - self.field = field - self.x = x - self.y = y - self.z = z - message = f"{field.name if field else 'Field'} sampled out-of-bound at the surface, at (depth={self.z}, lat={self.y}, lon={self.x})" - super().__init__(message) + pass + + +def _raise_field_sampling_error(z, y, x): + raise FieldSamplingError(f"Field sampled at (depth={z}, lat={y}, lon={x})") + + +def _raise_field_out_of_bound_error(z, y, x): + raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})") + + +def _raise_field_out_of_bound_surface_error(z: float | None, y: float | None, x: float | None) -> None: + def format_out(val): + return "unknown" if val is None else val + + raise FieldOutOfBoundSurfaceError( + f"Field sampled out-of-bound at the surface, at (depth={format_out(z)}, lat={format_out(y)}, lon={format_out(x)})" + ) class TimeExtrapolationError(RuntimeError): """Utility error class to propagate erroneous time extrapolation sampling.""" - def __init__(self, time, field=None, msg="allow_time_extrapoltion"): + def __init__(self, time, field=None): if field is not None and field.grid.time_origin and time is not None: time = field.grid.time_origin.fulltime(time) - message = f"{field.name if field else 'Field'} sampled outside time domain at time {time}." - if msg == "allow_time_extrapoltion": - message += " Try setting allow_time_extrapolation to True" - elif msg == "show_time": - message += " Try explicitly providing a 'show_time'" - else: - message += msg + " Try setting allow_time_extrapolation to True" + message = ( + f"{field.name if field else 'Field'} sampled outside time domain at time {time}." + " Try setting allow_time_extrapolation to True." + ) super().__init__(message) @@ -95,7 +90,7 @@ def __init__(self, particle, fieldset=None, msg=None): message = ( f"{particle.state}\n" f"Particle {particle}\n" - f"Time: {parse_particletime(particle.time, fieldset)}\n" + f"Time: {_parse_particletime(particle.time, fieldset)}\n" f"timestep dt: {particle.dt}\n" ) if msg: @@ -103,7 +98,7 @@ def __init__(self, particle, fieldset=None, msg=None): super().__init__(message) -def parse_particletime(time, fieldset): +def _parse_particletime(time, fieldset): if fieldset is not None and fieldset.time_origin: time = fieldset.time_origin.fulltime(time) return time diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000000..4ad9c9e070 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,14 @@ +import pytest + +from parcels._compat import add_note + + +def test_add_note_and_raise_value_error(): + with pytest.raises(ValueError) as excinfo: + try: + raise ValueError("original message") + except ValueError as e: + e = add_note(e, "additional note") + raise e + assert "additional note" in str(excinfo.value) + assert "original message" in str(excinfo.value) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index f75b85b574..b364582ec2 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -125,17 +125,17 @@ def test_testing_action_class(): Action("Field", "get_dim_filenames()", "make_private" ), Action("Field", "collect_timeslices()", "make_private" ), Action("Field", "reshape()", "make_private" ), - Action("Field", "calc_cell_edge_sizes()", "make_private" ), - Action("Field", "search_indices_vertical_z()", "make_private" ), - Action("Field", "search_indices_vertical_s()", "make_private" ), - Action("Field", "reconnect_bnd_indices()", "make_private" ), - Action("Field", "search_indices_rectilinear()", "make_private" ), - Action("Field", "search_indices_curvilinear()", "make_private" ), - Action("Field", "search_indices()", "make_private" ), - Action("Field", "interpolator2D()", "make_private" ), - Action("Field", "interpolator3D()", "make_private" ), + Action("Field", "calc_cell_edge_sizes()", "make_private" , skip_reason = "Moved to Grid"), + Action("Field", "search_indices_vertical_z()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "search_indices_vertical_s()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "reconnect_bnd_indices()", "make_private" , skip_reason = "Moved to grid"), + Action("Field", "search_indices_rectilinear()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "search_indices_curvilinear()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "search_indices()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "interpolator2D()", "make_private" , skip_reason = "Removed implementation"), + Action("Field", "interpolator3D()", "make_private" , skip_reason = "Removed implementation"), Action("Field", "spatial_interpolation()", "make_private" ), - Action("Field", "time_index()", "make_private" ), + Action("Field", "time_index()", "make_private" , skip_reason = "Removed implementation"), Action("Field", "ccode_eval()", "make_private" ), Action("Field", "ccode_convert()", "make_private" ), Action("Field", "get_block_id()", "make_private" ), diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index c48d822e96..adbcb757b8 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -441,24 +441,6 @@ def test_fieldset_samegrids_from_data(): assert fieldset1.U.grid == fieldset1.B.grid -@pytest.mark.parametrize("mesh", ["flat", "spherical"]) -def test_fieldset_celledgesizes(mesh): - data, dimensions = generate_fieldset_data(10, 7) - fieldset = FieldSet.from_data(data, dimensions, mesh=mesh) - - fieldset.U._calc_cell_edge_sizes() - D_meridional = fieldset.U.cell_edge_sizes["y"] - D_zonal = fieldset.U.cell_edge_sizes["x"] - - assert np.allclose( - D_meridional.flatten(), D_meridional[0, 0] - ) # all meridional distances should be the same in either mesh - if mesh == "flat": - assert np.allclose(D_zonal.flatten(), D_zonal[0, 0]) # all zonal distances should be the same in flat mesh - else: - assert all((np.gradient(D_zonal, axis=0) < 0).flatten()) # zonal distances should decrease in spherical mesh - - @pytest.mark.parametrize("dx, dy", [("e1u", "e2u"), ("e1v", "e2v")]) def test_fieldset_celledgesizes_curvilinear(dx, dy): fname = str(TEST_DATA / "mask_nemo_cross_180lon.nc") diff --git a/tests/test_grids.py b/tests/test_grids.py index dfb31aa83f..75caa7ebf1 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -20,6 +20,8 @@ UnitConverter, Variable, ) +from parcels.grid import Grid, _calc_cell_edge_sizes +from parcels.tools.converters import TimeConverter from tests.utils import TEST_DATA ptype = {"scipy": ScipyParticle, "jit": JITParticle} @@ -989,3 +991,27 @@ def VelocityInterpolator(particle, fieldset, time): # pragma: no cover assert np.allclose(pset.Wvel[0], 0, atol=1e-9) else: assert np.allclose(pset.Wvel[0], -w * convfactor) + + +@pytest.mark.parametrize( + "lon, lat", + [ + (np.arange(0.0, 20.0, 1.0), np.arange(0.0, 10.0, 1.0)), + ], +) +@pytest.mark.parametrize("mesh", ["flat", "spherical"]) +def test_grid_celledgesizes(lon, lat, mesh): + grid = Grid.create_grid( + lon=lon, lat=lat, depth=np.array([0]), time=np.array([0]), time_origin=TimeConverter(0), mesh=mesh + ) + + _calc_cell_edge_sizes(grid) + D_meridional = grid.cell_edge_sizes["y"] + D_zonal = grid.cell_edge_sizes["x"] + assert np.allclose( + D_meridional.flatten(), D_meridional[0, 0] + ) # all meridional distances should be the same in either mesh + if mesh == "flat": + assert np.allclose(D_zonal.flatten(), D_zonal[0, 0]) # all zonal distances should be the same in flat mesh + else: + assert all((np.gradient(D_zonal, axis=0) < 0).flatten()) # zonal distances should decrease in spherical mesh diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py new file mode 100644 index 0000000000..93ae14bb31 --- /dev/null +++ b/tests/test_interpolation.py @@ -0,0 +1,180 @@ +import numpy as np +import pytest +import xarray as xr + +import parcels._interpolation as interpolation +from parcels import AdvectionRK4_3D, FieldSet, JITParticle, ParticleSet, ScipyParticle +from tests.utils import create_fieldset_zeros_3d + + +@pytest.fixture +def tmp_interpolator_registry(): + """Resets the interpolator registry after the test. Vital when testing manipulating the registry.""" + old_2d = interpolation._interpolator_registry_2d.copy() + old_3d = interpolation._interpolator_registry_3d.copy() + yield + interpolation._interpolator_registry_2d = old_2d + interpolation._interpolator_registry_3d = old_3d + + +@pytest.mark.usefixtures("tmp_interpolator_registry") +def test_interpolation_registry(): + @interpolation.register_3d_interpolator("test") + @interpolation.register_2d_interpolator("test") + def some_function(): + return "test" + + assert "test" in interpolation.get_2d_interpolator_registry() + assert "test" in interpolation.get_3d_interpolator_registry() + + f = interpolation.get_2d_interpolator_registry()["test"] + g = interpolation.get_3d_interpolator_registry()["test"] + assert f() == g() == "test" + + +def create_interpolation_data(): + """Reference data used for testing interpolation. + + Most interpolation will be focussed around index + (depth, lat, lon) = (zi, yi, xi) = (1, 1, 1) with ti=0. + """ + z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous + [ + [0.0, 1.0, 2.0, 3.0], + [2.0, 3.0, 4.0, 5.0], + [4.0, 5.0, 6.0, 7.0], + [6.0, 7.0, 8.0, 9.0], + ] + ) + spatial_data = [z0, z0 + 3, z0 + 6, z0 + 9] # each z is +3 from the previous + return xr.DataArray([spatial_data, spatial_data, spatial_data], dims=("time", "depth", "lat", "lon")) + + +def create_interpolation_data_random(*, with_land_point: bool) -> xr.Dataset: + tdim, zdim, ydim, xdim = 20, 5, 10, 10 + ds = xr.Dataset( + { + "U": (("time", "depth", "lat", "lon"), np.random.random((tdim, zdim, ydim, xdim)) / 1e3), + "V": (("time", "depth", "lat", "lon"), np.random.random((tdim, zdim, ydim, xdim)) / 1e3), + "W": (("time", "depth", "lat", "lon"), np.random.random((tdim, zdim, ydim, xdim)) / 1e3), + }, + coords={ + "time": np.linspace(0, tdim - 1, tdim), + "depth": np.linspace(0, 1, zdim), + "lat": np.linspace(0, 1, ydim), + "lon": np.linspace(0, 1, xdim), + }, + ) + # Set a land point (for testing freeslip) + if with_land_point: + ds["U"][:, :, 2, 5] = 0.0 + ds["V"][:, :, 2, 5] = 0.0 + ds["W"][:, :, 2, 5] = 0.0 + + return ds + + +@pytest.fixture +def data_2d(): + """2D slice of the reference data at depth=0.""" + return create_interpolation_data().isel(depth=0).values + + +@pytest.mark.parametrize( + "func, eta, xsi, expected", + [ + pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"), + pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"), + pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"), + pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"), + pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"), + ], +) +def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected): + """Test the 2D interpolation functions on the raw arrays. + + Interpolation via the other interpolation methods are tested in `test_scipy_vs_jit`. + """ + ti = 0 + yi, xi = 1, 1 + ctx = interpolation.InterpolationContext2D(data_2d, eta, xsi, ti, yi, xi) + assert func(ctx) == expected + + +@pytest.mark.usefixtures("tmp_interpolator_registry") +def test_interpolator_override(): + fieldset = create_fieldset_zeros_3d() + + @interpolation.register_3d_interpolator("linear") + def test_interpolator(ctx: interpolation.InterpolationContext3D): + raise NotImplementedError + + with pytest.raises(NotImplementedError): + fieldset.U[0, 0.5, 0.5, 0.5] + + +@pytest.mark.usefixtures("tmp_interpolator_registry") +def test_full_depth_provided_to_interpolators(): + """The full depth needs to be provided to the interpolation schemes as some interpolators + need to know whether they are at the surface or bottom of the water column. + + https://github.com/OceanParcels/Parcels/pull/1816#discussion_r1908840408 + """ + xdim, ydim, zdim = 10, 11, 12 + fieldset = create_fieldset_zeros_3d(xdim=xdim, ydim=ydim, zdim=zdim) + + @interpolation.register_3d_interpolator("linear") + def test_interpolator2(ctx: interpolation.InterpolationContext3D): + assert ctx.data.shape[1] == zdim + # The array z dimension is the same as the fieldset z dimension + return 0 + + fieldset.U[0.5, 0.5, 0.5, 0.5] + + +@pytest.mark.parametrize( + "interp_method", + [ + "linear", + "freeslip", + "nearest", + "cgrid_velocity", + ], +) +def test_scipy_vs_jit(interp_method): + """Test that the scipy and JIT versions of the interpolation are the same.""" + variables = {"U": "U", "V": "V", "W": "W"} + dimensions = {"time": "time", "lon": "lon", "lat": "lat", "depth": "depth"} + fieldset = FieldSet.from_xarray_dataset( + create_interpolation_data_random(with_land_point=interp_method == "freeslip"), + variables, + dimensions, + mesh="flat", + ) + + for field in [fieldset.U, fieldset.V, fieldset.W]: # Set a land point (for testing freeslip) + field.interp_method = interp_method + + x, y, z = np.meshgrid(np.linspace(0, 1, 7), np.linspace(0, 1, 13), np.linspace(0, 1, 5)) + + TestP = ScipyParticle.add_variable("pid", dtype=np.int32, initial=0) + pset_scipy = ParticleSet(fieldset, pclass=TestP, lon=x, lat=y, depth=z, pid=np.arange(x.size)) + pset_jit = ParticleSet(fieldset, pclass=JITParticle, lon=x, lat=y, depth=z) + + def DeleteParticle(particle, fieldset, time): + if particle.state >= 50: + particle.delete() + + for pset in [pset_scipy, pset_jit]: + pset.execute([AdvectionRK4_3D, DeleteParticle], runtime=4, dt=1) + + tol = 1e-6 + for i in range(len(pset_scipy)): + # Check that the Scipy and JIT particles are at the same location + assert np.isclose(pset_scipy[i].lon, pset_jit[i].lon, atol=tol) + assert np.isclose(pset_scipy[i].lat, pset_jit[i].lat, atol=tol) + assert np.isclose(pset_scipy[i].depth, pset_jit[i].depth, atol=tol) + # Check that the Scipy and JIT particles have moved + assert not np.isclose(pset_scipy[i].lon, x.flatten()[pset_scipy.pid[i]], atol=tol) + assert not np.isclose(pset_scipy[i].lat, y.flatten()[pset_scipy.pid[i]], atol=tol) + assert not np.isclose(pset_scipy[i].depth, z.flatten()[pset_scipy.pid[i]], atol=tol) diff --git a/tests/utils.py b/tests/utils.py index c20d95ccbe..35e134b21e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from pathlib import Path import numpy as np +import xarray as xr import parcels from parcels import FieldSet @@ -22,6 +23,27 @@ def create_fieldset_unit_mesh(xdim=20, ydim=20, mesh="flat", transpose=False) -> return FieldSet.from_data(data, dimensions, mesh=mesh, transpose=transpose) +def create_fieldset_zeros_3d(zdim=5, ydim=10, xdim=10): + """3d fieldset with U, V, and W equivalent to longitude, latitude, and depth.""" + tdim = 20 + ds = xr.Dataset( + { + "U": (("time", "depth", "lat", "lon"), np.zeros((tdim, zdim, ydim, xdim))), + "V": (("time", "depth", "lat", "lon"), np.zeros((tdim, zdim, ydim, xdim))), + "W": (("time", "depth", "lat", "lon"), np.zeros((tdim, zdim, ydim, xdim))), + }, + coords={ + "time": np.linspace(0, tdim - 1, tdim), + "depth": np.linspace(0, 1, zdim), + "lat": np.linspace(0, 1, ydim), + "lon": np.linspace(0, 1, xdim), + }, + ) + variables = {"U": "U", "V": "V", "W": "W"} + dimensions = {"time": "time", "lon": "lon", "lat": "lat", "depth": "depth"} + return FieldSet.from_xarray_dataset(ds, variables, dimensions, mesh="flat") + + def create_fieldset_zeros_unit_mesh(xdim=100, ydim=100): """Standard unit mesh fieldset with flat mesh, and zero velocity.""" data = {"U": np.zeros((ydim, xdim), dtype=np.float32), "V": np.zeros((ydim, xdim), dtype=np.float32)}