Skip to content

Commit 83e1033

Browse files
committed
[SPARK-56661] Implementing UDFWorkerManager for new UDF worker sessions
1 parent 68c0042 commit 83e1033

14 files changed

Lines changed: 920 additions & 28 deletions

udf/worker/core/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@
5151
<groupId>org.scala-lang</groupId>
5252
<artifactId>scala-library</artifactId>
5353
</dependency>
54+
<dependency>
55+
<groupId>org.mockito</groupId>
56+
<artifactId>mockito-core</artifactId>
57+
<scope>test</scope>
58+
</dependency>
5459
</dependencies>
5560

5661
<build>
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.udf.worker.core
18+
19+
import java.util.{ArrayList, HashMap}

import scala.util.control.NonFatal

import org.apache.spark.annotation.Experimental
import org.apache.spark.udf.worker.UDFWorkerSpecification
23+
24+
/**
 * :: Experimental :: Creates [[WorkerSession]] instances for a given
 * [[UDFWorkerSpecification]], managing [[WorkerDispatcher]] instances and
 * their lifecycle internally.
 *
 * Dispatchers are cached by spec (protobuf value equality) and reused across
 * sessions. The manager tracks the number of active sessions per dispatcher
 * via [[WorkerSession#addSessionCompletionListener]]. When the last session
 * for a dispatcher is closed, the entry is removed and
 * [[onAllDispatcherSessionsClosed]] is called.
 *
 * You might be wondering why the Dispatcher does not track the number of
 * active sessions itself. The reason is that this would create an
 * unavoidable race condition: Clients can provide different worker
 * specs. Therefore, different dispatchers may be required, which cannot all
 * exist for the whole Spark lifetime -> Dispatchers need to be removed/terminated
 * at some point. If Dispatchers were to track their active sessions themselves
 * and we would use this to decide on the dispatcher lifetime, it can always
 * happen that there are concurrent [[createSession]] requests while
 * the Dispatcher is being disposed of - which would create session
 * initialization errors and may cause Spark task/query failures.
 * Instead, we track the active sessions per Dispatcher globally
 * in this manager.
 *
 * Thread safety: a single lock guards all state -- dispatchers, active
 * sessions, and the stopping flag.
 *
 * Subclasses must implement [[doCreateDispatcher]] to provide a concrete
 * dispatcher and [[onAllDispatcherSessionsClosed]] to control dispatcher
 * teardown policy. Keeping a dispatcher alive after
 * [[onAllDispatcherSessionsClosed]] should be a conscious decision
 * and requires additional clean-up logic in the subclasses implemented
 * via [[doStop]].
 */
