Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda_package/dev-spec.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ hdf5
inpoly
libnetcdf
matplotlib-base>=3.9.0
nco
netcdf4
networkx
numpy>=2.0,<3.0
Expand Down
2 changes: 2 additions & 0 deletions conda_package/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ analyzing simulations, and in other MPAS-related workflows.

config

io

logging

transects
Expand Down
28 changes: 28 additions & 0 deletions conda_package/docs/io.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. _io:

*********
I/O Tools
*********

The :py:mod:`mpas_tools.io` module provides utilities for reading and writing
NetCDF files, especially for compatibility with MPAS mesh and data conventions.

write_netcdf
============

The :py:func:`mpas_tools.io.write_netcdf()` function writes an
``xarray.Dataset`` to a NetCDF file, ensuring MPAS compatibility (e.g.,
converting int64 to int32, handling fill values, and updating the history
attribute). It also supports writing in various NetCDF formats, including
conversion to ``NETCDF3_64BIT_DATA`` using ``ncks`` if needed.

Example usage:

.. code-block:: python

    import xarray as xr
    from mpas_tools.io import write_netcdf

    # Create a simple dataset
    ds = xr.Dataset({'foo': (('x',), [1, 2, 3])})
    write_netcdf(ds, 'output.nc')
98 changes: 84 additions & 14 deletions conda_package/mpas_tools/io.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
from __future__ import absolute_import, division, print_function, \
unicode_literals
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import numpy
import netCDF4
from datetime import datetime
import sys
import numpy

from mpas_tools.logging import check_call

default_format = 'NETCDF3_64BIT'
default_engine = None
default_char_dim_name = 'StrLen'
default_fills = netCDF4.default_fillvals


