Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.annotation.nowarn
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

class ColumnarArrowPythonRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
Expand Down Expand Up @@ -162,7 +163,75 @@ class ColumnarArrowPythonRunner(
// For Spark 4.0. It overrides the corresponding abstract method in Writer class.
// We omitted the override keyword for compatibility consideration.
def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
writeToStreamHelper(dataOut)
writeNextInputToStreamHelper(dataOut)
}

private var nextInputRoot: VectorSchemaRoot = _
private var nextInputLoader: VectorLoader = _
private var nextInputWriter: ArrowStreamWriter = _
private var nextInputWriterClosed = false

context.addTaskCompletionListener[Unit] {
_ =>
try {
closeNextInputWriter()
} catch {
case NonFatal(_) =>
}
}

private def ensureNextInputWriter(dataOut: DataOutputStream): Unit = {
if (nextInputWriter == null) {
val arrowSchema = SparkSchemaUtil.toArrowSchema(schema, timeZoneId)
val allocator = ArrowBufferAllocators.contextInstance()
nextInputRoot = VectorSchemaRoot.create(arrowSchema, allocator)
nextInputLoader = new VectorLoader(nextInputRoot)
nextInputWriter = new ArrowStreamWriter(nextInputRoot, null, dataOut)
nextInputWriter.start()
}
}

private def closeNextInputWriter(): Unit = {
if (!nextInputWriterClosed && nextInputRoot != null) {
try {
if (nextInputWriter != null) {
nextInputWriter.end()
}
} finally {
nextInputRoot.close()
nextInputWriterClosed = true
}
}
}

private def writeNextInputToStreamHelper(dataOut: DataOutputStream): Boolean = {
ensureNextInputWriter(dataOut)
if (!inputIterator.hasNext) {
closeNextInputWriter()
// See https://issues.apache.org/jira/browse/SPARK-44705:
// Starting from Spark 4.0, we should return false once the iterator is drained out,
// otherwise Spark won't stop calling this method repeatedly.
return false
}
val nextBatch = inputIterator.next()
val cols = (0 until nextBatch.numCols).toList.map(
i =>
nextBatch
.asInstanceOf[ColumnarBatch]
.column(i)
.asInstanceOf[ArrowWritableColumnVector]
.getValueVector)
val nextRecordBatch =
SparkVectorUtil.toArrowRecordBatch(nextBatch.numRows, cols)
try {
nextInputLoader.load(nextRecordBatch)
nextInputWriter.writeBatch()
true
} finally {
if (nextRecordBatch != null) {
nextRecordBatch.close()
}
}
}

def writeToStreamHelper(dataOut: DataOutputStream): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import org.apache.gluten.execution.WholeStageTransformerSuite
import org.apache.spark.SparkConf
import org.apache.spark.api.python.ColumnarArrowEvalPythonExec
import org.apache.spark.sql.IntegratedUDFTestUtils
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.util.SparkVersionUtil

Expand All @@ -36,6 +38,9 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
newTestScalarPandasUDF(name = "pyarrowUDF", returnType = Some(StringType))
private val pyarrowTestUDFLong =
newTestScalarPandasUDF(name = "pyarrowUDF", returnType = Some(LongType))
private val SQL_ARROW_BATCHED_UDF = 101
private lazy val arrowBatchedTestUDFString =
newTestArrowBatchedPythonUDF(name = "arrowBatchedUDF", returnType = Some(StringType))

override def sparkConf: SparkConf = {
super.sparkConf
Expand Down Expand Up @@ -109,6 +114,52 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
checkAnswer(df, expected)
}

testWithMinSparkVersion("arrow batched python udf over parquet scan", "4.0") {
withTempPath {
f =>
Seq(("MAIL", 1), ("RAIL", 2), ("SHIP", 3)).toDF("shipmode", "id").write.parquet(
f.getCanonicalPath)
val base = spark.read.parquet(f.getCanonicalPath)
val arrowUdfCol = arrowBatchedTestUDFString(base("shipmode")).as("shipmode_arrow")
val df = base.select(arrowUdfCol).agg(max("shipmode_arrow").as("max_shipmode"))
val expected = Seq(Tuple1("SHIP")).toDF("max_shipmode")

checkSparkPlan[ColumnarArrowEvalPythonExec](df)
checkAnswer(df, expected)
}
}

private def newTestArrowBatchedPythonUDF(
name: String,
returnType: Option[DataType] = None): UserDefinedPythonFunction = {
val regularPythonUDF = newTestPythonUDF(name, returnType)
val regularUDF = regularPythonUDF.getClass
.getMethod("udf")
.invoke(regularPythonUDF)
.asInstanceOf[UserDefinedPythonFunction]
regularUDF.copy(
regularUDF.name,
regularUDF.func,
regularUDF.dataType,
SQL_ARROW_BATCHED_UDF,
regularUDF.udfDeterministic)
}

private def newTestPythonUDF(
name: String,
returnType: Option[DataType] = None): TestPythonUDF = {
if (SparkVersionUtil.gteSpark40) {
// After https://github.com/apache/spark/pull/42864 which landed in Spark 4.0, the return
// type of the UDF must be explicitly specified when creating the UDF instance with column
// expressions as parameter.
classOf[TestPythonUDF]
.getConstructor(classOf[String], classOf[Option[DataType]])
.newInstance(name, returnType)
} else {
TestPythonUDF(name)
}
}

private def newTestScalarPandasUDF(
name: String,
returnType: Option[DataType] = None): TestScalarPandasUDF = {
Expand Down
Loading