@Experimental
abstract class UDFWorkerManager(
    workerLogger: WorkerLogger = WorkerLogger.NoOp
) {

  protected val logger: WorkerLogger =
    workerLogger.forClass(getClass)

  /*
   * Why do we need an [[activeSessionCount]] and an [[activeSessions]]
   * list? [[activeSessionCount]] is per dispatcher. [[activeSessions]]
   * is global and allows us to perform session cleanup on [[stop]].
   * Moreover, this distinction allows us to create sessions without
   * requiring a lock on [[lock]].
   */
  private class DispatcherEntry(val dispatcher: WorkerDispatcher) {
    // Number of not-yet-terminated sessions created through this
    // dispatcher. Guarded by `lock`.
    var activeSessionCount: Int = 0
  }

  // All fields below are guarded by `lock`.
  private val lock = new Object
  private val dispatchers =
    new HashMap[UDFWorkerSpecification, DispatcherEntry]()
  private val activeSessions = new ArrayList[WorkerSession]()
  private var stopped = false

  /**
   * Creates a [[WorkerSession]] for the given worker specification and
   * optional security scope.
   *
   * If a dispatcher for this spec already exists it is reused; otherwise
   * [[doCreateDispatcher]] is called to create one. A completion listener
   * is registered on the session to track when it closes.
   *
   * @throws IllegalStateException if this manager has been stopped.
   */
  final def createSession(
      workerSpec: UDFWorkerSpecification,
      securityScope: Option[WorkerSecurityScope] = None
  ): WorkerSession = {
    // Get the dispatcher and reserve a session slot on it
    // (increments the entry's active session count).
    val entry = lock.synchronized {
      if (stopped) {
        throwStopped()
      }
      getOrCreateDispatcherEntry(workerSpec)
    }

    // Create a new session (potentially slow -> outside the lock).
    // Note: This might fail if Spark is concurrently being stopped
    // and the dispatcher is cleaned up. As Spark is stopping,
    // this failure is acceptable. On the happy path, no sessions
    // should try to be created while Spark is shutting down.
    val session =
      try {
        entry.dispatcher.createSession(securityScope)
      } catch {
        case NonFatal(e) =>
          // BUGFIX: roll back the reserved slot. Without this, a failed
          // session creation would leave activeSessionCount permanently
          // elevated and the dispatcher entry could never be removed.
          lock.synchronized {
            if (!stopped) {
              handleSessionTermination(workerSpec)
            }
          }
          throw e
      }

    // Register the session unless we were stopped in the meantime.
    // The active count is captured under the lock so the log line
    // below does not read guarded state unsynchronized.
    val activeCount = lock.synchronized {
      if (stopped) {
        -1 // sentinel: stopped concurrently
      } else {
        activeSessions.add(session)
        entry.activeSessionCount
      }
    }
    if (activeCount < 0) {
      // Stopped concurrently: close the orphan session outside the
      // lock (closing may be slow) and surface the stopped state.
      session.close()
      throwStopped()
    }

    logger.info(s"Created session ${session.sessionId}" +
      s" on dispatcher ${entry.dispatcher.dispatcherId}" +
      s" (active: $activeCount)")

    // Register a completion listener that updates the
    // state when the session is canceled or closed
    session.addSessionCompletionListener { completed =>
      logger.info(s"Session ${completed.sessionId} terminated")
      lock.synchronized {
        if (!stopped) {
          activeSessions.remove(completed)
          handleSessionTermination(workerSpec)
        }
      }
    }

    session
  }

  /**
   * Called on driver/executor shutdown. Cancels any active sessions,
   * closes all cached dispatchers, and resets internal state.
   *
   * Safety net -- in normal operation, sessions are closed by task
   * completion listeners and dispatchers are cleaned up via
   * [[onAllDispatcherSessionsClosed]] when their last session closes.
   */
  final def stop(): Unit = {
    lock.synchronized {
      // Log inside the lock: `activeSessions` and `dispatchers` are
      // guarded state and must not be read unsynchronized.
      logger.info("UDFWorkerManager stopping" +
        s" (${activeSessions.size()} active sessions," +
        s" ${dispatchers.size()} dispatchers)")

      stopped = true

      // Cancel any sessions that are still active. Cancel is a
      // no-op if the session was already closed/cancelled.
      activeSessions.forEach { s =>
        logger.debug(s"Cancelling session ${s.sessionId}" +
          " during stop")
        s.cancel()
      }
      activeSessions.clear()

      // Close all dispatchers we control.
      // When spark is stopped in a clean state
      // (only finished tasks), it is expected
      // that all dispatchers have been terminated
      // already. This is a safety-net.
      dispatchers.forEach { (_, entry) =>
        logger.debug(s"Closing dispatcher" +
          s" ${entry.dispatcher.dispatcherId} during stop")
        entry.dispatcher.close()
      }
      dispatchers.clear()
    }

    // Perform cleanup in the sub classes
    doStop()
    logger.info("UDFWorkerManager stopped")
  }

  /** Always throws: this manager no longer accepts new sessions. */
  private def throwStopped(): Nothing =
    throw new IllegalStateException(
      "UDFWorkerManager is stopped")

  // Decrements the session count for the spec's dispatcher and removes
  // the dispatcher entry when it reaches zero.
  // Must be called while holding `lock`.
  private def handleSessionTermination(
      workerSpec: UDFWorkerSpecification
  ): Unit = {
    val entry = dispatchers.get(workerSpec)
    if (entry == null) {
      // A terminating session must always have a dispatcher entry;
      // fail loudly with a descriptive error instead of an NPE below.
      throw new IllegalStateException(
        "No dispatcher entry found for terminating session")
    }
    entry.activeSessionCount -= 1
    if (entry.activeSessionCount == 0) {
      logger.info("All sessions closed for dispatcher " +
        s"${entry.dispatcher.dispatcherId}, removing from cache")
      dispatchers.remove(workerSpec)
      onAllDispatcherSessionsClosed(entry.dispatcher)
    }
  }

  // Returns the cached entry for the spec, creating one via
  // [[doCreateDispatcher]] if absent, and increments its session count.
  // Must be called while holding `lock`.
  private def getOrCreateDispatcherEntry(
      workerSpec: UDFWorkerSpecification
  ): DispatcherEntry = {
    var entry = dispatchers.get(workerSpec)
    if (entry == null) {
      val dispatcher = doCreateDispatcher(workerSpec, logger)
      logger.info(s"Created dispatcher ${dispatcher.dispatcherId}")
      entry = new DispatcherEntry(dispatcher)
      dispatchers.put(workerSpec, entry)
    }
    entry.activeSessionCount += 1
    entry
  }

  /**
   * Creates a new [[WorkerDispatcher]] for the given specification.
   * It is expected that creating the dispatcher
   * itself is not slow while creating a session might be.
   */
  protected def doCreateDispatcher(
      workerSpec: UDFWorkerSpecification,
      logger: WorkerLogger
  ): WorkerDispatcher

  /**
   * Called when the last active session for a dispatcher is closed.
   * Subclasses must decide what to do with the now-idle dispatcher.
   * Not called during [[stop]] -- the manager cleans up dispatchers
   * it holds directly in that case. Cleanup of dispatchers not
   * provided to the manager is the responsibility of the subclass
   * via [[doStop]].
   */
  protected def onAllDispatcherSessionsClosed(
      dispatcher: WorkerDispatcher
  ): Unit

  /**
   * Called when the executor/driver stops. Subclasses should clean
   * up any dispatchers/resources they hold beyond what the parent
   * class manages.
   */
  protected def doStop(): Unit
}

udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ import org.apache.spark.udf.worker.UDFWorkerSpecification
3131
@Experimental
3232
trait WorkerDispatcher extends AutoCloseable {
3333

34+
/** Unique identifier for this dispatcher. */
35+
val dispatcherId: String =
36+
java.util.UUID.randomUUID().toString
37+
3438
def workerSpec: UDFWorkerSpecification
3539

3640
/**

udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,43 @@ import org.apache.spark.annotation.Experimental
3636
trait WorkerLogger {
  def warn(msg: => String): Unit
  def warn(msg: => String, t: Throwable): Unit
  def info(msg: => String): Unit
  def info(msg: => String, t: Throwable): Unit
  def debug(msg: => String): Unit
  def debug(msg: => String, t: Throwable): Unit

  /**
   * Returns a [[WorkerLogger]] that delegates to this one, prefixing
   * every message with `[SimpleClassName]` so that log lines identify
   * the class that produced them.
   */
  def forClass(clazz: Class[_]): WorkerLogger = {
    val tag = s"[${clazz.getSimpleName}] "
    new WorkerLogger {
      // Delegate to the outer logger instance with the tag prepended.
      override def warn(msg: => String): Unit =
        WorkerLogger.this.warn(tag + msg)
      override def warn(msg: => String, t: Throwable): Unit =
        WorkerLogger.this.warn(tag + msg, t)
      override def info(msg: => String): Unit =
        WorkerLogger.this.info(tag + msg)
      override def info(msg: => String, t: Throwable): Unit =
        WorkerLogger.this.info(tag + msg, t)
      override def debug(msg: => String): Unit =
        WorkerLogger.this.debug(tag + msg)
      override def debug(msg: => String, t: Throwable): Unit =
        WorkerLogger.this.debug(tag + msg, t)
    }
  }
}
4268

4369
object WorkerLogger {
  /**
   * A logger that silently drops every message. Used as the default
   * for callers that do not wire up any logging.
   */
  val NoOp: WorkerLogger = new WorkerLogger {
    override def warn(msg: => String): Unit = {}
    override def warn(msg: => String, t: Throwable): Unit = {}
    override def info(msg: => String): Unit = {}
    override def info(msg: => String, t: Throwable): Unit = {}
    override def debug(msg: => String): Unit = {}
    override def debug(msg: => String, t: Throwable): Unit = {}
  }
}

0 commit comments

Comments
 (0)