diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml index 69088d284365f..0c7ea371df3c5 100644 --- a/udf/worker/core/pom.xml +++ b/udf/worker/core/pom.xml @@ -51,6 +51,11 @@ org.scala-lang scala-library + + org.mockito + mockito-core + test + diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala new file mode 100644 index 0000000000000..422f70944bae6 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.UDFWorkerSpecification + +/** + * :: Experimental :: + * Creates [[WorkerDispatcher]] instances and controls their + * lifecycle after all sessions have closed. + * + * Implementations are passed to [[UDFDispatcherManager]] which + * handles caching, session tracking, and shutdown. + */ +@Experimental +trait UDFDispatcherFactory { + + /** + * Creates a new [[WorkerDispatcher]] for the given specification. 
+ * It is expected that creating the dispatcher + * itself is not slow while creating a session might be. + */ + def createDispatcher( + workerSpec: UDFWorkerSpecification, + logger: WorkerLogger): WorkerDispatcher + + /** + * Called when the last active session for a dispatcher is closed. + * Implementations must decide what to do with the now-idle + * dispatcher: close it immediately, schedule idle-timeout + * eviction, etc. + * Not called during [[UDFDispatcherManager#stop]] -- the manager + * cleans up dispatchers it holds directly in that case. + */ + def onAllDispatcherSessionsClosed( + dispatcher: WorkerDispatcher): Unit + + /** + * Called when the executor/driver stops. Implementations should + * clean up any dispatchers/resources they hold beyond what the + * [[UDFDispatcherManager]] manages. + */ + def onStop(): Unit +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala new file mode 100644 index 0000000000000..d182d99b9bcbf --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.util.{ArrayList, HashMap} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.UDFWorkerSpecification + +/** + * :: Experimental :: Creates [[WorkerSession]] instances for a given + * [[UDFWorkerSpecification]], managing [[WorkerDispatcher]] instances and + * their lifecycle internally. + * + * Dispatchers are cached by spec (protobuf value equality) and reused across + * sessions. The manager tracks the number of active sessions per dispatcher + * via [[WorkerSession#addSessionCompletionListener]]. When the last session + * for a dispatcher is closed, the entry is removed and + * [[UDFDispatcherFactory#onAllDispatcherSessionsClosed]] is called. + * + * You might be wondering why the Dispatcher does not track the number of + * active sessions itself. The reason is that this would create an + * unavoidable race condition: Clients can provide different worker + * specs. Therefore, different dispatchers may be required, which cannot all + * exist for the whole Spark lifetime -> Dispatchers need to be removed/terminated + * at some point. If Dispatchers were to track their active sessions themselves + * and we would use this to decide on the dispatcher lifetime, it can always + * happen that there are concurrent [[createSession]] requests while + * the Dispatcher is being disposed of - which would create session + * initialization errors and may cause Spark task/query failures. + * Instead, we track the active sessions per Dispatcher globally + * in this manager. + * + * Thread safety: a single lock guards all state -- dispatchers, active + * sessions, and the stopping flag. 
+ */ +@Experimental +class UDFDispatcherManager( + private val dispatcherFactory: UDFDispatcherFactory, + workerLogger: WorkerLogger = WorkerLogger.NoOp +) { + + private val logger: WorkerLogger = + workerLogger.forClass(getClass) + + /* + * Why do we need an [[activeSessionCount]] and an [[activeSessions]] + * list? [[activeSessionCount]] is per dispatcher. [[activeSessions]] + * is globally and allows us to perform session cleanup on [[stop]]. + * Moreover, this distinction allows us to create sessions without + * requiring a lock on [[lock]]. + */ + private class DispatcherEntry(val dispatcher: WorkerDispatcher) { + var activeSessionCount: Int = 0 + } + + // All fields below are guarded by `lock`. + private val lock = new Object + private val dispatchers = + new HashMap[UDFWorkerSpecification, DispatcherEntry]() + private val activeSessions = new ArrayList[WorkerSession]() + private var stopped = false + + /** + * Creates a [[WorkerSession]] for the given worker specification and + * optional security scope. + * + * If a dispatcher for this spec already exists it is reused; otherwise + * [[UDFDispatcherFactory#createDispatcher]] is called to create one. + * A completion listener is registered on the session to track when + * it closes. + */ + final def createSession( + workerSpec: UDFWorkerSpecification, + securityScope: Option[WorkerSecurityScope] = None + ): WorkerSession = { + // Get the dispatcher + val entry = lock.synchronized { + if (stopped) { + throwStopped() + } + getOrCreateDispatcherEntry(workerSpec) + } + + // Create a new session (potentially slow -> outside the lock). + // Note: This might fail if Spark is concurrently being stopped + // and the dispatcher is cleaned up. As Spark is stopping, + // this failure is acceptable. On the happy path, no sessions + // should try to be created while Spark is shutting down. 
+ val session = entry.dispatcher.createSession(securityScope) + lock.synchronized { + if (stopped) { + session.close() + throwStopped() + } + activeSessions.add(session) + } + + logger.info(s"Created session ${session.sessionId}" + + s" on dispatcher ${entry.dispatcher.dispatcherId}" + + s" (active: ${entry.activeSessionCount})") + + // Register a completion listener that updates the + // state when the session is canceled or closed + session.addSessionCompletionListener { session => + logger.info(s"Session ${session.sessionId} terminated") + lock.synchronized { + if (!stopped) { + activeSessions.remove(session) + handleSessionTermination(workerSpec) + } + } + } + + session + } + + /** + * Called on driver/executor shutdown. Cancels any active sessions, + * closes all cached dispatchers, and resets internal state. + * + * Safety net -- in normal operation, sessions are closed + * by the physical Spark operators and dispatchers are cleaned up via + * [[UDFDispatcherFactory#onAllDispatcherSessionsClosed]] when their + * last session closes. + */ + final def stop(): Unit = { + logger.info("UDFDispatcherManager stopping" + + s" (${activeSessions.size()} active sessions," + + s" ${dispatchers.size()} dispatchers)") + + lock.synchronized { + stopped = true + + // Cancel any sessions that are still active. Cancel is a + // no-op if the session was already closed/cancelled. + activeSessions.forEach { session => + logger.debug(s"Cancelling session ${session.sessionId}" + + " during stop") + session.cancel() + } + activeSessions.clear() + + // Close all dispatchers we control. + // When spark is stopped in a clean state + // (only finished tasks), it is expected + // that all dispatchers have been terminated + // already. This is a safety-net. 
+ dispatchers.forEach { (_, entry) => + logger.debug(s"Closing dispatcher" + + s" ${entry.dispatcher.dispatcherId} during stop") + entry.dispatcher.close() + } + dispatchers.clear() + } + + // Perform cleanup in the factory + dispatcherFactory.onStop() + logger.info("UDFDispatcherManager stopped") + } + + private def throwStopped(): Nothing = + throw new IllegalStateException( + "UDFDispatcherManager is stopped") + + // Must be called while holding `lock`. + private def handleSessionTermination( + workerSpec: UDFWorkerSpecification + ): Unit = { + val entry = dispatchers.get(workerSpec) + // Note: entry == null is unexpected and should + // throw here. + entry.activeSessionCount -= 1 + if (entry.activeSessionCount == 0) { + logger.info("All sessions closed for dispatcher " + + s"${entry.dispatcher.dispatcherId}, removing from cache") + dispatchers.remove(workerSpec) + dispatcherFactory.onAllDispatcherSessionsClosed( + entry.dispatcher) + } + } + + // Must be called while holding `lock`. 
+ private def getOrCreateDispatcherEntry( + workerSpec: UDFWorkerSpecification + ): DispatcherEntry = { + var entry = dispatchers.get(workerSpec) + if (entry == null) { + val dispatcher = dispatcherFactory.createDispatcher( + workerSpec, logger) + logger.info(s"Created dispatcher ${dispatcher.dispatcherId}") + entry = new DispatcherEntry(dispatcher) + dispatchers.put(workerSpec, entry) + } + entry.activeSessionCount += 1 + entry + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala index 008cfc2993a09..f75c8a2171200 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala @@ -31,6 +31,10 @@ import org.apache.spark.udf.worker.UDFWorkerSpecification @Experimental trait WorkerDispatcher extends AutoCloseable { + /** Unique identifier for this dispatcher. */ + val dispatcherId: String = + java.util.UUID.randomUUID().toString + def workerSpec: UDFWorkerSpecification /** diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala index a8f135f688908..9b6646fb6fb5b 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala @@ -36,8 +36,34 @@ import org.apache.spark.annotation.Experimental trait WorkerLogger { def warn(msg: => String): Unit def warn(msg: => String, t: Throwable): Unit + def info(msg: => String): Unit + def info(msg: => String, t: Throwable): Unit def debug(msg: => String): Unit def debug(msg: => String, t: Throwable): Unit + + /** + * Returns a new [[WorkerLogger]] that prefixes every message with + * `[className]`. 
Useful for identifying which class produced a + * log line. + */ + def forClass(clazz: Class[_]): WorkerLogger = { + val prefix = s"[${clazz.getSimpleName}] " + val parent = this + new WorkerLogger { + override def warn(msg: => String): Unit = + parent.warn(prefix + msg) + override def warn(msg: => String, t: Throwable): Unit = + parent.warn(prefix + msg, t) + override def info(msg: => String): Unit = + parent.info(prefix + msg) + override def info(msg: => String, t: Throwable): Unit = + parent.info(prefix + msg, t) + override def debug(msg: => String): Unit = + parent.debug(prefix + msg) + override def debug(msg: => String, t: Throwable): Unit = + parent.debug(prefix + msg, t) + } + } } object WorkerLogger { @@ -45,6 +71,8 @@ object WorkerLogger { val NoOp: WorkerLogger = new WorkerLogger { override def warn(msg: => String): Unit = () override def warn(msg: => String, t: Throwable): Unit = () + override def info(msg: => String): Unit = () + override def info(msg: => String, t: Throwable): Unit = () override def debug(msg: => String): Unit = () override def debug(msg: => String, t: Throwable): Unit = () } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index f4c4091688c94..55690aeb1972e 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.udf.worker.core +import java.util.ArrayList import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.annotation.Experimental @@ -76,16 +77,66 @@ case class InitMessage( * - [[close]] must always be called (use try-finally). * - [[cancel]] may be called at any time to abort execution. 
* - * The lifecycle is enforced here: [[init]] and [[process]] are `final` - * and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards. + * The lifecycle is enforced here: [[init]], [[process]], [[cancel]], + * and [[close]] are `final` and delegate to [[doInit]] / [[doProcess]] / + * [[doCancel]], and [[doClose]] after AtomicBoolean guards. * Subclasses implement the protocol-specific work and do not re-check * the contract. + * + * Completion listeners registered via [[addSessionCompletionListener]] + * are fired exactly once, after [[doClose]] or [[doCancel]] + * (whichever runs first). Note that the completion listener can + * be executed in a completely separate thread from the thread who + * registered the listener. */ @Experimental -abstract class WorkerSession extends AutoCloseable { +abstract class WorkerSession( + workerLogger: WorkerLogger +) extends AutoCloseable { + + protected val logger: WorkerLogger = + workerLogger.forClass(getClass) + + /** Unique identifier for this session. */ + val sessionId: String = java.util.UUID.randomUUID().toString private val initialized = new AtomicBoolean(false) private val processed = new AtomicBoolean(false) + private val closed = new AtomicBoolean(false) + + // Guards `completionListeners`, and `completionListenersFired` + // to ensure that a listener added after close is fired + // immediately and exactly once. + private val listenerLock = new Object + private var completionListenersFired = false + private val completionListeners = + new ArrayList[WorkerSession => Unit]() + + /** + * Registers a closure to be invoked when this session completes + * (via [[close]] or [[cancel]]). Listeners fire exactly once, in + * registration order. If the session is already closed when + * registering, the listener is fired immediately. 
+ */ + final def addSessionCompletionListener( + f: WorkerSession => Unit): Unit = { + listenerLock.synchronized { + if (completionListenersFired) { + // Listeners from the list were already fired + // -> Invoke immediately. + f(this) + } else { + completionListeners.add(f) + } + } + } + + private def fireCompletionListeners(): Unit = { + listenerLock.synchronized { + completionListenersFired = true + completionListeners.forEach(_(this)) + } + } /** * Initializes the UDF execution. Must be called exactly once before @@ -100,6 +151,7 @@ abstract class WorkerSession extends AutoCloseable { if (!initialized.compareAndSet(false, true)) { throw new IllegalStateException("init has already been called on this session") } + logger.info(s"Session $sessionId: init") doInit(message) } @@ -117,22 +169,18 @@ abstract class WorkerSession extends AutoCloseable { * @param input iterator of raw input data batches (e.g., Arrow IPC) * @return iterator of raw result data batches */ - final def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = { + final def process( + input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = { if (!initialized.get()) { throw new IllegalStateException("process called before init") } if (!processed.compareAndSet(false, true)) { throw new IllegalStateException("process has already been called on this session") } + logger.info(s"Session $sessionId: process started") doProcess(input) } - /** Subclass hook for [[init]]. Called once, after the guard. */ - protected def doInit(message: InitMessage): Unit - - /** Subclass hook for [[process]]. Called at most once, after the guard. */ - protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] - /** * Requests cancellation of the current UDF execution. * @@ -141,8 +189,34 @@ abstract class WorkerSession extends AutoCloseable { * task interruption thread). It may be invoked at any point after * [[init]] and should be a no-op if execution has already finished. 
*/ - def cancel(): Unit + final def cancel(): Unit = { + // TODO [SPARK-55278]: Implement correct cancellation/finish semantics + // according to the worker_spec.proto. + if (closed.compareAndSet(false, true)) { + logger.info(s"Session $sessionId: cancel") + doCancel() + fireCompletionListeners() + } + } /** Closes this session and releases resources. */ - override def close(): Unit + override final def close(): Unit = { + if (closed.compareAndSet(false, true)) { + logger.info(s"Session $sessionId: close") + doClose() + fireCompletionListeners() + } + } + + /** Subclass hook for [[init]]. Called once, after the guard. */ + protected def doInit(message: InitMessage): Unit + + /** Subclass hook for [[process]]. Called at most once, after the guard. */ + protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] + + /** Subclass hook for [[cancel]]. Called once, after the guard. */ + protected def doCancel(): Unit + + /** Subclass hook for [[close]]. Called at most once, after the guard. */ + protected def doClose(): Unit } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala index 8da0354187e4f..7456a8a5b9bae 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala @@ -37,7 +37,7 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POL @Experimental abstract class DirectUnixSocketWorkerDispatcher( workerSpec: UDFWorkerSpecification, - logger: WorkerLogger = WorkerLogger.NoOp) + logger: WorkerLogger) extends DirectWorkerDispatcher(workerSpec, logger) { // Removed explicitly in closeTransport(). 
deleteOnExit is avoided because diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index afaf23791d80f..8925912b85851 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -62,9 +62,12 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR @Experimental abstract class DirectWorkerDispatcher( override val workerSpec: UDFWorkerSpecification, - protected val logger: WorkerLogger = WorkerLogger.NoOp) + workerLogger: WorkerLogger) extends WorkerDispatcher { + protected val logger: WorkerLogger = + workerLogger.forClass(getClass) + // TODO: Connection pooling -- reuse idle workers across sessions. // TODO: Security scope isolation -- partition pool by WorkerSecurityScope. 
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index f4b5c1df63193..495950f3e7048 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -51,7 +51,7 @@ class DirectWorkerProcess( val id: String, private[direct] val artifacts: WorkerArtifacts, val gracefulTimeoutMs: Long, - protected val logger: WorkerLogger = WorkerLogger.NoOp, + protected val logger: WorkerLogger, private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) extends AutoCloseable { diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala index 7cdc5329350e3..78b10c84a34e1 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.udf.worker.core.direct import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession} +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger, WorkerSession} /** * :: Experimental :: @@ -41,14 +41,15 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession} */ @Experimental abstract class DirectWorkerSession( - private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession { - + private[core] val workerProcess: DirectWorkerProcess, + workerLogger: WorkerLogger) + extends WorkerSession(workerLogger) { private val released = new 
AtomicBoolean(false) /** The connection to the worker for this session. */ def connection: WorkerConnection = workerProcess.connection - override def close(): Unit = { + override protected def doClose(): Unit = { if (released.compareAndSet(false, true)) { workerProcess.releaseSession() } diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 60f5e2211b702..b0b73ed70a825 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -55,8 +55,9 @@ class SocketFileConnection(socketPath: String) * no-op). Tracking the thread-safety contract in the docstring on * [[org.apache.spark.udf.worker.core.WorkerSession.cancel]]. */ -class StubWorkerSession( - workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) { +private class StubWorkerSession( + workerProcess: DirectWorkerProcess) + extends DirectWorkerSession(workerProcess, WorkerLogger.NoOp) { override protected def doInit(message: InitMessage): Unit = {} @@ -64,7 +65,9 @@ class StubWorkerSession( input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = Iterator.empty - override def cancel(): Unit = {} + override protected def doCancel(): Unit = {} + + override protected def doClose(): Unit = {} } /** @@ -73,7 +76,7 @@ class StubWorkerSession( * implementation. 
*/ class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification) - extends DirectUnixSocketWorkerDispatcher(spec) { + extends DirectUnixSocketWorkerDispatcher(spec, WorkerLogger.NoOp) { override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = @@ -362,7 +365,8 @@ class DirectWorkerDispatcherSuite val releaseLatch = new java.util.concurrent.CountDownLatch(1) val capturedWorkers = new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]() - val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + val racing = new DirectUnixSocketWorkerDispatcher( + specWithRunner(defaultRunner), WorkerLogger.NoOp) { override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = new SocketFileConnection(socketPath) @@ -538,7 +542,7 @@ class DirectWorkerDispatcherSuite // worker must be terminated rather than leaked until dispatcher.close(). var capturedWorker: DirectWorkerProcess = null val failingDispatcher = - new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner), WorkerLogger.NoOp) { override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = new SocketFileConnection(socketPath) @@ -594,7 +598,7 @@ class DirectWorkerDispatcherSuite test("socket file is cleaned up when createConnection throws") { val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]() val failingDispatcher = - new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner), WorkerLogger.NoOp) { override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = { capturedSocketPaths.add(socketPath) @@ -760,7 +764,7 @@ class DirectWorkerDispatcherSuite .addCommand("sleep 30").build() val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build() val 
shortTimeoutDispatcher = - new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) { + new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), WorkerLogger.NoOp) { override protected def callableTimeoutMs: Long = 500L override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = @@ -897,7 +901,7 @@ class DirectWorkerDispatcherSuite s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build()) .build() val timeoutDispatcher = - new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) { + new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), WorkerLogger.NoOp) { override protected def callableTimeoutMs: Long = 500L override protected def createConnection( socketPath: String): UnixSocketWorkerConnection = diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala new file mode 100644 index 0000000000000..c5af112a54f5e --- /dev/null +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.udf.worker.core
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, verify, when}
+import org.mockito.invocation.InvocationOnMock
+// scalastyle:off funsuite
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker._
+
+/**
+ * A minimal [[WorkerSession]] that tracks cancel/close calls.
+ */
+private class NoOpWorkerSession
+  extends WorkerSession(WorkerLogger.NoOp) {
+  // Flipped by doCancel(); tests use this to assert which sessions
+  // the manager cancelled during stop().
+  var cancelled: Boolean = false
+  override protected def doInit(msg: InitMessage): Unit = {}
+  override protected def doProcess(
+      input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
+    Iterator.empty
+  override protected def doCancel(): Unit = { cancelled = true }
+  override protected def doClose(): Unit = {}
+}
+
+/**
+ * Holds a test [[UDFDispatcherManager]] and all observable state,
+ * so tests can assert on whichever fields they care about.
+ */
+private case class TestManagerFixture(
+    manager: UDFDispatcherManager,
+    createdDispatchers: ArrayBuffer[WorkerDispatcher],
+    closedDispatchers: ArrayBuffer[WorkerDispatcher],
+    doStopCalls: ArrayBuffer[Boolean],
+    createdSessions: ArrayBuffer[NoOpWorkerSession])
+
+private object TestManagerFixture {
+  // Builds a manager wired to a recording UDFDispatcherFactory:
+  // every factory callback appends to one of the buffers returned
+  // in the fixture, so tests observe factory interactions directly.
+  def apply(): TestManagerFixture = {
+    val createdDispatchers = ArrayBuffer[WorkerDispatcher]()
+    val closedDispatchers = ArrayBuffer[WorkerDispatcher]()
+    val doStopCalls = ArrayBuffer[Boolean]()
+    val createdSessions = ArrayBuffer[NoOpWorkerSession]()
+
+    val factory = new UDFDispatcherFactory {
+      override def createDispatcher(
+          workerSpec: UDFWorkerSpecification,
+          logger: WorkerLogger): WorkerDispatcher = {
+        // Each dispatcher is a Mockito mock whose createSession answer
+        // manufactures and records a fresh NoOpWorkerSession, so tests
+        // can inspect sessions in creation order via createdSessions.
+        val dispatcher = mock(classOf[WorkerDispatcher])
+        when(dispatcher.createSession(
+          any[Option[WorkerSecurityScope]]))
+          .thenAnswer((_: InvocationOnMock) => {
+            val session = new NoOpWorkerSession()
+            createdSessions += session
+            session
+          })
+        createdDispatchers += dispatcher
+        dispatcher
+      }
+      override def onAllDispatcherSessionsClosed(
+          dispatcher: WorkerDispatcher): Unit = {
+        closedDispatchers += dispatcher
+      }
+      override def onStop(): Unit = {
+        doStopCalls += true
+      }
+    }
+    val manager = new UDFDispatcherManager(factory)
+
+    TestManagerFixture(
+      manager, createdDispatchers, closedDispatchers,
+      doStopCalls, createdSessions)
+  }
+}
+
+class UDFDispatcherManagerSuite
+  extends AnyFunSuite { // scalastyle:ignore funsuite
+
+  /**
+   * Builds a minimal direct-worker specification. Specs built here
+   * differ only in `command`, so two calls with the same command
+   * yield structurally equal (but distinct) specs — the property the
+   * dispatcher-reuse tests rely on.
+   */
+  private def makeSpec(command: String): UDFWorkerSpecification = {
+    val callable = ProcessCallable.newBuilder()
+    callable.addCommand(command)
+    val caps = WorkerCapabilities.newBuilder()
+      .addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
+      .addSupportedCommunicationPatterns(
+        UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
+    val conn = WorkerConnectionSpec.newBuilder()
+      .setTcp(LocalTcpConnection.newBuilder())
+    val props = UDFWorkerProperties.newBuilder()
+      .setConnection(conn)
+    val direct = DirectWorker.newBuilder()
+      .setRunner(callable).setProperties(props)
+    UDFWorkerSpecification.newBuilder()
+      .setEnvironment(WorkerEnvironment.newBuilder())
+      .setCapabilities(caps).setDirect(direct).build()
+  }
+
+  test("Same spec reuses the same dispatcher") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    val s2 = fixture.manager.createSession(spec)
+
+    // One cached dispatcher, asked for a session twice.
+    assert(fixture.createdDispatchers.size === 1)
+    verify(
+      fixture.createdDispatchers.head,
+      org.mockito.Mockito.times(2))
+      .createSession(any[Option[WorkerSecurityScope]])
+
+    s1.close()
+    s2.close()
+  }
+
+  test("Structurally equal specs reuse the same dispatcher") {
+    val fixture = TestManagerFixture()
+    val spec1 = makeSpec("worker.bin")
+    val spec2 = makeSpec("worker.bin")
+    // Distinct instances that compare equal: the manager's cache must
+    // key on value equality, not reference identity.
+    assert(spec1 ne spec2)
+    assert(spec1 == spec2)
+
+    val s1 = fixture.manager.createSession(spec1)
+    val s2 = fixture.manager.createSession(spec2)
+
+    assert(fixture.createdDispatchers.size === 1)
+
+    s1.close()
+    s2.close()
+  }
+
+  test("Different specs create different dispatchers") {
+    val fixture = TestManagerFixture()
+
+    val sA = fixture.manager.createSession(makeSpec("worker-a.bin"))
+    val sB = fixture.manager.createSession(makeSpec("worker-b.bin"))
+
+    assert(fixture.createdDispatchers.size === 2)
+    assert(
+      fixture.createdDispatchers(0) ne
+      fixture.createdDispatchers(1))
+
+    sA.close()
+    sB.close()
+  }
+
+  test("onAllDispatcherSessionsClosed when last session closes") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    val s2 = fixture.manager.createSession(spec)
+
+    s1.close()
+    assert(fixture.closedDispatchers.isEmpty)
+
+    s2.close()
+    assert(fixture.closedDispatchers.size === 1)
+    assert(
+      fixture.closedDispatchers.head eq
+      fixture.createdDispatchers.head)
+  }
+
+  test("onAllDispatcherSessionsClosed not called while sessions remain") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    val s2 = fixture.manager.createSession(spec)
+    val s3 = fixture.manager.createSession(spec)
+
+    s1.close()
+    s2.close()
+    assert(fixture.closedDispatchers.isEmpty)
+
+    s3.close()
+    assert(fixture.closedDispatchers.size === 1)
+  }
+
+  test("New dispatcher after all sessions closed") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    val s2 = fixture.manager.createSession(spec)
+    assert(fixture.createdDispatchers.size === 1)
+
+    s1.close()
+    s2.close()
+    assert(fixture.closedDispatchers.size === 1)
+
+    // After the idle notification, the same spec must yield a fresh
+    // dispatcher rather than the retired one.
+    val s3 = fixture.manager.createSession(spec)
+    assert(fixture.createdDispatchers.size === 2)
+    assert(
+      fixture.createdDispatchers(0) ne
+      fixture.createdDispatchers(1))
+
+    s3.close()
+    assert(fixture.closedDispatchers.size === 2)
+  }
+
+  test("Stop closes all cached dispatchers") {
+    val fixture = TestManagerFixture()
+
+    fixture.manager.createSession(makeSpec("worker-a.bin"))
+    fixture.manager.createSession(makeSpec("worker-b.bin"))
+    fixture.manager.stop()
+
+    fixture.createdDispatchers.foreach(d => verify(d).close())
+  }
+
+  test("Stop calls doStop for subclass cleanup") {
+    val fixture = TestManagerFixture()
+
+    fixture.manager.createSession(makeSpec("worker.bin"))
+    assert(fixture.doStopCalls.isEmpty)
+
+    fixture.manager.stop()
+    assert(fixture.doStopCalls.size === 1)
+  }
+
+  test("Stop cancels active sessions that were not closed") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    fixture.manager.createSession(spec)
+    s1.close()
+
+    // createdSessions(0) backs s1 (already closed); createdSessions(1)
+    // backs the session that is still open when stop() runs.
+    assert(fixture.createdSessions.size === 2)
+    assert(!fixture.createdSessions(0).cancelled)
+    assert(!fixture.createdSessions(1).cancelled)
+
+    fixture.manager.stop()
+
+    assert(fixture.createdSessions(1).cancelled)
+  }
+
+  test("Stop does not cancel already closed sessions") {
+    val fixture = TestManagerFixture()
+    val spec = makeSpec("worker.bin")
+
+    val s1 = fixture.manager.createSession(spec)
+    s1.close()
+
+    fixture.manager.stop()
+
+    assert(!fixture.createdSessions(0).cancelled)
+  }
+
+  test("createSession throws after stop") {
+    val fixture = TestManagerFixture()
+    fixture.manager.stop()
+
+    intercept[IllegalStateException] {
+      fixture.manager.createSession(makeSpec("worker.bin"))
+    }
+  }
+}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala
new file mode 100644
index 0000000000000..b63af22af65e3
--- /dev/null
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerSessionSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import java.util.concurrent.atomic.AtomicInteger
+
+// scalastyle:off funsuite
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * A [[WorkerSession]] whose lifecycle hooks are all no-ops, used to
+ * exercise the completion-listener behavior of the base class.
+ */
+private class TestWorkerSession
+  extends WorkerSession(WorkerLogger.NoOp) {
+  override protected def doInit(msg: InitMessage): Unit = {}
+  override protected def doProcess(
+      input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
+    Iterator.empty
+  override protected def doCancel(): Unit = {}
+  override protected def doClose(): Unit = {}
+}
+
+/**
+ * Pins the completion-listener contract exercised below: listeners
+ * fire exactly once on the first of close()/cancel(), and a listener
+ * added after completion fires immediately.
+ * NOTE(review): that contract is implemented in the WorkerSession
+ * base class, which is outside this chunk — these tests are the spec.
+ */
+class WorkerSessionSuite
+  extends AnyFunSuite { // scalastyle:ignore funsuite
+
+  test("Completion listener fires on close") {
+    val session = new TestWorkerSession()
+    val count = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+    session.close()
+    assert(count.get() === 1)
+  }
+
+  test("Completion listener fires on cancel") {
+    val session = new TestWorkerSession()
+    val count = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+    session.cancel()
+    assert(count.get() === 1)
+  }
+
+  test("Completion listener fires exactly once on close then cancel") {
+    val session = new TestWorkerSession()
+    val count = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count.incrementAndGet())
+
+    // The second completion event must not re-fire the listener.
+    session.close()
+    session.cancel()
+    assert(count.get() === 1)
+  }
+
+  test("Listener added after close fires immediately") {
+    val session = new TestWorkerSession()
+    session.close()
+
+    val count = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count.incrementAndGet())
+    assert(count.get() === 1)
+  }
+
+  test("Listener added after cancel fires immediately") {
+    val session = new TestWorkerSession()
+    session.cancel()
+
+    val count = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count.incrementAndGet())
+    assert(count.get() === 1)
+  }
+
+  test("Multiple listeners all fire exactly once") {
+    val session = new TestWorkerSession()
+    val count1 = new AtomicInteger(0)
+    val count2 = new AtomicInteger(0)
+    val count3 = new AtomicInteger(0)
+    session.addSessionCompletionListener(_ => count1.incrementAndGet())
+    session.addSessionCompletionListener(_ => count2.incrementAndGet())
+
+    session.close()
+
+    // Add a third after close; it should fire immediately like the
+    // post-completion cases above.
+    session.addSessionCompletionListener(_ => count3.incrementAndGet())
+
+    assert(count1.get() === 1)
+    assert(count2.get() === 1)
+    assert(count3.get() === 1)
+  }
+
+  test("Listener receives the correct session instance") {
+    val session = new TestWorkerSession()
+    var received: WorkerSession = null
+    session.addSessionCompletionListener(s => received = s)
+
+    session.close()
+    // Reference equality: the callback argument is the session itself.
+    assert(received eq session)
+  }
+}