diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml
index 69088d284365f..0c7ea371df3c5 100644
--- a/udf/worker/core/pom.xml
+++ b/udf/worker/core/pom.xml
@@ -51,6 +51,11 @@
org.scala-lang
scala-library
+
+ org.mockito
+ mockito-core
+ test
+
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala
new file mode 100644
index 0000000000000..422f70944bae6
--- /dev/null
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental ::
+ * Creates [[WorkerDispatcher]] instances and controls their
+ * lifecycle after all sessions have closed.
+ *
+ * Implementations are passed to [[UDFDispatcherManager]] which
+ * handles caching, session tracking, and shutdown.
+ */
+@Experimental
+trait UDFDispatcherFactory {
+
+ /**
+ * Creates a new [[WorkerDispatcher]] for the given specification.
+ * Creating the dispatcher itself is expected to be fast,
+ * while creating a session for it may be slow.
+ */
+ def createDispatcher(
+ workerSpec: UDFWorkerSpecification,
+ logger: WorkerLogger): WorkerDispatcher
+
+ /**
+ * Called when the last active session for a dispatcher is closed.
+ * Implementations must decide what to do with the now-idle
+ * dispatcher: close it immediately, schedule idle-timeout
+ * eviction, etc.
+ * Not called during [[UDFDispatcherManager#stop]] -- the manager
+ * cleans up dispatchers it holds directly in that case.
+ */
+ def onAllDispatcherSessionsClosed(
+ dispatcher: WorkerDispatcher): Unit
+
+ /**
+ * Called when the executor/driver stops. Implementations should
+ * clean up any dispatchers/resources they hold beyond what the
+ * [[UDFDispatcherManager]] manages.
+ */
+ def onStop(): Unit
+}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala
new file mode 100644
index 0000000000000..d182d99b9bcbf
--- /dev/null
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import java.util.{ArrayList, HashMap}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental :: Creates [[WorkerSession]] instances for a given
+ * [[UDFWorkerSpecification]], managing [[WorkerDispatcher]] instances and
+ * their lifecycle internally.
+ *
+ * Dispatchers are cached by spec (protobuf value equality) and reused across
+ * sessions. The manager tracks the number of active sessions per dispatcher
+ * via [[WorkerSession#addSessionCompletionListener]]. When the last session
+ * for a dispatcher is closed, the entry is removed and
+ * [[UDFDispatcherFactory#onAllDispatcherSessionsClosed]] is called.
+ *
+ * You might be wondering why the Dispatcher does not track the number of
+ * active sessions itself. The reason is that this would create a
+ * unavoidable race condition: Clients can provide different worker
+ * specs. Therefore, different dispatchers may be required, which cannot all
+ * exist for the whole Spark lifetime -> Dispatchers need to be removed/terminated
+ * at some point. If Dispatchers were to track their active sessions themselves
+ * and we used this to decide on the dispatcher lifetime, it could always
+ * happen that there are concurrent [[createSession]] requests while
+ * the Dispatcher is being disposed of - which would create session
+ * initialization errors and may cause Spark task/query failures.
+ * Instead, we track the active sessions per Dispatcher globally
+ * in this manager.
+ *
+ * Thread safety: a single lock guards all state -- dispatchers, active
+ * sessions, and the stopping flag.
+ */
+@Experimental
+class UDFDispatcherManager(
+ private val dispatcherFactory: UDFDispatcherFactory,
+ workerLogger: WorkerLogger = WorkerLogger.NoOp
+) {
+
+ private val logger: WorkerLogger =
+ workerLogger.forClass(getClass)
+
+ /*
+ * Why do we need an [[activeSessionCount]] and an [[activeSessions]]
+ * list? [[activeSessionCount]] is per dispatcher. [[activeSessions]]
+ * is global and allows us to perform session cleanup on [[stop]].
+ * Moreover, this distinction allows us to create sessions without
+ * holding [[lock]] while the (potentially slow) session is created.
+ */
+ private class DispatcherEntry(val dispatcher: WorkerDispatcher) {
+ var activeSessionCount: Int = 0
+ }
+
+ // All fields below are guarded by `lock`.
+ private val lock = new Object
+ private val dispatchers =
+ new HashMap[UDFWorkerSpecification, DispatcherEntry]()
+ private val activeSessions = new ArrayList[WorkerSession]()
+ private var stopped = false
+
+ /**
+ * Creates a [[WorkerSession]] for the given worker specification and
+ * optional security scope.
+ *
+ * If a dispatcher for this spec already exists it is reused; otherwise
+ * [[UDFDispatcherFactory#createDispatcher]] is called to create one.
+ * A completion listener is registered on the session to track when
+ * it closes.
+ */
+  final def createSession(
+      workerSpec: UDFWorkerSpecification,
+      securityScope: Option[WorkerSecurityScope] = None
+  ): WorkerSession = {
+    // Get the dispatcher (reserves one session on its entry).
+    val entry = lock.synchronized {
+      if (stopped) throwStopped()
+      getOrCreateDispatcherEntry(workerSpec)
+    }
+
+    // Create a new session (potentially slow -> outside the lock).
+    // On failure, roll back the session count reserved above so
+    // the idle dispatcher can still be evicted later on.
+    val session =
+      try entry.dispatcher.createSession(securityScope)
+      catch {
+        case e: Throwable =>
+          lock.synchronized {
+            if (!stopped) handleSessionTermination(workerSpec)
+          }
+          throw e
+      }
+    lock.synchronized {
+      if (stopped) {
+        session.close()
+        throwStopped()
+      }
+      activeSessions.add(session)
+    }
+
+    logger.info(s"Created session ${session.sessionId}" +
+      s" on dispatcher ${entry.dispatcher.dispatcherId}" +
+      s" (active: ${entry.activeSessionCount})")
+
+    // Register a completion listener that updates the
+    // state when the session is canceled or closed
+    session.addSessionCompletionListener { session =>
+      logger.info(s"Session ${session.sessionId} terminated")
+      lock.synchronized {
+        if (!stopped) {
+          activeSessions.remove(session)
+          handleSessionTermination(workerSpec)
+        }
+      }
+    }
+    session
+  }
+
+ /**
+ * Called on driver/executor shutdown. Cancels any active sessions,
+ * closes all cached dispatchers, and resets internal state.
+ *
+ * Safety net -- in normal operation, sessions are closed
+ * by the physical Spark operators and dispatchers are cleaned up via
+ * [[UDFDispatcherFactory#onAllDispatcherSessionsClosed]] when their
+ * last session closes.
+ */
+  final def stop(): Unit = {
+    lock.synchronized {
+      // Log the sizes under `lock` -- both collections are guarded by it.
+      logger.info("UDFDispatcherManager stopping" +
+        s" (${activeSessions.size()} active sessions," +
+        s" ${dispatchers.size()} dispatchers)")
+      stopped = true
+
+      // Cancel any sessions that are still active. Cancel is a
+      // no-op if the session was already closed/cancelled.
+      activeSessions.forEach { session =>
+        logger.debug(s"Cancelling session ${session.sessionId}" +
+          " during stop")
+        session.cancel()
+      }
+      activeSessions.clear()
+
+      // Close all dispatchers we control.
+      // When spark is stopped in a clean state
+      // (only finished tasks), it is expected
+      // that all dispatchers have been terminated
+      // already. This is a safety-net.
+      dispatchers.forEach { (_, entry) =>
+        logger.debug(s"Closing dispatcher" +
+          s" ${entry.dispatcher.dispatcherId} during stop")
+        entry.dispatcher.close()
+      }
+      dispatchers.clear()
+    }
+
+    // Perform cleanup in the factory
+    dispatcherFactory.onStop()
+    logger.info("UDFDispatcherManager stopped")
+  }
+
+ private def throwStopped(): Nothing =
+ throw new IllegalStateException(
+ "UDFDispatcherManager is stopped")
+
+  // Must be called while holding `lock`.
+  private def handleSessionTermination(
+      workerSpec: UDFWorkerSpecification
+  ): Unit = {
+    val entry = dispatchers.get(workerSpec)
+    // A null entry means the session bookkeeping is broken: fail loudly.
+    require(entry != null, s"No dispatcher entry for spec $workerSpec")
+    entry.activeSessionCount -= 1
+    if (entry.activeSessionCount == 0) {
+      logger.info("All sessions closed for dispatcher " +
+        s"${entry.dispatcher.dispatcherId}, removing from cache")
+      dispatchers.remove(workerSpec)
+      dispatcherFactory.onAllDispatcherSessionsClosed(
+        entry.dispatcher)
+    }
+  }
+
+ // Must be called while holding `lock`.
+ private def getOrCreateDispatcherEntry(
+ workerSpec: UDFWorkerSpecification
+ ): DispatcherEntry = {
+ var entry = dispatchers.get(workerSpec)
+ if (entry == null) {
+ val dispatcher = dispatcherFactory.createDispatcher(
+ workerSpec, logger)
+ logger.info(s"Created dispatcher ${dispatcher.dispatcherId}")
+ entry = new DispatcherEntry(dispatcher)
+ dispatchers.put(workerSpec, entry)
+ }
+ entry.activeSessionCount += 1
+ entry
+ }
+}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala
index 008cfc2993a09..f75c8a2171200 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala
@@ -31,6 +31,10 @@ import org.apache.spark.udf.worker.UDFWorkerSpecification
@Experimental
trait WorkerDispatcher extends AutoCloseable {
+ /** Unique identifier for this dispatcher. */
+ val dispatcherId: String =
+ java.util.UUID.randomUUID().toString
+
def workerSpec: UDFWorkerSpecification
/**
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
index a8f135f688908..9b6646fb6fb5b 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
@@ -36,8 +36,34 @@ import org.apache.spark.annotation.Experimental
trait WorkerLogger {
def warn(msg: => String): Unit
def warn(msg: => String, t: Throwable): Unit
+ def info(msg: => String): Unit
+ def info(msg: => String, t: Throwable): Unit
def debug(msg: => String): Unit
def debug(msg: => String, t: Throwable): Unit
+
+ /**
+ * Returns a new [[WorkerLogger]] that prefixes every message with
+ * `[className]`. Useful for identifying which class produced a
+ * log line.
+ */
+ def forClass(clazz: Class[_]): WorkerLogger = {
+ val prefix = s"[${clazz.getSimpleName}] "
+ val parent = this
+ new WorkerLogger {
+ override def warn(msg: => String): Unit =
+ parent.warn(prefix + msg)
+ override def warn(msg: => String, t: Throwable): Unit =
+ parent.warn(prefix + msg, t)
+ override def info(msg: => String): Unit =
+ parent.info(prefix + msg)
+ override def info(msg: => String, t: Throwable): Unit =
+ parent.info(prefix + msg, t)
+ override def debug(msg: => String): Unit =
+ parent.debug(prefix + msg)
+ override def debug(msg: => String, t: Throwable): Unit =
+ parent.debug(prefix + msg, t)
+ }
+ }
}
object WorkerLogger {
@@ -45,6 +71,8 @@ object WorkerLogger {
val NoOp: WorkerLogger = new WorkerLogger {
override def warn(msg: => String): Unit = ()
override def warn(msg: => String, t: Throwable): Unit = ()
+ override def info(msg: => String): Unit = ()
+ override def info(msg: => String, t: Throwable): Unit = ()
override def debug(msg: => String): Unit = ()
override def debug(msg: => String, t: Throwable): Unit = ()
}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
index f4c4091688c94..55690aeb1972e 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.udf.worker.core
+import java.util.ArrayList
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.annotation.Experimental
@@ -76,16 +77,66 @@ case class InitMessage(
* - [[close]] must always be called (use try-finally).
* - [[cancel]] may be called at any time to abort execution.
*
- * The lifecycle is enforced here: [[init]] and [[process]] are `final`
- * and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards.
+ * The lifecycle is enforced here: [[init]], [[process]], [[cancel]],
+ * and [[close]] are `final` and delegate to [[doInit]] / [[doProcess]] /
+ * [[doCancel]], and [[doClose]] after AtomicBoolean guards.
* Subclasses implement the protocol-specific work and do not re-check
* the contract.
+ *
+ * Completion listeners registered via [[addSessionCompletionListener]]
+ * are fired exactly once, after [[doClose]] or [[doCancel]]
+ * (whichever runs first). Note that the completion listener can
+ * be executed in a completely separate thread from the thread that
+ * registered the listener.
*/
@Experimental
-abstract class WorkerSession extends AutoCloseable {
+abstract class WorkerSession(
+ workerLogger: WorkerLogger
+) extends AutoCloseable {
+
+ protected val logger: WorkerLogger =
+ workerLogger.forClass(getClass)
+
+ /** Unique identifier for this session. */
+ val sessionId: String = java.util.UUID.randomUUID().toString
private val initialized = new AtomicBoolean(false)
private val processed = new AtomicBoolean(false)
+ private val closed = new AtomicBoolean(false)
+
+ // Guards `completionListeners`, and `completionListenersFired`
+ // to ensure that a listener added after close is fired
+ // immediately and exactly once.
+ private val listenerLock = new Object
+ private var completionListenersFired = false
+ private val completionListeners =
+ new ArrayList[WorkerSession => Unit]()
+
+  /**
+   * Registers a closure to be invoked when this session completes
+   * (via [[close]] or [[cancel]]). Listeners fire exactly once, in
+   * registration order. If the session is already closed when
+   * registering, the listener is fired immediately.
+   */
+  final def addSessionCompletionListener(
+      f: WorkerSession => Unit): Unit = {
+    val fireNow = listenerLock.synchronized {
+      if (!completionListenersFired) completionListeners.add(f)
+      completionListenersFired
+    }
+    // Invoke outside `listenerLock`: a listener may take other locks
+    // (e.g. the manager's `lock`), so calling it while holding
+    // `listenerLock` risks lock-order deadlocks (e.g. against stop()).
+    if (fireNow) f(this)
+  }
+
+  private def fireCompletionListeners(): Unit = {
+    val toFire = listenerLock.synchronized {
+      completionListenersFired = true
+      new ArrayList(completionListeners)
+    }
+    toFire.forEach(_(this))
+  }
/**
* Initializes the UDF execution. Must be called exactly once before
@@ -100,6 +151,7 @@ abstract class WorkerSession extends AutoCloseable {
if (!initialized.compareAndSet(false, true)) {
throw new IllegalStateException("init has already been called on this session")
}
+ logger.info(s"Session $sessionId: init")
doInit(message)
}
@@ -117,22 +169,18 @@ abstract class WorkerSession extends AutoCloseable {
* @param input iterator of raw input data batches (e.g., Arrow IPC)
* @return iterator of raw result data batches
*/
- final def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = {
+ final def process(
+ input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = {
if (!initialized.get()) {
throw new IllegalStateException("process called before init")
}
if (!processed.compareAndSet(false, true)) {
throw new IllegalStateException("process has already been called on this session")
}
+ logger.info(s"Session $sessionId: process started")
doProcess(input)
}
- /** Subclass hook for [[init]]. Called once, after the guard. */
- protected def doInit(message: InitMessage): Unit
-
- /** Subclass hook for [[process]]. Called at most once, after the guard. */
- protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]
-
/**
* Requests cancellation of the current UDF execution.
*
@@ -141,8 +189,34 @@ abstract class WorkerSession extends AutoCloseable {
* task interruption thread). It may be invoked at any point after
* [[init]] and should be a no-op if execution has already finished.
*/
- def cancel(): Unit
+ final def cancel(): Unit = {
+ // TODO [SPARK-55278]: Implement correct cancellation/finish semantics
+ // according to the worker_spec.proto.
+ if (closed.compareAndSet(false, true)) {
+ logger.info(s"Session $sessionId: cancel")
+ doCancel()
+ fireCompletionListeners()
+ }
+ }
/** Closes this session and releases resources. */
- override def close(): Unit
+ override final def close(): Unit = {
+ if (closed.compareAndSet(false, true)) {
+ logger.info(s"Session $sessionId: close")
+ doClose()
+ fireCompletionListeners()
+ }
+ }
+
+ /** Subclass hook for [[init]]. Called once, after the guard. */
+ protected def doInit(message: InitMessage): Unit
+
+ /** Subclass hook for [[process]]. Called at most once, after the guard. */
+ protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]
+
+ /** Subclass hook for [[cancel]]. Called once, after the guard. */
+ protected def doCancel(): Unit
+
+ /** Subclass hook for [[close]]. Called at most once, after the guard. */
+ protected def doClose(): Unit
}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
index 8da0354187e4f..7456a8a5b9bae 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
@@ -37,7 +37,7 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POL
@Experimental
abstract class DirectUnixSocketWorkerDispatcher(
workerSpec: UDFWorkerSpecification,
- logger: WorkerLogger = WorkerLogger.NoOp)
+ logger: WorkerLogger)
extends DirectWorkerDispatcher(workerSpec, logger) {
// Removed explicitly in closeTransport(). deleteOnExit is avoided because
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
index afaf23791d80f..8925912b85851 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
@@ -62,9 +62,12 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR
@Experimental
abstract class DirectWorkerDispatcher(
override val workerSpec: UDFWorkerSpecification,
- protected val logger: WorkerLogger = WorkerLogger.NoOp)
+ workerLogger: WorkerLogger)
extends WorkerDispatcher {
+ protected val logger: WorkerLogger =
+ workerLogger.forClass(getClass)
+
// TODO: Connection pooling -- reuse idle workers across sessions.
// TODO: Security scope isolation -- partition pool by WorkerSecurityScope.
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala
index f4b5c1df63193..495950f3e7048 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala
@@ -51,7 +51,7 @@ class DirectWorkerProcess(
val id: String,
private[direct] val artifacts: WorkerArtifacts,
val gracefulTimeoutMs: Long,
- protected val logger: WorkerLogger = WorkerLogger.NoOp,
+ protected val logger: WorkerLogger,
private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ())
extends AutoCloseable {
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
index 7cdc5329350e3..78b10c84a34e1 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
@@ -19,7 +19,7 @@ package org.apache.spark.udf.worker.core.direct
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.annotation.Experimental
-import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession}
+import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger, WorkerSession}
/**
* :: Experimental ::
@@ -41,14 +41,15 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession}
*/
@Experimental
abstract class DirectWorkerSession(
- private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession {
-
+ private[core] val workerProcess: DirectWorkerProcess,
+ workerLogger: WorkerLogger)
+ extends WorkerSession(workerLogger) {
private val released = new AtomicBoolean(false)
/** The connection to the worker for this session. */
def connection: WorkerConnection = workerProcess.connection
- override def close(): Unit = {
+ override protected def doClose(): Unit = {
if (released.compareAndSet(false, true)) {
workerProcess.releaseSession()
}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
index 60f5e2211b702..b0b73ed70a825 100644
--- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
@@ -55,8 +55,9 @@ class SocketFileConnection(socketPath: String)
* no-op). Tracking the thread-safety contract in the docstring on
* [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
*/
-class StubWorkerSession(
- workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {
+private class StubWorkerSession(
+ workerProcess: DirectWorkerProcess)
+ extends DirectWorkerSession(workerProcess, WorkerLogger.NoOp) {
override protected def doInit(message: InitMessage): Unit = {}
@@ -64,7 +65,9 @@ class StubWorkerSession(
input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
Iterator.empty
- override def cancel(): Unit = {}
+ override protected def doCancel(): Unit = {}
+
+ override protected def doClose(): Unit = {}
}
/**
@@ -73,7 +76,7 @@ class StubWorkerSession(
* implementation.
*/
class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
- extends DirectUnixSocketWorkerDispatcher(spec) {
+ extends DirectUnixSocketWorkerDispatcher(spec, WorkerLogger.NoOp) {
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection =
@@ -362,7 +365,8 @@ class DirectWorkerDispatcherSuite
val releaseLatch = new java.util.concurrent.CountDownLatch(1)
val capturedWorkers =
new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]()
- val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
+ val racing = new DirectUnixSocketWorkerDispatcher(
+ specWithRunner(defaultRunner), WorkerLogger.NoOp) {
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection =
new SocketFileConnection(socketPath)
@@ -538,7 +542,7 @@ class DirectWorkerDispatcherSuite
// worker must be terminated rather than leaked until dispatcher.close().
var capturedWorker: DirectWorkerProcess = null
val failingDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
+ new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner), WorkerLogger.NoOp) {
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection =
new SocketFileConnection(socketPath)
@@ -594,7 +598,7 @@ class DirectWorkerDispatcherSuite
test("socket file is cleaned up when createConnection throws") {
val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]()
val failingDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
+ new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner), WorkerLogger.NoOp) {
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection = {
capturedSocketPaths.add(socketPath)
@@ -760,7 +764,7 @@ class DirectWorkerDispatcherSuite
.addCommand("sleep 30").build()
val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build()
val shortTimeoutDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+ new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), WorkerLogger.NoOp) {
override protected def callableTimeoutMs: Long = 500L
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection =
@@ -897,7 +901,7 @@ class DirectWorkerDispatcherSuite
s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build())
.build()
val timeoutDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+ new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), WorkerLogger.NoOp) {
override protected def callableTimeoutMs: Long = 500L
override protected def createConnection(
socketPath: String): UnixSocketWorkerConnection =
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala
new file mode 100644
index 0000000000000..c5af112a54f5e
--- /dev/null
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, verify, when}
+import org.mockito.invocation.InvocationOnMock
+// scalastyle:off funsuite
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker._
+
+/**
+ * A minimal [[WorkerSession]] that records whether cancel was invoked.
+ */
+private class NoOpWorkerSession
+ extends WorkerSession(WorkerLogger.NoOp) {
+ var cancelled: Boolean = false
+ override protected def doInit(msg: InitMessage): Unit = {}
+ override protected def doProcess(
+ input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
+ Iterator.empty
+ override protected def doCancel(): Unit = { cancelled = true }
+ override protected def doClose(): Unit = {}
+}
+
+/**
+ * Holds a test [[UDFDispatcherManager]] and all observable state,
+ * so tests can assert on whichever fields they care about.
+ */
+private case class TestManagerFixture(
+ manager: UDFDispatcherManager,
+ createdDispatchers: ArrayBuffer[WorkerDispatcher],
+ closedDispatchers: ArrayBuffer[WorkerDispatcher],
+ doStopCalls: ArrayBuffer[Boolean],
+ createdSessions: ArrayBuffer[NoOpWorkerSession])
+
+private object TestManagerFixture {
+ def apply(): TestManagerFixture = {
+ val createdDispatchers = ArrayBuffer[WorkerDispatcher]()
+ val closedDispatchers = ArrayBuffer[WorkerDispatcher]()
+ val doStopCalls = ArrayBuffer[Boolean]()
+ val createdSessions = ArrayBuffer[NoOpWorkerSession]()
+
+ val factory = new UDFDispatcherFactory {
+ override def createDispatcher(
+ workerSpec: UDFWorkerSpecification,
+ logger: WorkerLogger): WorkerDispatcher = {
+ val dispatcher = mock(classOf[WorkerDispatcher])
+ when(dispatcher.createSession(
+ any[Option[WorkerSecurityScope]]))
+ .thenAnswer((_: InvocationOnMock) => {
+ val session = new NoOpWorkerSession()
+ createdSessions += session
+ session
+ })
+ createdDispatchers += dispatcher
+ dispatcher
+ }
+ override def onAllDispatcherSessionsClosed(
+ dispatcher: WorkerDispatcher): Unit = {
+ closedDispatchers += dispatcher
+ }
+ override def onStop(): Unit = {
+ doStopCalls += true
+ }
+ }
+ val manager = new UDFDispatcherManager(factory)
+
+ TestManagerFixture(
+ manager, createdDispatchers, closedDispatchers,
+ doStopCalls, createdSessions)
+ }
+}
+
+class UDFDispatcherManagerSuite
+ extends AnyFunSuite { // scalastyle:ignore funsuite
+
+ private def makeSpec(command: String): UDFWorkerSpecification = {
+ val callable = ProcessCallable.newBuilder()
+ callable.addCommand(command)
+ val caps = WorkerCapabilities.newBuilder()
+ .addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
+ .addSupportedCommunicationPatterns(
+ UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
+ val conn = WorkerConnectionSpec.newBuilder()
+ .setTcp(LocalTcpConnection.newBuilder())
+ val props = UDFWorkerProperties.newBuilder()
+ .setConnection(conn)
+ val direct = DirectWorker.newBuilder()
+ .setRunner(callable).setProperties(props)
+ UDFWorkerSpecification.newBuilder()
+ .setEnvironment(WorkerEnvironment.newBuilder())
+ .setCapabilities(caps).setDirect(direct).build()
+ }
+
+ test("Same spec reuses the same dispatcher") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ val s2 = fixture.manager.createSession(spec)
+
+ assert(fixture.createdDispatchers.size === 1)
+ verify(
+ fixture.createdDispatchers.head,
+ org.mockito.Mockito.times(2))
+ .createSession(any[Option[WorkerSecurityScope]])
+
+ s1.close()
+ s2.close()
+ }
+
+ test("Structurally equal specs reuse the same dispatcher") {
+ val fixture = TestManagerFixture()
+ val spec1 = makeSpec("worker.bin")
+ val spec2 = makeSpec("worker.bin")
+ assert(spec1 ne spec2)
+ assert(spec1 == spec2)
+
+ val s1 = fixture.manager.createSession(spec1)
+ val s2 = fixture.manager.createSession(spec2)
+
+ assert(fixture.createdDispatchers.size === 1)
+
+ s1.close()
+ s2.close()
+ }
+
+ test("Different specs create different dispatchers") {
+ val fixture = TestManagerFixture()
+
+ val sA = fixture.manager.createSession(makeSpec("worker-a.bin"))
+ val sB = fixture.manager.createSession(makeSpec("worker-b.bin"))
+
+ assert(fixture.createdDispatchers.size === 2)
+ assert(
+ fixture.createdDispatchers(0) ne
+ fixture.createdDispatchers(1))
+
+ sA.close()
+ sB.close()
+ }
+
+ test("onAllDispatcherSessionsClosed when last session closes") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ val s2 = fixture.manager.createSession(spec)
+
+ s1.close()
+ assert(fixture.closedDispatchers.isEmpty)
+
+ s2.close()
+ assert(fixture.closedDispatchers.size === 1)
+ assert(
+ fixture.closedDispatchers.head eq
+ fixture.createdDispatchers.head)
+ }
+
+ test("onAllDispatcherSessionsClosed not called while sessions remain") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ val s2 = fixture.manager.createSession(spec)
+ val s3 = fixture.manager.createSession(spec)
+
+ s1.close()
+ s2.close()
+ assert(fixture.closedDispatchers.isEmpty)
+
+ s3.close()
+ assert(fixture.closedDispatchers.size === 1)
+ }
+
+ test("New dispatcher after all sessions closed") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ val s2 = fixture.manager.createSession(spec)
+ assert(fixture.createdDispatchers.size === 1)
+
+ s1.close()
+ s2.close()
+ assert(fixture.closedDispatchers.size === 1)
+
+ val s3 = fixture.manager.createSession(spec)
+ assert(fixture.createdDispatchers.size === 2)
+ assert(
+ fixture.createdDispatchers(0) ne
+ fixture.createdDispatchers(1))
+
+ s3.close()
+ assert(fixture.closedDispatchers.size === 2)
+ }
+
+ test("Stop closes all cached dispatchers") {
+ val fixture = TestManagerFixture()
+
+ fixture.manager.createSession(makeSpec("worker-a.bin"))
+ fixture.manager.createSession(makeSpec("worker-b.bin"))
+ fixture.manager.stop()
+
+ fixture.createdDispatchers.foreach(d => verify(d).close())
+ }
+
+ test("Stop calls doStop for subclass cleanup") {
+ val fixture = TestManagerFixture()
+
+ fixture.manager.createSession(makeSpec("worker.bin"))
+ assert(fixture.doStopCalls.isEmpty)
+
+ fixture.manager.stop()
+ assert(fixture.doStopCalls.size === 1)
+ }
+
+ test("Stop cancels active sessions that were not closed") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ fixture.manager.createSession(spec)
+ s1.close()
+
+ assert(fixture.createdSessions.size === 2)
+ assert(!fixture.createdSessions(0).cancelled)
+ assert(!fixture.createdSessions(1).cancelled)
+
+ fixture.manager.stop()
+
+ assert(fixture.createdSessions(1).cancelled)
+ }
+
+ test("Stop does not cancel already closed sessions") {
+ val fixture = TestManagerFixture()
+ val spec = makeSpec("worker.bin")
+
+ val s1 = fixture.manager.createSession(spec)
+ s1.close()
+
+ fixture.manager.stop()
+
+ assert(!fixture.createdSessions(0).cancelled)
+ }
+
+ test("createSession throws after stop") {
+ val fixture = TestManagerFixture()
+ fixture.manager.stop()
+
+ intercept[IllegalStateException] {
+ fixture.manager.createSession(makeSpec("worker.bin"))
+ }
+ }
+}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala
new file mode 100644
index 0000000000000..b63af22af65e3
--- /dev/null
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import java.util.concurrent.atomic.AtomicInteger
+
+// scalastyle:off funsuite
+import org.scalatest.funsuite.AnyFunSuite
+
+private class TestWorkerSession
+ extends WorkerSession(WorkerLogger.NoOp) {
+ override protected def doInit(msg: InitMessage): Unit = {}
+ override protected def doProcess(
+ input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
+ Iterator.empty
+ override protected def doCancel(): Unit = {}
+ override protected def doClose(): Unit = {}
+}
+
+class WorkerSessionSuite
+ extends AnyFunSuite { // scalastyle:ignore funsuite
+
+ test("Completion listener fires on close") {
+ val session = new TestWorkerSession()
+ val count = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+ session.close()
+ assert(count.get() === 1)
+ }
+
+ test("Completion listener fires on cancel") {
+ val session = new TestWorkerSession()
+ val count = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+ session.cancel()
+ assert(count.get() === 1)
+ }
+
+ test("Completion listener fires exactly once on close then cancel") {
+ val session = new TestWorkerSession()
+ val count = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+ session.close()
+ session.cancel()
+ assert(count.get() === 1)
+ }
+
+ test("Listener added after close fires immediately") {
+ val session = new TestWorkerSession()
+ session.close()
+
+ val count = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count.incrementAndGet())
+ assert(count.get() === 1)
+ }
+
+ test("Listener added after cancel fires immediately") {
+ val session = new TestWorkerSession()
+ session.cancel()
+
+ val count = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count.incrementAndGet())
+ assert(count.get() === 1)
+ }
+
+ test("Multiple listeners all fire exactly once") {
+ val session = new TestWorkerSession()
+ val count1 = new AtomicInteger(0)
+ val count2 = new AtomicInteger(0)
+ val count3 = new AtomicInteger(0)
+ session.addSessionCompletionListener(_ => count1.incrementAndGet())
+ session.addSessionCompletionListener(_ => count2.incrementAndGet())
+
+ session.close()
+
+ // Add a third listener after close; it should fire immediately, exactly once.
+ session.addSessionCompletionListener(_ => count3.incrementAndGet())
+
+ assert(count1.get() === 1)
+ assert(count2.get() === 1)
+ assert(count3.get() === 1)
+ }
+
+ test("Listener receives the correct session instance") {
+ val session = new TestWorkerSession()
+ var received: WorkerSession = null
+ session.addSessionCompletionListener(s => received = s)
+
+ session.close()
+ assert(received eq session)
+ }
+}