From b1e1768f657a07d64b7dc0fadf37445f20401499 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Wed, 16 Nov 2022 12:44:05 -0500 Subject: [PATCH] fix: summarize raising error when a grouping col is all NA (or mostly NA) (#459) * fix(pandas): summarize works with na group cols, preserves keys * tests: correctly skip unneeded duckdb test --- siuba/dply/verbs.py | 19 ++++++++++-- siuba/tests/test_sql_misc.py | 11 ++----- siuba/tests/test_verb_mutate.py | 11 +++++++ siuba/tests/test_verb_summarize.py | 48 ++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 12 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index f138df8d..9ff7545a 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -154,7 +154,7 @@ def _mutate_cols(__data, args, kwargs): def _make_groupby_safe(gdf): - return gdf.obj.groupby(gdf.grouper, group_keys=False) + return gdf.obj.groupby(gdf.grouper, group_keys=False, dropna=False) MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}" @@ -363,9 +363,9 @@ def group_by(__data, *args, add = False, **kwargs): # ensures group levels are recalculated if varname was in transmute groupings[varname] = varname - return tmp_df.groupby(list(groupings.values())) + return tmp_df.groupby(list(groupings.values()), dropna=False, group_keys=True) - return tmp_df.groupby(by = by_vars) + return tmp_df.groupby(by = by_vars, dropna=False, group_keys=True) @singledispatch2((pd.DataFrame, DataFrameGroupBy)) @@ -563,6 +563,19 @@ def summarize(__data, *args, **kwargs): @summarize.register(DataFrameGroupBy) def _summarize(__data, *args, **kwargs): + if __data.dropna or not __data.group_keys: + warnings.warn( + f"Grouped data passed to summarize must have dropna=False and group_keys=True." + " Regrouping with these arguments set." + ) + + if __data.grouper.dropna: + # will need to recalculate groupings, otherwise it ignores dropna + group_cols = [ping.name for ping in __data.grouper.groupings] + else: + group_cols = __data.grouper.groupings + __data = __data.obj.groupby(group_cols, dropna=False, group_keys=True) + df_summarize = summarize.registry[pd.DataFrame] df = __data.apply(df_summarize, *args, **kwargs) diff --git a/siuba/tests/test_sql_misc.py b/siuba/tests/test_sql_misc.py index c168e4f1..5501cf3c 100644 --- a/siuba/tests/test_sql_misc.py +++ b/siuba/tests/test_sql_misc.py @@ -31,17 +31,10 @@ def test_raw_sql_mutate_grouped(backend, df): ) -@pytest.mark.skip_backend("snowflake") # supported by snowflake +@pytest.mark.skip_backend("snowflake", "duckdb") # they support this behavior @backend_sql def test_raw_sql_mutate_refer_previous_raise_dberror(backend, skip_backend, df): - # Note: unlikely will be able to support this case. Normally we analyze - if backend.name == "duckdb": - # duckdb dialect re-raises the engines exception, which is RuntimeError - # the expression to know whether we need to create a subquery. - import duckdb - exc = duckdb.BinderException - else: - exc = sqlalchemy.exc.DatabaseError + exc = sqlalchemy.exc.DatabaseError with pytest.raises(exc): assert_equal_query( diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py index 6a4359f6..9b8cc0ea 100644 --- a/siuba/tests/test_verb_mutate.py +++ b/siuba/tests/test_verb_mutate.py @@ -75,6 +75,17 @@ def test_mutate_reassign_all_cols_keeps_rowsize(dfs): data_frame(a = [1,1,1], b = [2,2,2]) ) + +def test_mutate_grouped_pandas_no_dropna(): + src = data_frame(x = [1, 2], g = [None, None]) + + assert_equal_query( + src, + group_by(_.g) >> mutate(res = _.x + 1), + data_frame(x = [1, 2], g = [None, None], res = [2, 3]) + ) + + @backend_sql def test_mutate_window_funcs(backend): data = data_frame(idx = range(0, 4), x = range(1, 5), g = [1,1,2,2]) diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py index d32ff4ed..89700def 100644 --- a/siuba/tests/test_verb_summarize.py +++ b/siuba/tests/test_verb_summarize.py @@ -3,6 +3,8 @@ https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R """ + +import numpy as np from siuba import _, mutate, select, group_by, summarize, filter, show_query, arrange from siuba.dply.vector import row_number, n @@ -47,6 +49,52 @@ def test_summarize_after_mutate_cuml_win(backend, df_float): ) +def test_summarize_keeps_na_grouping_cols(backend): + df = data_frame(x = [1, 2, 3], g = [None, None, None]) + src = backend.load_df(df) + + if backend.name == "pandas": + missing = np.nan + else: + missing = None + + assert_equal_query( + src, + group_by(_.g) >> summarize(res = _.x.min()), + data_frame(g = [missing], res = [1]) + ) + + +def test_summarize_regroups_group_keys(): + df = data_frame(x = [1, 2, 3], g = [None, None, None]) + + # bad group_keys choice + g_df = df.groupby("g", group_keys=False, dropna=False) + + with pytest.warns(UserWarning, match="group_keys=True"): + + assert_equal_query( + g_df, + summarize(res = _.x.min()), + data_frame(g = [np.nan], res = [1]) + ) + + +def test_summarize_regroups_dropna(): + df = data_frame(x = [1, 2, 3], g = [None, None, None]) + + # bad dropna choice + g_df = df.groupby("g", group_keys=True, dropna=True) + + with pytest.warns(UserWarning, match="dropna=False"): + + assert_equal_query( + g_df, + summarize(res = _.x.min()), + data_frame(g = [np.nan], res = [1]) + ) + + @backend_sql def test_summarize_keeps_group_vars(backend, gdf): q = gdf >> summarize(n = n(_))