diff --git a/pyfv3/stencils/a2b_ord4.py b/pyfv3/stencils/a2b_ord4.py index a9551046..0c9ff1f5 100644 --- a/pyfv3/stencils/a2b_ord4.py +++ b/pyfv3/stencils/a2b_ord4.py @@ -5,7 +5,7 @@ from ndsl.dsl.gt4py import horizontal, interval, region, sin, sqrt from ndsl.dsl.typing import Float, FloatField, FloatFieldI, FloatFieldIJ from ndsl.grid import GridData -from ndsl.stencils.basic_operations import copy +from ndsl.stencils import copy # compact 4-pt cubic interpolation diff --git a/pyfv3/stencils/del2cubed.py b/pyfv3/stencils/del2cubed.py index f06476f1..e74ab5c1 100644 --- a/pyfv3/stencils/del2cubed.py +++ b/pyfv3/stencils/del2cubed.py @@ -6,7 +6,7 @@ from ndsl.dsl.stencil import get_stencils_with_varied_bounds from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, cast_to_index3d from ndsl.grid import DampingCoefficients -from ndsl.stencils.basic_operations import copy +from ndsl.stencils import copy from pyfv3.stencils.copy_corners import CopyCornersX, CopyCornersY diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index 720535f3..8badfa46 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -18,7 +18,7 @@ from ndsl.grid import DampingCoefficients, GridData from ndsl.logging import ndsl_log from ndsl.performance import Timer -from ndsl.stencils.basic_operations import copy +from ndsl.stencils import copy from ndsl.stencils.c2l_ord import CubedToLatLon from ndsl.typing import Checkpointer, Communicator from pyfv3._config import DynamicalCoreConfig diff --git a/pyfv3/stencils/fv_subgridz.py b/pyfv3/stencils/fv_subgridz.py index 5f1f8b36..c81953d2 100644 --- a/pyfv3/stencils/fv_subgridz.py +++ b/pyfv3/stencils/fv_subgridz.py @@ -21,7 +21,7 @@ from ndsl.dsl.gt4py import function as gtfunction from ndsl.dsl.gt4py import interval from ndsl.dsl.typing import Float, FloatField -from ndsl.stencils.basic_operations import dim +from ndsl.stencils.arithmetical_functions import dim from pyfv3.dycore_state import DycoreState @@ -42,18 +42,8 @@ def standard_cm(cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel): q_liq = q0_liquid + q0_rain q_sol = q0_ice + q0_snow + q0_graupel - cpm = ( - (1.0 - (q0_vapor + q_liq + q_sol)) * CP_AIR - + q0_vapor * CP_VAP - + q_liq * C_LIQ - + q_sol * C_ICE - ) - cvm = ( - (1.0 - (q0_vapor + q_liq + q_sol)) * CV_AIR - + q0_vapor * CV_VAP - + q_liq * C_LIQ - + q_sol * C_ICE - ) + cpm = (1.0 - (q0_vapor + q_liq + q_sol)) * CP_AIR + q0_vapor * CP_VAP + q_liq * C_LIQ + q_sol * C_ICE + cvm = (1.0 - (q0_vapor + q_liq + q_sol)) * CV_AIR + q0_vapor * CV_VAP + q_liq * C_LIQ + q_sol * C_ICE return cpm, cvm @@ -114,9 +104,7 @@ def init( gzh = 0.0 with computation(BACKWARD), interval(0, -1): # note only for nwat = 6 - cpm, cvm = standard_cm( - cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel - ) + cpm, cvm = standard_cm(cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel) gz = gzh[0, 0, 1] - G2 * delz tmp = tvol(gz, u0, v0, w0) static_energy = cpm * t0 + tmp @@ -150,9 +138,7 @@ def adjust_cvm( """ Non-hydrostatic under constant volume heating/cooling """ - cpm, cvm = standard_cm( - cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel - ) + cpm, cvm = standard_cm(cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel) tv = tvol(gz, u0, v0, w0) t0 = (total_energy - tv) / cvm static_energy = cpm * t0 + tv @@ -160,30 +146,17 @@ def adjust_cvm( @gtfunction -def compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min -): +def compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min): tv1 = t0[0, 0, -1] * (1.0 + xvir * q0_vapor[0, 0, -1] - qcon[0, 0, -1]) tv2 = t0 * (1.0 + xvir * q0_vapor - qcon) pt1 = tv1 / pkz[0, 0, -1] pt2 = tv2 / pkz - ri = ( - (gz[0, 0, -1] - gz) - * (pt1 - pt2) - / ( - 0.5 - * (pt1 + pt2) - * ((u0[0, 0, -1] - u0) ** 2 + (v0[0, 0, -1] - v0) ** 2 + USTAR2) - ) - ) + ri = (gz[0, 0, -1] - gz) * (pt1 - pt2) / (0.5 * (pt1 + pt2) * ((u0[0, 0, -1] - u0) ** 2 + (v0[0, 0, -1] - v0) ** 2 + USTAR2)) if tv1 > t_max and tv1 > tv2: ri = 0 elif tv2 < t_min: ri = ri if ri < 0.1 else 0.1 - ri_ref = ( - RI_MIN - + (RI_MAX - RI_MIN) * dim(400.0e2, delp / (peln[0, 0, 1] - peln)) / 200.0e2 - ) + ri_ref = RI_MIN + (RI_MAX - RI_MIN) * dim(400.0e2, delp / (peln[0, 0, 1] - peln)) / 200.0e2 if RI_MAX < ri_ref: ri_ref = RI_MAX return ri, ri_ref @@ -196,13 +169,7 @@ def compute_mass_flux(ri, ri_ref, delp, ratio): if max_ri_ratio < 0.0: max_ri_ratio = 0.0 if ri < ri_ref: - mc = ( - ratio - * delp[0, 0, -1] - * delp - / (delp[0, 0, -1] + delp) - * (1.0 - max_ri_ratio) ** 2.0 - ) + mc = ratio * delp[0, 0, -1] * delp / (delp[0, 0, -1] + delp) * (1.0 - max_ri_ratio) ** 2.0 return mc @@ -269,9 +236,7 @@ def m_loop( ref = 0.0 with computation(BACKWARD): with interval(-1, None): - ri, ri_ref = compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min - ) + ri, ri_ref = compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min) mc = compute_mass_flux(ri, ri_ref, delp, ratio) if ri < ri_ref: # TODO: loop over tracers not hardcoded @@ -282,20 +247,14 @@ def m_loop( q0_rain, h0_rain = kh_adjust_down(mc, delp, q0_rain, h0_rain) q0_ice, h0_ice = kh_adjust_down(mc, delp, q0_ice, h0_ice) q0_snow, h0_snow = kh_adjust_down(mc, delp, q0_snow, h0_snow) - q0_graupel, h0_graupel = kh_adjust_down( - mc, delp, q0_graupel, h0_graupel - ) + q0_graupel, h0_graupel = kh_adjust_down(mc, delp, q0_graupel, h0_graupel) q0_o3mr, h0_o3mr = kh_adjust_down(mc, delp, q0_o3mr, h0_o3mr) - q0_sgs_tke, h0_sgs_tke = kh_adjust_down( - mc, delp, q0_sgs_tke, h0_sgs_tke - ) + q0_sgs_tke, h0_sgs_tke = kh_adjust_down(mc, delp, q0_sgs_tke, h0_sgs_tke) q0_cld, h0_cld = kh_adjust_down(mc, delp, q0_cld, h0_cld) u0, h0_u = kh_adjust_down(mc, delp, u0, h0_u) v0, h0_v = kh_adjust_down(mc, delp, v0, h0_v) w0, h0_w = kh_adjust_down(mc, delp, w0, h0_w) - total_energy, h0_total_energy = kh_adjust_energy_down( - mc, delp, static_energy, total_energy, h0_total_energy - ) + total_energy, h0_total_energy = kh_adjust_energy_down(mc, delp, static_energy, total_energy, h0_total_energy) cpm, cvm, t0, static_energy = adjust_cvm( cpm, cvm, @@ -346,9 +305,7 @@ def m_loop( total_energy, static_energy, ) - ri, ri_ref = compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min - ) + ri, ri_ref = compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min) mc = compute_mass_flux(ri, ri_ref, delp, ratio) if ri < ri_ref: q0_vapor, h0_vapor = kh_adjust_down(mc, delp, q0_vapor, h0_vapor) @@ -356,20 +313,14 @@ def m_loop( q0_rain, h0_rain = kh_adjust_down(mc, delp, q0_rain, h0_rain) q0_ice, h0_ice = kh_adjust_down(mc, delp, q0_ice, h0_ice) q0_snow, h0_snow = kh_adjust_down(mc, delp, q0_snow, h0_snow) - q0_graupel, h0_graupel = kh_adjust_down( - mc, delp, q0_graupel, h0_graupel - ) + q0_graupel, h0_graupel = kh_adjust_down(mc, delp, q0_graupel, h0_graupel) q0_o3mr, h0_o3mr = kh_adjust_down(mc, delp, q0_o3mr, h0_o3mr) - q0_sgs_tke, h0_sgs_tke = kh_adjust_down( - mc, delp, q0_sgs_tke, h0_sgs_tke - ) + q0_sgs_tke, h0_sgs_tke = kh_adjust_down(mc, delp, q0_sgs_tke, h0_sgs_tke) q0_cld, h0_cld = kh_adjust_down(mc, delp, q0_cld, h0_cld) u0, h0_u = kh_adjust_down(mc, delp, u0, h0_u) v0, h0_v = kh_adjust_down(mc, delp, v0, h0_v) w0, h0_w = kh_adjust_down(mc, delp, w0, h0_w) - total_energy, h0_total_energy = kh_adjust_energy_down( - mc, delp, static_energy, total_energy, h0_total_energy - ) + total_energy, h0_total_energy = kh_adjust_energy_down(mc, delp, static_energy, total_energy, h0_total_energy) cpm, cvm, t0, static_energy = adjust_cvm( cpm, cvm, @@ -424,9 +375,7 @@ def m_loop( total_energy, static_energy, ) - ri, ri_ref = compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min - ) + ri, ri_ref = compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min) # TODO, can we just check if index(K) == 3? ri_ref = ri_ref * 1.5 mc = compute_mass_flux(ri, ri_ref, delp, ratio) @@ -436,20 +385,14 @@ def m_loop( q0_rain, h0_rain = kh_adjust_down(mc, delp, q0_rain, h0_rain) q0_ice, h0_ice = kh_adjust_down(mc, delp, q0_ice, h0_ice) q0_snow, h0_snow = kh_adjust_down(mc, delp, q0_snow, h0_snow) - q0_graupel, h0_graupel = kh_adjust_down( - mc, delp, q0_graupel, h0_graupel - ) + q0_graupel, h0_graupel = kh_adjust_down(mc, delp, q0_graupel, h0_graupel) q0_o3mr, h0_o3mr = kh_adjust_down(mc, delp, q0_o3mr, h0_o3mr) - q0_sgs_tke, h0_sgs_tke = kh_adjust_down( - mc, delp, q0_sgs_tke, h0_sgs_tke - ) + q0_sgs_tke, h0_sgs_tke = kh_adjust_down(mc, delp, q0_sgs_tke, h0_sgs_tke) q0_cld, h0_cld = kh_adjust_down(mc, delp, q0_cld, h0_cld) u0, h0_u = kh_adjust_down(mc, delp, u0, h0_u) v0, h0_v = kh_adjust_down(mc, delp, v0, h0_v) w0, h0_w = kh_adjust_down(mc, delp, w0, h0_w) - total_energy, h0_total_energy = kh_adjust_energy_down( - mc, delp, static_energy, total_energy, h0_total_energy - ) + total_energy, h0_total_energy = kh_adjust_energy_down(mc, delp, static_energy, total_energy, h0_total_energy) cpm, cvm, t0, static_energy = adjust_cvm( cpm, cvm, @@ -501,9 +444,7 @@ def m_loop( total_energy, static_energy, ) - ri, ri_ref = compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min - ) + ri, ri_ref = compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min) ri_ref = ri_ref * 2.0 mc = compute_mass_flux(ri, ri_ref, delp, ratio) if ri < ri_ref: @@ -512,20 +453,14 @@ def m_loop( q0_rain, h0_rain = kh_adjust_down(mc, delp, q0_rain, h0_rain) q0_ice, h0_ice = kh_adjust_down(mc, delp, q0_ice, h0_ice) q0_snow, h0_snow = kh_adjust_down(mc, delp, q0_snow, h0_snow) - q0_graupel, h0_graupel = kh_adjust_down( - mc, delp, q0_graupel, h0_graupel - ) + q0_graupel, h0_graupel = kh_adjust_down(mc, delp, q0_graupel, h0_graupel) q0_o3mr, h0_o3mr = kh_adjust_down(mc, delp, q0_o3mr, h0_o3mr) - q0_sgs_tke, h0_sgs_tke = kh_adjust_down( - mc, delp, q0_sgs_tke, h0_sgs_tke - ) + q0_sgs_tke, h0_sgs_tke = kh_adjust_down(mc, delp, q0_sgs_tke, h0_sgs_tke) q0_cld, h0_cld = kh_adjust_down(mc, delp, q0_cld, h0_cld) u0, h0_u = kh_adjust_down(mc, delp, u0, h0_u) v0, h0_v = kh_adjust_down(mc, delp, v0, h0_v) w0, h0_w = kh_adjust_down(mc, delp, w0, h0_w) - total_energy, h0_total_energy = kh_adjust_energy_down( - mc, delp, static_energy, total_energy, h0_total_energy - ) + total_energy, h0_total_energy = kh_adjust_energy_down(mc, delp, static_energy, total_energy, h0_total_energy) cpm, cvm, t0, static_energy = adjust_cvm( cpm, cvm, @@ -577,9 +512,7 @@ def m_loop( total_energy, static_energy, ) - ri, ri_ref = compute_richardson_number( - t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min - ) + ri, ri_ref = compute_richardson_number(t0, q0_vapor, qcon, pkz, delp, peln, gz, u0, v0, xvir, t_max, t_min) ri_ref = ri_ref * 4.0 mc = compute_mass_flux(ri, ri_ref, delp, ratio) if ri < ri_ref: @@ -588,20 +521,14 @@ def m_loop( q0_rain, h0_rain = kh_adjust_down(mc, delp, q0_rain, h0_rain) q0_ice, h0_ice = kh_adjust_down(mc, delp, q0_ice, h0_ice) q0_snow, h0_snow = kh_adjust_down(mc, delp, q0_snow, h0_snow) - q0_graupel, h0_graupel = kh_adjust_down( - mc, delp, q0_graupel, h0_graupel - ) + q0_graupel, h0_graupel = kh_adjust_down(mc, delp, q0_graupel, h0_graupel) q0_o3mr, h0_o3mr = kh_adjust_down(mc, delp, q0_o3mr, h0_o3mr) - q0_sgs_tke, h0_sgs_tke = kh_adjust_down( - mc, delp, q0_sgs_tke, h0_sgs_tke - ) + q0_sgs_tke, h0_sgs_tke = kh_adjust_down(mc, delp, q0_sgs_tke, h0_sgs_tke) q0_cld, h0_cld = kh_adjust_down(mc, delp, q0_cld, h0_cld) u0, h0_u = kh_adjust_down(mc, delp, u0, h0_u) v0, h0_v = kh_adjust_down(mc, delp, v0, h0_v) w0, h0_w = kh_adjust_down(mc, delp, w0, h0_w) - total_energy, h0_total_energy = kh_adjust_energy_down( - mc, delp, static_energy, total_energy, h0_total_energy - ) + total_energy, h0_total_energy = kh_adjust_energy_down(mc, delp, static_energy, total_energy, h0_total_energy) cpm, cvm, t0, static_energy = adjust_cvm( cpm, cvm, @@ -728,9 +655,7 @@ def finalize( qcld = q0_cld -ArgSpec = collections.namedtuple( - "ArgSpec", ["arg_name", "standard_name", "units", "intent"] -) +ArgSpec = collections.namedtuple("ArgSpec", ["arg_name", "standard_name", "units", "intent"]) class DryConvectiveAdjustment: @@ -765,12 +690,8 @@ class DryConvectiveAdjustment: ArgSpec("qo3mr", "ozone_mixing_ratio", "kg/kg", intent="inout"), ArgSpec("qsgs_tke", "turbulent_kinetic_energy", "m**2/s**2", intent="inout"), ArgSpec("qcld", "cloud_fraction", "", intent="inout"), - ArgSpec( - "u_dt", "eastward_wind_tendency_due_to_physics", "m/s**2", intent="inout" - ), - ArgSpec( - "v_dt", "northward_wind_tendency_due_to_physics", "m/s**2", intent="inout" - ), + ArgSpec("u_dt", "eastward_wind_tendency_due_to_physics", "m/s**2", intent="inout"), + ArgSpec("v_dt", "northward_wind_tendency_due_to_physics", "m/s**2", intent="inout"), ) def __init__( @@ -783,15 +704,11 @@ def __init__( hydrostatic: bool, ): if hydrostatic: - raise NotImplementedError( - "DryConvectiveAdjustment (fv_subgridz): Hydrostatic is not implemented" - ) + raise NotImplementedError("DryConvectiveAdjustment (fv_subgridz): Hydrostatic is not implemented") grid_indexing = stencil_factory.grid_indexing self._k_sponge = n_sponge if self._k_sponge is not None and self._k_sponge < 3: - raise ValueError( - "DryConvectiveAdjustment (fv_subgridz): n_sponge < 3 is invalid." - ) + raise ValueError("DryConvectiveAdjustment (fv_subgridz): n_sponge < 3 is invalid.") else: self._k_sponge = grid_indexing.domain[2] if self._k_sponge < min(grid_indexing.domain[2], 24): diff --git a/pyfv3/stencils/map_single.py b/pyfv3/stencils/map_single.py index 0cab1b9f..73623d8e 100644 --- a/pyfv3/stencils/map_single.py +++ b/pyfv3/stencils/map_single.py @@ -4,7 +4,7 @@ from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, computation, interval from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, Int, IntFieldIJ -from ndsl.stencils.basic_operations import copy +from ndsl.stencils import copy from pyfv3.stencils.remap_profile import RemapProfile diff --git a/pyfv3/stencils/remapping.py b/pyfv3/stencils/remapping.py index 223e70d4..e19fde73 100644 --- a/pyfv3/stencils/remapping.py +++ b/pyfv3/stencils/remapping.py @@ -19,7 +19,7 @@ region, ) from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, FloatFieldK -from ndsl.stencils.basic_operations import adjust_divide_stencil +from ndsl.stencils import divide_self from pyfv3._config import RemappingConfig from pyfv3.stencils import moist_cv from pyfv3.stencils.map_single import MapSingle @@ -508,8 +508,8 @@ def __init__( ), ) - self._basic_adjust_divide_stencil = stencil_factory.from_origin_domain( - adjust_divide_stencil, + self._basic_divide_self_stencil = stencil_factory.from_origin_domain( + divide_self, origin=grid_indexing.origin_compute(), domain=grid_indexing.domain_compute(), ) @@ -720,4 +720,4 @@ def __call__( ) else: # converts virtual temperature back to virtual potential temperature - self._basic_adjust_divide_stencil(pkz, pt) + self._basic_divide_self_stencil(pkz, pt) diff --git a/pyfv3/stencils/saturation_adjustment.py b/pyfv3/stencils/saturation_adjustment.py index b4e4009f..f64b9fa8 100644 --- a/pyfv3/stencils/saturation_adjustment.py +++ b/pyfv3/stencils/saturation_adjustment.py @@ -6,7 +6,7 @@ from ndsl.dsl.gt4py import function as gtfunction from ndsl.dsl.gt4py import interval, log from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ -from ndsl.stencils.basic_operations import dim +from ndsl.stencils.arithmetical_functions import dim from pyfv3._config import SatAdjustConfig from pyfv3.stencils.moist_cv import compute_pkz_func diff --git a/pyfv3/stencils/temperature_adjust.py b/pyfv3/stencils/temperature_adjust.py index 2f17bce3..b2e56ce2 100644 --- a/pyfv3/stencils/temperature_adjust.py +++ b/pyfv3/stencils/temperature_adjust.py @@ -1,7 +1,7 @@ import ndsl.constants as constants from ndsl.dsl.gt4py import PARALLEL, computation, exp, interval, log from ndsl.dsl.typing import Float, FloatField -from ndsl.stencils.basic_operations import sign +from ndsl.stencils.arithmetical_functions import sign def apply_diffusive_heating( diff --git a/pyfv3/stencils/xppm.py b/pyfv3/stencils/xppm.py index 4b9d85c1..277daa5d 100644 --- a/pyfv3/stencils/xppm.py +++ b/pyfv3/stencils/xppm.py @@ -3,7 +3,7 @@ from ndsl.dsl.gt4py import function as gtfunction from ndsl.dsl.gt4py import horizontal, interval, region from ndsl.dsl.typing import FloatField, FloatFieldIJ, Index3D -from ndsl.stencils.basic_operations import sign +from ndsl.stencils.arithmetical_functions import sign from pyfv3.stencils import ppm @@ -68,9 +68,7 @@ def get_flux(q: FloatField, courant: FloatField, al: FloatField): @gtfunction -def get_flux_ord8plus( - q: FloatField, courant: FloatField, bl: FloatField, br: FloatField -): +def get_flux_ord8plus(q: FloatField, courant: FloatField, bl: FloatField, br: FloatField): b0 = bl + br fx1 = fx1_fn(courant, br, b0, bl) return apply_flux(courant, q, fx1, 1.0) @@ -101,16 +99,14 @@ def blbr_iord8(q: FloatField, al: FloatField, dm: FloatField): def xt_dxa_edge_0_base(q, dxa): return 0.5 * ( ((2.0 * dxa + dxa[-1, 0]) * q - dxa * q[-1, 0, 0]) / (dxa[-1, 0] + dxa) - + ((2.0 * dxa[1, 0] + dxa[2, 0]) * q[1, 0, 0] - dxa[1, 0] * q[2, 0, 0]) - / (dxa[1, 0] + dxa[2, 0]) + + ((2.0 * dxa[1, 0] + dxa[2, 0]) * q[1, 0, 0] - dxa[1, 0] * q[2, 0, 0]) / (dxa[1, 0] + dxa[2, 0]) ) @gtfunction def xt_dxa_edge_1_base(q, dxa): return 0.5 * ( - ((2.0 * dxa[-1, 0] + dxa[-2, 0]) * q[-1, 0, 0] - dxa[-1, 0] * q[-2, 0, 0]) - / (dxa[-2, 0] + dxa[-1, 0]) + ((2.0 * dxa[-1, 0] + dxa[-2, 0]) * q[-1, 0, 0] - dxa[-1, 0] * q[-2, 0, 0]) / (dxa[-2, 0] + dxa[-1, 0]) + ((2.0 * dxa + dxa[1, 0]) * q - dxa * q[1, 0, 0]) / (dxa + dxa[1, 0]) ) @@ -166,13 +162,8 @@ def compute_al(q: FloatField, dxa: FloatFieldIJ): al = ppm.c1 * q[-2, 0, 0] + ppm.c2 * q[-1, 0, 0] + ppm.c3 * q with horizontal(region[i_start, :], region[i_end + 1, :]): al = 0.5 * ( - ( - (2.0 * dxa[-1, 0] + dxa[-2, 0]) * q[-1, 0, 0] - - dxa[-1, 0] * q[-2, 0, 0] - ) - / (dxa[-2, 0] + dxa[-1, 0]) - + ((2.0 * dxa[0, 0] + dxa[1, 0]) * q[0, 0, 0] - dxa[0, 0] * q[1, 0, 0]) - / (dxa[0, 0] + dxa[1, 0]) + ((2.0 * dxa[-1, 0] + dxa[-2, 0]) * q[-1, 0, 0] - dxa[-1, 0] * q[-2, 0, 0]) / (dxa[-2, 0] + dxa[-1, 0]) + + ((2.0 * dxa[0, 0] + dxa[1, 0]) * q[0, 0, 0] - dxa[0, 0] * q[1, 0, 0]) / (dxa[0, 0] + dxa[1, 0]) ) with horizontal(region[i_start + 1, :], region[i_end + 2, :]): al = ppm.c3 * q[-1, 0, 0] + ppm.c2 * q[0, 0, 0] + ppm.c1 * q[1, 0, 0] @@ -236,9 +227,7 @@ def bl_br_edges(bl, br, q, dxa, al, dm): xt_bl = xt_dxa_edge_1(q, dxa) xt_br = ppm.s11 * (q[1, 0, 0] - q) - ppm.s14 * dm_right_end + q - with horizontal( - region[i_start - 1 : i_start + 2, :], region[i_end - 1 : i_end + 2, :] - ): + with horizontal(region[i_start - 1 : i_start + 2, :], region[i_end - 1 : i_end + 2, :]): bl = xt_bl - q br = xt_br - q @@ -259,17 +248,13 @@ def compute_blbr_ord8plus(q: FloatField, dxa: FloatFieldIJ): if __INLINED(grid_type < 3): bl, br = bl_br_edges(bl, br, q, dxa, al, dm) - with horizontal( - region[i_start - 1 : i_start + 2, :], region[i_end - 1 : i_end + 2, :] - ): + with horizontal(region[i_start - 1 : i_start + 2, :], region[i_end - 1 : i_end + 2, :]): bl, br = ppm.pert_ppm_standard_constraint_fcn(q, bl, br) return bl, br -def compute_x_flux( - q: FloatField, courant: FloatField, dxa: FloatFieldIJ, xflux: FloatField -): +def compute_x_flux(q: FloatField, courant: FloatField, dxa: FloatFieldIJ, xflux: FloatField): """ Args: q (in): @@ -307,21 +292,13 @@ def __init__( # namelist.grid_type # grid.dxa if grid_type == 3 or grid_type > 4: - raise NotImplementedError( - "X Piecewise Parabolic (xppm): " - f" grid type {grid_type} not implemented. <3 or 4 available." - ) + raise NotImplementedError("X Piecewise Parabolic (xppm): " f" grid type {grid_type} not implemented. <3 or 4 available.") if abs(iord) >= 8 and iord != 8: - raise NotImplementedError( - "X Piecewise Parabolic (xppm): " - f"iord {iord} != 8 not implemented when >= 8." - ) + raise NotImplementedError("X Piecewise Parabolic (xppm): " f"iord {iord} != 8 not implemented when >= 8.") if iord < 0: - raise NotImplementedError( - f"X Piecewise Parabolic (xppm): iord {iord} < 0 not implemented." - ) + raise NotImplementedError(f"X Piecewise Parabolic (xppm): iord {iord} < 0 not implemented.") self._dxa = dxa ax_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain) @@ -367,8 +344,6 @@ def __call__( # in the Fortran version of this code, "x_advection" routines # were called "get_flux", while the routine which got the flux was called # fx1_fn. The final value was called xflux instead of q_out. - self._compute_flux_stencil( - q_in, c, self._dxa, q_mean_advected_through_x_interface - ) + self._compute_flux_stencil(q_in, c, self._dxa, q_mean_advected_through_x_interface) # bl and br are "edge perturbation values" as in equation 4.1 # of the FV3 documentation diff --git a/pyfv3/stencils/yppm.py b/pyfv3/stencils/yppm.py index 5174c883..b498b2ec 100644 --- a/pyfv3/stencils/yppm.py +++ b/pyfv3/stencils/yppm.py @@ -3,7 +3,7 @@ from ndsl.dsl.gt4py import function as gtfunction from ndsl.dsl.gt4py import horizontal, interval, region from ndsl.dsl.typing import FloatField, FloatFieldIJ, Index3D -from ndsl.stencils.basic_operations import sign +from ndsl.stencils.arithmetical_functions import sign from pyfv3.stencils import ppm @@ -68,9 +68,7 @@ def get_flux(q: FloatField, courant: FloatField, al: FloatField): @gtfunction -def get_flux_ord8plus( - q: FloatField, courant: FloatField, bl: FloatField, br: FloatField -): +def get_flux_ord8plus(q: FloatField, courant: FloatField, bl: FloatField, br: FloatField): b0 = bl + br fx1 = fx1_fn(courant, br, b0, bl) return apply_flux(courant, q, fx1, 1.0) @@ -101,16 +99,14 @@ def blbr_jord8(q: FloatField, al: FloatField, dm: FloatField): def yt_dya_edge_0_base(q, dya): return 0.5 * ( ((2.0 * dya + dya[0, -1]) * q - dya * q[0, -1, 0]) / (dya[0, -1] + dya) - + ((2.0 * dya[0, 1] + dya[0, 2]) * q[0, 1, 0] - dya[0, 1] * q[0, 2, 0]) - / (dya[0, 1] + dya[0, 2]) + + ((2.0 * dya[0, 1] + dya[0, 2]) * q[0, 1, 0] - dya[0, 1] * q[0, 2, 0]) / (dya[0, 1] + dya[0, 2]) ) @gtfunction def yt_dya_edge_1_base(q, dya): return 0.5 * ( - ((2.0 * dya[0, -1] + dya[0, -2]) * q[0, -1, 0] - dya[0, -1] * q[0, -2, 0]) - / (dya[0, -2] + dya[0, -1]) + ((2.0 * dya[0, -1] + dya[0, -2]) * q[0, -1, 0] - dya[0, -1] * q[0, -2, 0]) / (dya[0, -2] + dya[0, -1]) + ((2.0 * dya + dya[0, 1]) * q - dya * q[0, 1, 0]) / (dya + dya[0, 1]) ) @@ -166,13 +162,8 @@ def compute_al(q: FloatField, dya: FloatFieldIJ): al = ppm.c1 * q[0, -2, 0] + ppm.c2 * q[0, -1, 0] + ppm.c3 * q with horizontal(region[:, j_start], region[:, j_end + 1]): al = 0.5 * ( - ( - (2.0 * dya[0, -1] + dya[0, -2]) * q[0, -1, 0] - - dya[0, -1] * q[0, -2, 0] - ) - / (dya[0, -2] + dya[0, -1]) - + ((2.0 * dya[0, 0] + dya[0, 1]) * q[0, 0, 0] - dya[0, 0] * q[0, 1, 0]) - / (dya[0, 0] + dya[0, 1]) + ((2.0 * dya[0, -1] + dya[0, -2]) * q[0, -1, 0] - dya[0, -1] * q[0, -2, 0]) / (dya[0, -2] + dya[0, -1]) + + ((2.0 * dya[0, 0] + dya[0, 1]) * q[0, 0, 0] - dya[0, 0] * q[0, 1, 0]) / (dya[0, 0] + dya[0, 1]) ) with horizontal(region[:, j_start + 1], region[:, j_end + 2]): al = ppm.c3 * q[0, -1, 0] + ppm.c2 * q[0, 0, 0] + ppm.c1 * q[0, 1, 0] @@ -236,9 +227,7 @@ def bl_br_edges(bl, br, q, dya, al, dm): yt_bl = yt_dya_edge_1(q, dya) yt_br = ppm.s11 * (q[0, 1, 0] - q) - ppm.s14 * dm_right_end + q - with horizontal( - region[:, j_start - 1 : j_start + 2], region[:, j_end - 1 : j_end + 2] - ): + with horizontal(region[:, j_start - 1 : j_start + 2], region[:, j_end - 1 : j_end + 2]): bl = yt_bl - q br = yt_br - q @@ -259,17 +248,13 @@ def compute_blbr_ord8plus(q: FloatField, dya: FloatFieldIJ): if __INLINED(grid_type < 3): bl, br = bl_br_edges(bl, br, q, dya, al, dm) - with horizontal( - region[:, j_start - 1 : j_start + 2], region[:, j_end - 1 : j_end + 2] - ): + with horizontal(region[:, j_start - 1 : j_start + 2], region[:, j_end - 1 : j_end + 2]): bl, br = ppm.pert_ppm_standard_constraint_fcn(q, bl, br) return bl, br -def compute_y_flux( - q: FloatField, courant: FloatField, dya: FloatFieldIJ, yflux: FloatField -): +def compute_y_flux(q: FloatField, courant: FloatField, dya: FloatFieldIJ, yflux: FloatField): """ Args: q (in): @@ -307,21 +292,13 @@ def __init__( # namelist.grid_type # grid.dya if grid_type == 3 or grid_type > 4: - raise NotImplementedError( - "Y Piecewise Parabolic (yppm): " - f" grid type {grid_type} not implemented. <3 or 4 available." - ) + raise NotImplementedError("Y Piecewise Parabolic (yppm): " f" grid type {grid_type} not implemented. <3 or 4 available.") if abs(jord) >= 8 and jord != 8: - raise NotImplementedError( - "Y Piecewise Parabolic (yppm): " - f"jord {jord} != 8 not implemented when >= 8." - ) + raise NotImplementedError("Y Piecewise Parabolic (yppm): " f"jord {jord} != 8 not implemented when >= 8.") if jord < 0: - raise NotImplementedError( - f"Y Piecewise Parabolic (yppm): jord {jord} < 0 not implemented." - ) + raise NotImplementedError(f"Y Piecewise Parabolic (yppm): jord {jord} < 0 not implemented.") self._dya = dya ax_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain) @@ -367,8 +344,6 @@ def __call__( # in the Fortran version of this code, "x_advection" routines # were called "get_flux", while the routine which got the flux was called # fx1_fn. The final value was called yflux instead of q_out. - self._compute_flux_stencil( - q_in, c, self._dya, q_mean_advected_through_y_interface - ) + self._compute_flux_stencil(q_in, c, self._dya, q_mean_advected_through_y_interface) # bl and br are "edge perturbation values" as in equation 4.1 # of the FV3 documentation