diff --git a/pyproject.toml b/pyproject.toml index 7bf13af..e946f10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,3 +78,4 @@ opendata = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" none = "ai_models.outputs:NoneOutput" +netcdf = "ai_models.outputs:NetCDFOutput" diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 01eda32..395cb25 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -542,6 +542,10 @@ def write_input_fields( if ignore is None: ignore = [] + if all(map(lambda x: x is None, fields.metadata("shortName", default = None))): + LOG.warning("Could not find 'shortName' in metadata. Are you using a grib input? Skipping writing input fields") + return + with self.timer("Writing step 0"): for field in fields: if field.metadata("shortName") in ignore: diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 2e9a616..600e1c8 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -9,6 +9,7 @@ import logging import warnings from functools import cached_property +from collections import defaultdict import earthkit.data as ekd import entrypoints @@ -110,6 +111,40 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) LOG.info("Writing results to %s", self.path) +class NetCDFOutput(Output): + def __init__(self, owner, path, metadata, **kwargs): + metadata.setdefault("stream", "oper") + metadata.setdefault("expver", owner.expver) + metadata.setdefault("class", "ml") + + self.path = path + self.owner = owner + self.metadata = metadata + + self._outputs = defaultdict(list) + + def write(self, data, *args, check = False, **kwargs): + template = kwargs.pop("template") + step = kwargs.pop("step") + import xarray as xr + + xarray_obj: xr.DataArray = template.to_xarray() + xarray_obj.data = data + xarray_obj = xarray_obj.assign_coords(step = step) + if 'pl' in xarray_obj.coords: + xarray_obj = xarray_obj.expand_dims('pl') + if 'ml' in xarray_obj.coords: + xarray_obj = xarray_obj.expand_dims('ml') + + self._outputs[step].append(xarray_obj) + + def flush(self, *args, **kwargs): + import xarray as xr + + output = xr.concat(map(xr.merge, self._outputs.values()), dim = 'step') + output.attrs.update(self.metadata) + output.to_netcdf(self.path) + class NoneOutput(Output): def __init__(self, *args, **kwargs):