Skip to content

Commit fa515d2

Browse files
phofl and crusaderky authored
Implement cumulative aggregation (dask#433)
Co-authored-by: crusaderky <[email protected]>
1 parent 9f76576 commit fa515d2

File tree

5 files changed

+176
-0
lines changed

5 files changed

+176
-0
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ API Coverage
7979
- `combine_first`
8080
- `copy`
8181
- `count`
82+
- `cummax`
83+
- `cummin`
84+
- `cumprod`
85+
- `cumsum`
8286
- `dask`
8387
- `div`
8488
- `divide`
@@ -169,6 +173,10 @@ API Coverage
169173
- `combine_first`
170174
- `copy`
171175
- `count`
176+
- `cummax`
177+
- `cummin`
178+
- `cumprod`
179+
- `cumsum`
172180
- `dask`
173181
- `div`
174182
- `divide`
@@ -204,6 +212,7 @@ API Coverage
204212
- `partitions`
205213
- `pow`
206214
- `prod`
215+
- `product`
207216
- `radd`
208217
- `rdiv`
209218
- `rename`
@@ -295,6 +304,7 @@ API Coverage
295304
- `mean`
296305
- `median`
297306
- `min`
307+
- `nunique`
298308
- `prod`
299309
- `shift`
300310
- `size`

dask_expr/_collection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,18 @@ def align(self, other, join="outer", fill_value=None):
765765
def nunique_approx(self):
    """Wrap the expression-level ``nunique_approx`` in a new collection."""
    approx_expr = self.expr.nunique_approx()
    return new_collection(approx_expr)
767767

768+
def cumsum(self, skipna=True):
    """Wrap the expression-level ``cumsum`` in a new collection.

    ``skipna`` is forwarded unchanged as a keyword argument.
    """
    cumulative_expr = self.expr.cumsum(skipna=skipna)
    return new_collection(cumulative_expr)
770+
771+
def cumprod(self, skipna=True):
    """Wrap the expression-level ``cumprod`` in a new collection.

    ``skipna`` is forwarded unchanged as a keyword argument.
    """
    cumulative_expr = self.expr.cumprod(skipna=skipna)
    return new_collection(cumulative_expr)
773+
774+
def cummax(self, skipna=True):
    """Wrap the expression-level ``cummax`` in a new collection.

    ``skipna`` is forwarded unchanged as a keyword argument.
    """
    cumulative_expr = self.expr.cummax(skipna=skipna)
    return new_collection(cumulative_expr)
776+
777+
def cummin(self, skipna=True):
    """Wrap the expression-level ``cummin`` in a new collection.

    ``skipna`` is forwarded unchanged as a keyword argument.
    """
    cumulative_expr = self.expr.cummin(skipna=skipna)
    return new_collection(cumulative_expr)
779+
768780
def memory_usage_per_partition(self, index=True, deep=False):
    """Wrap the expression-level ``memory_usage_per_partition`` in a collection.

    ``index`` and ``deep`` are forwarded positionally, in that order.
    """
    usage_expr = self.expr.memory_usage_per_partition(index, deep)
    return new_collection(usage_expr)
770782

dask_expr/_cumulative.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import functools
2+
3+
from dask.dataframe import methods
4+
from dask.utils import M
5+
6+
from dask_expr._expr import Blockwise, Expr, Projection
7+
8+
9+
class CumulativeAggregations(Expr):
    """Base expression for cumulative aggregations.

    Concrete subclasses supply ``chunk_operation`` (handed to the
    per-partition ``CumulativeBlockwise`` step) and ``aggregate_operation``
    (handed to ``CumulativeFinalize`` to combine partitions).
    """

    _parameters = ["frame", "axis", "skipna"]
    _defaults = {"axis": None}

    # Placeholders; concrete subclasses override both.
    chunk_operation = None
    aggregate_operation = None

    def _divisions(self):
        # Divisions are taken unchanged from the input frame.
        return self.frame._divisions()

    @functools.cached_property
    def _meta(self):
        # Metadata is taken unchanged from the input frame.
        return self.frame._meta

    def _lower(self):
        """Lower into a per-partition step, a last-row step and a finalize step."""
        per_partition = CumulativeBlockwise(
            self.frame, self.axis, self.skipna, self.chunk_operation
        )
        last_rows = TakeLast(per_partition, self.skipna)
        return CumulativeFinalize(per_partition, last_rows, self.aggregate_operation)

    def _simplify_up(self, parent):
        """Move a column projection from above this node onto its input frame.

        Returns ``None`` when ``parent`` is not a ``Projection``.
        """
        if not isinstance(parent, Projection):
            return None
        projected_frame = self.frame[parent.operand("columns")]
        return type(self)(projected_frame, *self.operands[1:])
33+
34+
35+
class CumulativeBlockwise(Blockwise):
    """Blockwise step carrying the per-partition cumulative callable.

    The callable is stored as the last operand (``operation``); the operands
    before it (``frame``, ``axis``, ``skipna``) are exposed through ``_args``.
    NOTE(review): presumably ``Blockwise`` calls ``operation(*_args)`` for each
    partition -- confirm against the ``Blockwise`` base class.
    """

    _parameters = ["frame", "axis", "skipna", "operation"]
    _defaults = {"axis": None, "skipna": True}
    _projection_passthrough = True

    @functools.cached_property
    def _meta(self):
        # Metadata is taken unchanged from the input frame.
        return self.frame._meta

    @functools.cached_property
    def operation(self):
        # The callable is an operand, not a class attribute.
        return self.operand("operation")

    @functools.cached_property
    def _args(self) -> list:
        # Everything except the trailing ``operation`` operand.
        return self.operands[:-1]
51+
52+
53+
class TakeLast(Blockwise):
    """Blockwise step that reduces each partition to its last row."""

    _parameters = ["frame", "skipna"]
    _projection_passthrough = True

    @staticmethod
    def operation(a, skipna=True):
        """Return the last row of ``a``, squeezed.

        Parameters
        ----------
        a
            Object providing ``ffill``, ``tail`` and ``squeeze`` (presumably a
            pandas-like Series or DataFrame -- confirm against callers).
        skipna : bool, default True
            When true, missing values are forward-filled before the last row
            is taken, so a trailing missing entry is replaced by the last
            valid value before it.

        Returns
        -------
        The result of ``a.tail(n=1).squeeze()`` after the optional fill.
        """
        if skipna:
            # Forward fill, not backward fill: a trailing NaN has nothing
            # after it to be filled from, so ``bfill`` would leave it in place
            # and the NaN would be carried into the following partitions.
            a = a.ffill()
        return a.tail(n=1).squeeze()
62+
63+
64+
class CumulativeFinalize(Expr):
    """Combine per-partition cumulative results into the final partitions.

    ``frame`` provides the per-partition results, ``previous_partitions``
    provides one value per partition to carry forward, and ``aggregator`` is
    the callable used to combine them.
    """

    _parameters = ["frame", "previous_partitions", "aggregator"]

    def _divisions(self):
        # Divisions are taken unchanged from the input frame.
        return self.frame._divisions()

    @functools.cached_property
    def _meta(self):
        # Metadata is taken unchanged from the input frame.
        return self.frame._meta

    def _layer(self) -> dict:
        """Build the task graph for all output partitions.

        Partition 0 is an alias of the input partition. Every later partition
        ``i`` is ``aggregator(frame[i], carry[i])``, where ``carry[1]`` aliases
        ``previous_partitions[0]`` and ``carry[i]`` for ``i > 1`` is built from
        ``carry[i - 1]`` and ``previous_partitions[i - 1]``.
        """
        out_name = self._name
        frame_name = self.frame._name
        last_name = self.previous_partitions._name
        carry_name = out_name + "-intermediate"

        # The first partition has nothing before it; pass it through.
        graph = {(out_name, 0): (frame_name, 0)}

        for i in range(1, self.frame.npartitions):
            if i == 1:
                # Only one earlier partition: its carried value is used as is.
                carry_task = (last_name, 0)
            else:
                # Aggregate with the previous cumulation results.
                carry_task = (
                    methods._cum_aggregate_apply,
                    self.aggregator,
                    (carry_name, i - 1),
                    (last_name, i - 1),
                )
            graph[(carry_name, i)] = carry_task
            graph[(out_name, i)] = (
                self.aggregator,
                (frame_name, i),
                (carry_name, i),
            )
        return graph
97+
98+
99+
class CumSum(CumulativeAggregations):
    """Cumulative aggregation using ``cumsum`` and ``cumsum_aggregate``."""

    # Method caller applied within each partition.
    chunk_operation = M.cumsum
    # ``staticmethod`` keeps the plain function from being bound to the
    # instance when it is read as ``self.aggregate_operation``.
    aggregate_operation = staticmethod(methods.cumsum_aggregate)
102+
103+
104+
class CumProd(CumulativeAggregations):
    """Cumulative aggregation using ``cumprod`` and ``cumprod_aggregate``."""

    # Method caller applied within each partition.
    chunk_operation = M.cumprod
    # ``staticmethod`` keeps the plain function from being bound to the
    # instance when it is read as ``self.aggregate_operation``.
    aggregate_operation = staticmethod(methods.cumprod_aggregate)
107+
108+
109+
class CumMax(CumulativeAggregations):
    """Cumulative aggregation using ``cummax`` and ``cummax_aggregate``."""

    # Method caller applied within each partition.
    chunk_operation = M.cummax
    # ``staticmethod`` keeps the plain function from being bound to the
    # instance when it is read as ``self.aggregate_operation``.
    aggregate_operation = staticmethod(methods.cummax_aggregate)
112+
113+
114+
class CumMin(CumulativeAggregations):
    """Cumulative aggregation using ``cummin`` and ``cummin_aggregate``."""

    # Method caller applied within each partition.
    chunk_operation = M.cummin
    # ``staticmethod`` keeps the plain function from being bound to the
    # instance when it is read as ``self.aggregate_operation``.
    aggregate_operation = staticmethod(methods.cummin_aggregate)

dask_expr/_expr.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,26 @@ def min(self, skipna=True, numeric_only=False, min_count=0):
361361
def count(self, numeric_only=False):
    """Build a ``Count`` expression over this expression.

    ``numeric_only`` is forwarded positionally.
    """
    count_expr = Count(self, numeric_only)
    return count_expr
363363

364+
def cumsum(self, skipna=True):
    """Build a ``CumSum`` expression over this expression.

    ``skipna`` is forwarded as a keyword argument.
    """
    # Imported locally: ``dask_expr._cumulative`` itself imports from
    # ``dask_expr._expr``, so a module-level import would be circular.
    from dask_expr._cumulative import CumSum

    cumulative = CumSum(self, skipna=skipna)
    return cumulative
368+
369+
def cumprod(self, skipna=True):
    """Build a ``CumProd`` expression over this expression.

    ``skipna`` is forwarded as a keyword argument.
    """
    # Imported locally: ``dask_expr._cumulative`` itself imports from
    # ``dask_expr._expr``, so a module-level import would be circular.
    from dask_expr._cumulative import CumProd

    cumulative = CumProd(self, skipna=skipna)
    return cumulative
373+
374+
def cummax(self, skipna=True):
    """Build a ``CumMax`` expression over this expression.

    ``skipna`` is forwarded as a keyword argument.
    """
    # Imported locally: ``dask_expr._cumulative`` itself imports from
    # ``dask_expr._expr``, so a module-level import would be circular.
    from dask_expr._cumulative import CumMax

    cumulative = CumMax(self, skipna=skipna)
    return cumulative
378+
379+
def cummin(self, skipna=True):
    """Build a ``CumMin`` expression over this expression.

    ``skipna`` is forwarded as a keyword argument.
    """
    # Imported locally: ``dask_expr._cumulative`` itself imports from
    # ``dask_expr._expr``, so a module-level import would be circular.
    from dask_expr._cumulative import CumMin

    cumulative = CumMin(self, skipna=skipna)
    return cumulative
383+
364384
def abs(self):
    """Build an ``Abs`` expression over this expression."""
    abs_expr = Abs(self)
    return abs_expr
366386

dask_expr/tests/test_collection.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,24 @@ def test_std_kwargs(axis, skipna, ddof):
242242
)
243243

