[SPARK-49547][SQL][PYTHON] Add iterator of RecordBatch API to applyInArrow #52440
Changes from all commits: b941c9d, 5fb8918, 4f3c4c4, e0fe7b8, 4a1fd6e, 6e0a2d0, 735a3a1
@@ -22,6 +22,7 @@
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.streaming.state import GroupStateTimeout
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor
 from pyspark.sql.types import StructType
@@ -703,27 +704,33 @@ def applyInArrow(
 Maps each group of the current :class:`DataFrame` using an Arrow udf and returns the result
 as a `DataFrame`.

-The function should take a `pyarrow.Table` and return another
-`pyarrow.Table`. Alternatively, the user can pass a function that takes
-a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`.
-For each group, all columns are passed together as a `pyarrow.Table`
-to the user-function and the returned `pyarrow.Table` are combined as a
-:class:`DataFrame`.
+The function can take one of two forms: it can take a `pyarrow.Table` and return a
+`pyarrow.Table`, or it can take an iterator of `pyarrow.RecordBatch` and yield
+`pyarrow.RecordBatch`. Alternatively, each form can take a tuple of `pyarrow.Scalar`
+grouping key(s) as the first argument in addition to the input above. For each group,
+all columns are passed together in the `pyarrow.Table` or `pyarrow.RecordBatch`, and the
+returned `pyarrow.Table` or iterator of `pyarrow.RecordBatch` is combined as a
+:class:`DataFrame`.

-The `schema` should be a :class:`StructType` describing the schema of the returned
-`pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match
-the field names in the defined schema if specified as strings, or match the
-field data types by position if not strings, e.g. integer indices.
-The length of the returned `pyarrow.Table` can be arbitrary.
+The `schema` should be a :class:`StructType` describing the schema of the returned
+`pyarrow.Table` or `pyarrow.RecordBatch`. The column labels of the returned `pyarrow.Table`
+or `pyarrow.RecordBatch` must either match the field names in the defined schema if
+specified as strings, or match the field data types by position if not strings, e.g.
+integer indices. The length of the returned `pyarrow.Table` or iterator of
+`pyarrow.RecordBatch` can be arbitrary.

 .. versionadded:: 4.0.0

+.. versionchanged:: 4.1.0
+    Added support for the iterator of `pyarrow.RecordBatch` API.
+
 Parameters
 ----------
 func : function
-    a Python native function that takes a `pyarrow.Table` and outputs a
-    `pyarrow.Table`, or that takes one tuple (grouping keys) and a
-    `pyarrow.Table` and outputs a `pyarrow.Table`.
+    a Python native function that either takes a `pyarrow.Table` and outputs a
+    `pyarrow.Table` or takes an iterator of `pyarrow.RecordBatch` and yields
+    `pyarrow.RecordBatch`. Additionally, each form can take a tuple of grouping keys
+    as the first argument, with the `pyarrow.Table` or iterator of `pyarrow.RecordBatch`
+    as the second argument.
 schema : :class:`pyspark.sql.types.DataType` or str
     the return type of the `func` in PySpark. The value can be either a
     :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Review thread on the `func` parameter description:
> let's add …
> lets also add two simple examples (w/o …
> I added one with and without key, since the key is more relevant for this API I think
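For reference, the four signatures the updated docstring describes can be sketched as plain Python stubs; the function names below are illustrative only, not part of the API:

from typing import Iterator, Tuple

import pyarrow

# Form 1: the whole group arrives as a single pyarrow.Table.
def table_func(table: pyarrow.Table) -> pyarrow.Table: ...

# Form 2: the group is streamed as RecordBatches; results are yielded batch by batch.
def batch_func(batches: Iterator[pyarrow.RecordBatch]) -> Iterator[pyarrow.RecordBatch]: ...

# Keyed variants: the grouping key(s) arrive first, as a tuple of pyarrow.Scalar.
def keyed_table_func(key: Tuple[pyarrow.Scalar, ...], table: pyarrow.Table) -> pyarrow.Table: ...

def keyed_batch_func(
    key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
) -> Iterator[pyarrow.RecordBatch]: ...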
@@ -752,6 +759,28 @@ def applyInArrow(
 |  2| 1.1094003924504583|
 +---+-------------------+

+The function can also take and return an iterator of `pyarrow.RecordBatch`, selected via
+type hints.
+
+>>> df = spark.createDataFrame(
+...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+...     ("id", "v"))  # doctest: +SKIP
+>>> def sum_func(
+...     batches: Iterator[pyarrow.RecordBatch]
+... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+...     total = 0
+...     for batch in batches:
+...         total += pc.sum(batch.column("v")).as_py()
+...     yield pyarrow.RecordBatch.from_pydict({"v": [total]})
+>>> df.groupby("id").applyInArrow(
+...     sum_func, schema="v double").show()  # doctest: +SKIP
++----+
+|   v|
++----+
+| 3.0|
+|18.0|
++----+
+
 Alternatively, the user can pass a function that takes two arguments.
 In this case, the grouping key(s) will be passed as the first argument and the data will
 be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow
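The doctest above (and the keyed variant below) reference a few names whose imports are not part of this diff; the surrounding module is assumed to provide the equivalent of:

from typing import Iterator, Tuple

import pyarrow
import pyarrow.compute as pc  # pc.sum aggregates the "v" column in the examples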
@@ -796,11 +825,28 @@ def applyInArrow(
 |  2|          2| 3.0|
 +---+-----------+----+

+>>> def sum_func(
+...     key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
+... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+...     total = 0
+...     for batch in batches:
+...         total += pc.sum(batch.column("v")).as_py()
+...     yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
+>>> df.groupby("id").applyInArrow(
+...     sum_func, schema="id long, v double").show()  # doctest: +SKIP
++---+----+
+| id|   v|
++---+----+
+|  1| 3.0|
+|  2|18.0|
++---+----+
+
 Notes
 -----
-This function requires a full shuffle. All the data of a group will be loaded
-into memory, so the user should be aware of the potential OOM risk if data is skewed
-and certain groups are too large to fit in memory.
+This function requires a full shuffle. If using the `pyarrow.Table` API, all data of a
+group will be loaded into memory, so the user should be aware of the potential OOM risk
+if data is skewed and certain groups are too large to fit in memory; the iterator of
+`pyarrow.RecordBatch` API can be used to mitigate this.

 This API is unstable, and for developers.
@@ -813,9 +859,18 @@ def applyInArrow(

         assert isinstance(self, GroupedData)

+        try:
+            # Try to infer the eval type from type hints
+            eval_type = infer_group_arrow_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints.", UserWarning)
+            eval_type = None  # fall back to the table API below
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
         # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
-            func, returnType=schema, functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+            func, returnType=schema, functionType=eval_type
         )  # type: ignore[call-overload]
         df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
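The helper `infer_group_arrow_eval_type_from_func` imported from `pyspark.sql.pandas.typehints` is not shown in this diff. A minimal sketch of how such type-hint dispatch could work is below, assuming the hinted signatures described in the docstring; the real helper and the exact `PythonEvalType` constant it returns for the iterator form may differ:

import collections.abc
from typing import get_origin, get_type_hints

import pyarrow

def infer_eval_type_sketch(func):
    # Read the annotated parameters and return type of the user function.
    hints = get_type_hints(func)
    ret = hints.pop("return", None)
    params = list(hints.values())
    # The keyed variants take a leading Tuple[pyarrow.Scalar, ...]; skip it.
    if params and get_origin(params[0]) is tuple:
        params = params[1:]
    if len(params) != 1:
        return None
    arg = params[0]
    # Iterator[RecordBatch] -> Iterator[RecordBatch] selects the batch-streaming path.
    if get_origin(arg) is collections.abc.Iterator and get_origin(ret) is collections.abc.Iterator:
        return "GROUPED_MAP_ARROW_ITER"  # placeholder; real code returns a PythonEvalType constant
    # pyarrow.Table -> pyarrow.Table selects the whole-table path.
    if arg is pyarrow.Table and ret is pyarrow.Table:
        return "GROUPED_MAP_ARROW"  # placeholder for PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
    return None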