Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 47 additions & 29 deletions conda_package/mpas_tools/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
default_engine = None
default_char_dim_name = 'StrLen'
default_fills = netCDF4.default_fillvals
default_nchar = 64


def write_netcdf(
Expand All @@ -23,6 +24,7 @@ def write_netcdf(
engine=None,
char_dim_name=None,
logger=None,
nchar=None,
):
"""
Write an xarray.Dataset to a file with NetCDF4 fill values and the given
Expand All @@ -31,9 +33,9 @@ def write_netcdf(

Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
because xarray output with this format is not performant. First, the file
is written in `NETCDF4` format, which supports larger files and variables.
Then, the `ncks` command is used to convert the file to the
`NETCDF3_64BIT_DATA` format.
is written in ``NETCDF4`` format, which supports larger files and
variables. Then, the ``ncks`` command is used to convert the file to the
``NETCDF3_64BIT_DATA`` format.

Note: All int64 variables are automatically converted to int32 for MPAS
compatibility.
Expand Down Expand Up @@ -63,15 +65,19 @@ def write_netcdf(
``mpas_tools.io.default_engine``

char_dim_name : str, optional
The name of the dimension used for character strings, or None to let
xarray figure this out. Default is
The name of the dimension used for character strings. Default is
``mpas_tools.io.default_char_dim_name``, which can be modified but
which defaults to ``'StrLen'``

nchar : int, optional
The number of characters to use for string variables. If None, the
default is ``mpas_tools.io.default_nchar``, which can be modified but
which defaults to 64.

logger : logging.Logger, optional
A logger to write messages to write the output of `ncks` conversion
calls to. If None, `ncks` output is suppressed. This is only
relevant if `format` is 'NETCDF3_64BIT_DATA'
A logger to which the output of ``ncks`` conversion calls is written.
If None, ``ncks`` output is suppressed. This is only relevant if
``format`` is 'NETCDF3_64BIT_DATA'
""" # noqa: E501
if format is None:
format = default_format
Expand All @@ -85,31 +91,43 @@ def write_netcdf(
if char_dim_name is None:
char_dim_name = default_char_dim_name

# Convert int64 variables to int32 for MPAS compatibility
for var in list(ds.data_vars.keys()) + list(ds.coords.keys()):
if ds[var].dtype == numpy.int64:
attrs = ds[var].attrs.copy()
ds[var] = ds[var].astype(numpy.int32)
ds[var].attrs = attrs
if nchar is None:
nchar = default_nchar

numpyFillValues = {}
for fillType in fillValues:
# drop string fill values
if not fillType.startswith('S'):
numpyFillValues[numpy.dtype(fillType)] = fillValues[fillType]

encodingDict = {}
variableNames = list(ds.data_vars.keys()) + list(ds.coords.keys())
for variableName in variableNames:
isNumeric = numpy.issubdtype(ds[variableName].dtype, numpy.number)
if isNumeric and numpy.any(numpy.isnan(ds[variableName])):
dtype = ds[variableName].dtype
for fillType in fillValues:
if dtype == numpy.dtype(fillType):
encodingDict[variableName] = {
'_FillValue': fillValues[fillType]
}
break
else:
encodingDict[variableName] = {'_FillValue': None}

isString = numpy.issubdtype(ds[variableName].dtype, numpy.bytes_)
if isString and char_dim_name is not None:
encodingDict[variableName] = {'char_dim_name': char_dim_name}
var = ds[variableName]
encodingDict[variableName] = {}
dtype = var.dtype

# Convert int64 variables to int32 for MPAS compatibility
if dtype == numpy.int64:
encodingDict[variableName]['dtype'] = 'int32'

# add fill values
if dtype in numpyFillValues:
if numpy.any(numpy.isnan(var)):
# only add fill values if they're needed
fill = numpyFillValues[dtype]
else:
fill = None
encodingDict[variableName]['_FillValue'] = fill

isString = numpy.issubdtype(dtype, numpy.bytes_) or numpy.issubdtype(
dtype, numpy.str_
)
if isString:
# set the encoding for string variables
encodingDict[variableName].update(
{'dtype': f'|S{nchar}', 'char_dim_name': char_dim_name}
)

update_history(ds)

Expand Down
54 changes: 21 additions & 33 deletions conda_package/mpas_tools/mesh/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from shapely.strtree import STRtree

from mpas_tools.cime.constants import constants
from mpas_tools.io import write_netcdf
from mpas_tools.io import default_nchar, write_netcdf
from mpas_tools.logging import LoggingContext
from mpas_tools.parallel import create_pool
from mpas_tools.transects import (
Expand Down Expand Up @@ -100,7 +100,7 @@ def compute_mpas_region_masks(

# create shapely geometry for lon and lat
points = [shapely.geometry.Point(x, y) for x, y in zip(lon, lat)]
regionNames, masks, properties, nchar = _compute_region_masks(
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
logger,
Expand Down Expand Up @@ -133,7 +133,6 @@ def compute_mpas_region_masks(
ds=dsMasks,
properties=properties,
dim='nRegions',
nchar=nchar,
)

if logger is not None:
Expand Down Expand Up @@ -339,17 +338,15 @@ def compute_mpas_transect_masks(
polygons, nPolygons, duplicatePolygons = _get_polygons(
dsMesh, maskType
)
transectNames, masks, properties, nchar, shapes = (
_compute_transect_masks(
fcMask,
polygons,
logger,
pool,
chunkSize,
showProgress,
subdivisionResolution,
earthRadius,
)
transectNames, masks, properties, shapes = _compute_transect_masks(
fcMask,
polygons,
logger,
pool,
chunkSize,
showProgress,
subdivisionResolution,
earthRadius,
)

if logger is not None:
Expand Down Expand Up @@ -393,7 +390,6 @@ def compute_mpas_transect_masks(
ds=dsMasks,
properties=properties,
dim='nTransects',
nchar=nchar,
)

if logger is not None:
Expand Down Expand Up @@ -723,7 +719,7 @@ def compute_lon_lat_region_masks(

# create shapely geometry for lon and lat
points = [shapely.geometry.Point(x, y) for x, y in zip(Lon, Lat)]
regionNames, masks, properties, nchar = _compute_region_masks(
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
logger,
Expand Down Expand Up @@ -757,7 +753,6 @@ def compute_lon_lat_region_masks(
ds=dsMasks,
properties=properties,
dim='nRegions',
nchar=nchar,
)

if logger is not None:
Expand Down Expand Up @@ -959,7 +954,7 @@ def compute_projection_grid_region_masks(
points = [
shapely.geometry.Point(x, y) for x, y in zip(lon.ravel(), lat.ravel())
]
regionNames, masks, properties, nchar = _compute_region_masks(
regionNames, masks, properties = _compute_region_masks(
fcMask,
points,
logger,
Expand Down Expand Up @@ -990,7 +985,6 @@ def compute_projection_grid_region_masks(
ds=dsMasks,
properties=properties,
dim='nRegions',
nchar=nchar,
)

if logger is not None:
Expand Down Expand Up @@ -1171,10 +1165,11 @@ def _compute_mask_from_shapes(
return mask


def _add_properties(ds, properties, dim, nchar):
def _add_properties(ds, properties, dim):
"""
Add properties to the dataset from a dictionary of properties
"""
nchar = default_nchar
for name, prop_list in properties.items():
if name not in ds:
if isinstance(prop_list[0], str):
Expand All @@ -1186,7 +1181,7 @@ def _add_properties(ds, properties, dim, nchar):
for index, value in enumerate(prop_list):
ds[name][index] = value
else:
ds[name] = ((dim,), properties[prop_list])
ds[name] = ((dim,), prop_list)


def _get_region_names_and_properties(fc):
Expand All @@ -1208,19 +1203,16 @@ def _get_region_names_and_properties(fc):
propertyNames.add(propertyName)

properties = {}
nchar = 0
for propertyName in propertyNames:
properties[propertyName] = []
for feature in fc.features:
if propertyName in feature['properties']:
propertyVal = feature['properties'][propertyName]
properties[propertyName].append(propertyVal)
if isinstance(propertyVal, str):
nchar = max(nchar, len(propertyVal))
else:
properties[propertyName].append('')

return regionNames, properties, nchar
return regionNames, properties


def _compute_region_masks(
Expand All @@ -1231,7 +1223,7 @@ def _compute_region_masks(
a set of regions.
"""

regionNames, properties, nchar = _get_region_names_and_properties(fcMask)
regionNames, properties = _get_region_names_and_properties(fcMask)

masks = []

Expand All @@ -1253,11 +1245,9 @@ def _compute_region_masks(
showProgress=showProgress,
)

nchar = max(nchar, len(name))

masks.append(mask)

return regionNames, masks, properties, nchar
return regionNames, masks, properties


def _contains(shapes, points):
Expand Down Expand Up @@ -1355,7 +1345,7 @@ def _compute_transect_masks(
a set of transects.
"""

transectNames, properties, nchar = _get_region_names_and_properties(fcMask)
transectNames, properties = _get_region_names_and_properties(fcMask)

masks = []
shapes = []
Expand Down Expand Up @@ -1405,12 +1395,10 @@ def _compute_transect_masks(
showProgress=showProgress,
)

nchar = max(nchar, len(name))

masks.append(mask)
shapes.append(shape)

return transectNames, masks, properties, nchar, shapes
return transectNames, masks, properties, shapes


def _intersects(shape, polygons):
Expand Down
Loading
Loading