|
16 | 16 | )
|
17 | 17 |
|
18 | 18 |
|
19 |
| -def _check_expanding( |
20 |
| - func, static_comp, preserve_nan=True, series=None, frame=None, nan_locs=None |
21 |
| -): |
22 |
| - |
23 |
| - series_result = func(series) |
24 |
| - assert isinstance(series_result, Series) |
25 |
| - frame_result = func(frame) |
26 |
| - assert isinstance(frame_result, DataFrame) |
27 |
| - |
28 |
| - result = func(series) |
29 |
| - tm.assert_almost_equal(result[10], static_comp(series[:11])) |
30 |
| - |
31 |
| - if preserve_nan: |
32 |
| - assert result.iloc[nan_locs].isna().all() |
33 |
| - |
34 |
| - |
35 |
| -def _check_expanding_has_min_periods(func, static_comp, has_min_periods): |
36 |
| - ser = Series(np.random.randn(50)) |
37 |
| - |
38 |
| - if has_min_periods: |
39 |
| - result = func(ser, min_periods=30) |
40 |
| - assert result[:29].isna().all() |
41 |
| - tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) |
42 |
| - |
43 |
| - # min_periods is working correctly |
44 |
| - result = func(ser, min_periods=15) |
45 |
| - assert isna(result.iloc[13]) |
46 |
| - assert notna(result.iloc[14]) |
47 |
| - |
48 |
| - ser2 = Series(np.random.randn(20)) |
49 |
| - result = func(ser2, min_periods=5) |
50 |
| - assert isna(result[3]) |
51 |
| - assert notna(result[4]) |
52 |
| - |
53 |
| - # min_periods=0 |
54 |
| - result0 = func(ser, min_periods=0) |
55 |
| - result1 = func(ser, min_periods=1) |
56 |
| - tm.assert_almost_equal(result0, result1) |
57 |
| - else: |
58 |
| - result = func(ser) |
59 |
| - tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) |
60 |
| - |
61 |
| - |
62 | 19 | def test_expanding_corr(series):
|
63 | 20 | A = series.dropna()
|
64 | 21 | B = (A + np.random.randn(len(A)))[:-5]
|
@@ -111,50 +68,106 @@ def test_expanding_corr_pairwise(frame):
|
111 | 68 | tm.assert_frame_equal(result, rolling_result)
|
112 | 69 |
|
113 | 70 |
|
114 |
| -@pytest.mark.parametrize("has_min_periods", [True, False]) |
115 | 71 | @pytest.mark.parametrize(
|
116 | 72 | "func,static_comp",
|
117 | 73 | [("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
|
118 | 74 | ids=["sum", "mean", "max", "min"],
|
119 | 75 | )
|
120 |
| -def test_expanding_func(func, static_comp, has_min_periods, series, frame, nan_locs): |
121 |
| - def expanding_func(x, min_periods=1, axis=0): |
122 |
| - exp = x.expanding(min_periods=min_periods, axis=axis) |
123 |
| - return getattr(exp, func)() |
124 |
| - |
125 |
| - _check_expanding( |
126 |
| - expanding_func, |
127 |
| - static_comp, |
128 |
| - preserve_nan=False, |
129 |
| - series=series, |
130 |
| - frame=frame, |
131 |
| - nan_locs=nan_locs, |
| 76 | +def test_expanding_func(func, static_comp, frame_or_series): |
| 77 | + data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10)) |
| 78 | + result = getattr(data.expanding(min_periods=1, axis=0), func)() |
| 79 | + assert isinstance(result, frame_or_series) |
| 80 | + |
| 81 | + if frame_or_series is Series: |
| 82 | + tm.assert_almost_equal(result[10], static_comp(data[:11])) |
| 83 | + else: |
| 84 | + tm.assert_series_equal( |
| 85 | + result.iloc[10], static_comp(data[:11]), check_names=False |
| 86 | + ) |
| 87 | + |
| 88 | + |
| 89 | +@pytest.mark.parametrize( |
| 90 | + "func,static_comp", |
| 91 | + [("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)], |
| 92 | + ids=["sum", "mean", "max", "min"], |
| 93 | +) |
| 94 | +def test_expanding_min_periods(func, static_comp): |
| 95 | + ser = Series(np.random.randn(50)) |
| 96 | + |
| 97 | + result = getattr(ser.expanding(min_periods=30, axis=0), func)() |
| 98 | + assert result[:29].isna().all() |
| 99 | + tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) |
| 100 | + |
| 101 | + # min_periods is working correctly |
| 102 | + result = getattr(ser.expanding(min_periods=15, axis=0), func)() |
| 103 | + assert isna(result.iloc[13]) |
| 104 | + assert notna(result.iloc[14]) |
| 105 | + |
| 106 | + ser2 = Series(np.random.randn(20)) |
| 107 | + result = getattr(ser2.expanding(min_periods=5, axis=0), func)() |
| 108 | + assert isna(result[3]) |
| 109 | + assert notna(result[4]) |
| 110 | + |
| 111 | + # min_periods=0 |
| 112 | + result0 = getattr(ser.expanding(min_periods=0, axis=0), func)() |
| 113 | + result1 = getattr(ser.expanding(min_periods=1, axis=0), func)() |
| 114 | + tm.assert_almost_equal(result0, result1) |
| 115 | + |
| 116 | + result = getattr(ser.expanding(min_periods=1, axis=0), func)() |
| 117 | + tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) |
| 118 | + |
| 119 | + |
| 120 | +def test_expanding_apply(engine_and_raw, frame_or_series): |
| 121 | + engine, raw = engine_and_raw |
| 122 | + data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10)) |
| 123 | + result = data.expanding(min_periods=1).apply( |
| 124 | + lambda x: x.mean(), raw=raw, engine=engine |
132 | 125 | )
|
133 |
| - _check_expanding_has_min_periods(expanding_func, static_comp, has_min_periods) |
| 126 | + assert isinstance(result, frame_or_series) |
134 | 127 |
|
| 128 | + if frame_or_series is Series: |
| 129 | + tm.assert_almost_equal(result[9], np.mean(data[:11])) |
| 130 | + else: |
| 131 | + tm.assert_series_equal(result.iloc[9], np.mean(data[:11]), check_names=False) |
135 | 132 |
|
136 |
| -@pytest.mark.parametrize("has_min_periods", [True, False]) |
137 |
| -def test_expanding_apply(engine_and_raw, has_min_periods, series, frame, nan_locs): |
138 | 133 |
|
| 134 | +def test_expanding_min_periods_apply(engine_and_raw): |
139 | 135 | engine, raw = engine_and_raw
|
| 136 | + ser = Series(np.random.randn(50)) |
| 137 | + |
| 138 | + result = ser.expanding(min_periods=30).apply( |
| 139 | + lambda x: x.mean(), raw=raw, engine=engine |
| 140 | + ) |
| 141 | + assert result[:29].isna().all() |
| 142 | + tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50])) |
| 143 | + |
| 144 | + # min_periods is working correctly |
| 145 | + result = ser.expanding(min_periods=15).apply( |
| 146 | + lambda x: x.mean(), raw=raw, engine=engine |
| 147 | + ) |
| 148 | + assert isna(result.iloc[13]) |
| 149 | + assert notna(result.iloc[14]) |
| 150 | + |
| 151 | + ser2 = Series(np.random.randn(20)) |
| 152 | + result = ser2.expanding(min_periods=5).apply( |
| 153 | + lambda x: x.mean(), raw=raw, engine=engine |
| 154 | + ) |
| 155 | + assert isna(result[3]) |
| 156 | + assert notna(result[4]) |
| 157 | + |
| 158 | + # min_periods=0 |
| 159 | + result0 = ser.expanding(min_periods=0).apply( |
| 160 | + lambda x: x.mean(), raw=raw, engine=engine |
| 161 | + ) |
| 162 | + result1 = ser.expanding(min_periods=1).apply( |
| 163 | + lambda x: x.mean(), raw=raw, engine=engine |
| 164 | + ) |
| 165 | + tm.assert_almost_equal(result0, result1) |
140 | 166 |
|
141 |
| - def expanding_mean(x, min_periods=1): |
142 |
| - |
143 |
| - exp = x.expanding(min_periods=min_periods) |
144 |
| - result = exp.apply(lambda x: x.mean(), raw=raw, engine=engine) |
145 |
| - return result |
146 |
| - |
147 |
| - # TODO(jreback), needed to add preserve_nan=False |
148 |
| - # here to make this pass |
149 |
| - _check_expanding( |
150 |
| - expanding_mean, |
151 |
| - np.mean, |
152 |
| - preserve_nan=False, |
153 |
| - series=series, |
154 |
| - frame=frame, |
155 |
| - nan_locs=nan_locs, |
| 167 | + result = ser.expanding(min_periods=1).apply( |
| 168 | + lambda x: x.mean(), raw=raw, engine=engine |
156 | 169 | )
|
157 |
| - _check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods) |
| 170 | + tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50])) |
158 | 171 |
|
159 | 172 |
|
160 | 173 | @pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
|
|
0 commit comments