Skip to content

Commit 0eb6658

Browse files
Enable numbagg in calculation of quantiles (#8684)
* Use `numbagg.nanquantile` by default when `method=linear` and `skipna=True` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add `"None"` option to `compute_backend` * skip tests when `compute_backend == "numbagg"` * adjust regex pattern to include numbagg error message * skip test if `compute_backend == "numbagg"` and `q == -0.1` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test quantile method w/o numbagg backend * change `compute_backend` param `"None"` to `None` * add numbagg `minversion` requirement in `quantile` method * align `test_quantile_out_of_bounds` with numbagg>=0.7.2 * avoid using numbagg on pint arrays; remove exclusion from tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move numbagg nanquantiles logic to `nputils`-module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix logic related to numbagg `nanquantiles` * fix logic related to numbagg `nanquantiles` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add `whats-new` entry --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0f7a034 commit 0eb6658

File tree

8 files changed

+41
-14
lines changed

8 files changed

+41
-14
lines changed

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ New Features
2828
By `Mathias Hauser <https://github.com/mathause>`_.
2929
- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and :py:meth:`NamedArray.broadcast_to`
3030
(:pull:`8380`) By `Anderson Banihirwe <https://github.com/andersy005>`_.
31-
3231
- Xarray now defers to flox's `heuristics <https://flox.readthedocs.io/en/latest/implementation.html#heuristics>`_
3332
to set default `method` for groupby problems. This only applies to ``flox>=0.9``.
3433
By `Deepak Cherian <https://github.com/dcherian>`_.
34+
- All `quantile` methods (e.g. :py:meth:`DataArray.quantile`) now use `numbagg`
35+
for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed.
36+
This is currently limited to the linear interpolation method (`method='linear'`).
37+
(:issue:`7377`, :pull:`8684`) By `Marco Wolsza <https://github.com/maawoo>`_.
3538

3639
Breaking changes
3740
~~~~~~~~~~~~~~~~

xarray/core/nputils.py

+12
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ def f(values, axis=None, **kwargs):
195195
and values.dtype.kind in "uifc"
196196
# and values.dtype.isnative
197197
and (dtype is None or np.dtype(dtype) == values.dtype)
198+
# numbagg.nanquantile only available after 0.8.0 and with linear method
199+
and (
200+
name != "nanquantile"
201+
or (
202+
pycompat.mod_version("numbagg") >= Version("0.8.0")
203+
and kwargs.get("method", "linear") == "linear"
204+
)
205+
)
198206
):
199207
import numbagg
200208

