From 40fc15257cf9d52b6adfcc969dc699d09cbada94 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Aug 2025 16:05:07 +0800
Subject: [PATCH 1/8] [#2596] feat(spark): Introduce fory serializer

---
 client-spark/extension/pom.xml                |  39 ++++
 .../spark/serializer/ForySerializer.scala     | 173 +++++++++++++++
 .../spark/serializer/ForySerializerTest.scala | 201 ++++++++++++++++++
 3 files changed, 413 insertions(+)
 create mode 100644 client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
 create mode 100644 client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala

diff --git a/client-spark/extension/pom.xml b/client-spark/extension/pom.xml
index ff0511eba9..7ca13eebcd 100644
--- a/client-spark/extension/pom.xml
+++ b/client-spark/extension/pom.xml
@@ -68,6 +68,25 @@
       <artifactId>rss-client-spark-common</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.fory</groupId>
+      <artifactId>fory-core</artifactId>
+      <version>0.12.0</version>
+    </dependency>
+
+    <!-- Test dependencies -->
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <version>3.2.15</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>junit-4-13_${scala.binary.version}</artifactId>
+      <version>3.2.15.0</version>
+      <scope>test</scope>
+    </dependency>
@@ -88,6 +107,26 @@
           <scalaVersion>${scala.version}</scalaVersion>
         </configuration>
       </plugin>
+
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>2.0.2</version>
+        <configuration>
+          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+          <junitxml>.</junitxml>
+          <filereports>TestSuite.txt</filereports>
+        </configuration>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
diff --git a/client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
new file mode 100644
index 0000000000..8bc7f9e9d0
--- /dev/null
+++ b/client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -0,0 +1,173 @@
+package org.apache.spark.serializer
+
+import org.apache.spark.internal.Logging
+import org.apache.fory.Fory
+import org.apache.fory.config.{CompatibleMode, Language}
+
+import java.io.{InputStream, OutputStream, Serializable}
+import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
+class ForySerializer extends org.apache.spark.serializer.Serializer
+  with Logging
+  with Serializable {
+
+  override def newInstance(): SerializerInstance = new ForySerializerInstance()
+
+  override def supportsRelocationOfSerializedObjects: Boolean = true
+
+}
+
+class ForySerializerInstance extends org.apache.spark.serializer.SerializerInstance {
+
+  // Thread-local Fury instance for thread safety
+  private val fury: ThreadLocal[Fory] = ThreadLocal.withInitial(() => {
+    val f = Fory.builder()
+      .withLanguage(Language.JAVA)
+      .withRefTracking(true)
+      .withCompatibleMode(CompatibleMode.COMPATIBLE)
+      .requireClassRegistration(false)
+      .build()
+    f
+  })
+
+  private def createFuryInstance(): Fory = {
+    Fory.builder()
+      .withLanguage(Language.JAVA)
+      .withRefTracking(true)
+      .withCompatibleMode(CompatibleMode.COMPATIBLE)
+      .requireClassRegistration(false)
+      .build()
+  }
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = {
+    val bytes = fury.get().serialize(t.asInstanceOf[AnyRef])
+    ByteBuffer.wrap(bytes)
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+    val array = if (bytes.hasArray) {
+      val offset = bytes.arrayOffset() + bytes.position()
+      val length = bytes.remaining()
+      java.util.Arrays.copyOfRange(bytes.array(), offset, offset + length)
+    } else {
+      val array = new Array[Byte](bytes.remaining())
+      bytes.get(array)
+      array
+    }
+    fury.get().deserialize(array).asInstanceOf[T]
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+    // Fury handles class loading internally, so we can use the standard deserialize method
+
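+    // Note: the `loader` parameter is deliberately ignored; per the comment above, Fory
+    // resolves classes through its own internal class resolver rather than the loader
+    // supplied by Spark.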
deserialize[T](bytes) + } + + override def serializeStream(s: OutputStream): SerializationStream = { + new ForySerializationStream(fury.get(), s) + } + + override def deserializeStream(s: InputStream): DeserializationStream = { + new ForyDeserializationStream(fury.get(), s) + } +} + +class ForySerializationStream(fury: Fory, outputStream: OutputStream) + extends org.apache.spark.serializer.SerializationStream { + + private val out = outputStream + private var closed = false + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + if (closed) { + throw new IllegalStateException("Stream is closed") + } + + val bytes = fury.serialize(t.asInstanceOf[AnyRef]) + // Write length first, then data + writeInt(bytes.length) + out.write(bytes) + this + } + + private def writeInt(value: Int): Unit = { + out.write((value >>> 24) & 0xFF) + out.write((value >>> 16) & 0xFF) + out.write((value >>> 8) & 0xFF) + out.write(value & 0xFF) + } + + override def flush(): Unit = { + if (!closed) { + out.flush() + } + } + + override def close(): Unit = { + if (!closed) { + try { + out.close() + } finally { + closed = true + } + } + } +} + +class ForyDeserializationStream(fury: Fory, inputStream: InputStream) + extends org.apache.spark.serializer.DeserializationStream { + + private val in = inputStream + private var closed = false + + override def readObject[T: ClassTag](): T = { + if (closed) { + throw new IllegalStateException("Stream is closed") + } + + try { + val length = readInt() + if (length < 0) { + throw new java.io.EOFException("Reached end of stream") + } + + val bytes = new Array[Byte](length) + var bytesRead = 0 + while (bytesRead < length) { + val read = in.read(bytes, bytesRead, length - bytesRead) + if (read == -1) { + throw new java.io.EOFException("Unexpected end of stream") + } + bytesRead += read + } + + fury.deserialize(bytes).asInstanceOf[T] + } catch { + case _: java.io.EOFException => + throw new java.io.EOFException("Reached end of stream") + } + } + + private def readInt(): Int = { + val b1 = in.read() + val b2 = in.read() + val b3 = in.read() + val b4 = in.read() + + if ((b1 | b2 | b3 | b4) < 0) { + throw new java.io.EOFException() + } + + (b1 << 24) + (b2 << 16) + (b3 << 8) + b4 + } + + override def close(): Unit = { + if (!closed) { + try { + in.close() + } finally { + closed = true + } + } + } +} \ No newline at end of file diff --git a/client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala b/client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala new file mode 100644 index 0000000000..1ebcec618f --- /dev/null +++ b/client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala @@ -0,0 +1,201 @@ +package org.apache.spark.serializer + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import scala.collection.mutable +import scala.reflect.ClassTag + +class ForySerializerTest extends AnyFunSuite with Matchers { + + test("ForySerializer should create new instance") { + val serializer = new ForySerializer() + val instance = serializer.newInstance() + + instance should not be null + instance shouldBe a[ForySerializerInstance] + } + + test("ForySerializer should support relocation of serialized objects") { + val serializer = new ForySerializer() + serializer.supportsRelocationOfSerializedObjects shouldBe true + } + + test("ForySerializerInstance should serialize and deserialize 
simple primitive types") { + val instance = new ForySerializerInstance() + + // Test String + testSerializeDeserialize(instance, "Hello, Fory!") + + // Test Integer + testSerializeDeserialize(instance, Integer.valueOf(42)) + + // Test Long + testSerializeDeserialize(instance, java.lang.Long.valueOf(123456789L)) + + // Test Double + testSerializeDeserialize(instance, java.lang.Double.valueOf(3.14159)) + + // Test Boolean + testSerializeDeserialize(instance, java.lang.Boolean.valueOf(true)) + testSerializeDeserialize(instance, java.lang.Boolean.valueOf(false)) + } + + test("ForySerializerInstance should serialize and deserialize simple collections") { + val instance = new ForySerializerInstance() + + // Test simple Java ArrayList + val javaList = new java.util.ArrayList[String]() + javaList.add("apple") + javaList.add("banana") + javaList.add("cherry") + testSerializeDeserialize(instance, javaList) + + // Test simple Java HashMap + val javaMap = new java.util.HashMap[String, String]() + javaMap.put("key1", "value1") + javaMap.put("key2", "value2") + testSerializeDeserialize(instance, javaMap) + } + + test("ForySerializerInstance should serialize and deserialize simple case class") { + val instance = new ForySerializerInstance() + + // Simple class without complex dependencies + val data = new SimpleData("test", 123) + + testSerializeDeserialize(instance, data) + } + + test("ForySerializationStream should serialize and deserialize simple objects") { + val instance = new ForySerializerInstance() + val baos = new ByteArrayOutputStream() + val stream = instance.serializeStream(baos) + + // Write simple objects only + val objects = List("Hello", Integer.valueOf(123), java.lang.Boolean.valueOf(true)) + objects.foreach { obj => + stream.writeObject(obj)(getClassTag(obj)) + } + stream.flush() + stream.close() + + // Read back the objects + val bais = new ByteArrayInputStream(baos.toByteArray) + val deserStream = instance.deserializeStream(bais) + + val readObjects = mutable.ListBuffer[Any]() + try { + while (true) { + readObjects += deserStream.readObject[Any]() + } + } catch { + case _: java.io.EOFException => // Expected when reaching end + } + deserStream.close() + + readObjects.toList should contain theSameElementsInOrderAs objects + } + + test("ForySerializationStream should handle empty stream") { + val instance = new ForySerializerInstance() + val baos = new ByteArrayOutputStream() + val stream = instance.serializeStream(baos) + stream.flush() + stream.close() + + val bais = new ByteArrayInputStream(baos.toByteArray) + val deserStream = instance.deserializeStream(bais) + + intercept[java.io.EOFException] { + deserStream.readObject[String]() + } + + deserStream.close() + } + + test("ForySerializationStream should handle stream operations after close") { + val instance = new ForySerializerInstance() + val baos = new ByteArrayOutputStream() + val stream = instance.serializeStream(baos) + + stream.close() + + // Writing after close should throw exception + intercept[IllegalStateException] { + stream.writeObject("test") + } + } + + test("ForyDeserializationStream should handle stream operations after close") { + val instance = new ForySerializerInstance() + val bais = new ByteArrayInputStream(Array.empty[Byte]) + val deserStream = instance.deserializeStream(bais) + + deserStream.close() + + // Reading after close should throw exception + intercept[IllegalStateException] { + deserStream.readObject[String]() + } + } + + test("ForySerializerInstance should handle null values") { + val instance = new 
ForySerializerInstance()
+
+    val serialized = instance.serialize[AnyRef](null)
+    val deserialized = instance.deserialize[AnyRef](serialized)
+    deserialized should be(null)
+  }
+
+  test("ForySerializerInstance should handle byte arrays") {
+    val instance = new ForySerializerInstance()
+
+    val byteArray = Array[Byte](1, 2, 3, 4, 5)
+    testSerializeDeserialize(instance, byteArray)
+  }
+
+  test("ForySerializerInstance should handle large strings") {
+    val instance = new ForySerializerInstance()
+
+    // Create a large string
+    val largeString = "x" * 10000
+    testSerializeDeserialize(instance, largeString)
+  }
+
+  private def testSerializeDeserialize[T](instance: ForySerializerInstance, obj: T)(implicit ct: ClassTag[T]): Unit = {
+    val serialized = instance.serialize(obj)
+    val deserialized = instance.deserialize[T](serialized)
+
+    if (obj != null && obj.getClass.isArray) {
+      // Special handling for arrays since they don't implement equals properly
+      if (obj.isInstanceOf[Array[Byte]]) {
+        obj.asInstanceOf[Array[Byte]] should equal(deserialized.asInstanceOf[Array[Byte]])
+      } else {
+        obj.asInstanceOf[Array[_]].toList should equal(deserialized.asInstanceOf[Array[_]].toList)
+      }
+    } else {
+      deserialized should equal(obj)
+    }
+  }
+
+  private def getClassTag[T](obj: T): ClassTag[T] = {
+    if (obj == null) {
+      ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
+    } else {
+      ClassTag(obj.getClass).asInstanceOf[ClassTag[T]]
+    }
+  }
+}
+// SimpleData is a normal class with equals/hashCode overridden for test assertions
+class SimpleData(val name: String, val value: Int) {
+  override def equals(other: Any): Boolean = other match {
+    case that: SimpleData => this.name == that.name && this.value == that.value
+    case _ => false
+  }
+  override def hashCode(): Int = {
+    31 * name.hashCode + value
+  }
+  override def toString: String = s"SimpleData($name, $value)"
+}
\ No newline at end of file

From 9931e5a11033973e00a98080cb73d6ccae522c11 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Aug 2025 16:12:38 +0800
Subject: [PATCH 2/8] reorg

---
 client-spark/common/pom.xml                   | 60 +++++++++++++++++++
 .../spark/serializer/ForySerializer.scala     |  2 +-
 .../spark/serializer/ForySerializerTest.scala |  3 +-
 client-spark/extension/pom.xml                | 39 ------------
 4 files changed, 63 insertions(+), 41 deletions(-)
 rename client-spark/{extension => common}/src/main/scala/org/apache/spark/serializer/ForySerializer.scala (100%)
 rename client-spark/{extension => common}/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala (98%)

diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 639864b196..2589fd02e8 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -89,6 +89,66 @@
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
+
+    <dependency>
+      <groupId>org.apache.fory</groupId>
+      <artifactId>fory-core</artifactId>
+      <version>0.12.0</version>
+    </dependency>
+
+    <!-- Test dependencies -->
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <version>3.2.15</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>junit-4-13_${scala.binary.version}</artifactId>
+      <version>3.2.15.0</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+        <version>${scala.maven.plugin.version}</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <scalaVersion>${scala.version}</scalaVersion>
+        </configuration>
+      </plugin>
+
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>2.0.2</version>
+        <configuration>
+          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+          <junitxml>.</junitxml>
+          <filereports>TestSuite.txt</filereports>
+        </configuration>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
 </project>
diff --git a/client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
similarity index 100%
rename from client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
rename to client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
index 8bc7f9e9d0..7dee991403 100644
--- a/client-spark/extension/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
+++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -1,8 +1,8 @@
 package org.apache.spark.serializer
 
-import org.apache.spark.internal.Logging
 import org.apache.fory.Fory
 import org.apache.fory.config.{CompatibleMode, Language}
+import org.apache.spark.internal.Logging
 
 import java.io.{InputStream, OutputStream, Serializable}
 import java.nio.ByteBuffer
diff --git a/client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
similarity index 98%
rename from client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
rename to client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
index 1ebcec618f..cb2206a94f 100644
--- a/client-spark/extension/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
+++ b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
@@ -1,7 +1,8 @@
 package org.apache.spark.serializer
 
 import org.scalatest.funsuite.AnyFunSuite
-import org.scalatest.matchers.should.Matchers
+import org.scalatest.matchers.must.Matchers
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import scala.collection.mutable
diff --git a/client-spark/extension/pom.xml b/client-spark/extension/pom.xml
index 7ca13eebcd..ff0511eba9 100644
--- a/client-spark/extension/pom.xml
+++ b/client-spark/extension/pom.xml
@@ -68,25 +68,6 @@
       <artifactId>rss-client-spark-common</artifactId>
       <version>${project.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.apache.fory</groupId>
-      <artifactId>fory-core</artifactId>
-      <version>0.12.0</version>
-    </dependency>
-
-    <!-- Test dependencies -->
-    <dependency>
-      <groupId>org.scalatest</groupId>
-      <artifactId>scalatest_${scala.binary.version}</artifactId>
-      <version>3.2.15</version>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
-      <groupId>org.scalatestplus</groupId>
-      <artifactId>junit-4-13_${scala.binary.version}</artifactId>
-      <version>3.2.15.0</version>
-      <scope>test</scope>
-    </dependency>
@@ -107,26 +88,6 @@
           <scalaVersion>${scala.version}</scalaVersion>
         </configuration>
       </plugin>
-
-      <plugin>
-        <groupId>org.scalatest</groupId>
-        <artifactId>scalatest-maven-plugin</artifactId>
-        <version>2.0.2</version>
-        <configuration>
-          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
-          <junitxml>.</junitxml>
-          <filereports>TestSuite.txt</filereports>
-        </configuration>
-        <executions>
-          <execution>
-            <id>test</id>
-            <goals>
-              <goal>test</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>

From 793fe550fdb7a2d577b4dcafaa3782f277bd0d92 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Aug 2025 17:06:49 +0800
Subject: [PATCH 3/8] add option

---
 client-spark/common/pom.xml                   | 30 +++++++++++++++++++
 .../apache/spark/shuffle/RssSparkConfig.java  |  7 +++++
 .../shuffle/writer/WriteBufferManager.java    | 11 ++++---
 .../uniffle/shuffle/ShuffleSerializer.java    | 22 ++++++++++++++
 .../spark/serializer/ForySerializer.scala     |  9 ------
 .../shuffle/reader/RssShuffleReader.java      |  7 ++++-
 6 files changed, 72 insertions(+), 14 deletions(-)
 create mode 100644 client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java

diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 2589fd02e8..d68ec63487 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -95,6 +95,14 @@
       <groupId>org.apache.fory</groupId>
      <artifactId>fory-core</artifactId>
       <version>0.12.0</version>
     </dependency>
+
+    <!-- Scala dependencies -->
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>${scala.version}</version>
+      <scope>provided</scope>
+    </dependency>
@@ -119,8 +127,17 @@
         <version>${scala.maven.plugin.version}</version>
         <executions>
           <execution>
+            <id>scala-compile-first</id>
+            <phase>process-resources</phase>
             <goals>
+              <goal>add-source</goal>
               <goal>compile</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>scala-test-compile-first</id>
+            <phase>process-test-resources</phase>
+            <goals>
               <goal>testCompile</goal>
             </goals>
           </execution>
@@ -129,6 +146,19 @@
         <configuration>
           <scalaVersion>${scala.version}</scalaVersion>
         </configuration>
       </plugin>
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>compile</id>
+            <goals>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6e2536caaf..0dd696ce7b 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -36,6 +36,7 @@
 import org.apache.uniffle.common.config.ConfigUtils;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.shuffle.ShuffleSerializer;
 
 public class RssSparkConfig {
 
@@ -110,6 +111,12 @@
           .defaultValue(true)
           .withDescription("indicates row based shuffle, set false when use in columnar shuffle");
 
+  public static final ConfigOption<ShuffleSerializer> RSS_SHUFFLE_SERIALIZER =
+      ConfigOptions.key("rss.client.shuffle.serializer")
+          .enumType(ShuffleSerializer.class)
+          .noDefaultValue()
+          .withDescription("Shuffle serializer type");
+
   public static final ConfigOption<Boolean> RSS_MEMORY_SPILL_ENABLED =
       ConfigOptions.key("rss.client.memory.spill.enabled")
           .booleanType()
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 3da4b147bb..6bbf9f8a96 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -40,9 +40,9 @@
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.ForySerializerInstance;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,6 +58,7 @@
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
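+          <!-- junitxml: XML output dir, relative to reportsDirectory; filereports: plain-text report name -->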
 
 public class WriteBufferManager extends MemoryConsumer {
@@ -81,7 +82,6 @@ public class WriteBufferManager extends MemoryConsumer {
   private int shuffleId;
   private String taskId;
   private long taskAttemptId;
-  private SerializerInstance instance;
   private ShuffleWriteMetrics shuffleWriteMetrics;
   // cache partition -> records
   private Map<Integer, WriterBuffer> buffers;
@@ -192,8 +192,11 @@ public WriteBufferManager(
     // in columnar shuffle, the serializer here is never used
     this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
     if (isRowBased) {
-      this.instance = serializer.newInstance();
-      this.serializeStream = instance.serializeStream(arrayOutputStream);
+      if (rssConf.contains(RSS_SHUFFLE_SERIALIZER)) {
+        this.serializeStream = new ForySerializerInstance().serializeStream(arrayOutputStream);
+      } else {
+        this.serializeStream = serializer.newInstance().serializeStream(arrayOutputStream);
+      }
     }
     boolean compress =
         rssConf.getBoolean(
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
new file mode 100644
index 0000000000..850b80e79c
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+public enum ShuffleSerializer {
+  FORY
+}
diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
index 7dee991403..b5b04f7cf9 100644
--- a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
+++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -31,15 +31,6 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta
     f
   })
 
-  private def createFuryInstance(): Fory = {
-    Fory.builder()
-      .withLanguage(Language.JAVA)
-      .withRefTracking(true)
-      .withCompatibleMode(CompatibleMode.COMPATIBLE)
-      .requireClassRegistration(false)
-      .build()
-  }
-
   override def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bytes = fury.get().serialize(t.asInstanceOf[AnyRef])
     ByteBuffer.wrap(bytes)
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 4113f06274..daffc28274 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -42,6 +42,7 @@
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.serializer.ForySerializer;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.FunctionUtils;
 import org.apache.spark.shuffle.RssShuffleHandle;
@@ -68,6 +69,7 @@
 
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_REORDER_MULTI_SERVERS_ENABLED;
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -125,7 +127,10 @@ public RssShuffleReader(
     this.numMaps = rssShuffleHandle.getNumMaps();
     this.shuffleDependency = rssShuffleHandle.getDependency();
     this.shuffleId = shuffleDependency.shuffleId();
-    this.serializer = rssShuffleHandle.getDependency().serializer();
+    this.serializer =
+        rssConf.contains(RSS_SHUFFLE_SERIALIZER)
+            ? new ForySerializer()
+            : rssShuffleHandle.getDependency().serializer();
     this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     this.basePath = basePath;
     this.partitionNum = partitionNum;

From 9d44b311ace9aed89ae6fa72e1f09a820b70953e Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Aug 2025 17:17:00 +0800
Subject: [PATCH 4/8] fix reader

---
 .../java/org/apache/spark/shuffle/reader/RssShuffleReader.java | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index daffc28274..b00e02089a 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -312,8 +312,7 @@ class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
               .retryIntervalMax(retryIntervalMax)
               .rssConf(rssConf));
       RssShuffleDataIterator<K, C> iterator =
-          new RssShuffleDataIterator<>(
-              shuffleDependency.serializer(), shuffleReadClient, readMetrics, rssConf);
+          new RssShuffleDataIterator<>(serializer, shuffleReadClient, readMetrics, rssConf);
       CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
           CompletionIterator$.MODULE$.apply(
               iterator,

From 32f2ad5172ed007407af464d9aadbc1eae81e3b4 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Aug 2025 17:17:33 +0800
Subject: [PATCH 5/8] add license header

---
 .../spark/serializer/ForySerializer.scala     | 17 +++++++++++++++
 .../spark/serializer/ForySerializerTest.scala | 17 +++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
index b5b04f7cf9..a926b0c4cf 100644
--- a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
+++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.spark.serializer
 
diff --git a/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
index cb2206a94f..650ee658a0 100644
--- a/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
+++ b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.serializer import org.scalatest.funsuite.AnyFunSuite From 277d84de4e131f392ffbefc6f27cad42489fc3ab Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Thu, 28 Aug 2025 17:31:23 +0800 Subject: [PATCH 6/8] use internal thread local fory --- .../spark/serializer/ForySerializer.scala | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala index a926b0c4cf..5674da6853 100644 --- a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala +++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import org.apache.fory.Fory +import org.apache.fory.{Fory, ThreadLocalFory} import org.apache.fory.config.{CompatibleMode, Language} import org.apache.spark.internal.Logging @@ -37,19 +37,15 @@ class ForySerializer extends org.apache.spark.serializer.Serializer class ForySerializerInstance extends org.apache.spark.serializer.SerializerInstance { - // Thread-local Fury instance for thread safety - private val fury: ThreadLocal[Fory] = ThreadLocal.withInitial(() => { - val f = Fory.builder() - .withLanguage(Language.JAVA) - .withRefTracking(true) - .withCompatibleMode(CompatibleMode.COMPATIBLE) - .requireClassRegistration(false) - .build() - f - }) + private val fury = Fory.builder() + .withLanguage(Language.JAVA) + .withRefTracking(true) + .withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT) + .requireClassRegistration(false) + .buildThreadLocalFory(); override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bytes = fury.get().serialize(t.asInstanceOf[AnyRef]) + val bytes = fury.serialize(t.asInstanceOf[AnyRef]) ByteBuffer.wrap(bytes) } @@ -63,7 +59,7 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta bytes.get(array) array } - fury.get().deserialize(array).asInstanceOf[T] + fury.deserialize(array).asInstanceOf[T] } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { @@ -72,15 +68,15 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta } override def serializeStream(s: OutputStream): SerializationStream = { - new ForySerializationStream(fury.get(), s) + new ForySerializationStream(fury, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new ForyDeserializationStream(fury.get(), s) + new ForyDeserializationStream(fury, s) } } -class ForySerializationStream(fury: Fory, outputStream: OutputStream) +class ForySerializationStream(fury: ThreadLocalFory, outputStream: OutputStream) extends org.apache.spark.serializer.SerializationStream { private val out = 
outputStream @@ -122,7 +118,7 @@ class ForySerializationStream(fury: Fory, outputStream: OutputStream) } } -class ForyDeserializationStream(fury: Fory, inputStream: InputStream) +class ForyDeserializationStream(fury: ThreadLocalFory, inputStream: InputStream) extends org.apache.spark.serializer.DeserializationStream { private val in = inputStream @@ -132,7 +128,7 @@ class ForyDeserializationStream(fury: Fory, inputStream: InputStream) if (closed) { throw new IllegalStateException("Stream is closed") } - + try { val length = readInt() if (length < 0) { @@ -161,11 +157,11 @@ class ForyDeserializationStream(fury: Fory, inputStream: InputStream) val b2 = in.read() val b3 = in.read() val b4 = in.read() - + if ((b1 | b2 | b3 | b4) < 0) { throw new java.io.EOFException() } - + (b1 << 24) + (b2 << 16) + (b3 << 8) + b4 } From de3b1f65d97e9d877345f4318979e92ecbac30ad Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 2 Sep 2025 17:55:51 +0800 Subject: [PATCH 7/8] simplify ser --- .../spark/serializer/ForySerializer.scala | 74 +++---------------- 1 file changed, 10 insertions(+), 64 deletions(-) diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala index 5674da6853..6945bb5d10 100644 --- a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala +++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import org.apache.fory.{Fory, ThreadLocalFory} import org.apache.fory.config.{CompatibleMode, Language} +import org.apache.fory.io.ForyInputStream +import org.apache.fory.{Fory, ThreadLocalFory} import org.apache.spark.internal.Logging import java.io.{InputStream, OutputStream, Serializable} @@ -42,7 +43,7 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta .withRefTracking(true) .withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT) .requireClassRegistration(false) - .buildThreadLocalFory(); + .buildThreadLocalFory() override def serialize[T: ClassTag](t: T): ByteBuffer = { val bytes = fury.serialize(t.asInstanceOf[AnyRef]) @@ -50,16 +51,7 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - val array = if (bytes.hasArray) { - val offset = bytes.arrayOffset() + bytes.position() - val length = bytes.remaining() - java.util.Arrays.copyOfRange(bytes.array(), offset, offset + length) - } else { - val array = new Array[Byte](bytes.remaining()) - bytes.get(array) - array - } - fury.deserialize(array).asInstanceOf[T] + fury.deserialize(bytes).asInstanceOf[T] } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { @@ -79,38 +71,26 @@ class ForySerializerInstance extends org.apache.spark.serializer.SerializerInsta class ForySerializationStream(fury: ThreadLocalFory, outputStream: OutputStream) extends org.apache.spark.serializer.SerializationStream { - private val out = outputStream private var closed = false override def writeObject[T: ClassTag](t: T): SerializationStream = { if (closed) { throw new IllegalStateException("Stream is closed") } - - val bytes = fury.serialize(t.asInstanceOf[AnyRef]) - // Write length first, then data - writeInt(bytes.length) - out.write(bytes) + fury.serialize(outputStream, t) this } - private def writeInt(value: Int): Unit = { - out.write((value >>> 24) & 0xFF) - 
out.write((value >>> 16) & 0xFF) - out.write((value >>> 8) & 0xFF) - out.write(value & 0xFF) - } - override def flush(): Unit = { if (!closed) { - out.flush() + outputStream.flush() } } override def close(): Unit = { if (!closed) { try { - out.close() + outputStream.close() } finally { closed = true } @@ -121,54 +101,20 @@ class ForySerializationStream(fury: ThreadLocalFory, outputStream: OutputStream) class ForyDeserializationStream(fury: ThreadLocalFory, inputStream: InputStream) extends org.apache.spark.serializer.DeserializationStream { - private val in = inputStream private var closed = false + private val foryStream = new ForyInputStream(inputStream) override def readObject[T: ClassTag](): T = { if (closed) { throw new IllegalStateException("Stream is closed") } - - try { - val length = readInt() - if (length < 0) { - throw new java.io.EOFException("Reached end of stream") - } - - val bytes = new Array[Byte](length) - var bytesRead = 0 - while (bytesRead < length) { - val read = in.read(bytes, bytesRead, length - bytesRead) - if (read == -1) { - throw new java.io.EOFException("Unexpected end of stream") - } - bytesRead += read - } - - fury.deserialize(bytes).asInstanceOf[T] - } catch { - case _: java.io.EOFException => - throw new java.io.EOFException("Reached end of stream") - } - } - - private def readInt(): Int = { - val b1 = in.read() - val b2 = in.read() - val b3 = in.read() - val b4 = in.read() - - if ((b1 | b2 | b3 | b4) < 0) { - throw new java.io.EOFException() - } - - (b1 << 24) + (b2 << 16) + (b3 << 8) + b4 + fury.deserialize(foryStream).asInstanceOf[T] } override def close(): Unit = { if (!closed) { try { - in.close() + foryStream.close() } finally { closed = true } From c2a7d46a45a560f1197704c2b731a63e07d2a557 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 2 Sep 2025 17:58:32 +0800 Subject: [PATCH 8/8] fix --- .../main/scala/org/apache/spark/serializer/ForySerializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala index 6945bb5d10..be9f34c938 100644 --- a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala +++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala @@ -26,6 +26,7 @@ import java.io.{InputStream, OutputStream, Serializable} import java.nio.ByteBuffer import scala.reflect.ClassTag +@SerialVersionUID(1L) class ForySerializer extends org.apache.spark.serializer.Serializer with Logging with Serializable {
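
A minimal usage sketch (illustrative only; the "spark.rss." prefix assumes Uniffle's usual
convention of forwarding "spark.rss.*" keys into RssConf): with this series applied, the Fory
path is switched on purely through configuration, and the value must name a ShuffleSerializer
enum constant.

    // Scala: enable the Fory serializer for a Uniffle-backed shuffle
    val conf = new org.apache.spark.SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
      .set("spark.rss.client.shuffle.serializer", "FORY")

When the option is unset, RssShuffleReader and WriteBufferManager fall back to the shuffle
dependency's serializer, so existing jobs keep their current behavior.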