diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 639864b196..d68ec63487 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -89,6 +89,96 @@
net.jpountz.lz4
lz4
+
+
+ org.apache.fory
+ fory-core
+ 0.12.0
+
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+ provided
+
+
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ 3.2.15
+ test
+
+
+ org.scalatestplus
+ junit-4-13_${scala.binary.version}
+ 3.2.15.0
+ test
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ ${scala.maven.plugin.version}
+
+
+ scala-compile-first
+ process-resources
+
+ add-source
+ compile
+
+
+
+ scala-test-compile-first
+ process-test-resources
+
+ testCompile
+
+
+
+
+ ${scala.version}
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ compile
+
+ compile
+
+
+
+
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ 2.0.2
+
+ ${project.build.directory}/surefire-reports
+ .
+ TestSuite.txt
+
+
+
+ test
+
+ test
+
+
+
+
+
+
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6e2536caaf..0dd696ce7b 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -36,6 +36,7 @@
import org.apache.uniffle.common.config.ConfigUtils;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.shuffle.ShuffleSerializer;
public class RssSparkConfig {
@@ -110,6 +111,12 @@ public class RssSparkConfig {
.defaultValue(true)
.withDescription("indicates row based shuffle, set false when use in columnar shuffle");
+ public static final ConfigOption RSS_SHUFFLE_SERIALIZER =
+ ConfigOptions.key("rss.client.shuffle.serializer")
+ .enumType(ShuffleSerializer.class)
+ .noDefaultValue()
+ .withDescription("Shuffle serializer type");
+
public static final ConfigOption RSS_MEMORY_SPILL_ENABLED =
ConfigOptions.key("rss.client.memory.spill.enabled")
.booleanType()
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 3da4b147bb..6bbf9f8a96 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -40,9 +40,9 @@
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.ForySerializerInstance;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.RssSparkConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -58,6 +58,7 @@
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
public class WriteBufferManager extends MemoryConsumer {
@@ -81,7 +82,6 @@ public class WriteBufferManager extends MemoryConsumer {
private int shuffleId;
private String taskId;
private long taskAttemptId;
- private SerializerInstance instance;
private ShuffleWriteMetrics shuffleWriteMetrics;
// cache partition -> records
private Map buffers;
@@ -192,8 +192,11 @@ public WriteBufferManager(
// in columnar shuffle, the serializer here is never used
this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
if (isRowBased) {
- this.instance = serializer.newInstance();
- this.serializeStream = instance.serializeStream(arrayOutputStream);
+ if (rssConf.contains(RSS_SHUFFLE_SERIALIZER)) {
+ this.serializeStream = new ForySerializerInstance().serializeStream(arrayOutputStream);
+ } else {
+ this.serializeStream = serializer.newInstance().serializeStream(arrayOutputStream);
+ }
}
boolean compress =
rssConf.getBoolean(
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
new file mode 100644
index 0000000000..850b80e79c
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+public enum ShuffleSerializer {
+ FORY
+}
diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
new file mode 100644
index 0000000000..be9f34c938
--- /dev/null
+++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.fory.config.{CompatibleMode, Language}
+import org.apache.fory.io.ForyInputStream
+import org.apache.fory.{Fory, ThreadLocalFory}
+import org.apache.spark.internal.Logging
+
+import java.io.{InputStream, OutputStream, Serializable}
+import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
+@SerialVersionUID(1L)
+class ForySerializer extends org.apache.spark.serializer.Serializer
+ with Logging
+ with Serializable {
+
+ override def newInstance(): SerializerInstance = new ForySerializerInstance()
+
+ override def supportsRelocationOfSerializedObjects: Boolean = true
+
+}
+
+class ForySerializerInstance extends org.apache.spark.serializer.SerializerInstance {
+
+ private val fury = Fory.builder()
+ .withLanguage(Language.JAVA)
+ .withRefTracking(true)
+ .withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT)
+ .requireClassRegistration(false)
+ .buildThreadLocalFory()
+
+ override def serialize[T: ClassTag](t: T): ByteBuffer = {
+ val bytes = fury.serialize(t.asInstanceOf[AnyRef])
+ ByteBuffer.wrap(bytes)
+ }
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+ fury.deserialize(bytes).asInstanceOf[T]
+ }
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+ // Fury handles class loading internally, so we can use the standard deserialize method
+ deserialize[T](bytes)
+ }
+
+ override def serializeStream(s: OutputStream): SerializationStream = {
+ new ForySerializationStream(fury, s)
+ }
+
+ override def deserializeStream(s: InputStream): DeserializationStream = {
+ new ForyDeserializationStream(fury, s)
+ }
+}
+
+class ForySerializationStream(fury: ThreadLocalFory, outputStream: OutputStream)
+ extends org.apache.spark.serializer.SerializationStream {
+
+ private var closed = false
+
+ override def writeObject[T: ClassTag](t: T): SerializationStream = {
+ if (closed) {
+ throw new IllegalStateException("Stream is closed")
+ }
+ fury.serialize(outputStream, t)
+ this
+ }
+
+ override def flush(): Unit = {
+ if (!closed) {
+ outputStream.flush()
+ }
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ try {
+ outputStream.close()
+ } finally {
+ closed = true
+ }
+ }
+ }
+}
+
+class ForyDeserializationStream(fury: ThreadLocalFory, inputStream: InputStream)
+ extends org.apache.spark.serializer.DeserializationStream {
+
+ private var closed = false
+ private val foryStream = new ForyInputStream(inputStream)
+
+ override def readObject[T: ClassTag](): T = {
+ if (closed) {
+ throw new IllegalStateException("Stream is closed")
+ }
+ fury.deserialize(foryStream).asInstanceOf[T]
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ try {
+ foryStream.close()
+ } finally {
+ closed = true
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
new file mode 100644
index 0000000000..650ee658a0
--- /dev/null
+++ b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.must.Matchers
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+class ForySerializerTest extends AnyFunSuite with Matchers {
+
+ test("ForySerializer should create new instance") {
+ val serializer = new ForySerializer()
+ val instance = serializer.newInstance()
+
+ instance should not be null
+ instance shouldBe a[ForySerializerInstance]
+ }
+
+ test("ForySerializer should support relocation of serialized objects") {
+ val serializer = new ForySerializer()
+ serializer.supportsRelocationOfSerializedObjects shouldBe true
+ }
+
+ test("ForySerializerInstance should serialize and deserialize simple primitive types") {
+ val instance = new ForySerializerInstance()
+
+ // Test String
+ testSerializeDeserialize(instance, "Hello, Fory!")
+
+ // Test Integer
+ testSerializeDeserialize(instance, Integer.valueOf(42))
+
+ // Test Long
+ testSerializeDeserialize(instance, java.lang.Long.valueOf(123456789L))
+
+ // Test Double
+ testSerializeDeserialize(instance, java.lang.Double.valueOf(3.14159))
+
+ // Test Boolean
+ testSerializeDeserialize(instance, java.lang.Boolean.valueOf(true))
+ testSerializeDeserialize(instance, java.lang.Boolean.valueOf(false))
+ }
+
+ test("ForySerializerInstance should serialize and deserialize simple collections") {
+ val instance = new ForySerializerInstance()
+
+ // Test simple Java ArrayList
+ val javaList = new java.util.ArrayList[String]()
+ javaList.add("apple")
+ javaList.add("banana")
+ javaList.add("cherry")
+ testSerializeDeserialize(instance, javaList)
+
+ // Test simple Java HashMap
+ val javaMap = new java.util.HashMap[String, String]()
+ javaMap.put("key1", "value1")
+ javaMap.put("key2", "value2")
+ testSerializeDeserialize(instance, javaMap)
+ }
+
+ test("ForySerializerInstance should serialize and deserialize simple case class") {
+ val instance = new ForySerializerInstance()
+
+ // Simple class without complex dependencies
+ val data = new SimpleData("test", 123)
+
+ testSerializeDeserialize(instance, data)
+ }
+
+ test("ForySerializationStream should serialize and deserialize simple objects") {
+ val instance = new ForySerializerInstance()
+ val baos = new ByteArrayOutputStream()
+ val stream = instance.serializeStream(baos)
+
+ // Write simple objects only
+ val objects = List("Hello", Integer.valueOf(123), java.lang.Boolean.valueOf(true))
+ objects.foreach { obj =>
+ stream.writeObject(obj)(getClassTag(obj))
+ }
+ stream.flush()
+ stream.close()
+
+ // Read back the objects
+ val bais = new ByteArrayInputStream(baos.toByteArray)
+ val deserStream = instance.deserializeStream(bais)
+
+ val readObjects = mutable.ListBuffer[Any]()
+ try {
+ while (true) {
+ readObjects += deserStream.readObject[Any]()
+ }
+ } catch {
+ case _: java.io.EOFException => // Expected when reaching end
+ }
+ deserStream.close()
+
+ readObjects.toList should contain theSameElementsInOrderAs objects
+ }
+
+ test("ForySerializationStream should handle empty stream") {
+ val instance = new ForySerializerInstance()
+ val baos = new ByteArrayOutputStream()
+ val stream = instance.serializeStream(baos)
+ stream.flush()
+ stream.close()
+
+ val bais = new ByteArrayInputStream(baos.toByteArray)
+ val deserStream = instance.deserializeStream(bais)
+
+ intercept[java.io.EOFException] {
+ deserStream.readObject[String]()
+ }
+
+ deserStream.close()
+ }
+
+ test("ForySerializationStream should handle stream operations after close") {
+ val instance = new ForySerializerInstance()
+ val baos = new ByteArrayOutputStream()
+ val stream = instance.serializeStream(baos)
+
+ stream.close()
+
+ // Writing after close should throw exception
+ intercept[IllegalStateException] {
+ stream.writeObject("test")
+ }
+ }
+
+ test("ForyDeserializationStream should handle stream operations after close") {
+ val instance = new ForySerializerInstance()
+ val bais = new ByteArrayInputStream(Array.empty[Byte])
+ val deserStream = instance.deserializeStream(bais)
+
+ deserStream.close()
+
+ // Reading after close should throw exception
+ intercept[IllegalStateException] {
+ deserStream.readObject[String]()
+ }
+ }
+
+ test("ForySerializerInstance should handle null values") {
+ val instance = new ForySerializerInstance()
+
+ val serialized = instance.serialize[AnyRef](null)
+ val deserialized = instance.deserialize[AnyRef](serialized)
+ deserialized should be(null)
+ }
+
+ test("ForySerializerInstance should handle byte arrays") {
+ val instance = new ForySerializerInstance()
+
+ val byteArray = Array[Byte](1, 2, 3, 4, 5)
+ testSerializeDeserialize(instance, byteArray)
+ }
+
+ test("ForySerializerInstance should handle large strings") {
+ val instance = new ForySerializerInstance()
+
+ // Create a large string
+ val largeString = "x" * 10000
+ testSerializeDeserialize(instance, largeString)
+ }
+
+ private def testSerializeDeserialize[T](instance: ForySerializerInstance, obj: T)(implicit ct: ClassTag[T]): Unit = {
+ val serialized = instance.serialize(obj)
+ val deserialized = instance.deserialize[T](serialized)
+
+ if (obj != null && obj.getClass.isArray) {
+ // Special handling for arrays since they don't implement equals properly
+ if (obj.isInstanceOf[Array[Byte]]) {
+ obj.asInstanceOf[Array[Byte]] should equal(deserialized.asInstanceOf[Array[Byte]])
+ } else {
+ obj.asInstanceOf[Array[_]].toList should equal(deserialized.asInstanceOf[Array[_]].toList)
+ }
+ } else {
+ deserialized should equal(obj)
+ }
+ }
+
+ private def getClassTag[T](obj: T): ClassTag[T] = {
+ if (obj == null) {
+ ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
+ } else {
+ ClassTag(obj.getClass).asInstanceOf[ClassTag[T]]
+ }
+ }
+}
+// SimpleData is a normal class with equals/hashCode overridden for test assertions
+class SimpleData(val name: String, val value: Int) {
+ override def equals(other: Any): Boolean = other match {
+ case that: SimpleData => this.name == that.name && this.value == that.value
+ case _ => false
+ }
+ override def hashCode(): Int = {
+ 31 * name.hashCode + value
+ }
+ override def toString: String = s"SimpleData($name, $value)"
+}
\ No newline at end of file
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 4113f06274..b00e02089a 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -42,6 +42,7 @@
import org.apache.spark.ShuffleDependency;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.serializer.ForySerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.FunctionUtils;
import org.apache.spark.shuffle.RssShuffleHandle;
@@ -68,6 +69,7 @@
import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_REORDER_MULTI_SERVERS_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
public class RssShuffleReader implements ShuffleReader {
private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -125,7 +127,10 @@ public RssShuffleReader(
this.numMaps = rssShuffleHandle.getNumMaps();
this.shuffleDependency = rssShuffleHandle.getDependency();
this.shuffleId = shuffleDependency.shuffleId();
- this.serializer = rssShuffleHandle.getDependency().serializer();
+ this.serializer =
+ rssConf.contains(RSS_SHUFFLE_SERIALIZER)
+ ? new ForySerializer()
+ : rssShuffleHandle.getDependency().serializer();
this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
this.basePath = basePath;
this.partitionNum = partitionNum;
@@ -307,8 +312,7 @@ class MultiPartitionIterator extends AbstractIterator> {
.retryIntervalMax(retryIntervalMax)
.rssConf(rssConf));
RssShuffleDataIterator iterator =
- new RssShuffleDataIterator<>(
- shuffleDependency.serializer(), shuffleReadClient, readMetrics, rssConf);
+ new RssShuffleDataIterator<>(serializer, shuffleReadClient, readMetrics, rssConf);
CompletionIterator, RssShuffleDataIterator> completionIterator =
CompletionIterator$.MODULE$.apply(
iterator,