diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 639864b196..d68ec63487 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -89,6 +89,96 @@
     <dependency>
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.fory</groupId>
+      <artifactId>fory-core</artifactId>
+      <version>0.12.0</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>${scala.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <version>3.2.15</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>junit-4-13_${scala.binary.version}</artifactId>
+      <version>3.2.15.0</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+        <version>${scala.maven.plugin.version}</version>
+        <executions>
+          <execution>
+            <id>scala-compile-first</id>
+            <phase>process-resources</phase>
+            <goals>
+              <goal>add-source</goal>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>scala-test-compile-first</id>
+            <phase>process-test-resources</phase>
+            <goals>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <scalaVersion>${scala.version}</scalaVersion>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <executions>
+          <execution>
+            <phase>compile</phase>
+            <goals>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>2.0.2</version>
+        <configuration>
+          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
+          <junitxml>.</junitxml>
+          <filereports>TestSuite.txt</filereports>
+        </configuration>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6e2536caaf..0dd696ce7b 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -36,6 +36,7 @@
 import org.apache.uniffle.common.config.ConfigUtils;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.shuffle.ShuffleSerializer;
 
 public class RssSparkConfig {
 
@@ -110,6 +111,12 @@ public class RssSparkConfig {
           .defaultValue(true)
           .withDescription("indicates row based shuffle, set false when use in columnar shuffle");
 
+  public static final ConfigOption<ShuffleSerializer> RSS_SHUFFLE_SERIALIZER =
+      ConfigOptions.key("rss.client.shuffle.serializer")
+          .enumType(ShuffleSerializer.class)
+          .noDefaultValue()
+          .withDescription(
+              "The serializer used for row-based shuffle records; currently FORY is the only supported value.");
+
   public static final ConfigOption<Boolean> RSS_MEMORY_SPILL_ENABLED =
       ConfigOptions.key("rss.client.memory.spill.enabled")
           .booleanType()
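Reviewer note: the new option is declared on the Uniffle side, so a Spark job would set it through the usual `spark.` prefix mapping. A minimal sketch of enabling it, assuming the `spark.rss.*` key convention used by the other `RssSparkConfig` options:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: opts into the Fory serializer added by this patch.
// "FORY" must match the ShuffleSerializer enum constant, since the
// option is declared with enumType(ShuffleSerializer.class).
val spark = SparkSession.builder()
  .appName("fory-shuffle-demo")
  .config("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
  .config("spark.rss.client.shuffle.serializer", "FORY")
  .getOrCreate()
```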
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 3da4b147bb..6bbf9f8a96 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -40,9 +40,9 @@
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.ForySerializerInstance;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,6 +58,7 @@
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
 
 public class WriteBufferManager extends MemoryConsumer {
@@ -81,7 +82,6 @@ public class WriteBufferManager extends MemoryConsumer {
   private int shuffleId;
   private String taskId;
   private long taskAttemptId;
-  private SerializerInstance instance;
   private ShuffleWriteMetrics shuffleWriteMetrics;
   // cache partition -> records
   private Map<Integer, WriterBuffer> buffers;
@@ -192,8 +192,11 @@ public WriteBufferManager(
     // in columnar shuffle, the serializer here is never used
     this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
     if (isRowBased) {
-      this.instance = serializer.newInstance();
-      this.serializeStream = instance.serializeStream(arrayOutputStream);
+      if (rssConf.contains(RSS_SHUFFLE_SERIALIZER)) {
+        this.serializeStream = new ForySerializerInstance().serializeStream(arrayOutputStream);
+      } else {
+        this.serializeStream = serializer.newInstance().serializeStream(arrayOutputStream);
+      }
     }
     boolean compress =
         rssConf.getBoolean(
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
new file mode 100644
index 0000000000..850b80e79c
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ShuffleSerializer.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.shuffle;
+
+public enum ShuffleSerializer {
+  FORY
+}
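One observation on the write path: it only checks `rssConf.contains(RSS_SHUFFLE_SERIALIZER)` and never reads the enum value, which is fine while `FORY` is the only constant but would silently pick Fory for any future value. An illustrative alternative (not part of the patch; `resolveInstance` is a hypothetical helper) that dispatches on the value instead:

```scala
import org.apache.spark.serializer.{ForySerializerInstance, Serializer, SerializerInstance}
import org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER
import org.apache.uniffle.common.config.RssConf
import org.apache.uniffle.shuffle.ShuffleSerializer

// Hypothetical helper: match on the configured enum value so that adding a
// second ShuffleSerializer constant forces an explicit code change here.
def resolveInstance(rssConf: RssConf, fallback: Serializer): SerializerInstance =
  if (rssConf.contains(RSS_SHUFFLE_SERIALIZER)) {
    rssConf.get(RSS_SHUFFLE_SERIALIZER) match {
      case ShuffleSerializer.FORY => new ForySerializerInstance()
    }
  } else {
    fallback.newInstance()
  }
```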
diff --git a/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
new file mode 100644
index 0000000000..be9f34c938
--- /dev/null
+++ b/client-spark/common/src/main/scala/org/apache/spark/serializer/ForySerializer.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.fory.config.{CompatibleMode, Language}
+import org.apache.fory.io.ForyInputStream
+import org.apache.fory.{Fory, ThreadLocalFory}
+import org.apache.spark.internal.Logging
+
+import java.io.{InputStream, OutputStream, Serializable}
+import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
+@SerialVersionUID(1L)
+class ForySerializer extends Serializer with Logging with Serializable {
+
+  override def newInstance(): SerializerInstance = new ForySerializerInstance()
+
+  // Fory records are self-contained, so serialized objects can be relocated.
+  override def supportsRelocationOfSerializedObjects: Boolean = true
+}
+
+class ForySerializerInstance extends SerializerInstance {
+
+  private val fory = Fory.builder()
+    .withLanguage(Language.JAVA)
+    .withRefTracking(true)
+    .withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT)
+    .requireClassRegistration(false)
+    .buildThreadLocalFory()
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = {
+    val bytes = fory.serialize(t.asInstanceOf[AnyRef])
+    ByteBuffer.wrap(bytes)
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
+    fory.deserialize(bytes).asInstanceOf[T]
+  }
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
+    // Fory handles class loading internally, so the standard deserialize method suffices
+    deserialize[T](bytes)
+  }
+
+  override def serializeStream(s: OutputStream): SerializationStream = {
+    new ForySerializationStream(fory, s)
+  }
+
+  override def deserializeStream(s: InputStream): DeserializationStream = {
+    new ForyDeserializationStream(fory, s)
+  }
+}
+
+class ForySerializationStream(fory: ThreadLocalFory, outputStream: OutputStream)
+  extends SerializationStream {
+
+  private var closed = false
+
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
+    if (closed) {
+      throw new IllegalStateException("Stream is closed")
+    }
+    fory.serialize(outputStream, t)
+    this
+  }
+
+  override def flush(): Unit = {
+    if (!closed) {
+      outputStream.flush()
+    }
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      try {
+        outputStream.close()
+      } finally {
+        closed = true
+      }
+    }
+  }
+}
+
+class ForyDeserializationStream(fory: ThreadLocalFory, inputStream: InputStream)
+  extends DeserializationStream {
+
+  private var closed = false
+  private val foryStream = new ForyInputStream(inputStream)
+
+  override def readObject[T: ClassTag](): T = {
+    if (closed) {
+      throw new IllegalStateException("Stream is closed")
+    }
+    fory.deserialize(foryStream).asInstanceOf[T]
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      try {
+        foryStream.close()
+      } finally {
+        closed = true
+      }
+    }
+  }
+}
\ No newline at end of file
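For context, a quick round trip through the new instance and stream API (standalone sketch; the JDK-level types used here need no registration because the builder sets `requireClassRegistration(false)`):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.serializer.ForySerializer

val instance = new ForySerializer().newInstance()

// Stream round trip, mirroring how WriteBufferManager uses serializeStream.
val baos = new ByteArrayOutputStream()
val out = instance.serializeStream(baos)
out.writeObject("hello")
out.writeObject(Integer.valueOf(7))
out.close()

val in = instance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
assert(in.readObject[String]() == "hello")
assert(in.readObject[Integer]() == 7)
in.close()
```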
diff --git a/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
new file mode 100644
index 0000000000..650ee658a0
--- /dev/null
+++ b/client-spark/common/src/test/scala/org/apache/spark/serializer/ForySerializerTest.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.must.Matchers
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+class ForySerializerTest extends AnyFunSuite with Matchers {
+
+  test("ForySerializer should create a new instance") {
+    val serializer = new ForySerializer()
+    val instance = serializer.newInstance()
+
+    instance should not be null
+    instance shouldBe a[ForySerializerInstance]
+  }
+
+  test("ForySerializer should support relocation of serialized objects") {
+    val serializer = new ForySerializer()
+    serializer.supportsRelocationOfSerializedObjects shouldBe true
+  }
+
+  test("ForySerializerInstance should serialize and deserialize simple primitive types") {
+    val instance = new ForySerializerInstance()
+
+    // String
+    testSerializeDeserialize(instance, "Hello, Fory!")
+    // Integer
+    testSerializeDeserialize(instance, Integer.valueOf(42))
+    // Long
+    testSerializeDeserialize(instance, java.lang.Long.valueOf(123456789L))
+    // Double
+    testSerializeDeserialize(instance, java.lang.Double.valueOf(3.14159))
+    // Boolean
+    testSerializeDeserialize(instance, java.lang.Boolean.valueOf(true))
+    testSerializeDeserialize(instance, java.lang.Boolean.valueOf(false))
+  }
+
+  test("ForySerializerInstance should serialize and deserialize simple collections") {
+    val instance = new ForySerializerInstance()
+
+    // Simple Java ArrayList
+    val javaList = new java.util.ArrayList[String]()
+    javaList.add("apple")
+    javaList.add("banana")
+    javaList.add("cherry")
+    testSerializeDeserialize(instance, javaList)
+
+    // Simple Java HashMap
+    val javaMap = new java.util.HashMap[String, String]()
+    javaMap.put("key1", "value1")
+    javaMap.put("key2", "value2")
+    testSerializeDeserialize(instance, javaMap)
+  }
+
+  test("ForySerializerInstance should serialize and deserialize a simple class") {
+    val instance = new ForySerializerInstance()
+
+    // Simple class without complex dependencies
+    val data = new SimpleData("test", 123)
+    testSerializeDeserialize(instance, data)
+  }
+
+  test("ForySerializationStream should serialize and deserialize simple objects") {
+    val instance = new ForySerializerInstance()
+    val baos = new ByteArrayOutputStream()
+    val stream = instance.serializeStream(baos)
+
+    // Write simple objects only
+    val objects = List("Hello", Integer.valueOf(123), java.lang.Boolean.valueOf(true))
+    objects.foreach { obj =>
+      stream.writeObject(obj)(getClassTag(obj))
+    }
+    stream.flush()
+    stream.close()
+
+    // Read the objects back
+    val bais = new ByteArrayInputStream(baos.toByteArray)
+    val deserStream = instance.deserializeStream(bais)
+
+    val readObjects = mutable.ListBuffer[Any]()
+    try {
+      while (true) {
+        readObjects += deserStream.readObject[Any]()
+      }
+    } catch {
+      case _: java.io.EOFException => // expected at end of stream
+    }
+    deserStream.close()
+
+    readObjects.toList should contain theSameElementsInOrderAs objects
+  }
+
+  test("ForySerializationStream should handle an empty stream") {
+    val instance = new ForySerializerInstance()
+    val baos = new ByteArrayOutputStream()
+    val stream = instance.serializeStream(baos)
+    stream.flush()
+    stream.close()
+
+    val bais = new ByteArrayInputStream(baos.toByteArray)
+    val deserStream = instance.deserializeStream(bais)
+
+    intercept[java.io.EOFException] {
+      deserStream.readObject[String]()
+    }
+
+    deserStream.close()
+  }
+
+  test("ForySerializationStream should reject writes after close") {
+    val instance = new ForySerializerInstance()
+    val baos = new ByteArrayOutputStream()
+    val stream = instance.serializeStream(baos)
+
+    stream.close()
+
+    // Writing after close should throw an exception
+    intercept[IllegalStateException] {
+      stream.writeObject("test")
+    }
+  }
+
+  test("ForyDeserializationStream should reject reads after close") {
+    val instance = new ForySerializerInstance()
+    val bais = new ByteArrayInputStream(Array.empty[Byte])
+    val deserStream = instance.deserializeStream(bais)
+
+    deserStream.close()
+
+    // Reading after close should throw an exception
+    intercept[IllegalStateException] {
+      deserStream.readObject[String]()
+    }
+  }
+
+  test("ForySerializerInstance should handle null values") {
+    val instance = new ForySerializerInstance()
+
+    val serialized = instance.serialize[AnyRef](null)
+    val deserialized = instance.deserialize[AnyRef](serialized)
+    deserialized should be(null)
+  }
+
+  test("ForySerializerInstance should handle byte arrays") {
+    val instance = new ForySerializerInstance()
+
+    val byteArray = Array[Byte](1, 2, 3, 4, 5)
+    testSerializeDeserialize(instance, byteArray)
+  }
+
+  test("ForySerializerInstance should handle large strings") {
+    val instance = new ForySerializerInstance()
+
+    val largeString = "x" * 10000
+    testSerializeDeserialize(instance, largeString)
+  }
+
+  private def testSerializeDeserialize[T](instance: ForySerializerInstance, obj: T)(
+      implicit ct: ClassTag[T]): Unit = {
+    val serialized = instance.serialize(obj)
+    val deserialized = instance.deserialize[T](serialized)
+
+    if (obj != null && obj.getClass.isArray) {
+      // Arrays do not implement structural equals, so compare element-wise
+      if (obj.isInstanceOf[Array[Byte]]) {
+        obj.asInstanceOf[Array[Byte]] should equal(deserialized.asInstanceOf[Array[Byte]])
+      } else {
+        obj.asInstanceOf[Array[_]].toList should equal(deserialized.asInstanceOf[Array[_]].toList)
+      }
+    } else {
+      deserialized should equal(obj)
+    }
+  }
+
+  private def getClassTag[T](obj: T): ClassTag[T] = {
+    if (obj == null) {
+      ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
+    } else {
+      ClassTag(obj.getClass).asInstanceOf[ClassTag[T]]
+    }
+  }
+}
+
+// SimpleData is a plain class with equals/hashCode overridden for test assertions
+class SimpleData(val name: String, val value: Int) {
+  override def equals(other: Any): Boolean = other match {
+    case that: SimpleData => this.name == that.name && this.value == that.value
+    case _ => false
+  }
+
+  override def hashCode(): Int = 31 * name.hashCode + value
+
+  override def toString: String = s"SimpleData($name, $value)"
+}
\ No newline at end of file
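Side note on the stream round-trip test: because `ForyDeserializationStream` extends Spark's `DeserializationStream`, the manual read-until-`EOFException` loop could also lean on the inherited `asIterator`, which stops at `EOFException` and closes the stream. A sketch reusing the test's `instance`, `baos`, and `objects`:

```scala
// Drop-in replacement for the manual read loop in the stream test.
val bais = new ByteArrayInputStream(baos.toByteArray)
val readBack = instance.deserializeStream(bais).asIterator.toList
readBack should contain theSameElementsInOrderAs objects
```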
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 4113f06274..b00e02089a 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -42,6 +42,7 @@
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.serializer.ForySerializer;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.FunctionUtils;
 import org.apache.spark.shuffle.RssShuffleHandle;
@@ -68,6 +69,7 @@
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_REORDER_MULTI_SERVERS_ENABLED;
 import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -125,7 +127,10 @@ public RssShuffleReader(
     this.numMaps = rssShuffleHandle.getNumMaps();
     this.shuffleDependency = rssShuffleHandle.getDependency();
     this.shuffleId = shuffleDependency.shuffleId();
-    this.serializer = rssShuffleHandle.getDependency().serializer();
+    this.serializer =
+        rssConf.contains(RSS_SHUFFLE_SERIALIZER)
+            ? new ForySerializer()
+            : rssShuffleHandle.getDependency().serializer();
     this.taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     this.basePath = basePath;
     this.partitionNum = partitionNum;
@@ -307,8 +312,7 @@ class MultiPartitionIterator extends AbstractIterator<Product2<K, C>> {
             .retryIntervalMax(retryIntervalMax)
             .rssConf(rssConf));
     RssShuffleDataIterator<K, C> iterator =
-        new RssShuffleDataIterator<>(
-            shuffleDependency.serializer(), shuffleReadClient, readMetrics, rssConf);
+        new RssShuffleDataIterator<>(serializer, shuffleReadClient, readMetrics, rssConf);
     CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
         CompletionIterator$.MODULE$.apply(
             iterator,
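Finally, a correctness observation: the writer (`WriteBufferManager`) and the reader (`RssShuffleReader`) both branch on the same `RSS_SHUFFLE_SERIALIZER` key, so a single `RssConf` decides the serializer for both ends of a shuffle, and the on-wire format is only consistent when every executor sees the same value. A tiny sketch of that invariant, assuming `RssConf` exposes a `set(ConfigOption, value)` writer like the rest of the Uniffle config API:

```scala
import org.apache.spark.shuffle.RssSparkConfig.RSS_SHUFFLE_SERIALIZER
import org.apache.uniffle.common.config.RssConf
import org.apache.uniffle.shuffle.ShuffleSerializer

val rssConf = new RssConf()
rssConf.set(RSS_SHUFFLE_SERIALIZER, ShuffleSerializer.FORY)
// Both the write path and the read path key off the same predicate,
// so they cannot disagree within one configuration.
assert(rssConf.contains(RSS_SHUFFLE_SERIALIZER))
```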