diff --git a/data_cube_utilities/dc_mosaic.py b/data_cube_utilities/dc_mosaic.py index cc277316..b47a058d 100644 --- a/data_cube_utilities/dc_mosaic.py +++ b/data_cube_utilities/dc_mosaic.py @@ -19,6 +19,7 @@ # License for the specific language governing permissions and limitations # under the License. +import dask.array import numpy as np import xarray as xr from collections import OrderedDict @@ -164,6 +165,38 @@ def create_mean_mosaic(dataset_in, clean_mask=None, no_data=-9999, dtype=None, * dataset_out = restore_or_convert_dtypes(dtype, band_list, dataset_in_dtypes, dataset_out, no_data) return dataset_out +# Based on a gist by Andrew Hicks: https://gist.github.com/andrewdhicks/d89849997453cdfad6fa568816ca7160 + +def median(array, dim, keep_attrs=False, skipna=False, **kwargs): + """ Runs a median on an dask-backed xarray. + + This function does not scale! + It will rechunk along the given dimension, so make sure + your other chunk sizes are small enough that it + will fit into memory. + + :param array: An xarray.DataArray or xarray.Dataset wrapping one or more dask arrays + :type array: xarray.DataArray or xarray.Dataset + :param dim: The name of the dim in array to calculate the median + :type dim: str + """ + if type(array) is xr.Dataset: + return array.apply(median, dim=dim, keep_attrs=keep_attrs, skipna=skipna, **kwargs) + + if not hasattr(array.data, 'dask'): + return array.median(dim, keep_attrs=keep_attrs, skipna=skipna, **kwargs) + + array = array.chunk({dim:-1}) + axis = array.dims.index(dim) + median_func = np.nanmedian if skipna else np.median + blocks = dask.array.map_blocks(median_func, array.data, dtype=array.dtype, drop_axis=axis, axis=axis, **kwargs) + + new_coords={k: v for k, v in array.coords.items() if k != dim and dim not in v.dims} + new_dims = tuple(d for d in array.dims if d != dim) + new_attrs = array.attrs if keep_attrs else None + + return xr.DataArray(blocks, coords=new_coords, dims=new_dims, attrs=new_attrs) + def create_median_mosaic(dataset_in, clean_mask=None, no_data=-9999, dtype=None, **kwargs): """ @@ -205,7 +238,7 @@ def create_median_mosaic(dataset_in, clean_mask=None, no_data=-9999, dtype=None, # Mask out clouds and Landsat 7 scan lines. dataset_in = dataset_in.where((dataset_in != no_data) & (clean_mask)) - dataset_out = dataset_in.median(dim='time', skipna=True, keep_attrs=False) + dataset_out = median(dataset_in, dim='time', skipna=True, keep_attrs=False) # Handle datatype conversions. dataset_out = restore_or_convert_dtypes(dtype, band_list, dataset_in_dtypes, dataset_out, no_data) @@ -478,6 +511,8 @@ def create_hdmedians_multiple_band_mosaic(dataset_in, dtype=None, intermediate_product=None, operation="median", + x_coord='longitude', + y_coord='latitude', **kwargs): """ Calculates the geomedian or geomedoid using a multi-band processing method. @@ -497,6 +532,8 @@ def create_hdmedians_multiple_band_mosaic(dataset_in, A string denoting a Python datatype name (e.g. int, float) or a NumPy dtype (e.g. np.int16, np.float32) to convert the data to. operation: str in ['median', 'medoid'] + x_coord, y_coord: str + Names of DataArrays in `dataset_in` to use as x and y coordinates. Returns ------- @@ -546,8 +583,8 @@ def create_hdmedians_multiple_band_mosaic(dataset_in, for index, value in enumerate(band_list) } dataset_out = xr.Dataset(output_dict, - coords={'latitude': dataset_in['latitude'], - 'longitude': dataset_in['longitude']}, + coords={y_coord: dataset_in[y_coord], + x_coord: dataset_in[x_coord]}, attrs=dataset_in.attrs) dataset_out = restore_or_convert_dtypes(dtype, band_list, dataset_in_dtypes, dataset_out, no_data) return dataset_out