Skip to content

Commit 085a22d

Browse files
authored
CLN: tests/window/* (#37926)
1 parent a170e97 commit 085a22d

File tree

5 files changed

+102
-94
lines changed

5 files changed

+102
-94
lines changed

pandas/tests/window/conftest.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,18 @@ def win_types_special(request):
3535

3636

3737
@pytest.fixture(
38-
params=["sum", "mean", "median", "max", "min", "var", "std", "kurt", "skew"]
38+
params=[
39+
"sum",
40+
"mean",
41+
"median",
42+
"max",
43+
"min",
44+
"var",
45+
"std",
46+
"kurt",
47+
"skew",
48+
"count",
49+
]
3950
)
4051
def arithmetic_win_operators(request):
4152
return request.param

pandas/tests/window/moments/test_moments_consistency_expanding.py

+88-75
Original file line numberDiff line numberDiff line change
@@ -16,49 +16,6 @@
1616
)
1717

1818

19-
def _check_expanding(
20-
func, static_comp, preserve_nan=True, series=None, frame=None, nan_locs=None
21-
):
22-
23-
series_result = func(series)
24-
assert isinstance(series_result, Series)
25-
frame_result = func(frame)
26-
assert isinstance(frame_result, DataFrame)
27-
28-
result = func(series)
29-
tm.assert_almost_equal(result[10], static_comp(series[:11]))
30-
31-
if preserve_nan:
32-
assert result.iloc[nan_locs].isna().all()
33-
34-
35-
def _check_expanding_has_min_periods(func, static_comp, has_min_periods):
36-
ser = Series(np.random.randn(50))
37-
38-
if has_min_periods:
39-
result = func(ser, min_periods=30)
40-
assert result[:29].isna().all()
41-
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
42-
43-
# min_periods is working correctly
44-
result = func(ser, min_periods=15)
45-
assert isna(result.iloc[13])
46-
assert notna(result.iloc[14])
47-
48-
ser2 = Series(np.random.randn(20))
49-
result = func(ser2, min_periods=5)
50-
assert isna(result[3])
51-
assert notna(result[4])
52-
53-
# min_periods=0
54-
result0 = func(ser, min_periods=0)
55-
result1 = func(ser, min_periods=1)
56-
tm.assert_almost_equal(result0, result1)
57-
else:
58-
result = func(ser)
59-
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
60-
61-
6219
def test_expanding_corr(series):
6320
A = series.dropna()
6421
B = (A + np.random.randn(len(A)))[:-5]
@@ -111,50 +68,106 @@ def test_expanding_corr_pairwise(frame):
11168
tm.assert_frame_equal(result, rolling_result)
11269

11370

114-
@pytest.mark.parametrize("has_min_periods", [True, False])
11571
@pytest.mark.parametrize(
11672
"func,static_comp",
11773
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
11874
ids=["sum", "mean", "max", "min"],
11975
)
120-
def test_expanding_func(func, static_comp, has_min_periods, series, frame, nan_locs):
121-
def expanding_func(x, min_periods=1, axis=0):
122-
exp = x.expanding(min_periods=min_periods, axis=axis)
123-
return getattr(exp, func)()
124-
125-
_check_expanding(
126-
expanding_func,
127-
static_comp,
128-
preserve_nan=False,
129-
series=series,
130-
frame=frame,
131-
nan_locs=nan_locs,
76+
def test_expanding_func(func, static_comp, frame_or_series):
77+
data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10))
78+
result = getattr(data.expanding(min_periods=1, axis=0), func)()
79+
assert isinstance(result, frame_or_series)
80+
81+
if frame_or_series is Series:
82+
tm.assert_almost_equal(result[10], static_comp(data[:11]))
83+
else:
84+
tm.assert_series_equal(
85+
result.iloc[10], static_comp(data[:11]), check_names=False
86+
)
87+
88+
89+
@pytest.mark.parametrize(
90+
"func,static_comp",
91+
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
92+
ids=["sum", "mean", "max", "min"],
93+
)
94+
def test_expanding_min_periods(func, static_comp):
95+
ser = Series(np.random.randn(50))
96+
97+
result = getattr(ser.expanding(min_periods=30, axis=0), func)()
98+
assert result[:29].isna().all()
99+
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
100+
101+
# min_periods is working correctly
102+
result = getattr(ser.expanding(min_periods=15, axis=0), func)()
103+
assert isna(result.iloc[13])
104+
assert notna(result.iloc[14])
105+
106+
ser2 = Series(np.random.randn(20))
107+
result = getattr(ser2.expanding(min_periods=5, axis=0), func)()
108+
assert isna(result[3])
109+
assert notna(result[4])
110+
111+
# min_periods=0
112+
result0 = getattr(ser.expanding(min_periods=0, axis=0), func)()
113+
result1 = getattr(ser.expanding(min_periods=1, axis=0), func)()
114+
tm.assert_almost_equal(result0, result1)
115+
116+
result = getattr(ser.expanding(min_periods=1, axis=0), func)()
117+
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
118+
119+
120+
def test_expanding_apply(engine_and_raw, frame_or_series):
121+
engine, raw = engine_and_raw
122+
data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10))
123+
result = data.expanding(min_periods=1).apply(
124+
lambda x: x.mean(), raw=raw, engine=engine
132125
)
133-
_check_expanding_has_min_periods(expanding_func, static_comp, has_min_periods)
126+
assert isinstance(result, frame_or_series)
134127

128+
if frame_or_series is Series:
129+
tm.assert_almost_equal(result[9], np.mean(data[:11]))
130+
else:
131+
tm.assert_series_equal(result.iloc[9], np.mean(data[:11]), check_names=False)
135132

136-
@pytest.mark.parametrize("has_min_periods", [True, False])
137-
def test_expanding_apply(engine_and_raw, has_min_periods, series, frame, nan_locs):
138133

134+
def test_expanding_min_periods_apply(engine_and_raw):
139135
engine, raw = engine_and_raw
136+
ser = Series(np.random.randn(50))
137+
138+
result = ser.expanding(min_periods=30).apply(
139+
lambda x: x.mean(), raw=raw, engine=engine
140+
)
141+
assert result[:29].isna().all()
142+
tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50]))
143+
144+
# min_periods is working correctly
145+
result = ser.expanding(min_periods=15).apply(
146+
lambda x: x.mean(), raw=raw, engine=engine
147+
)
148+
assert isna(result.iloc[13])
149+
assert notna(result.iloc[14])
150+
151+
ser2 = Series(np.random.randn(20))
152+
result = ser2.expanding(min_periods=5).apply(
153+
lambda x: x.mean(), raw=raw, engine=engine
154+
)
155+
assert isna(result[3])
156+
assert notna(result[4])
157+
158+
# min_periods=0
159+
result0 = ser.expanding(min_periods=0).apply(
160+
lambda x: x.mean(), raw=raw, engine=engine
161+
)
162+
result1 = ser.expanding(min_periods=1).apply(
163+
lambda x: x.mean(), raw=raw, engine=engine
164+
)
165+
tm.assert_almost_equal(result0, result1)
140166

141-
def expanding_mean(x, min_periods=1):
142-
143-
exp = x.expanding(min_periods=min_periods)
144-
result = exp.apply(lambda x: x.mean(), raw=raw, engine=engine)
145-
return result
146-
147-
# TODO(jreback), needed to add preserve_nan=False
148-
# here to make this pass
149-
_check_expanding(
150-
expanding_mean,
151-
np.mean,
152-
preserve_nan=False,
153-
series=series,
154-
frame=frame,
155-
nan_locs=nan_locs,
167+
result = ser.expanding(min_periods=1).apply(
168+
lambda x: x.mean(), raw=raw, engine=engine
156169
)
157-
_check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)
170+
tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50]))
158171

159172

160173
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])

pandas/tests/window/test_rolling.py

-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def test_numpy_compat(method):
122122
getattr(r, method)(dtype=np.float64)
123123

124124

125-
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
126125
def test_closed_fixed(closed, arithmetic_win_operators):
127126
# GH 34315
128127
func_name = arithmetic_win_operators

pandas/tests/window/test_timeseries_window.py

+2-17
Original file line numberDiff line numberDiff line change
@@ -621,23 +621,8 @@ def test_all(self, f):
621621
expected = er.quantile(0.5)
622622
tm.assert_frame_equal(result, expected)
623623

624-
@pytest.mark.parametrize(
625-
"f",
626-
[
627-
"sum",
628-
"mean",
629-
"count",
630-
"median",
631-
"std",
632-
"var",
633-
"kurt",
634-
"skew",
635-
"min",
636-
"max",
637-
],
638-
)
639-
def test_all2(self, f):
640-
624+
def test_all2(self, arithmetic_win_operators):
625+
f = arithmetic_win_operators
641626
# more sophisticated comparison of integer vs.
642627
# time-based windowing
643628
df = DataFrame(

0 commit comments

Comments
 (0)