diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 600e1c8..dd8d473 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -126,15 +126,35 @@ def __init__(self, owner, path, metadata, **kwargs): def write(self, data, *args, check = False, **kwargs): template = kwargs.pop("template") step = kwargs.pop("step") + + if data is None: + return + import xarray as xr + + if isinstance(template, ekd.readers.grib.codes.GribField): + xarray_obj: xr.Dataset = template.to_xarray() + attrs = xarray_obj.attrs - xarray_obj: xr.DataArray = template.to_xarray() - xarray_obj.data = data - xarray_obj = xarray_obj.assign_coords(step = step) + xarray_obj = xarray_obj[list(xarray_obj.data_vars)[0]] + xarray_obj.attrs.update(attrs) + + xarray_obj.data = data + + if 'levtype' in attrs and 'levelist' in attrs: + xarray_obj = xarray_obj.assign_coords({attrs['levtype']: attrs['levelist']}) + + else: + xarray_obj: xr.DataArray = template.to_xarray() + xarray_obj.data = data + if 'pl' in xarray_obj.coords: xarray_obj = xarray_obj.expand_dims('pl') if 'ml' in xarray_obj.coords: xarray_obj = xarray_obj.expand_dims('ml') + + xarray_obj = xarray_obj.assign_coords(step = step) + xarray_obj.attrs.pop('_earthkit', None) self._outputs[step].append(xarray_obj)