diff --git a/conda_package/dev-spec.txt b/conda_package/dev-spec.txt index 56755df14..90e46adc6 100644 --- a/conda_package/dev-spec.txt +++ b/conda_package/dev-spec.txt @@ -12,6 +12,7 @@ hdf5 inpoly libnetcdf matplotlib-base>=3.9.0 +nco netcdf4 networkx numpy>=2.0,<3.0 diff --git a/conda_package/docs/index.rst b/conda_package/docs/index.rst index 02f775a4a..fa692db9c 100644 --- a/conda_package/docs/index.rst +++ b/conda_package/docs/index.rst @@ -27,6 +27,8 @@ analyzing simulations, and in other MPAS-related workflows. config + io + logging transects diff --git a/conda_package/docs/io.rst b/conda_package/docs/io.rst new file mode 100644 index 000000000..12d412c35 --- /dev/null +++ b/conda_package/docs/io.rst @@ -0,0 +1,28 @@ +.. _io: + +********* +I/O Tools +********* + +The :py:mod:`mpas_tools.io` module provides utilities for reading and writing +NetCDF files, especially for compatibility with MPAS mesh and data conventions. + +write_netcdf +============ + +The :py:func:`mpas_tools.io.write_netcdf()` function writes an +``xarray.Dataset`` to a NetCDF file, ensuring MPAS compatibility (e.g., +converting int64 to int32, handling fill values, and updating the history +attribute). It also supports writing in various NetCDF formats, including +conversion to ``NETCDF3_64BIT_DATA`` using ``ncks`` if needed. + +Example usage: + +.. code-block:: python + + import xarray as xr + from mpas_tools.io import write_netcdf + + # Create a simple dataset + ds = xr.Dataset({'foo': (('x',), [1, 2, 3])}) + write_netcdf(ds, 'output.nc') diff --git a/conda_package/mpas_tools/io.py b/conda_package/mpas_tools/io.py index af300b281..d9e5133d1 100644 --- a/conda_package/mpas_tools/io.py +++ b/conda_package/mpas_tools/io.py @@ -1,11 +1,13 @@ -from __future__ import absolute_import, division, print_function, \ - unicode_literals +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path -import numpy import netCDF4 -from datetime import datetime -import sys +import numpy +from mpas_tools.logging import check_call default_format = 'NETCDF3_64BIT' default_engine = None @@ -13,13 +15,29 @@ default_fills = netCDF4.default_fillvals -def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None, - char_dim_name=None): +def write_netcdf( + ds, + fileName, + fillValues=None, + format=None, + engine=None, + char_dim_name=None, + logger=None, +): """ Write an xarray.Dataset to a file with NetCDF4 fill values and the given name of the string dimension. Also adds the time and command-line to the history attribute. + Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case + because xarray output with this format is not performant. First, the file + is written in `NETCDF4` format, which supports larger files and variables. + Then, the `ncks` command is used to convert the file to the + `NETCDF3_64BIT_DATA` format. + + Note: All int64 variables are automatically converted to int32 for MPAS + compatibility. + Parameters ---------- ds : xarray.Dataset @@ -50,7 +68,11 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None, ``mpas_tools.io.default_char_dim_name``, which can be modified but which defaults to ``'StrLen'`` - """ + logger : logging.Logger, optional + A logger to write messages to write the output of `ncks` conversion + calls to. If None, `ncks` output is suppressed. This is only + relevant if `format` is 'NETCDF3_64BIT_DATA' + """ # noqa: E501 if format is None: format = default_format @@ -63,6 +85,13 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None, if char_dim_name is None: char_dim_name = default_char_dim_name + # Convert int64 variables to int32 for MPAS compatibility + for var in list(ds.data_vars.keys()) + list(ds.coords.keys()): + if ds[var].dtype == numpy.int64: + attrs = ds[var].attrs.copy() + ds[var] = ds[var].astype(numpy.int32) + ds[var].attrs = attrs + encodingDict = {} variableNames = list(ds.data_vars.keys()) + list(ds.coords.keys()) for variableName in variableNames: @@ -71,8 +100,9 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None, dtype = ds[variableName].dtype for fillType in fillValues: if dtype == numpy.dtype(fillType): - encodingDict[variableName] = \ - {'_FillValue': fillValues[fillType]} + encodingDict[variableName] = { + '_FillValue': fillValues[fillType] + } break else: encodingDict[variableName] = {'_FillValue': None} @@ -88,14 +118,54 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None, # reading Time otherwise ds.encoding['unlimited_dims'] = {'Time'} - ds.to_netcdf(fileName, encoding=encodingDict, format=format, engine=engine) + # for performance, we have to handle this as a special case + convert = format == 'NETCDF3_64BIT_DATA' + + if convert: + out_path = Path(fileName) + out_filename = ( + out_path.parent / f'_tmp_{out_path.stem}.netcdf4{out_path.suffix}' + ) + format = 'NETCDF4' + if engine == 'scipy': + # that's not going to work + engine = 'netcdf4' + else: + out_filename = fileName + + ds.to_netcdf( + out_filename, encoding=encodingDict, format=format, engine=engine + ) + + if convert: + args = [ + 'ncks', + '-O', + '-5', + out_filename, + fileName, + ] + if logger is None: + subprocess.run( + args, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + else: + check_call(args, logger=logger) + # delete the temporary NETCDF4 file + os.remove(out_filename) def update_history(ds): - '''Add or append history to attributes of a data set''' + """Add or append history to attributes of a data set""" - thiscommand = datetime.now().strftime("%a %b %d %H:%M:%S %Y") + ": " + \ - " ".join(sys.argv[:]) + thiscommand = ( + datetime.now().strftime('%a %b %d %H:%M:%S %Y') + + ': ' + + ' '.join(sys.argv[:]) + ) if 'history' in ds.attrs: newhist = '\n'.join([thiscommand, ds.attrs['history']]) else: diff --git a/conda_package/mpas_tools/mesh/mask.py b/conda_package/mpas_tools/mesh/mask.py index d82a4c8da..b30e64f87 100644 --- a/conda_package/mpas_tools/mesh/mask.py +++ b/conda_package/mpas_tools/mesh/mask.py @@ -251,9 +251,13 @@ def entry_point_compute_mpas_region_masks(): subdivisionThreshold=args.subdivision, ) - write_netcdf( - dsMasks, args.mask_file_name, format=args.format, engine=args.engine - ) + write_netcdf( + dsMasks, + args.mask_file_name, + format=args.format, + engine=args.engine, + logger=logger, + ) def compute_mpas_transect_masks( @@ -516,9 +520,13 @@ def entry_point_compute_mpas_transect_masks(): addEdgeSign=args.add_edge_sign, ) - write_netcdf( - dsMasks, args.mask_file_name, format=args.format, engine=args.engine - ) + write_netcdf( + dsMasks, + args.mask_file_name, + format=args.format, + engine=args.engine, + logger=logger, + ) def compute_mpas_flood_fill_mask( @@ -641,9 +649,13 @@ def entry_point_compute_mpas_flood_fill_mask(): dsMesh=dsMesh, fcSeed=fcSeed, logger=logger ) - write_netcdf( - dsMasks, args.mask_file_name, format=args.format, engine=args.engine - ) + write_netcdf( + dsMasks, + args.mask_file_name, + format=args.format, + engine=args.engine, + logger=logger, + ) def compute_lon_lat_region_masks( @@ -868,7 +880,11 @@ def entry_point_compute_lon_lat_region_masks(): ) write_netcdf( - dsMasks, args.mask_file_name, format=args.format, engine=args.engine + dsMasks, + args.mask_file_name, + format=args.format, + engine=args.engine, + logger=logger, ) @@ -1101,7 +1117,11 @@ def entry_point_compute_projection_grid_region_masks(): ) write_netcdf( - dsMasks, args.mask_file_name, format=args.format, engine=args.engine + dsMasks, + args.mask_file_name, + format=args.format, + engine=args.engine, + logger=logger, ) diff --git a/conda_package/recipe/meta.yaml b/conda_package/recipe/meta.yaml index b0d0b372e..f4aaa6f7c 100644 --- a/conda_package/recipe/meta.yaml +++ b/conda_package/recipe/meta.yaml @@ -42,6 +42,7 @@ requirements: - networkx - netcdf-fortran - matplotlib-base >=3.9.0 + - nco - netcdf4 - numpy >=2.0,<3.0 - progressbar2 diff --git a/conda_package/tests/test_io.py b/conda_package/tests/test_io.py new file mode 100644 index 000000000..d6b58b359 --- /dev/null +++ b/conda_package/tests/test_io.py @@ -0,0 +1,102 @@ +import os +import subprocess + +import numpy as np +import pytest +import xarray as xr + +from mpas_tools.io import write_netcdf + +from .util import get_test_data_file + +TEST_MESH = get_test_data_file('mesh.QU.1920km.151026.nc') + + +@pytest.mark.skipif( + not os.path.exists(TEST_MESH), reason='Test mesh not available' +) +def test_write_netcdf_basic(tmp_path): + ds = xr.open_dataset(TEST_MESH) + out_file = tmp_path / 'test_basic.nc' + write_netcdf(ds, str(out_file)) + ds2 = xr.open_dataset(out_file) + # Should have same dimensions and variables + assert set(ds.dims) == set(ds2.dims) + for var in ds.data_vars: + assert var in ds2.data_vars + ds2.close() + + +@pytest.mark.skipif( + not os.path.exists(TEST_MESH), reason='Test mesh not available' +) +def test_write_netcdf_cdf5_format(tmp_path): + ds = xr.open_dataset(TEST_MESH) + out_file = tmp_path / 'test_cdf5.nc' + write_netcdf(ds, str(out_file), format='NETCDF3_64BIT_DATA') + # Use ncdump -k to check format + result = subprocess.run( + ['ncdump', '-k', str(out_file)], + capture_output=True, + text=True, + check=True, + ) + # Should be cdf5 for NETCDF3_64BIT_DATA + assert result.stdout.strip() == 'cdf5' + # Check that the temporary file was deleted + tmp_file = ( + out_file.parent / f'_tmp_{out_file.stem}.netcdf4{out_file.suffix}' + ) + assert not os.path.exists(tmp_file) + + +def test_write_netcdf_int64_conversion_and_attr(tmp_path): + # Create a dataset with int64 variable and an attribute + arr = np.array([1, 2, 3], dtype=np.int64) + ds = xr.Dataset({'foo': (('x',), arr)}) + ds['foo'].attrs['myattr'] = 'testattr' + out_file = tmp_path / 'test_int64.nc' + write_netcdf(ds, str(out_file)) + ds2 = xr.open_dataset(out_file) + # Should be int32, not int64 + assert ds2['foo'].dtype == np.int32 + # Attribute should be preserved + assert ds2['foo'].attrs['myattr'] == 'testattr' + ds2.close() + + +def test_write_netcdf_fill_value(tmp_path): + # Test that NaN values are written with correct fill value + arr = np.array([1.0, np.nan, 3.0], dtype=np.float32) + ds = xr.Dataset({'bar': (('x',), arr)}) + out_file = tmp_path / 'test_fill.nc' + write_netcdf(ds, str(out_file)) + ds2 = xr.open_dataset(out_file) + # The second value should be the default fill value for float32 + fill_value = ds2['bar'].encoding.get('_FillValue', None) + assert fill_value is not None + assert np.isnan(ds2['bar'].values[1]) + ds2.close() + + +def test_write_netcdf_string_dim_name(tmp_path): + # Test that custom char_dim_name is used in encoding + arr = np.array([b'abc', b'def']) + ds = xr.Dataset({'baz': (('x',), arr)}) + out_file = tmp_path / 'test_strdim.nc' + write_netcdf(ds, str(out_file), char_dim_name='CustomStrLen') + ds2 = xr.open_dataset(out_file) + # Should have the variable and correct shape + assert 'baz' in ds2.variables + ds2.close() + arr = np.array([1, 2, 3], dtype=np.int64) + ds = xr.Dataset({'foo': (('x',), arr)}) + ds['foo'].attrs['myattr'] = 'testattr' + out_file = tmp_path / 'test_int64.nc' + write_netcdf(ds, str(out_file)) + ds2 = xr.open_dataset(out_file) + # Should be int32, not int64 + assert ds2['foo'].dtype == np.int32 + # Attribute should be preserved + assert ds2['foo'].attrs['myattr'] == 'testattr' + ds2.close()