From cba4a2c4bbcd5a320c25d8f35d5ad7478b60b213 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:03:19 -0500 Subject: [PATCH 1/9] fix: coerce column names in pivot funcs to strings --- siuba/experimental/pivot/pivot_wide.py | 5 +++-- siuba/experimental/pivot/sql_pivot_wide.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/siuba/experimental/pivot/pivot_wide.py b/siuba/experimental/pivot/pivot_wide.py index a65d93bb..f5f9f9c9 100644 --- a/siuba/experimental/pivot/pivot_wide.py +++ b/siuba/experimental/pivot/pivot_wide.py @@ -403,8 +403,9 @@ def pivot_wider_spec( # validate names and move id vars to columns ---- # note: in pandas 1.5+ we can use the allow_duplicates option to reset, even # when index and column names overlap. for now, repair names, rename, then reset. - unique_names = vec_as_names([*id_vars, *wide.columns], repair="unique") - repaired_names = vec_as_names([*id_vars, *wide.columns], repair=names_repair) + _all_raw_names = list(map(str, [*id_vars, *wide.columns])) + unique_names = vec_as_names(_all_raw_names, repair="unique") + repaired_names = vec_as_names(_all_raw_names, repair=names_repair) uniq_id_vars = unique_names[:len(id_vars)] uniq_val_vars = unique_names[len(id_vars):] diff --git a/siuba/experimental/pivot/sql_pivot_wide.py b/siuba/experimental/pivot/sql_pivot_wide.py index 1406134b..eadba7a8 100644 --- a/siuba/experimental/pivot/sql_pivot_wide.py +++ b/siuba/experimental/pivot/sql_pivot_wide.py @@ -131,7 +131,8 @@ def _pivot_wider_spec( wide_id_cols = [sel_cols[id_] for id_ in id_vars] - repaired_names = vec_as_names([*id_vars, *spec[".name"]], repair=names_repair) + _all_raw_names = list(map(str, [*id_vars, *spec[".name"]])) + repaired_names = vec_as_names(_all_raw_names, repair=names_repair) labeled_cols = [ col.label(name) for name, col in zip(repaired_names, [*wide_id_cols, *wide_name_cols]) From 7073cbbff1b49f74615642edf0c0c809dd7b4ba9 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:04:05 -0500 Subject: [PATCH 2/9] feat: tbl with duckdb now accept polars DataFrame --- siuba/dply/verbs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 9ff7545a..ac815e30 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -2545,7 +2545,7 @@ def _extract_gdf(__data, *args, **kwargs): # tbl ---- -from siuba.siu._databackend import SqlaEngine +from siuba.siu._databackend import SqlaEngine, PlDataFrame, PdDataFrame @singledispatch2((pd.DataFrame, DataFrameGroupBy)) def tbl(src, *args, **kwargs): @@ -2614,7 +2614,7 @@ def _tbl_sqla(src: SqlaEngine, table_name, columns=None): # TODO: once we subclass LazyTbl per dialect (e.g. duckdb), we can move out # this dialect specific logic. - if src.dialect.name == "duckdb" and isinstance(columns, pd.DataFrame): + if src.dialect.name == "duckdb" and isinstance(columns, (PdDataFrame, PlDataFrame)): src.execute("register", (table_name, columns)) return LazyTbl(src, table_name) From 4c9b430f634656c647e7eeecd998bf16d860fc6f Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:14:19 -0500 Subject: [PATCH 3/9] dev: add polars DataFrame to databackend --- siuba/siu/_databackend.py | 4 ++++ siuba/siu/calls.py | 27 ++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/siuba/siu/_databackend.py b/siuba/siu/_databackend.py index 25db0b6b..8375bb61 100644 --- a/siuba/siu/_databackend.py +++ b/siuba/siu/_databackend.py @@ -48,5 +48,9 @@ def __subclasshook__(cls, subclass): # Implementations ------------------------------------------------------------- class SqlaEngine(AbstractBackend): pass +class PlDataFrame(AbstractBackend): pass +class PdDataFrame(AbstractBackend): pass SqlaEngine.register_backend("sqlalchemy.engine", "Connectable") +PlDataFrame.register_backend("polars", "DataFrame") +PdDataFrame.register_backend("pandas", "DataFrame") diff --git a/siuba/siu/calls.py b/siuba/siu/calls.py index e66224e8..7fb6fe39 100644 --- a/siuba/siu/calls.py +++ b/siuba/siu/calls.py @@ -308,9 +308,19 @@ def obj_name(self): return None - @classmethod - def _construct_pipe(cls, *args): - return PipeCall(*args) + @staticmethod + def _construct_pipe(meta, lhs, rhs): + if isinstance(lhs, PipeCall): + lh_args = lhs.args + else: + lh_args = [lhs] + + if isinstance(rhs, PipeCall): + rh_args = rhs.args + else: + rh_args = [rhs] + + return PipeCall(meta, *lh_args, *rh_args) class Lazy(Call): @@ -674,8 +684,15 @@ class PipeCall(Call): """ def __init__(self, func, *args, **kwargs): - self.func = "__siu_pipe_call__" - self.args = (func, *args) + if func == "__siu_pipe_call__": + # it was a mistake to make func the first parameter to Call + # but basically we need to catch when it is passed, so + # we can ignore it + self.func = func + self.args = args + else: + self.func = "__siu_pipe_call__" + self.args = (func, *args) if kwargs: raise ValueError("Keyword arguments are not allowed.") self.kwargs = {} From 0bf8ea4dbed7459aef5be25dda2c5d8aa5712492 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:26:16 -0500 Subject: [PATCH 4/9] fix: do not put extra MetaArg in the pipe call --- siuba/siu/calls.py | 6 +++--- siuba/siu/symbolic.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/siuba/siu/calls.py b/siuba/siu/calls.py index 7fb6fe39..e429b181 100644 --- a/siuba/siu/calls.py +++ b/siuba/siu/calls.py @@ -201,7 +201,7 @@ def __rshift__(self, x): stripped = strip_symbolic(x) if isinstance(stripped, Call): - return self._construct_pipe(MetaArg("_"), self, x) + return self._construct_pipe(self, x) raise TypeError() @@ -309,7 +309,7 @@ def obj_name(self): return None @staticmethod - def _construct_pipe(meta, lhs, rhs): + def _construct_pipe(lhs, rhs): if isinstance(lhs, PipeCall): lh_args = lhs.args else: @@ -320,7 +320,7 @@ def _construct_pipe(meta, lhs, rhs): else: rh_args = [rhs] - return PipeCall(meta, *lh_args, *rh_args) + return PipeCall(*lh_args, *rh_args) class Lazy(Call): diff --git a/siuba/siu/symbolic.py b/siuba/siu/symbolic.py index 4a710d6e..3960a389 100644 --- a/siuba/siu/symbolic.py +++ b/siuba/siu/symbolic.py @@ -85,7 +85,7 @@ def __rshift__(self, x): if isinstance(stripped, Call): lhs_call = self.__source - return Call._construct_pipe(MetaArg("_"), lhs_call, stripped) + return self.__class__(Call._construct_pipe(lhs_call, stripped)) # strip_symbolic(self)(x) # x is a symbolic raise NotImplementedError("Symbolic may only be used on right-hand side of >> operator.") From cab45ee41ae966fcfee4bfcd72a44a6904328006 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:27:00 -0500 Subject: [PATCH 5/9] tests: test pipe call is flat, number of steps --- siuba/tests/test_siu_dispatchers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/siuba/tests/test_siu_dispatchers.py b/siuba/tests/test_siu_dispatchers.py index f2aa1ad7..de14f93a 100644 --- a/siuba/tests/test_siu_dispatchers.py +++ b/siuba/tests/test_siu_dispatchers.py @@ -1,11 +1,21 @@ import pytest +from siuba.siu.calls import PipeCall from siuba.siu.dispatchers import call -from siuba.siu import _ +from siuba.siu import _, strip_symbolic, Symbolic # TODO: direct test of lazy elements # TODO: NSECall - no map subcalls + +def test_siu_pipe_call_is_flat(): + pipe_expr = _ >> _.a >> _.b + pipe_call = strip_symbolic(pipe_expr) + + assert isinstance(pipe_expr, Symbolic) + assert isinstance(pipe_call, PipeCall) + assert len(pipe_call.args) == 3 + def test_siu_call_no_args(): assert 1 >> call(range) == range(1) From 5752fb189fe086c7337d7b7cb68b8278c33dd822 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 17 Sep 2023 16:28:21 -0500 Subject: [PATCH 6/9] fix(sql): explicitly get list of column names --- siuba/sql/across.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/siuba/sql/across.py b/siuba/sql/across.py index 59a2ba5e..58908067 100644 --- a/siuba/sql/across.py +++ b/siuba/sql/across.py @@ -50,8 +50,10 @@ def _across_sql_cols( lazy_tbl = ctx_verb_data.get() window = ctx_verb_window.get() + column_names = list(__data.keys()) + name_template = _get_name_template(fns, names) - selected_cols = var_select(__data, *var_create(cols), data=__data) + selected_cols = var_select(column_names, *var_create(cols), data=__data) fns_map = _across_setup_fns(fns) From adc4c18755b5ed2ea7277f3104955de32e0308a8 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 18 Sep 2023 10:49:35 -0500 Subject: [PATCH 7/9] fix(siu): ensure the first arg to a pipe is MetaArg again --- siuba/siu/calls.py | 13 +++++++++++-- siuba/tests/test_siu_dispatchers.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/siuba/siu/calls.py b/siuba/siu/calls.py index e429b181..f32fd94b 100644 --- a/siuba/siu/calls.py +++ b/siuba/siu/calls.py @@ -312,15 +312,24 @@ def obj_name(self): def _construct_pipe(lhs, rhs): if isinstance(lhs, PipeCall): lh_args = lhs.args + + # ensure we don't keep adding MetaArg to the left when + # combining two pipes + if lh_args and isinstance(lh_args[0], MetaArg): + lh_args = lh_args[1:] else: lh_args = [lhs] if isinstance(rhs, PipeCall): rh_args = rhs.args + + # similar to above, but for rh args + if rh_args and isinstance(rh_args[0], MetaArg): + rh_args = rh_args[1:] else: rh_args = [rhs] - return PipeCall(*lh_args, *rh_args) + return PipeCall(MetaArg("_"), *lh_args, *rh_args) class Lazy(Call): @@ -684,7 +693,7 @@ class PipeCall(Call): """ def __init__(self, func, *args, **kwargs): - if func == "__siu_pipe_call__": + if isinstance(func, str) and func == "__siu_pipe_call__": # it was a mistake to make func the first parameter to Call # but basically we need to catch when it is passed, so # we can ignore it diff --git a/siuba/tests/test_siu_dispatchers.py b/siuba/tests/test_siu_dispatchers.py index de14f93a..8e4d70b8 100644 --- a/siuba/tests/test_siu_dispatchers.py +++ b/siuba/tests/test_siu_dispatchers.py @@ -14,7 +14,7 @@ def test_siu_pipe_call_is_flat(): assert isinstance(pipe_expr, Symbolic) assert isinstance(pipe_call, PipeCall) - assert len(pipe_call.args) == 3 + assert len(pipe_call.args) == 4 def test_siu_call_no_args(): assert 1 >> call(range) == range(1) From 4f69352067cf91021813af3379c085ae8626d74c Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 19 Sep 2023 10:24:24 -0500 Subject: [PATCH 8/9] feat: convert_literal for handling literals to sql --- siuba/dply/verbs.py | 4 +++- siuba/sql/dialects/duckdb.py | 12 +++++++++++- siuba/sql/translate.py | 10 ++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index a2b8e8cc..cda82792 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -2614,7 +2614,9 @@ def _tbl_sqla(src: SqlaEngine, table_name, columns=None): # TODO: once we subclass LazyTbl per dialect (e.g. duckdb), we can move out # this dialect specific logic. if src.dialect.name == "duckdb" and isinstance(columns, (PdDataFrame, PlDataFrame)): - src.execute("register", (table_name, columns)) + with src.begin() as conn: + conn.exec_driver_sql("register", (table_name, columns)) + return LazyTbl(src, table_name) return LazyTbl(src, table_name, columns=columns) diff --git a/siuba/sql/dialects/duckdb.py b/siuba/sql/dialects/duckdb.py index 8f0bc25c..eb73f476 100644 --- a/siuba/sql/dialects/duckdb.py +++ b/siuba/sql/dialects/duckdb.py @@ -13,7 +13,8 @@ sql_not_impl, # wiring up translator extend_base, - SqlTranslator + SqlTranslator, + convert_literal ) from .postgresql import ( @@ -44,6 +45,15 @@ def returns_int(func_names): f_annotated = wrap_annotate(f_concrete, result_type="int") generic.register(DuckdbColumn, f_annotated) +# Literal Conversions ========================================================= + +@convert_literal.register +def _cl_duckdb(codata: DuckdbColumn, lit): + from sqlalchemy.dialects.postgresql import array + if isinstance(lit, list): + return array(lit) + + return sql.literal(lit) # Translations ================================================================ diff --git a/siuba/sql/translate.py b/siuba/sql/translate.py index 62348727..79ab2a2d 100644 --- a/siuba/sql/translate.py +++ b/siuba/sql/translate.py @@ -252,6 +252,7 @@ def wrapper(*args, **kwargs): # Translator ================================================================= from siuba.ops.translate import create_pandas_translator +from functools import singledispatch def extend_base(cls, **kwargs): @@ -323,7 +324,7 @@ def shape_call( # verbs that can use strings as accessors, like group_by, or # arrange, need to convert those strings into a getitem call return str_to_getitem_call(call) - elif isinstance(call, sql.elements.ColumnClause): + elif isinstance(call, (sql.elements.ClauseElement)): return Lazy(call) elif callable(call): #TODO: should not happen here @@ -332,7 +333,8 @@ def shape_call( else: # verbs that use literal strings, need to convert them to a call # that returns a sqlalchemy "literal" object - return Lazy(sql.literal(call)) + _lit = convert_literal(self.window.dispatch_cls(), call) + return Lazy(_lit) # raise informative error message if missing translation try: @@ -367,3 +369,7 @@ def from_mappings(WinCls, AggCls): aggregate = create_pandas_translator(ALL_OPS, AggCls, sql.elements.ClauseElement) ) + +@singledispatch +def convert_literal(codata, lit): + return sql.literal(lit) From 0c4d10351baeb699904bcd316fb8c2d857220502 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 19 Sep 2023 15:34:15 -0500 Subject: [PATCH 9/9] feat(sql): support limited str.cat --- siuba/sql/dialects/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/siuba/sql/dialects/base.py b/siuba/sql/dialects/base.py index 47535adb..c73732a7 100644 --- a/siuba/sql/dialects/base.py +++ b/siuba/sql/dialects/base.py @@ -127,6 +127,20 @@ def sql_func_capitalize(_, col): rest = fn.right(col, fn.length(col) - 1) return sql.functions.concat(first_char, rest) +def sql_str_cat(_, col, others=None, sep=None, na_rep=None, join=None): + if sep is not None: + raise NotImplementedError("sep argument not supported for sql cat") + + if na_rep is not None: + raise NotImplementedError("na_rep argument not supported for sql cat") + + if join is not None: + raise NotImplementedError("join argument not supported for sql cat") + + if isinstance(others, (list, tuple)): + raise NotImplementedError("others argument must be a single column for sql cat") + + return sql.functions.concat(col, others) # Numpy ufuncs ---------------------------------------------------------------- # symbolic objects have a generic dispatch for when _.__array_ufunc__ is called, @@ -252,6 +266,7 @@ def req_bool(f): **{ # TODO: check generality of trim functions, since MYSQL overrides "str.capitalize" : sql_func_capitalize, + "str.cat" : sql_str_cat, #"str.center" :, #"str.contains" :, #"str.count" :,