Skip to content

Commit

Permalink
Merge pull request #476 from machow/misc-fixes
Browse files Browse the repository at this point in the history
Misc fixes
  • Loading branch information
machow authored Sep 19, 2023
2 parents 4099f48 + 0c4d103 commit 99a3682
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 18 deletions.
8 changes: 5 additions & 3 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2545,7 +2545,7 @@ def _extract_gdf(__data, *args, **kwargs):

# tbl ----

from siuba.siu._databackend import SqlaEngine
from siuba.siu._databackend import SqlaEngine, PlDataFrame, PdDataFrame

@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def tbl(src, *args, **kwargs):
Expand Down Expand Up @@ -2613,8 +2613,10 @@ def _tbl_sqla(src: SqlaEngine, table_name, columns=None):

# TODO: once we subclass LazyTbl per dialect (e.g. duckdb), we can move out
# this dialect specific logic.
if src.dialect.name == "duckdb" and isinstance(columns, pd.DataFrame):
src.execute("register", (table_name, columns))
if src.dialect.name == "duckdb" and isinstance(columns, (PdDataFrame, PlDataFrame)):
with src.begin() as conn:
conn.exec_driver_sql("register", (table_name, columns))

return LazyTbl(src, table_name)

return LazyTbl(src, table_name, columns=columns)
Expand Down
5 changes: 3 additions & 2 deletions siuba/experimental/pivot/pivot_wide.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,9 @@ def pivot_wider_spec(
# validate names and move id vars to columns ----
# note: in pandas 1.5+ we can use the allow_duplicates option to reset, even
# when index and column names overlap. for now, repair names, rename, then reset.
unique_names = vec_as_names([*id_vars, *wide.columns], repair="unique")
repaired_names = vec_as_names([*id_vars, *wide.columns], repair=names_repair)
_all_raw_names = list(map(str, [*id_vars, *wide.columns]))
unique_names = vec_as_names(_all_raw_names, repair="unique")
repaired_names = vec_as_names(_all_raw_names, repair=names_repair)

uniq_id_vars = unique_names[:len(id_vars)]
uniq_val_vars = unique_names[len(id_vars):]
Expand Down
3 changes: 2 additions & 1 deletion siuba/experimental/pivot/sql_pivot_wide.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def _pivot_wider_spec(

wide_id_cols = [sel_cols[id_] for id_ in id_vars]

repaired_names = vec_as_names([*id_vars, *spec[".name"]], repair=names_repair)
_all_raw_names = list(map(str, [*id_vars, *spec[".name"]]))
repaired_names = vec_as_names(_all_raw_names, repair=names_repair)
labeled_cols = [
col.label(name) for name, col in
zip(repaired_names, [*wide_id_cols, *wide_name_cols])
Expand Down
4 changes: 4 additions & 0 deletions siuba/siu/_databackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,9 @@ def __subclasshook__(cls, subclass):
# Implementations -------------------------------------------------------------

class SqlaEngine(AbstractBackend): pass
class PlDataFrame(AbstractBackend): pass
class PdDataFrame(AbstractBackend): pass

SqlaEngine.register_backend("sqlalchemy.engine", "Connectable")
PlDataFrame.register_backend("polars", "DataFrame")
PdDataFrame.register_backend("pandas", "DataFrame")
38 changes: 32 additions & 6 deletions siuba/siu/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __rshift__(self, x):
stripped = strip_symbolic(x)

if isinstance(stripped, Call):
return self._construct_pipe(MetaArg("_"), self, x)
return self._construct_pipe(self, x)

raise TypeError()

Expand Down Expand Up @@ -308,9 +308,28 @@ def obj_name(self):

return None

@classmethod
def _construct_pipe(cls, *args):
return PipeCall(*args)
@staticmethod
def _construct_pipe(lhs, rhs):
if isinstance(lhs, PipeCall):
lh_args = lhs.args

# ensure we don't keep adding MetaArg to the left when
# combining two pipes
if lh_args and isinstance(lh_args[0], MetaArg):
lh_args = lh_args[1:]
else:
lh_args = [lhs]

if isinstance(rhs, PipeCall):
rh_args = rhs.args

# similar to above, but for rh args
if rh_args and isinstance(rh_args[0], MetaArg):
rh_args = rh_args[1:]
else:
rh_args = [rhs]

return PipeCall(MetaArg("_"), *lh_args, *rh_args)


class Lazy(Call):
Expand Down Expand Up @@ -674,8 +693,15 @@ class PipeCall(Call):
"""

def __init__(self, func, *args, **kwargs):
self.func = "__siu_pipe_call__"
self.args = (func, *args)
if isinstance(func, str) and func == "__siu_pipe_call__":
# it was a mistake to make func the first parameter to Call
# but basically we need to catch when it is passed, so
# we can ignore it
self.func = func
self.args = args
else:
self.func = "__siu_pipe_call__"
self.args = (func, *args)
if kwargs:
raise ValueError("Keyword arguments are not allowed.")
self.kwargs = {}
Expand Down
2 changes: 1 addition & 1 deletion siuba/siu/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __rshift__(self, x):

if isinstance(stripped, Call):
lhs_call = self.__source
return Call._construct_pipe(MetaArg("_"), lhs_call, stripped)
return self.__class__(Call._construct_pipe(lhs_call, stripped))
# strip_symbolic(self)(x)
# x is a symbolic
raise NotImplementedError("Symbolic may only be used on right-hand side of >> operator.")
Expand Down
4 changes: 3 additions & 1 deletion siuba/sql/across.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def _across_sql_cols(
lazy_tbl = ctx_verb_data.get()
window = ctx_verb_window.get()

column_names = list(__data.keys())

name_template = _get_name_template(fns, names)
selected_cols = var_select(__data, *var_create(cols), data=__data)
selected_cols = var_select(column_names, *var_create(cols), data=__data)

fns_map = _across_setup_fns(fns)

Expand Down
15 changes: 15 additions & 0 deletions siuba/sql/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ def sql_func_capitalize(_, col):
rest = fn.right(col, fn.length(col) - 1)
return sql.functions.concat(first_char, rest)

def sql_str_cat(_, col, others=None, sep=None, na_rep=None, join=None):
if sep is not None:
raise NotImplementedError("sep argument not supported for sql cat")

if na_rep is not None:
raise NotImplementedError("na_rep argument not supported for sql cat")

if join is not None:
raise NotImplementedError("join argument not supported for sql cat")

if isinstance(others, (list, tuple)):
raise NotImplementedError("others argument must be a single column for sql cat")

return sql.functions.concat(col, others)

# Numpy ufuncs ----------------------------------------------------------------
# symbolic objects have a generic dispatch for when _.__array_ufunc__ is called,
Expand Down Expand Up @@ -252,6 +266,7 @@ def req_bool(f):
**{
# TODO: check generality of trim functions, since MYSQL overrides
"str.capitalize" : sql_func_capitalize,
"str.cat" : sql_str_cat,
#"str.center" :,
#"str.contains" :,
#"str.count" :,
Expand Down
12 changes: 11 additions & 1 deletion siuba/sql/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
sql_not_impl,
# wiring up translator
extend_base,
SqlTranslator
SqlTranslator,
convert_literal
)

from .postgresql import (
Expand Down Expand Up @@ -44,6 +45,15 @@ def returns_int(func_names):
f_annotated = wrap_annotate(f_concrete, result_type="int")
generic.register(DuckdbColumn, f_annotated)

# Literal Conversions =========================================================

@convert_literal.register
def _cl_duckdb(codata: DuckdbColumn, lit):
from sqlalchemy.dialects.postgresql import array
if isinstance(lit, list):
return array(lit)

return sql.literal(lit)

# Translations ================================================================

Expand Down
10 changes: 8 additions & 2 deletions siuba/sql/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def wrapper(*args, **kwargs):
# Translator =================================================================

from siuba.ops.translate import create_pandas_translator
from functools import singledispatch


def extend_base(cls, **kwargs):
Expand Down Expand Up @@ -323,7 +324,7 @@ def shape_call(
# verbs that can use strings as accessors, like group_by, or
# arrange, need to convert those strings into a getitem call
return str_to_getitem_call(call)
elif isinstance(call, sql.elements.ColumnClause):
elif isinstance(call, (sql.elements.ClauseElement)):
return Lazy(call)
elif callable(call):
#TODO: should not happen here
Expand All @@ -332,7 +333,8 @@ def shape_call(
else:
# verbs that use literal strings, need to convert them to a call
# that returns a sqlalchemy "literal" object
return Lazy(sql.literal(call))
_lit = convert_literal(self.window.dispatch_cls(), call)
return Lazy(_lit)

# raise informative error message if missing translation
try:
Expand Down Expand Up @@ -367,3 +369,7 @@ def from_mappings(WinCls, AggCls):
aggregate = create_pandas_translator(ALL_OPS, AggCls, sql.elements.ClauseElement)
)


@singledispatch
def convert_literal(codata, lit):
return sql.literal(lit)
12 changes: 11 additions & 1 deletion siuba/tests/test_siu_dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import pytest

from siuba.siu.calls import PipeCall
from siuba.siu.dispatchers import call
from siuba.siu import _
from siuba.siu import _, strip_symbolic, Symbolic

# TODO: direct test of lazy elements
# TODO: NSECall - no map subcalls


def test_siu_pipe_call_is_flat():
pipe_expr = _ >> _.a >> _.b
pipe_call = strip_symbolic(pipe_expr)

assert isinstance(pipe_expr, Symbolic)
assert isinstance(pipe_call, PipeCall)
assert len(pipe_call.args) == 4

def test_siu_call_no_args():
assert 1 >> call(range) == range(1)

Expand Down

0 comments on commit 99a3682

Please sign in to comment.