diff --git a/greenplumpython/func.py b/greenplumpython/func.py index f4676f5e..dff81877 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -18,7 +18,7 @@ from greenplumpython.db import Database from greenplumpython.expr import Expr, _serialize_to_expr from greenplumpython.group import DataFrameGroupingSet -from greenplumpython.type import _serialize_to_type +from greenplumpython.type import _defined_types, _serialize_to_type_name class FunctionExpr(Expr): @@ -111,52 +111,79 @@ def apply( if grouping_col_names is not None and len(grouping_col_names) != 0 else None ) - unexpanded_dataframe = DataFrame( - " ".join( + if ( + isinstance(self._function, NormalFunction) + and self._function._language_handler == "plcontainer" + ): + return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType + _serialize_to_type_name(return_annotation, db=db, for_return=True) + input_args = self._args + if len(input_args) == 0: + raise Exception("No input data specified, please specify a DataFrame or Columns") + input_clause = ( + "*" + if (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) + else ",".join([arg._serialize(db=db) for arg in input_args]) + ) + return DataFrame( + f""" + SELECT * FROM plcontainer_apply(TABLE( + SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS + {_defined_types[return_annotation.__args__[0]]._serialize(db=db)} + """, + db=db, + parents=parents, + ) + else: + unexpanded_dataframe = DataFrame( + " ".join( + [ + f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}", + ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", + from_clause, + group_by_clause, + ] + ), + db=db, + parents=parents, + ) + # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a + # function returning records that contains more than one attributes + # will be called multiple times if we do + # ```sql + # SELECT (func(a, b)).* FROM t; + # ``` + # which might cause performance issue. To workaround we need to do + # ```sql + # WITH func_call AS ( + # SELECT func(a, b) AS result FROM t + # ) + # SELECT (result).* FROM func_call; + # ``` + rebased_grouping_cols = ( [ - f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}", - ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", - from_clause, - group_by_clause, + _serialize_to_expr(unexpanded_dataframe[name], db=db) + for name in grouping_col_names ] - ), - db=db, - parents=parents, - ) - # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a - # function returning records that contains more than one attributes - # will be called multiple times if we do - # ```sql - # SELECT (func(a, b)).* FROM t; - # ``` - # which might cause performance issue. To workaround we need to do - # ```sql - # WITH func_call AS ( - # SELECT func(a, b) AS result FROM t - # ) - # SELECT (result).* FROM func_call; - # ``` - rebased_grouping_cols = ( - [_serialize_to_expr(unexpanded_dataframe[name], db=db) for name in grouping_col_names] - if grouping_col_names is not None - else None - ) - result_cols = ( - _serialize_to_expr(unexpanded_dataframe["*"], db=db) - if not expand - else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db) - # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())` - if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0 - else f"({unexpanded_dataframe._name}).*" - if not expand - else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}" - ) + if grouping_col_names is not None + else None + ) + result_cols = ( + _serialize_to_expr(unexpanded_dataframe["*"], db=db) + if not expand + else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db) + # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())` + if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0 + else f"({unexpanded_dataframe._name}).*" + if not expand + else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}" + ) - return DataFrame( - f"SELECT {result_cols} FROM {unexpanded_dataframe._name}", - db=db, - parents=[unexpanded_dataframe], - ) + return DataFrame( + f"SELECT {result_cols} FROM {unexpanded_dataframe._name}", + db=db, + parents=[unexpanded_dataframe], + ) @property def _function(self) -> "_AbstractFunction": @@ -272,12 +299,14 @@ def __init__( name: Optional[str] = None, schema: Optional[str] = None, language_handler: Literal["plpython3u"] = "plpython3u", + runtime: Optional[str] = None, ) -> None: # noqa D107 super().__init__(wrapped_func, name, schema) self._created_in_dbs: Optional[Set[Database]] = set() if wrapped_func is not None else None self._wrapped_func = wrapped_func self._language_handler = language_handler + self._runtime = runtime def unwrap(self) -> Callable[..., Any]: """Get the wrapped Python function in the database function.""" @@ -302,14 +331,18 @@ def _serialize(self, db: Database) -> str: func_sig = inspect.signature(self._wrapped_func) func_args = ",".join( [ - f'"{param.name}" {_serialize_to_type(param.annotation, db=db)}' + f'"{param.name}" {_serialize_to_type_name(param.annotation, db=db)}' for param in func_sig.parameters.values() ] ) func_arg_names = ",".join( [f"{param.name}={param.name}" for param in func_sig.parameters.values()] ) - return_type = _serialize_to_type(func_sig.return_annotation, db=db, for_return=True) + return_type = ( + _serialize_to_type_name(func_sig.return_annotation, db=db, for_return=True) + if self._language_handler != "plcontainer" + else "SETOF record" + ) func_pickled: bytes = dill.dumps(self._wrapped_func) _, func_name = self._qualified_name # Modify the AST of the wrapped function to minify dependency: (1-3) @@ -335,6 +368,7 @@ def _serialize(self, db: Database) -> str: f"CREATE FUNCTION {self._qualified_name_str} ({func_args}) " f"RETURNS {return_type} " f"AS $$\n" + f"# container: {self._runtime}\n" f"try:\n" f" return GD['{func_ast.name}']({func_arg_names})\n" f"except KeyError:\n" @@ -344,6 +378,7 @@ def _serialize(self, db: Database) -> str: f" import sys as {sys_lib_name}\n" f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n" f" raise ModuleNotFoundError\n" + f" {sys_lib_name}.modules['plpy']=plpy\n" f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n" f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n" f" except ModuleNotFoundError:\n" @@ -461,7 +496,7 @@ def _create_in_db(self, db: Database) -> None: state_param = next(param_list) args_string = ",".join( [ - f"{param.name} {_serialize_to_type(param.annotation, db=db)}" + f"{param.name} {_serialize_to_type_name(param.annotation, db=db)}" for param in param_list ] ) @@ -470,7 +505,7 @@ def _create_in_db(self, db: Database) -> None: ( f"CREATE AGGREGATE {self._qualified_name_str} ({args_string}) (\n" f" SFUNC = {self.transition_function._qualified_name_str},\n" - f" STYPE = {_serialize_to_type(state_param.annotation, db=db)}\n" + f" STYPE = {_serialize_to_type_name(state_param.annotation, db=db)}\n" f");\n" ), has_results=False, @@ -547,6 +582,7 @@ def aggregate_function(name: str, schema: Optional[str] = None) -> AggregateFunc def create_function( wrapped_func: Optional[Callable[..., Any]] = None, language_handler: Literal["plpython3u"] = "plpython3u", + runtime: Optional[str] = None, ) -> NormalFunction: """ Create a :class:`~func.NormalFunction` from the given Python function. @@ -610,8 +646,12 @@ def create_function( """ # If user needs extra parameters when creating a function if wrapped_func is None: - return functools.partial(create_function, language_handler=language_handler) - return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler) + return functools.partial( + create_function, language_handler=language_handler, runtime=runtime + ) + return NormalFunction( + wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime + ) # FIXME: Add test cases for optional parameters diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 891ff6c1..338cd599 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -94,6 +94,17 @@ def __init__( if self._modifier is not None: self._qualified_name_str += f"({self._modifier})" + def _serialize(self, db: Database) -> str: + if self._annotation is None: + raise Exception("No type annotation to serialize") + members = get_type_hints(self._annotation) + if len(members) == 0: + raise Exception(f"Failed to get annotations for type {self._annotation}") + members_str = ",\n".join( + [f"{name} {_serialize_to_type_name(type_t, db)}" for name, type_t in members.items()] + ) + return f"({members_str})" + # -- Creation of a composite type in Greenplum corresponding to the class_type given def _create_in_db(self, db: Database): # noqa: D400 @@ -115,14 +126,9 @@ def _create_in_db(self, db: Database): self._annotation, type ), "Only composite data types can be created in database." schema = "pg_temp" - members = get_type_hints(self._annotation) - if len(members) == 0: - raise Exception(f"Failed to get annotations for type {self._annotation}") - att_type_str = ",\n".join( - [f"{name} {_serialize_to_type(type_t, db)}" for name, type_t in members.items()] - ) + db._execute( - f'CREATE TYPE "{schema}"."{self._name}" AS (\n' f"{att_type_str}\n" f");", + f'CREATE TYPE "{schema}"."{self._name}" AS {self._serialize(db=db)};', has_results=False, ) self._created_in_dbs.add(db) @@ -178,7 +184,7 @@ def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = Non return DataType(name, schema=schema, modifier=modifier) -def _serialize_to_type( +def _serialize_to_type_name( annotation: Union[DataType, type], db: Database, for_return: bool = False, @@ -204,10 +210,10 @@ def _serialize_to_type( if annotation.__origin__ == list or annotation.__origin__ == List: args: Tuple[type, ...] = annotation.__args__ if for_return: - return f"SETOF {_serialize_to_type(args[0], db)}" # type: ignore - if args[0] in _defined_types: - return f"{_serialize_to_type(args[0], db)}[]" # type: ignore - raise NotImplementedError() + return f"SETOF {_serialize_to_type_name(args[0], db)}" # type: ignore + else: + return f"{_serialize_to_type_name(args[0], db)}[]" # type: ignore + raise NotImplementedError("Only list is supported as generic data type") else: if isinstance(annotation, DataType): return annotation._qualified_name_str diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py new file mode 100644 index 00000000..1a1b39e4 --- /dev/null +++ b/tests/test_plcontainer.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass + +import pytest + +import greenplumpython as gp +from tests import db + + +@dataclass +class Int: + i: int + + +@dataclass +class Pair: + i: int + j: int + + +@pytest.fixture +def t(db: gp.Database): + rows = [(i, i) for i in range(10)] + return db.create_dataframe(rows=rows, column_names=["a", "b"]) + + +@gp.create_function(language_handler="plcontainer", runtime="plc_python_example") +def add_one(x: list[Int]) -> list[Int]: + return [{"i": arg["i"] + 1} for arg in x] + + +def test_simple_func(db: gp.Database): + assert ( + len( + list( + db.create_dataframe(columns={"i": range(10)}).apply( + lambda t: add_one(t), expand=True + ) + ) + ) + == 10 + ) + + +def test_func_no_input(db: gp.Database): + + with pytest.raises(Exception) as exc_info: # no input data for func raises Exception + db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value) + + +def test_func_column(db: gp.Database, t: gp.DataFrame): + @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") + def add(x: list[Pair]) -> list[Int]: + return [{"i": arg["i"] + arg["j"]} for arg in x] + + assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10 diff --git a/tests/test_type.py b/tests/test_type.py index c7fc2ed6..4dd12944 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -3,7 +3,7 @@ import pytest import greenplumpython as gp -from greenplumpython.type import _serialize_to_type +from greenplumpython.type import _serialize_to_type_name from tests import db @@ -76,7 +76,7 @@ class Person: _first_name: str _last_name: str - type_name = _serialize_to_type(Person, db=db) + type_name = _serialize_to_type_name(Person, db=db) assert isinstance(type_name, str) @@ -88,5 +88,5 @@ def __init__(self, _first_name: str, _last_name: str) -> None: self._last_name = _last_name with pytest.raises(Exception) as exc_info: - _serialize_to_type(Person, db=db) + _serialize_to_type_name(Person, db=db) assert "Failed to get annotations" in str(exc_info.value)