Commit 59b8a44

pranavdev022 authored and hvanhovell committed
[SPARK-54001][SQL] Optimize memory usage in session cloning with ref-counted cached local relations
### What changes were proposed in this pull request?

This PR optimizes memory management for cached local relations when cloning Spark sessions by using reference counting instead of replicating the data.

**Current behavior:**
- When a session is cloned, cached local relation data stored in the block manager is replicated.
- Each clone creates a duplicate copy of the data under a new block ID.
- This causes unnecessary memory pressure.

**Proposed changes:**
- Implement reference counting for cached local relations during session cloning.
- Retain the same block ID and data reference when cloning sessions, incrementing a ref count instead of copying the block.
- Add a hash-to-blockId mapping in ArtifactManager for efficient block lookup.
- Remove a block from block manager memory only when its ref count reaches zero.

### Why are the changes needed?

Cloning sessions is a common operation in Spark applications (e.g., for creating isolated execution contexts). The current approach of duplicating cached data can significantly increase the memory footprint, especially when:
- sessions are cloned frequently,
- cached relations contain large datasets, or
- multiple clones exist simultaneously.

This optimization reduces memory pressure and improves performance by avoiding unnecessary data copies.

### Does this PR introduce _any_ user-facing change?

No. This is an internal optimization that improves memory efficiency without changing user-facing APIs or behavior.

### How was this patch tested?

- Added unit tests verifying the reference-counting logic.
- Existing unit tests for ArtifactManager and session cloning.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52651 from pranavdev022/clone-artifactmanager-fix.

Authored-by: pranavdev022 <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent cf5c7b6 commit 59b8a44
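
To make the ref-counting idea described above concrete, here is a minimal standalone sketch (not the commit's code): a hypothetical `InMemoryStore` stands in for Spark's `BlockManager`, and a simplified `RefCountedEntry` stands in for the `RefCountedCacheId` class that this commit adds to `ArtifactManager.scala`; the real implementation additionally fails fast when an already-released entry is acquired.

```scala
import java.nio.charset.StandardCharsets
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for the block manager: maps a relation hash to its bytes.
final class InMemoryStore {
  private val blocks = new ConcurrentHashMap[String, Array[Byte]]()
  def put(hash: String, bytes: Array[Byte]): Unit = { blocks.put(hash, bytes); () }
  def remove(hash: String): Unit = { blocks.remove(hash); () }
  def contains(hash: String): Boolean = blocks.containsKey(hash)
}

// One shared entry per cached relation: every session clone holds the same instance
// and bumps the count instead of copying the bytes; the last owner to release
// removes the block from the store.
final class RefCountedEntry(val hash: String, store: InMemoryStore) {
  private val refCount = new AtomicInteger(1)
  def acquire(): Unit = { refCount.incrementAndGet(); () }
  def release(): Unit = if (refCount.decrementAndGet() == 0) store.remove(hash)
}

object RefCountSketch {
  def main(args: Array[String]): Unit = {
    val store = new InMemoryStore
    store.put("h1", "relation data".getBytes(StandardCharsets.UTF_8))
    val entry = new RefCountedEntry("h1", store) // owned by the original session
    entry.acquire()                              // a cloned session shares the same block
    entry.release()                              // original session closes; the block survives
    assert(store.contains("h1"))
    entry.release()                              // clone closes; count hits zero, block removed
    assert(!store.contains("h1"))
  }
}
```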

File tree

5 files changed: +129, -46 lines


project/MimaExcludes.scala

Lines changed: 4 additions & 1 deletion
@@ -45,7 +45,10 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.PrimitiveKeyOpenHashMap*"),

     // [SPARK-54041][SQL] Enable Direct Passthrough Partitioning in the DataFrame API
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById"),
+
+    // [SPARK-54001][CONNECT] Replace block copying with ref-counting in ArtifactManager cloning
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.artifact.ArtifactManager.cachedBlockIdList")
   )

   // Default exclude rules

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala

Lines changed: 2 additions & 0 deletions
@@ -234,6 +234,8 @@ object CheckConnectJvmClientCompatibility {
       "org.apache.spark.sql.artifact.ArtifactManager$"),
     ProblemFilters.exclude[MissingClassProblem](
       "org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"),
+    ProblemFilters.exclude[MissingClassProblem](
+      "org.apache.spark.sql.artifact.RefCountedCacheId"),

     // ColumnNode conversions
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"),

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 4 additions & 2 deletions
@@ -60,7 +60,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils}
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.ml.MLHandler
 import org.apache.spark.sql.connect.pipelines.PipelinesHandler

@@ -1330,7 +1330,9 @@ class SparkConnectPlanner(

   private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = {
     val blockManager = session.sparkContext.env.blockManager
-    val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
+    val blockId = session.artifactManager.getCachedBlockId(rel.getHash).getOrElse {
+      throw InvalidPlanInput(s"Cannot find a cached local relation for hash: ${rel.getHash}")
+    }
     val bytes = blockManager.getLocalBytes(blockId)
     bytes
       .map { blockData =>

sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala

Lines changed: 68 additions & 34 deletions
@@ -20,10 +20,9 @@ package org.apache.spark.sql.artifact
 import java.io.{File, IOException}
 import java.lang.ref.Cleaner
 import java.net.{URI, URL, URLClassLoader}
-import java.nio.ByteBuffer
 import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
-import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

 import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._

@@ -114,7 +113,7 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     }
   }

-  protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
+  private val hashToCachedIdMap = new ConcurrentHashMap[String, RefCountedCacheId]
   protected val jarsList = new CopyOnWriteArrayList[Path]
   protected val pythonIncludeList = new CopyOnWriteArrayList[String]
   protected val sparkContextRelativePaths =

@@ -136,6 +135,10 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
    */
   def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq

+  protected[sql] def getCachedBlockId(hash: String): Option[CacheId] = {
+    Option(hashToCachedIdMap.get(hash)).map(_.id)
+  }
+
   private def transferFile(
       source: Path,
       target: Path,

@@ -192,7 +195,14 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
         blockSize = tmpFile.length(),
         tellMaster = false)
       updater.save()
-      cachedBlockIdList.add(blockId)
+      val oldBlock = hashToCachedIdMap.put(blockId.hash, new RefCountedCacheId(blockId))
+      if (oldBlock != null) {
+        logWarning(
+          log"Replacing existing cache artifact with hash ${MDC(LogKeys.BLOCK_ID, blockId)} " +
+            log"in session ${MDC(LogKeys.SESSION_ID, session.sessionUUID)}. " +
+            log"This may indicate duplicate artifact addition.")
+        oldBlock.release(blockManager)
+      }
     }(finallyBlock = { tmpFile.delete() })
   } else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
     // Move class files to the right directory.

@@ -354,10 +364,27 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     if (artifactPath.toFile.exists()) {
       Utils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
     }
-    val blockManager = sparkContext.env.blockManager
-    val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
-      val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
-      copyBlock(blockId, newBlockId, blockManager)
+
+    // Share cached blocks with the cloned session by copying the references and incrementing
+    // their reference counts. Both the original and cloned ArtifactManager will reference the
+    // same underlying cached data blocks. When either session releases a block, only the ref-count
+    // decreases.
+    // The block is removed from memory only when the ref-count reaches zero.
+    hashToCachedIdMap.forEach { (hash: String, refCountedCacheId: RefCountedCacheId) =>
+      try {
+        refCountedCacheId.acquire() // Increment ref-count to prevent premature cleanup
+        newArtifactManager.hashToCachedIdMap.put(hash, refCountedCacheId)
+      } catch {
+        case e: SparkRuntimeException if e.getCondition == "BLOCK_ALREADY_RELEASED" =>
+          // The parent session was closed or this block was released during cloning.
+          // This indicates a race condition - we cannot safely complete the clone operation.
+          // With the ref-counting optimization, cloning is fast and this should be rare.
+          throw new SparkRuntimeException(
+            "INTERNAL_ERROR",
+            Map("message" -> (s"Cannot clone ArtifactManager: cached block with hash $hash " +
+              s"was already released. The parent session may have been closed during cloning.")),
+            e)
+      }
     }

     // Re-register resources to SparkContext

@@ -382,7 +409,6 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
       }
     }

-    newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
     newArtifactManager.jarsList.addAll(jarsList)
     newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
     newArtifactManager.sparkContextRelativePaths.addAll(sparkContextRelativePaths)

@@ -412,10 +438,16 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     // Note that this will only be run once per instance.
     cleanable.clean()

+    // Clean-up cached blocks.
+    val blockManager = session.sparkContext.env.blockManager
+    hashToCachedIdMap.values().forEach { refCountedCacheId =>
+      refCountedCacheId.release(blockManager)
+    }
+    hashToCachedIdMap.clear()
+
     // Clean up internal trackers
     jarsList.clear()
     pythonIncludeList.clear()
-    cachedBlockIdList.clear()
     sparkContextRelativePaths.clear()

     // Removed cached classloader

@@ -484,25 +516,6 @@ object ArtifactManager extends Logging {
     val JAR, FILE, ARCHIVE = Value
   }

-  private def copyBlock(fromId: CacheId, toId: CacheId, blockManager: BlockManager): CacheId = {
-    require(fromId != toId)
-    blockManager.getLocalBytes(fromId) match {
-      case Some(blockData) =>
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          val updater = blockManager.ByteBufferBlockStoreUpdater(
-            blockId = toId,
-            level = StorageLevel.MEMORY_AND_DISK_SER,
-            classTag = implicitly[ClassTag[Array[Byte]]],
-            bytes = blockData.toChunkedByteBuffer(ByteBuffer.allocate),
-            tellMaster = false)
-          updater.save()
-          toId
-        }(finallyBlock = { blockManager.releaseLock(fromId); blockData.dispose() })
-      case None =>
-        throw SparkException.internalError(s"Block $fromId not found in the block manager.")
-    }
-  }
-
   // Shared cleaner instance
   private val cleaner: Cleaner = Cleaner.create()

@@ -530,10 +543,6 @@
     }
   }

-    // Clean up cached relations
-    val blockManager = sparkContext.env.blockManager
-    blockManager.removeCache(sparkSessionUUID)
-
     // Clean up artifacts folder
     try {
       Utils.deleteRecursively(artifactPath.toFile)

@@ -550,3 +559,28 @@ private[artifact] case class ArtifactStateForCleanup(
     sparkContext: SparkContext,
     jobArtifactState: JobArtifactState,
     artifactPath: Path)
+
+private class RefCountedCacheId(val id: CacheId) {
+  private val rc = new AtomicInteger(1)
+
+  def acquire(): Unit = updateRc(1)
+
+  def release(blockManager: BlockManager): Unit = {
+    val newRc = updateRc(-1)
+    if (newRc == 0) {
+      blockManager.removeBlock(id)
+    }
+  }
+
+  private def updateRc(delta: Int): Int = {
+    rc.updateAndGet { currentRc: Int =>
+      if (currentRc == 0) {
+        throw new SparkRuntimeException(
+          "BLOCK_ALREADY_RELEASED",
+          Map("blockId" -> id.toString)
+        )
+      }
+      currentRc + delta
+    }
+  }
+}
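
The clone loop above acquires each shared entry and, if the parent session released it concurrently, surfaces a `BLOCK_ALREADY_RELEASED` failure. Below is a minimal sketch of that guard outside Spark, under stated assumptions: a plain `IllegalStateException` stands in for `SparkRuntimeException`, and a hypothetical `onZero` callback stands in for `BlockManager.removeBlock`.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Simplified version of the acquire/release contract enforced by updateAndGet:
// once the count reaches zero the entry is dead, and any further acquire or
// release fails instead of resurrecting a removed block.
final class RefCounted(onZero: () => Unit) {
  private val rc = new AtomicInteger(1)

  private def update(delta: Int): Int =
    rc.updateAndGet { current =>
      if (current == 0) throw new IllegalStateException("already released")
      current + delta
    }

  def acquire(): Unit = { update(1); () }
  def release(): Unit = if (update(-1) == 0) onZero()
}

object GuardSketch {
  def main(args: Array[String]): Unit = {
    var removed = false
    val entry = new RefCounted(() => removed = true)
    entry.acquire()   // a clone takes a reference
    entry.release()   // the parent releases its reference
    entry.release()   // the clone releases: count hits zero, cleanup runs exactly once
    assert(removed)
    // Acquiring after release fails, which the clone path turns into an INTERNAL_ERROR.
    try { entry.acquire(); assert(false) } catch { case _: IllegalStateException => }
  }
}
```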

sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala

Lines changed: 51 additions & 9 deletions
@@ -503,15 +503,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
     assert(newArtifactManager.artifactPath !== artifactManager.artifactPath)

     // Load the cached artifact
-    val blockManager = newSession.sparkContext.env.blockManager
-    for (sessionId <- Seq(spark.sessionUUID, newSession.sessionUUID)) {
-      val cacheId = CacheId(sessionId, "test")
-      try {
-        assert(blockManager.getLocalBytes(cacheId).get.toByteBuffer().array() === testBytes)
-      } finally {
-        blockManager.releaseLock(cacheId)
-      }
-    }
+    assert(spark.artifactManager.getCachedBlockId("test")
+      == newArtifactManager.getCachedBlockId("test"))

     val allFiles = Utils.listFiles(newArtifactManager.artifactPath.toFile)
     assert(allFiles.size() === 3)

@@ -540,6 +533,55 @@
     }
   }

+  test("Share blocks between ArtifactManagers") {
+    def isBlockRegistered(id: CacheId): Boolean = {
+      sparkContext.env.blockManager.getStatus(id).isDefined
+    }
+
+    def addCachedArtifact(session: SparkSession, name: String, data: String): CacheId = {
+      val bytes = new Artifact.InMemory(data.getBytes(StandardCharsets.UTF_8))
+      session.artifactManager.addLocalArtifacts(Artifact.newCacheArtifact(name, bytes) :: Nil)
+      val id = CacheId(session.sessionUUID, name)
+      assert(isBlockRegistered(id))
+      id
+    }
+
+    // Create fresh session so there is no interference with other tests.
+    val session1 = spark.newSession()
+    val b1 = addCachedArtifact(session1, "b1", "b_one")
+    val b2 = addCachedArtifact(session1, "b2", "b_two")
+
+    // Clone, check that existing blocks are the same, add another block, clean-up, make sure
+    // shared blocks survive and new block is cleaned.
+    val session2 = session1.cloneSession()
+    val b3 = addCachedArtifact(session2, "b3", "b_three")
+    session2.artifactManager.cleanUpResourcesForTesting()
+    assert(isBlockRegistered(b1))
+    assert(isBlockRegistered(b2))
+    assert(!isBlockRegistered(b3))
+
+    // Clone, check that existing blocks are the same, replace existing blocks, clone parent, check
+    // that inherited blocks are removed now.
+    val session3 = session1.cloneSession()
+    session1.artifactManager.cleanUpResourcesForTesting()
+    assert(isBlockRegistered(b1))
+    assert(isBlockRegistered(b2))
+    assert(session3.artifactManager.getCachedBlockId("b1").get == b1)
+    assert(session3.artifactManager.getCachedBlockId("b2").get == b2)
+
+    val b1a = addCachedArtifact(session3, "b1", "b_one_a")
+    val b2a = addCachedArtifact(session3, "b2", "b_two_a")
+    assert(!isBlockRegistered(b1))
+    assert(!isBlockRegistered(b2))
+    assert(session3.artifactManager.getCachedBlockId("b1").get == b1a)
+    assert(session3.artifactManager.getCachedBlockId("b2").get == b2a)
+
+    // Clean-up last AM. No block should be left.
+    session3.artifactManager.cleanUpResourcesForTesting()
+    assert(!isBlockRegistered(b1a))
+    assert(!isBlockRegistered(b2a))
+  }
+
   test("Codegen cache should be invalid when artifacts are added - class artifact") {
     withTempDir { dir =>
       runCodegenTest("class artifact") {
