Skip to content

Commit a525405

Browse files
fujiisoupshoyer
authored andcommitted
Fixes centerized rolling with bottleneck. Also, fixed rolling with an integer dask array. (#2122)
1 parent d63001c commit a525405

File tree

4 files changed

+44
-10
lines changed

4 files changed

+44
-10
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ Enhancements
4949
Bug fixes
5050
~~~~~~~~~
5151

52+
- Fixed a bug in `rolling` with bottleneck. Also, fixed a bug in rolling an
53+
integer dask array. (:issue:`21133`)
54+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
5255
- Fixed a bug where `keep_attrs=True` flag was neglected if
5356
:py:func:`apply_func` was used with :py:class:`Variable`. (:issue:`2114`)
5457
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

xarray/core/dask_array_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
from . import nputils
6+
from . import dtypes
67

78
try:
89
import dask.array as da
@@ -12,12 +13,14 @@
1213

1314
def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
1415
'''wrapper to apply bottleneck moving window funcs on dask arrays'''
16+
dtype, fill_value = dtypes.maybe_promote(a.dtype)
17+
a = a.astype(dtype)
1518
# inputs for ghost
1619
if axis < 0:
1720
axis = a.ndim + axis
1821
depth = {d: 0 for d in range(a.ndim)}
1922
depth[axis] = window - 1
20-
boundary = {d: np.nan for d in range(a.ndim)}
23+
boundary = {d: fill_value for d in range(a.ndim)}
2124
# create ghosted arrays
2225
ag = da.ghost.ghost(a, depth=depth, boundary=boundary)
2326
# apply rolling func

xarray/core/rolling.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,26 @@ def wrapped_func(self, **kwargs):
285285

286286
padded = self.obj.variable
287287
if self.center:
288-
shift = (-self.window // 2) + 1
289-
290288
if (LooseVersion(np.__version__) < LooseVersion('1.13') and
291289
self.obj.dtype.kind == 'b'):
292290
# with numpy < 1.13 bottleneck cannot handle np.nan-Boolean
293291
# mixed array correctly. We cast boolean array to float.
294292
padded = padded.astype(float)
293+
294+
if isinstance(padded.data, dask_array_type):
295+
# Workaround to make the padded chunk size is larger than
296+
# self.window-1
297+
shift = - (self.window - 1)
298+
offset = -shift - self.window // 2
299+
valid = (slice(None), ) * axis + (
300+
slice(offset, offset + self.obj.shape[axis]), )
301+
else:
302+
shift = (-self.window // 2) + 1
303+
valid = (slice(None), ) * axis + (slice(-shift, None), )
295304
padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)})
296-
valid = (slice(None), ) * axis + (slice(-shift, None), )
297305

298306
if isinstance(padded.data, dask_array_type):
299-
values = dask_rolling_wrapper(func, self.obj.data,
307+
values = dask_rolling_wrapper(func, padded,
300308
window=self.window,
301309
min_count=min_count,
302310
axis=axis)

xarray/tests/test_dataarray.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3439,23 +3439,43 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods):
34393439
assert_equal(actual, da['time'])
34403440

34413441

3442-
@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max',
3443-
'median'))
3442+
@pytest.mark.parametrize('name', ('mean', 'count'))
34443443
@pytest.mark.parametrize('center', (True, False, None))
34453444
@pytest.mark.parametrize('min_periods', (1, None))
3446-
def test_rolling_wrapped_bottleneck_dask(da_dask, name, center, min_periods):
3445+
@pytest.mark.parametrize('window', (7, 8))
3446+
def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window):
34473447
pytest.importorskip('dask.array')
34483448
# dask version
3449-
rolling_obj = da_dask.rolling(time=7, min_periods=min_periods)
3449+
rolling_obj = da_dask.rolling(time=window, min_periods=min_periods,
3450+
center=center)
34503451
actual = getattr(rolling_obj, name)().load()
34513452
# numpy version
3452-
rolling_obj = da_dask.load().rolling(time=7, min_periods=min_periods)
3453+
rolling_obj = da_dask.load().rolling(time=window, min_periods=min_periods,
3454+
center=center)
34533455
expected = getattr(rolling_obj, name)()
34543456

34553457
# using all-close because rolling over ghost cells introduces some
34563458
# precision errors
34573459
assert_allclose(actual, expected)
34583460

3461+
# with zero chunked array GH:2113
3462+
rolling_obj = da_dask.chunk().rolling(time=window, min_periods=min_periods,
3463+
center=center)
3464+
actual = getattr(rolling_obj, name)().load()
3465+
assert_allclose(actual, expected)
3466+
3467+
3468+
@pytest.mark.parametrize('center', (True, None))
3469+
def test_rolling_wrapped_dask_nochunk(center):
3470+
# GH:2113
3471+
pytest.importorskip('dask.array')
3472+
3473+
da_day_clim = xr.DataArray(np.arange(1, 367),
3474+
coords=[np.arange(1, 367)], dims='dayofyear')
3475+
expected = da_day_clim.rolling(dayofyear=31, center=center).mean()
3476+
actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean()
3477+
assert_allclose(actual, expected)
3478+
34593479

34603480
@pytest.mark.parametrize('center', (True, False))
34613481
@pytest.mark.parametrize('min_periods', (None, 1, 2, 3))

0 commit comments

Comments
 (0)