@@ -20,10 +20,9 @@ package org.apache.spark.sql.artifact
 import java.io.{File, IOException}
 import java.lang.ref.Cleaner
 import java.net.{URI, URL, URLClassLoader}
-import java.nio.ByteBuffer
 import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
-import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

 import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
@@ -114,7 +113,7 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     }
   }

-  protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
+  private val hashToCachedIdMap = new ConcurrentHashMap[String, RefCountedCacheId]
   protected val jarsList = new CopyOnWriteArrayList[Path]
   protected val pythonIncludeList = new CopyOnWriteArrayList[String]
   protected val sparkContextRelativePaths =
@@ -136,6 +135,10 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
    */
   def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq

+  protected[sql] def getCachedBlockId(hash: String): Option[CacheId] = {
+    Option(hashToCachedIdMap.get(hash)).map(_.id)
+  }
+
   private def transferFile(
       source: Path,
       target: Path,
@@ -192,7 +195,14 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
           blockSize = tmpFile.length(),
           tellMaster = false)
         updater.save()
-        cachedBlockIdList.add(blockId)
+        val oldBlock = hashToCachedIdMap.put(blockId.hash, new RefCountedCacheId(blockId))
+        if (oldBlock != null) {
+          logWarning(
+            log"Replacing existing cache artifact with hash ${MDC(LogKeys.BLOCK_ID, blockId)} " +
+              log"in session ${MDC(LogKeys.SESSION_ID, session.sessionUUID)}. " +
+              log"This may indicate duplicate artifact addition.")
+          oldBlock.release(blockManager)
+        }
       }(finallyBlock = { tmpFile.delete() })
     } else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
       // Move class files to the right directory.
@@ -354,10 +364,27 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     if (artifactPath.toFile.exists()) {
       Utils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
     }
-    val blockManager = sparkContext.env.blockManager
-    val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
-      val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
-      copyBlock(blockId, newBlockId, blockManager)
+
+    // Share cached blocks with the cloned session by copying the references and incrementing
+    // their reference counts. Both the original and cloned ArtifactManager will reference the
+    // same underlying cached data blocks. When either session releases a block, only the
+    // ref-count decreases. The block is removed from memory only when the ref-count reaches
+    // zero.
+    hashToCachedIdMap.forEach { (hash: String, refCountedCacheId: RefCountedCacheId) =>
+      try {
+        refCountedCacheId.acquire() // Increment ref-count to prevent premature cleanup
+        newArtifactManager.hashToCachedIdMap.put(hash, refCountedCacheId)
+      } catch {
+        case e: SparkRuntimeException if e.getCondition == "BLOCK_ALREADY_RELEASED" =>
+          // The parent session was closed or this block was released during cloning.
+          // This indicates a race condition - we cannot safely complete the clone operation.
+          // With the ref-counting optimization, cloning is fast and this should be rare.
+          throw new SparkRuntimeException(
+            "INTERNAL_ERROR",
+            Map("message" -> (s"Cannot clone ArtifactManager: cached block with hash $hash " +
+              s"was already released. The parent session may have been closed during cloning.")),
+            e)
+      }
     }

     // Re-register resources to SparkContext
@@ -382,7 +409,6 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
       }
     }

-    newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
     newArtifactManager.jarsList.addAll(jarsList)
     newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
     newArtifactManager.sparkContextRelativePaths.addAll(sparkContextRelativePaths)
@@ -412,10 +438,16 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
     // Note that this will only be run once per instance.
     cleanable.clean()

+    // Clean up cached blocks.
+    val blockManager = session.sparkContext.env.blockManager
+    hashToCachedIdMap.values().forEach { refCountedCacheId =>
+      refCountedCacheId.release(blockManager)
+    }
+    hashToCachedIdMap.clear()
+
     // Clean up internal trackers
     jarsList.clear()
     pythonIncludeList.clear()
-    cachedBlockIdList.clear()
     sparkContextRelativePaths.clear()

     // Removed cached classloader
@@ -484,25 +516,6 @@ object ArtifactManager extends Logging {
     val JAR, FILE, ARCHIVE = Value
   }

-  private def copyBlock(fromId: CacheId, toId: CacheId, blockManager: BlockManager): CacheId = {
-    require(fromId != toId)
-    blockManager.getLocalBytes(fromId) match {
-      case Some(blockData) =>
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          val updater = blockManager.ByteBufferBlockStoreUpdater(
-            blockId = toId,
-            level = StorageLevel.MEMORY_AND_DISK_SER,
-            classTag = implicitly[ClassTag[Array[Byte]]],
-            bytes = blockData.toChunkedByteBuffer(ByteBuffer.allocate),
-            tellMaster = false)
-          updater.save()
-          toId
-        }(finallyBlock = { blockManager.releaseLock(fromId); blockData.dispose() })
-      case None =>
-        throw SparkException.internalError(s"Block $fromId not found in the block manager.")
-    }
-  }
-
   // Shared cleaner instance
   private val cleaner: Cleaner = Cleaner.create()

@@ -530,10 +543,6 @@ object ArtifactManager extends Logging {
       }
     }

-    // Clean up cached relations
-    val blockManager = sparkContext.env.blockManager
-    blockManager.removeCache(sparkSessionUUID)
-
     // Clean up artifacts folder
     try {
       Utils.deleteRecursively(artifactPath.toFile)
@@ -550,3 +559,28 @@ private[artifact] case class ArtifactStateForCleanup(
     sparkContext: SparkContext,
     jobArtifactState: JobArtifactState,
     artifactPath: Path)
+
+private class RefCountedCacheId(val id: CacheId) {
+  private val rc = new AtomicInteger(1)
+
+  def acquire(): Unit = updateRc(1)
+
+  def release(blockManager: BlockManager): Unit = {
+    val newRc = updateRc(-1)
+    if (newRc == 0) {
+      blockManager.removeBlock(id)
+    }
+  }
+
+  private def updateRc(delta: Int): Int = {
+    rc.updateAndGet { currentRc: Int =>
+      if (currentRc == 0) {
+        throw new SparkRuntimeException(
+          "BLOCK_ALREADY_RELEASED",
+          Map("blockId" -> id.toString)
+        )
+      }
+      currentRc + delta
+    }
+  }
+}
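
For readers skimming the diff, the ref-counting lifecycle introduced above can be summarised with a small standalone sketch. The names below (SimpleBlockStore, RefCountedHandle, RefCountDemo, the "cache/abc123" id) are illustrative stand-ins rather than Spark APIs; only the counting behaviour mirrors RefCountedCacheId: a handle starts with one reference when the artifact is added, clone() acquires an extra reference, each session's close() releases one, the block is physically removed only when the count reaches zero, and any use after that fails.

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap

// Illustrative stand-in for the block manager: a concurrent map of block id -> payload.
final class SimpleBlockStore {
  private val blocks = TrieMap.empty[String, Array[Byte]]
  def put(id: String, data: Array[Byte]): Unit = blocks.update(id, data)
  def remove(id: String): Unit = { blocks.remove(id); () }
  def contains(id: String): Boolean = blocks.contains(id)
}

// Illustrative analogue of RefCountedCacheId: one reference per session holding the block.
final class RefCountedHandle(val id: String) {
  private val rc = new AtomicInteger(1)

  // Called when a session is cloned and shares the cached block.
  def acquire(): Unit = updateRc(+1)

  // Called when a session holding the block is closed; drops the block on the last release.
  def release(store: SimpleBlockStore): Unit = {
    if (updateRc(-1) == 0) store.remove(id)
  }

  private def updateRc(delta: Int): Int =
    rc.updateAndGet { current: Int =>
      if (current == 0) throw new IllegalStateException(s"block $id already released")
      current + delta
    }
}

object RefCountDemo extends App {
  val store = new SimpleBlockStore
  store.put("cache/abc123", Array[Byte](1, 2, 3))

  val handle = new RefCountedHandle("cache/abc123") // rc = 1: the session that added the artifact
  handle.acquire()                                  // rc = 2: a cloned session shares the block

  handle.release(store)                             // rc = 1: block stays cached for the other session
  assert(store.contains("cache/abc123"))

  handle.release(store)                             // rc = 0: last session closed, block removed
  assert(!store.contains("cache/abc123"))
}

In the clone path above, both ArtifactManagers keep the same RefCountedCacheId in their hashToCachedIdMap, so a release from either side only decrements the shared count.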