def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
char_dim_name=None):
def write_netcdf(
ds,
fileName,
fillValues=None,
format=None,
engine=None,
char_dim_name=None,
logger=None,
):
"""
Write an xarray.Dataset to a file with NetCDF4 fill values and the given
name of the string dimension. Also adds the time and command-line to the
history attribute.

Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
because xarray output with this format is not performant. First, the file
is written in `NETCDF4` format, which supports larger files and variables.
Then, the `ncks` command is used to convert the file to the
`NETCDF3_64BIT_DATA` format.

Note: All int64 variables are automatically converted to int32 for MPAS
compatibility.

Parameters
----------
ds : xarray.Dataset
Expand Down Expand Up @@ -50,7 +68,11 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
``mpas_tools.io.default_char_dim_name``, which can be modified but
which defaults to ``'StrLen'``

"""
logger : logging.Logger, optional
A logger to which the output of the `ncks` conversion call is written.
If None, `ncks` output is suppressed. This is only relevant if
`format` is 'NETCDF3_64BIT_DATA'
""" # noqa: E501
if format is None:
format = default_format

Expand All @@ -63,6 +85,13 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
if char_dim_name is None:
char_dim_name = default_char_dim_name

# Convert int64 variables to int32 for MPAS compatibility
for var in list(ds.data_vars.keys()) + list(ds.coords.keys()):
if ds[var].dtype == numpy.int64:
attrs = ds[var].attrs.copy()
ds[var] = ds[var].astype(numpy.int32)
ds[var].attrs = attrs

encodingDict = {}
variableNames = list(ds.data_vars.keys()) + list(ds.coords.keys())
for variableName in variableNames:
Expand All @@ -71,8 +100,9 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
dtype = ds[variableName].dtype
for fillType in fillValues:
if dtype == numpy.dtype(fillType):
encodingDict[variableName] = \
{'_FillValue': fillValues[fillType]}
encodingDict[variableName] = {
'_FillValue': fillValues[fillType]
}
break
else:
encodingDict[variableName] = {'_FillValue': None}
Expand All @@ -88,14 +118,54 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
# reading Time otherwise
ds.encoding['unlimited_dims'] = {'Time'}

ds.to_netcdf(fileName, encoding=encodingDict, format=format, engine=engine)
# for performance, we have to handle this as a special case
convert = format == 'NETCDF3_64BIT_DATA'

if convert:
out_path = Path(fileName)
out_filename = (
out_path.parent / f'_tmp_{out_path.stem}.netcdf4{out_path.suffix}'
)
format = 'NETCDF4'
if engine == 'scipy':
# that's not going to work
engine = 'netcdf4'
else:
out_filename = fileName

ds.to_netcdf(
out_filename, encoding=encodingDict, format=format, engine=engine
)

if convert:
args = [
'ncks',
'-O',
'-5',
out_filename,
fileName,
]
if logger is None:
subprocess.run(
args,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
else:
check_call(args, logger=logger)
Comment thread
xylar marked this conversation as resolved.
# delete the temporary NETCDF4 file
os.remove(out_filename)


def update_history(ds):
'''Add or append history to attributes of a data set'''
"""Add or append history to attributes of a data set"""

thiscommand = datetime.now().strftime("%a %b %d %H:%M:%S %Y") + ": " + \
" ".join(sys.argv[:])
thiscommand = (
datetime.now().strftime('%a %b %d %H:%M:%S %Y')
+ ': '
+ ' '.join(sys.argv[:])
)
if 'history' in ds.attrs:
newhist = '\n'.join([thiscommand, ds.attrs['history']])
else:
Expand Down
42 changes: 31 additions & 11 deletions conda_package/mpas_tools/mesh/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ def entry_point_compute_mpas_region_masks():
subdivisionThreshold=args.subdivision,
)

write_netcdf(
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
)
write_netcdf(
dsMasks,
args.mask_file_name,
format=args.format,
engine=args.engine,
logger=logger,
)


def compute_mpas_transect_masks(
Expand Down Expand Up @@ -516,9 +520,13 @@ def entry_point_compute_mpas_transect_masks():
addEdgeSign=args.add_edge_sign,
)

write_netcdf(
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
)
write_netcdf(
dsMasks,
args.mask_file_name,
format=args.format,
engine=args.engine,
logger=logger,
)


def compute_mpas_flood_fill_mask(
Expand Down Expand Up @@ -641,9 +649,13 @@ def entry_point_compute_mpas_flood_fill_mask():
dsMesh=dsMesh, fcSeed=fcSeed, logger=logger
)

write_netcdf(
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
)
write_netcdf(
dsMasks,
args.mask_file_name,
format=args.format,
engine=args.engine,
logger=logger,
)


def compute_lon_lat_region_masks(
Expand Down Expand Up @@ -868,7 +880,11 @@ def entry_point_compute_lon_lat_region_masks():
)

write_netcdf(
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
dsMasks,
args.mask_file_name,
format=args.format,
engine=args.engine,
logger=logger,
)


Expand Down Expand Up @@ -1101,7 +1117,11 @@ def entry_point_compute_projection_grid_region_masks():
)

write_netcdf(
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
dsMasks,
args.mask_file_name,
format=args.format,
engine=args.engine,
logger=logger,
)


Expand Down
1 change: 1 addition & 0 deletions conda_package/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ requirements:
- networkx
- netcdf-fortran
- matplotlib-base >=3.9.0
- nco
- netcdf4
- numpy >=2.0,<3.0
- progressbar2
Expand Down
102 changes: 102 additions & 0 deletions conda_package/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import subprocess

import numpy as np
import pytest
import xarray as xr

from mpas_tools.io import write_netcdf

from .util import get_test_data_file

TEST_MESH = get_test_data_file('mesh.QU.1920km.151026.nc')


@pytest.mark.skipif(
    not os.path.exists(TEST_MESH), reason='Test mesh not available'
)
def test_write_netcdf_basic(tmp_path):
    """Round-trip the test mesh and check dimensions and variables survive."""
    out_file = tmp_path / 'test_basic.nc'
    # Context managers guarantee both datasets are closed even when an
    # assertion fails, so no file handles leak into later tests.
    with xr.open_dataset(TEST_MESH) as ds:
        write_netcdf(ds, str(out_file))
        with xr.open_dataset(out_file) as ds2:
            # Should have same dimensions and variables
            assert set(ds.dims) == set(ds2.dims)
            for var in ds.data_vars:
                assert var in ds2.data_vars


@pytest.mark.skipif(
    not os.path.exists(TEST_MESH), reason='Test mesh not available'
)
def test_write_netcdf_cdf5_format(tmp_path):
    """Check the CDF5 special case: output format and temp-file cleanup."""
    out_file = tmp_path / 'test_cdf5.nc'
    # Close the source dataset deterministically, even if the write raises
    with xr.open_dataset(TEST_MESH) as ds:
        write_netcdf(ds, str(out_file), format='NETCDF3_64BIT_DATA')
    # Use ncdump -k to check format
    result = subprocess.run(
        ['ncdump', '-k', str(out_file)],
        capture_output=True,
        text=True,
        check=True,
    )
    # Should be cdf5 for NETCDF3_64BIT_DATA
    assert result.stdout.strip() == 'cdf5'
    # Check that the temporary file was deleted
    tmp_file = (
        out_file.parent / f'_tmp_{out_file.stem}.netcdf4{out_file.suffix}'
    )
    assert not tmp_file.exists()


def test_write_netcdf_int64_conversion_and_attr(tmp_path):
    """Check int64 data is written as int32 and attributes are kept."""
    # Create a dataset with int64 variable and an attribute
    arr = np.array([1, 2, 3], dtype=np.int64)
    ds = xr.Dataset({'foo': (('x',), arr)})
    ds['foo'].attrs['myattr'] = 'testattr'
    out_file = tmp_path / 'test_int64.nc'
    write_netcdf(ds, str(out_file))
    # The context manager closes the file even if an assertion fails
    with xr.open_dataset(out_file) as ds2:
        # Should be int32, not int64
        assert ds2['foo'].dtype == np.int32
        # Attribute should be preserved
        assert ds2['foo'].attrs['myattr'] == 'testattr'


def test_write_netcdf_fill_value(tmp_path):
    """Check a NaN in float32 data gets a fill value and decodes to NaN."""
    arr = np.array([1.0, np.nan, 3.0], dtype=np.float32)
    ds = xr.Dataset({'bar': (('x',), arr)})
    out_file = tmp_path / 'test_fill.nc'
    write_netcdf(ds, str(out_file))
    # The context manager closes the file even if an assertion fails
    with xr.open_dataset(out_file) as ds2:
        # A _FillValue must have been recorded for the variable; its exact
        # value is not checked here
        fill_value = ds2['bar'].encoding.get('_FillValue', None)
        assert fill_value is not None
        # The masked entry must decode back to NaN on read
        assert np.isnan(ds2['bar'].values[1])


def test_write_netcdf_string_dim_name(tmp_path):
    """Check a bytes-string variable is written with a custom char_dim_name.

    The previous version of this test ended with a verbatim copy of the
    int64 conversion test; that duplicate has been removed because
    ``test_write_netcdf_int64_conversion_and_attr`` already covers it.
    """
    arr = np.array([b'abc', b'def'])
    ds = xr.Dataset({'baz': (('x',), arr)})
    out_file = tmp_path / 'test_strdim.nc'
    write_netcdf(ds, str(out_file), char_dim_name='CustomStrLen')
    # The context manager closes the file even if the assertion fails
    with xr.open_dataset(out_file) as ds2:
        # Only the presence of the variable is verified here.
        # NOTE(review): the name of the character dimension in the file is
        # not asserted -- consider checking it once the expected on-disk
        # layout is confirmed.
        assert 'baz' in ds2.variables
Loading