244244

245+
@pytest.mark.parametrize("func", ["cumsum", "cumprod", "cummin", "cummax"])
def test_cumulative_methods(df, pdf, func):
    """Check cumulative methods against pandas, including projection pushdown.

    Covers the frame and the ``x`` column with the default ``skipna``, then
    repeats the comparison with ``skipna=False`` on data containing NaN.
    """
    assert_eq(getattr(df, func)(), getattr(pdf, func)(), check_dtype=False)
    assert_eq(getattr(df.x, func)(), getattr(pdf.x, func)())

    # Selecting a column after the cumulative op must simplify to the same
    # expression as selecting the column first. Compare names on both sides:
    # comparing a name string with a collection object does not test this.
    q = getattr(df, func)()["x"]
    expected = getattr(df.x, func)()
    assert q.simplify()._name == expected.simplify()._name

    # Blank out every second row of ``x`` to exercise NaN handling.
    pdf.loc[slice(None, None, 2), "x"] = np.nan
    df = from_pandas(pdf, npartitions=10)
    assert_eq(
        getattr(df, func)(skipna=False),
        getattr(pdf, func)(skipna=False),
        check_dtype=False,
    )
    assert_eq(getattr(df.x, func)(skipna=False), getattr(pdf.x, func)(skipna=False))
261+
262+
245263
@xfail_gpu("nbytes not supported by cudf")
246264
def test_nbytes(pdf, df):
247265
with pytest.raises(NotImplementedError, match="nbytes is not implemented"):

0 commit comments

Comments
 (0)