@@ -66,6 +66,7 @@ private[spark] object PythonEvalType {
val SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF = 212
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215

// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
14 changes: 12 additions & 2 deletions python/pyspark/sql/connect/group.py
@@ -35,6 +35,7 @@
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.pandas.functions import _validate_vectorized_udf # type: ignore[attr-defined]
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
from pyspark.sql.types import NumericType, StructType

import pyspark.sql.connect.plan as plan
@@ -472,13 +473,22 @@ def applyInArrow(
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame

_validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
try:
# Try to infer the eval type from the function's type hints
eval_type = infer_group_arrow_eval_type_from_func(func)
except Exception:
warnings.warn("Cannot infer the eval type from type hints.", UserWarning)
eval_type = None

if eval_type is None:
eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF

_validate_vectorized_udf(func, eval_type)
if isinstance(schema, str):
schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
evalType=eval_type,
)

res = DataFrame(
12 changes: 11 additions & 1 deletion python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -20,6 +20,7 @@ from typing import (
Any,
Callable,
Iterable,
Iterator,
NewType,
Tuple,
Type,
@@ -59,6 +60,7 @@ PandasGroupedMapUDFTransformWithStateType = Literal[211]
PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]
GroupedMapUDFTransformWithStateType = Literal[213]
GroupedMapUDFTransformWithStateInitStateType = Literal[214]
ArrowGroupedMapIterUDFType = Literal[215]

# Arrow UDFs
ArrowScalarUDFType = Literal[250]
@@ -430,10 +432,18 @@ PandasCogroupedMapFunction = Union[
Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike],
]

ArrowGroupedMapFunction = Union[
ArrowGroupedMapTableFunction = Union[
Callable[[pyarrow.Table], pyarrow.Table],
Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
]
ArrowGroupedMapIterFunction = Union[
Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],
Callable[
[Tuple[pyarrow.Scalar, ...], Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]
],
]
ArrowGroupedMapFunction = Union[ArrowGroupedMapTableFunction, ArrowGroupedMapIterFunction]

ArrowCogroupedMapFunction = Union[
Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
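For orientation, a minimal sketch (not part of the diff) of user functions that would satisfy the table and iterator aliases above; all names here are illustrative:

    from typing import Iterator, Tuple
    import pyarrow as pa

    # Matches ArrowGroupedMapTableFunction: one pyarrow.Table in, one pyarrow.Table out.
    def per_group_table(table: pa.Table) -> pa.Table:
        return table

    # Matches ArrowGroupedMapIterFunction: a stream of RecordBatches in, a stream out.
    def per_group_batches(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        for batch in batches:
            yield batch

    # Keyed variant: the grouping key(s) arrive first as a tuple of pyarrow.Scalar.
    def per_group_batches_keyed(
        key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
    ) -> Iterator[pa.RecordBatch]:
        yield from batches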
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -700,6 +700,7 @@ def vectorized_udf(
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
None,
]: # None means it should infer the type from type hints.
@@ -779,6 +780,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
]:
89 changes: 72 additions & 17 deletions python/pyspark/sql/pandas/group_ops.py
@@ -22,6 +22,7 @@
from pyspark.util import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.streaming.stateful_processor import StatefulProcessor
from pyspark.sql.types import StructType
@@ -703,27 +704,33 @@ def applyInArrow(
Maps each group of the current :class:`DataFrame` using an Arrow udf and returns the result
as a `DataFrame`.

The function should take a `pyarrow.Table` and return another
`pyarrow.Table`. Alternatively, the user can pass a function that takes
a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`.
For each group, all columns are passed together as a `pyarrow.Table`
to the user-function and the returned `pyarrow.Table` are combined as a
:class:`DataFrame`.
The function can take one of two forms: it can take a `pyarrow.Table` and return a
`pyarrow.Table`, or it can take an iterator of `pyarrow.RecordBatch` and yield
`pyarrow.RecordBatch`. In either form, it can additionally take a tuple of `pyarrow.Scalar`
grouping key(s) as the first argument, followed by the input described above. For each
group, all columns are passed together in the `pyarrow.Table` or `pyarrow.RecordBatch`es,
and the returned `pyarrow.Table`s or `pyarrow.RecordBatch`es are combined into a
:class:`DataFrame`.

The `schema` should be a :class:`StructType` describing the schema of the returned
`pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match
the field names in the defined schema if specified as strings, or match the
field data types by position if not strings, e.g. integer indices.
The length of the returned `pyarrow.Table` can be arbitrary.
`pyarrow.Table` or `pyarrow.RecordBatch`. The column labels of the returned `pyarrow.Table`
or `pyarrow.RecordBatch` must either match the field names in the defined schema if
specified as strings, or match the field data types by position if not strings, e.g.
integer indices. The length of the returned `pyarrow.Table` or iterator of
`pyarrow.RecordBatch` can be arbitrary.

.. versionadded:: 4.0.0

.. versionchanged:: 4.1.0
Added support for the iterator of `pyarrow.RecordBatch` API.

Parameters
----------
func : function
a Python native function that takes a `pyarrow.Table` and outputs a
`pyarrow.Table`, or that takes one tuple (grouping keys) and a
`pyarrow.Table` and outputs a `pyarrow.Table`.
a Python native function that either takes a `pyarrow.Table` and outputs a
`pyarrow.Table` or takes an iterator of `pyarrow.RecordBatch` and yields
`pyarrow.RecordBatch`. Additionally, each form can take a tuple of grouping keys
as the first argument, with the `pyarrow.Table` or iterator of `pyarrow.RecordBatch`
as the second argument.
Review thread on this parameter description:

Contributor: let's add

    .. versionchanged:: 4.1.0
        Supports iterator API ...

Contributor: lets also add two simple examples (w/o key) in the Examples section

Contributor (author): I added one with and without key, since the key is more relevant for this API I think

schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -752,6 +759,28 @@ def applyInArrow(
| 2| 1.1094003924504583|
+---+-------------------+

When annotated with type hints, the function can also take and return an iterator of
`pyarrow.RecordBatch`.

>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> def sum_func(
... batches: Iterator[pyarrow.RecordBatch]
... ) -> Iterator[pyarrow.RecordBatch]: # doctest: +SKIP
... total = 0
... for batch in batches:
... total += pc.sum(batch.column("v")).as_py()
... yield pyarrow.RecordBatch.from_pydict({"v": [total]})
>>> df.groupby("id").applyInArrow(
... sum_func, schema="v double").show() # doctest: +SKIP
+----+
| v|
+----+
| 3.0|
|18.0|
+----+

Alternatively, the user can pass a function that takes two arguments.
In this case, the grouping key(s) will be passed as the first argument and the data will
be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow
@@ -796,11 +825,28 @@ def applyInArrow(
| 2| 2| 3.0|
+---+-----------+----+

>>> def sum_func(
... key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
... ) -> Iterator[pyarrow.RecordBatch]: # doctest: +SKIP
... total = 0
... for batch in batches:
... total += pc.sum(batch.column("v")).as_py()
... yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
>>> df.groupby("id").applyInArrow(
... sum_func, schema="id long, v double").show() # doctest: +SKIP
+---+----+
| id| v|
+---+----+
| 1| 3.0|
| 2|18.0|
+---+----+

Notes
-----
This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
This function requires a full shuffle. With the `pyarrow.Table` API, all data of a
group is loaded into memory, so the user should be aware of the potential OOM risk
if data is skewed and certain groups are too large to fit in memory; the iterator of
`pyarrow.RecordBatch` API can be used to mitigate this.

This API is unstable, and for developers.

@@ -813,9 +859,18 @@ def applyInArrow(

assert isinstance(self, GroupedData)

try:
# Try to infer the eval type from the function's type hints
eval_type = infer_group_arrow_eval_type_from_func(func)
except Exception:
warnings.warn("Cannot infer the eval type from type hints.", UserWarning)
eval_type = None

if eval_type is None:
eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF

# The usage of the pandas_udf is internal so type checking is disabled.
udf = pandas_udf(
func, returnType=schema, functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
func, returnType=schema, functionType=eval_type
) # type: ignore[call-overload]
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
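As a usage sketch of the dispatch above (illustrative only; assumes this patch is applied and an active SparkSession bound to `spark`): the same `applyInArrow` call routes to the existing table UDF or the new iterator UDF purely based on the function's type hints.

    from typing import Iterator
    import pyarrow as pa
    import pyarrow.compute as pc

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

    # No Iterator hints: inferred as SQL_GROUPED_MAP_ARROW_UDF (whole group as one Table).
    def group_mean(table: pa.Table) -> pa.Table:
        return pa.Table.from_pydict({"v": [pc.mean(table.column("v")).as_py()]})

    # Iterator hints: inferred as SQL_GROUPED_MAP_ARROW_ITER_UDF (group streamed in batches).
    def group_sum(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        total = 0.0
        for batch in batches:
            total += pc.sum(batch.column("v")).as_py()
        yield pa.RecordBatch.from_pydict({"v": [total]})

    df.groupby("id").applyInArrow(group_mean, schema="v double").show()
    df.groupby("id").applyInArrow(group_sum, schema="v double").show()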
19 changes: 11 additions & 8 deletions python/pyspark/sql/pandas/serializers.py
@@ -21,7 +21,7 @@

from decimal import Decimal
from itertools import groupby
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Iterator, Optional

import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -1116,19 +1116,22 @@ def load_stream(self, stream):
"""
import pyarrow as pa

def process_group(batches: "Iterator[pa.RecordBatch]"):
for batch in batches:
struct = batch.column(0)
yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))

dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == 1:
structs = [
batch.column(0) for batch in ArrowStreamSerializer.load_stream(self, stream)
]
yield [
pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
for struct in structs
]
batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream))
yield batch_iter
# Make sure the batches are fully iterated before getting the next group
for _ in batch_iter:
pass

elif dataframes_in_group != 0:
raise PySparkValueError(
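The drain loop after `yield batch_iter` is what keeps the stream aligned: every group's batches ride on the same underlying Arrow stream, so the parent generator must consume whatever the user function left unread before reading the next group header. A standalone sketch of the pattern, with hypothetical names and plain integers standing in for Arrow batches:

    from typing import Iterator

    def load_groups(tokens: Iterator[int]) -> Iterator[Iterator[int]]:
        while True:
            n = next(tokens, 0)              # per-group header, like read_int(stream)
            if n == 0:                       # 0 marks the end of the stream
                return
            def one_group(n: int = n) -> Iterator[int]:
                for _ in range(n):
                    yield next(tokens)
            group = one_group()
            yield group
            for _ in group:                  # drain anything the consumer skipped
                pass

    stream = iter([2, 10, 11, 3, 20, 21, 22, 0])
    for g in load_groups(stream):
        print(next(g))                       # consumer reads only one item per group
    # prints 10, then 20; without the drain loop the second "header" would be read as 11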
91 changes: 91 additions & 0 deletions python/pyspark/sql/pandas/typehints.py
@@ -29,6 +29,9 @@
ArrowScalarUDFType,
ArrowScalarIterUDFType,
ArrowGroupedAggUDFType,
ArrowGroupedMapIterUDFType,
ArrowGroupedMapUDFType,
ArrowGroupedMapFunction,
)


@@ -303,6 +306,94 @@ def infer_eval_type_for_udf(  # type: ignore[no-untyped-def]
return None


def infer_group_arrow_eval_type(
sig: Signature,
type_hints: Dict[str, Any],
) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
from pyspark.sql.pandas.functions import PythonEvalType

require_minimum_pyarrow_version()

import pyarrow as pa

annotations = {}
for param in sig.parameters.values():
if param.annotation is not param.empty:
annotations[param.name] = type_hints.get(param.name, param.annotation)

# Check if all arguments have type hints
parameters_sig = [
annotations[parameter] for parameter in sig.parameters if parameter in annotations
]
if len(parameters_sig) != len(sig.parameters):
raise PySparkValueError(
errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
messageParameters={"target": "all parameters", "sig": str(sig)},
)

# Check if the return has a type hint
return_annotation = type_hints.get("return", sig.return_annotation)
if sig.empty is return_annotation:
raise PySparkValueError(
errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
messageParameters={"target": "the return type", "sig": str(sig)},
)

# Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
is_iterator_batch = (
len(parameters_sig) == 1
and check_iterator_annotation( # Iterator
parameters_sig[0],
parameter_check_func=lambda t: t == pa.RecordBatch,
)
and check_iterator_annotation(
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
)
)
# Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
is_iterator_batch_with_keys = (
len(parameters_sig) == 2
and check_iterator_annotation( # Iterator
parameters_sig[1],
parameter_check_func=lambda t: t == pa.RecordBatch,
)
and check_iterator_annotation(
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
)
)

if is_iterator_batch or is_iterator_batch_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF

# pa.Table -> pa.Table
is_table = (
len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
)
# Tuple[pa.Scalar, ...], pa.Table -> pa.Table
is_table_with_keys = (
len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
)
if is_table or is_table_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF

return None


def infer_group_arrow_eval_type_from_func(
f: "ArrowGroupedMapFunction",
) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
argspec = getfullargspec(f)
if len(argspec.annotations) > 0:
try:
type_hints = get_type_hints(f)
except NameError:
type_hints = {}

return infer_group_arrow_eval_type(signature(f), type_hints)
else:
return None


def check_tuple_annotation(
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
) -> bool:
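To make the inference rules above concrete, a sketch of how the new helper classifies functions (illustrative; assumes this patch and pyarrow are available):

    from typing import Iterator
    import pyarrow as pa

    from pyspark.util import PythonEvalType
    from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func

    def as_table(table: pa.Table) -> pa.Table:
        return table

    def as_batches(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        yield from batches

    def no_hints(table):
        return table

    # pa.Table -> pa.Table is classified as the existing grouped-map table UDF.
    assert (
        infer_group_arrow_eval_type_from_func(as_table)
        == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
    )
    # Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch] selects the new iterator UDF.
    assert (
        infer_group_arrow_eval_type_from_func(as_batches)
        == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
    )
    # Without annotations the helper returns None and callers fall back to the table UDF.
    assert infer_group_arrow_eval_type_from_func(no_hints) is None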