@@ -206,6 +214,9 @@ def f(values, axis=None, **kwargs):
206214
# to ddof=1 above.
207215
if pycompat.mod_version("numbagg") < Version("0.7.0"):
208216
kwargs.pop("ddof", None)
217+
if name == "nanquantile":
218+
kwargs["quantiles"] = kwargs.pop("q")
219+
kwargs.pop("method", None)
209220
return nba_func(values, axis=axis, **kwargs)
210221
if (
211222
_BOTTLENECK_AVAILABLE
@@ -285,3 +296,4 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
285296
nancumprod = _create_method("nancumprod")
286297
nanargmin = _create_method("nanargmin")
287298
nanargmax = _create_method("nanargmax")
299+
nanquantile = _create_method("nanquantile")

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,7 @@ def quantile(
19921992
method = interpolation
19931993

19941994
if skipna or (skipna is None and self.dtype.kind in "cfO"):
1995-
_quantile_func = np.nanquantile
1995+
_quantile_func = nputils.nanquantile
19961996
else:
19971997
_quantile_func = np.quantile
19981998

xarray/tests/conftest.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ def backend(request):
1414
return request.param
1515

1616

17-
@pytest.fixture(params=["numbagg", "bottleneck"])
17+
@pytest.fixture(params=["numbagg", "bottleneck", None])
1818
def compute_backend(request):
19-
if request.param == "bottleneck":
19+
if request.param is None:
20+
options = dict(use_bottleneck=False, use_numbagg=False)
21+
elif request.param == "bottleneck":
2022
options = dict(use_bottleneck=True, use_numbagg=False)
2123
elif request.param == "numbagg":
2224
options = dict(use_bottleneck=False, use_numbagg=True)

xarray/tests/test_dataarray.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2888,12 +2888,13 @@ def test_reduce_out(self) -> None:
28882888
with pytest.raises(TypeError):
28892889
orig.mean(out=np.ones(orig.shape))
28902890

2891+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
28912892
@pytest.mark.parametrize("skipna", [True, False, None])
28922893
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
28932894
@pytest.mark.parametrize(
28942895
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
28952896
)
2896-
def test_quantile(self, q, axis, dim, skipna) -> None:
2897+
def test_quantile(self, q, axis, dim, skipna, compute_backend) -> None:
28972898
va = self.va.copy(deep=True)
28982899
va[0, 0] = np.nan
28992900

xarray/tests/test_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -5612,9 +5612,10 @@ def test_reduce_keepdims(self) -> None:
56125612
)
56135613
assert_identical(expected, actual)
56145614

5615+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
56155616
@pytest.mark.parametrize("skipna", [True, False, None])
56165617
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
5617-
def test_quantile(self, q, skipna) -> None:
5618+
def test_quantile(self, q, skipna, compute_backend) -> None:
56185619
ds = create_test_data(seed=123)
56195620
ds.var1.data[0, 0] = np.nan
56205621

@@ -5635,8 +5636,9 @@ def test_quantile(self, q, skipna) -> None:
56355636
assert "dim3" in ds_quantile.dims
56365637
assert all(d not in ds_quantile.dims for d in dim)
56375638

5639+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
56385640
@pytest.mark.parametrize("skipna", [True, False])
5639-
def test_quantile_skipna(self, skipna) -> None:
5641+
def test_quantile_skipna(self, skipna, compute_backend) -> None:
56405642
q = 0.1
56415643
dim = "time"
56425644
ds = Dataset({"a": ([dim], np.arange(0, 11))})

xarray/tests/test_units.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,7 @@ def test_squeeze(self, dim, dtype):
20142014
assert_units_equal(expected, actual)
20152015
assert_identical(expected, actual)
20162016

2017+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
20172018
@pytest.mark.parametrize(
20182019
"func",
20192020
(
@@ -2035,7 +2036,7 @@ def test_squeeze(self, dim, dtype):
20352036
),
20362037
ids=repr,
20372038
)
2038-
def test_computation(self, func, dtype):
2039+
def test_computation(self, func, dtype, compute_backend):
20392040
base_unit = unit_registry.m
20402041
array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit
20412042
variable = xr.Variable(("x", "y"), array)
@@ -3767,6 +3768,7 @@ def test_differentiate_integrate(self, func, variant, dtype):
37673768
assert_units_equal(expected, actual)
37683769
assert_identical(expected, actual)
37693770

3771+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
37703772
@pytest.mark.parametrize(
37713773
"variant",
37723774
(
@@ -3787,7 +3789,7 @@ def test_differentiate_integrate(self, func, variant, dtype):
37873789
),
37883790
ids=repr,
37893791
)
3790-
def test_computation(self, func, variant, dtype):
3792+
def test_computation(self, func, variant, dtype, compute_backend):
37913793
unit = unit_registry.m
37923794

37933795
variants = {
@@ -3893,6 +3895,7 @@ def test_resample(self, dtype):
38933895
assert_units_equal(expected, actual)
38943896
assert_identical(expected, actual)
38953897

3898+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
38963899
@pytest.mark.parametrize(
38973900
"variant",
38983901
(
@@ -3913,7 +3916,7 @@ def test_resample(self, dtype):
39133916
),
39143917
ids=repr,
39153918
)
3916-
def test_grouped_operations(self, func, variant, dtype):
3919+
def test_grouped_operations(self, func, variant, dtype, compute_backend):
39173920
unit = unit_registry.m
39183921

39193922
variants = {
@@ -5250,6 +5253,7 @@ def test_interp_reindex_like_indexing(self, func, unit, error, dtype):
52505253
assert_units_equal(expected, actual)
52515254
assert_equal(expected, actual)
52525255

5256+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
52535257
@pytest.mark.parametrize(
52545258
"func",
52555259
(
@@ -5272,7 +5276,7 @@ def test_interp_reindex_like_indexing(self, func, unit, error, dtype):
52725276
"coords",
52735277
),
52745278
)
5275-
def test_computation(self, func, variant, dtype):
5279+
def test_computation(self, func, variant, dtype, compute_backend):
52765280
variants = {
52775281
"data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
52785282
"dims": ((1, 1), unit_registry.m, 1),
@@ -5404,6 +5408,7 @@ def test_resample(self, variant, dtype):
54045408
assert_units_equal(expected, actual)
54055409
assert_equal(expected, actual)
54065410

5411+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
54075412
@pytest.mark.parametrize(
54085413
"func",
54095414
(
@@ -5425,7 +5430,7 @@ def test_resample(self, variant, dtype):
54255430
"coords",
54265431
),
54275432
)
5428-
def test_grouped_operations(self, func, variant, dtype):
5433+
def test_grouped_operations(self, func, variant, dtype, compute_backend):
54295434
variants = {
54305435
"data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
54315436
"dims": ((1, 1), unit_registry.m, 1),

xarray/tests/test_variable.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1842,13 +1842,15 @@ def test_quantile_chunked_dim_error(self):
18421842
with pytest.raises(ValueError, match=r"consists of multiple chunks"):
18431843
v.quantile(0.5, dim="x")
18441844

1845+
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
18451846
@pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])
1846-
def test_quantile_out_of_bounds(self, q):
1847+
def test_quantile_out_of_bounds(self, q, compute_backend):
18471848
v = Variable(["x", "y"], self.d)
18481849

18491850
# escape special characters
18501851
with pytest.raises(
1851-
ValueError, match=r"Quantiles must be in the range \[0, 1\]"
1852+
ValueError,
1853+
match=r"(Q|q)uantiles must be in the range \[0, 1\]",
18521854
):
18531855
v.quantile(q, dim="x")
18541856

0 commit comments

Comments
 (0)