diff --git a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala index 6bb2f50e1f8..b4257d76f4b 100644 --- a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala +++ b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala @@ -46,6 +46,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.nowarn import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal class ColumnarArrowPythonRunner( funcs: Seq[(ChainedPythonFunctions, Long)], @@ -162,7 +163,75 @@ class ColumnarArrowPythonRunner( // For Spark 4.0. It overrides the corresponding abstract method in Writer class. // We omitted the override keyword for compatibility consideration. def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - writeToStreamHelper(dataOut) + writeNextInputToStreamHelper(dataOut) + } + + private var nextInputRoot: VectorSchemaRoot = _ + private var nextInputLoader: VectorLoader = _ + private var nextInputWriter: ArrowStreamWriter = _ + private var nextInputWriterClosed = false + + context.addTaskCompletionListener[Unit] { + _ => + try { + closeNextInputWriter() + } catch { + case NonFatal(_) => + } + } + + private def ensureNextInputWriter(dataOut: DataOutputStream): Unit = { + if (nextInputWriter == null) { + val arrowSchema = SparkSchemaUtil.toArrowSchema(schema, timeZoneId) + val allocator = ArrowBufferAllocators.contextInstance() + nextInputRoot = VectorSchemaRoot.create(arrowSchema, allocator) + nextInputLoader = new VectorLoader(nextInputRoot) + nextInputWriter = new ArrowStreamWriter(nextInputRoot, null, dataOut) + nextInputWriter.start() + } + } + + private def closeNextInputWriter(): Unit = { + if (!nextInputWriterClosed && nextInputRoot != null) { + try { + if (nextInputWriter != null) { + nextInputWriter.end() + } + } finally { + nextInputRoot.close() + nextInputWriterClosed = true + } + } + } + + private def writeNextInputToStreamHelper(dataOut: DataOutputStream): Boolean = { + ensureNextInputWriter(dataOut) + if (!inputIterator.hasNext) { + closeNextInputWriter() + // See https://issues.apache.org/jira/browse/SPARK-44705: + // Starting from Spark 4.0, we should return false once the iterator is drained out, + // otherwise Spark won't stop calling this method repeatedly. + return false + } + val nextBatch = inputIterator.next() + val cols = (0 until nextBatch.numCols).toList.map( + i => + nextBatch + .asInstanceOf[ColumnarBatch] + .column(i) + .asInstanceOf[ArrowWritableColumnVector] + .getValueVector) + val nextRecordBatch = + SparkVectorUtil.toArrowRecordBatch(nextBatch.numRows, cols) + try { + nextInputLoader.load(nextRecordBatch) + nextInputWriter.writeBatch() + true + } finally { + if (nextRecordBatch != null) { + nextRecordBatch.close() + } + } } def writeToStreamHelper(dataOut: DataOutputStream): Boolean = { diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala index 7747bd21937..b664ba437a6 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala @@ -21,6 +21,8 @@ import org.apache.gluten.execution.WholeStageTransformerSuite import org.apache.spark.SparkConf import org.apache.spark.api.python.ColumnarArrowEvalPythonExec import org.apache.spark.sql.IntegratedUDFTestUtils +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.functions.max import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.util.SparkVersionUtil @@ -36,6 +38,9 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite { newTestScalarPandasUDF(name = "pyarrowUDF", returnType = Some(StringType)) private val pyarrowTestUDFLong = newTestScalarPandasUDF(name = "pyarrowUDF", returnType = Some(LongType)) + private val SQL_ARROW_BATCHED_UDF = 101 + private lazy val arrowBatchedTestUDFString = + newTestArrowBatchedPythonUDF(name = "arrowBatchedUDF", returnType = Some(StringType)) override def sparkConf: SparkConf = { super.sparkConf @@ -109,6 +114,52 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite { checkAnswer(df, expected) } + testWithMinSparkVersion("arrow batched python udf over parquet scan", "4.0") { + withTempPath { + f => + Seq(("MAIL", 1), ("RAIL", 2), ("SHIP", 3)).toDF("shipmode", "id").write.parquet( + f.getCanonicalPath) + val base = spark.read.parquet(f.getCanonicalPath) + val arrowUdfCol = arrowBatchedTestUDFString(base("shipmode")).as("shipmode_arrow") + val df = base.select(arrowUdfCol).agg(max("shipmode_arrow").as("max_shipmode")) + val expected = Seq(Tuple1("SHIP")).toDF("max_shipmode") + + checkSparkPlan[ColumnarArrowEvalPythonExec](df) + checkAnswer(df, expected) + } + } + + private def newTestArrowBatchedPythonUDF( + name: String, + returnType: Option[DataType] = None): UserDefinedPythonFunction = { + val regularPythonUDF = newTestPythonUDF(name, returnType) + val regularUDF = regularPythonUDF.getClass + .getMethod("udf") + .invoke(regularPythonUDF) + .asInstanceOf[UserDefinedPythonFunction] + regularUDF.copy( + regularUDF.name, + regularUDF.func, + regularUDF.dataType, + SQL_ARROW_BATCHED_UDF, + regularUDF.udfDeterministic) + } + + private def newTestPythonUDF( + name: String, + returnType: Option[DataType] = None): TestPythonUDF = { + if (SparkVersionUtil.gteSpark40) { + // After https://github.com/apache/spark/pull/42864 which landed in Spark 4.0, the return + // type of the UDF must be explicitly specified when creating the UDF instance with column + // expressions as parameter. + classOf[TestPythonUDF] + .getConstructor(classOf[String], classOf[Option[DataType]]) + .newInstance(name, returnType) + } else { + TestPythonUDF(name) + } + } + private def newTestScalarPandasUDF( name: String, returnType: Option[DataType] = None): TestScalarPandasUDF = {