diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eda5c9dfbf6..0a9a1464df1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -96,7 +96,7 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ - Use :py:func:`dask.array.apply_gufunc` instead of :py:func:`dask.array.blockwise` in - :py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:pull:`4060`, :pull:`4391`) + :py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:pull:`4060`, :pull:`4391`, :pull:`4392`) - Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH. By `Guido Imperiale `_ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 507f12fe55e..db4a9b46e98 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -625,6 +625,26 @@ def apply_variable_ufunc( if dask_gufunc_kwargs is None: dask_gufunc_kwargs = {} + allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None) + if allow_rechunk is None: + for n, (data, core_dims) in enumerate( + zip(input_data, signature.input_core_dims) + ): + if is_duck_dask_array(data): + # core dimensions cannot span multiple chunks + for axis, dim in enumerate(core_dims, start=-len(core_dims)): + if len(data.chunks[axis]) != 1: + raise ValueError( + f"dimension {dim} on {n}th function argument to " + "apply_ufunc with dask='parallelized' consists of " + "multiple chunks, but is also a core dimension. To " + "fix, either rechunk into a single dask array chunk along " + f"this dimension, i.e., ``.chunk({dim}: -1)``, or " + "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " + "but beware that this may significantly increase memory usage." + ) + dask_gufunc_kwargs["allow_rechunk"] = True + output_sizes = dask_gufunc_kwargs.pop("output_sizes", {}) if output_sizes: output_sizes_renamed = {} diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 7cb755b6dac..63bedfaf280 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -695,8 +695,7 @@ def check(x, y): check(data_array, 0 * data_array) check(data_array, 0 * data_array[0]) check(data_array[:, 0], 0 * data_array[0]) - with raises_regex(ValueError, "with different chunksize present"): - check(data_array, 0 * data_array.compute()) + check(data_array, 0 * data_array.compute()) @requires_dask @@ -710,7 +709,7 @@ def test_apply_dask_parallelized_errors(): with raises_regex(ValueError, "at least one input is an xarray object"): apply_ufunc(identity, array, dask="parallelized") - # formerly from _apply_blockwise, now from dask.array.apply_gufunc + # formerly from _apply_blockwise, now from apply_variable_ufunc with raises_regex(ValueError, "consists of multiple chunks"): apply_ufunc( identity,