From 66d60ff54e7c05fd1a806e3f8c78dab98c6fdba9 Mon Sep 17 00:00:00 2001 From: Sergei Sysoev Date: Wed, 10 Jul 2024 22:35:25 +0200 Subject: [PATCH 1/2] Move `ThreadContextElement` to common --- .../api/kotlinx-coroutines-core.klib.api | 4 + .../common/src/ThreadContextElement.common.kt | 82 ++++++++ .../src/internal/ThreadContext.common.kt | 53 ++++- .../jsAndWasmShared/src/CoroutineContext.kt | 104 +++++++++- .../src/internal/ThreadContext.kt | 38 +++- .../jvm/src/ThreadContextElement.kt | 79 -------- .../jvm/src/internal/ThreadContext.kt | 51 ----- .../native/src/CoroutineContext.kt | 186 +++++++++++++++++- .../native/src/internal/ThreadContext.kt | 38 +++- 9 files changed, 494 insertions(+), 141 deletions(-) create mode 100644 kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api index a820fe0b35..e8506dd670 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api @@ -159,6 +159,10 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko abstract fun complete(#A): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.complete|complete(1:0){}[0] abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0] } +abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0] + abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0] + abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0] +} abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0] abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0] } diff --git a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt new file mode 100644 index 0000000000..e5ea541c95 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt @@ -0,0 +1,82 @@ +package kotlinx.coroutines + +import kotlin.coroutines.* + +/** + * Defines elements in [CoroutineContext] that are installed into thread context + * every time the coroutine with this element in the context is resumed on a thread. + * + * Implementations of this interface define a type [S] of the thread-local state that they need to store on + * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. + * + * Example usage looks like this: + * + * ``` + * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * class CoroutineName(val name: String) : ThreadContextElement { + * // declare companion object for a key of this element in coroutine context + * companion object Key : CoroutineContext.Key + * + * // provide the key of the corresponding context element + * override val key: CoroutineContext.Key + * get() = Key + * + * // this is invoked before coroutine is resumed on current thread + * override fun updateThreadContext(context: CoroutineContext): String { + * val previousName = Thread.currentThread().name + * Thread.currentThread().name = "$previousName # $name" + * return previousName + * } + * + * // this is invoked after coroutine has suspended on current thread + * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + * Thread.currentThread().name = oldState + * } + * } + * + * // Usage + * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } + * ``` + * + * Every time this coroutine is resumed on a thread, UI thread name is updated to + * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * this coroutine suspends. + * + * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. + * + * ### Reentrancy and thread-safety + * + * Correct implementations of this interface must expect that calls to [restoreThreadContext] + * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. + * See [CopyableThreadContextElement] for advanced interleaving details. + * + * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state + * within an element accordingly. + */ +public interface ThreadContextElement : CoroutineContext.Element { + /** + * Updates context of the current thread. + * This function is invoked before the coroutine in the specified [context] is resumed in the current thread + * when the context of the coroutine this element. + * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + */ + public fun updateThreadContext(context: CoroutineContext): S + + /** + * Restores context of the current thread. + * This function is invoked after the coroutine in the specified [context] is suspended in the current thread + * if [updateThreadContext] was previously invoked on resume of this coroutine. + * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should + * be restored in the thread-local state by this function. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + * @param oldState the value returned by the previous invocation of [updateThreadContext]. + */ + public fun restoreThreadContext(context: CoroutineContext, oldState: S) +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt index c52d35c128..03d718eae4 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt @@ -1,5 +1,56 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* +import kotlin.jvm.* -internal expect fun threadContextElements(context: CoroutineContext): Any +@JvmField +internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") + +// Used when there are >= 2 active elements in the context +@Suppress("UNCHECKED_CAST") +internal class ThreadState(@JvmField val context: CoroutineContext, n: Int) { + private val values = arrayOfNulls(n) + private val elements = arrayOfNulls>(n) + private var i = 0 + + fun append(element: ThreadContextElement<*>, value: Any?) { + values[i] = value + elements[i++] = element as ThreadContextElement + } + + fun restore(context: CoroutineContext) { + for (i in elements.indices.reversed()) { + elements[i]!!.restoreThreadContext(context, values[i]) + } + } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +private val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +internal val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +internal val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element, element.updateThreadContext(state.context)) + } + return state + } + +internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt index 82862ac8aa..6095f93606 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt @@ -1,7 +1,13 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* +import kotlinx.coroutines.internal.CoroutineStackFrame +import kotlinx.coroutines.internal.NO_THREAD_ELEMENTS import kotlinx.coroutines.internal.ScopeCoroutine +import kotlinx.coroutines.internal.restoreThreadContext +import kotlinx.coroutines.internal.updateThreadContext import kotlin.coroutines.* +import kotlin.jvm.* @PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI internal actual val DefaultDelay: Delay @@ -18,8 +24,73 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo } // No debugging facilities on Wasm and JS -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +/** + * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. + * Used as a performance optimization to avoid stack walking where it is not necessary. + */ +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS @@ -27,5 +98,32 @@ internal actual class UndispatchedCoroutine actual constructor( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) + + private var savedContext: CoroutineContext? = null + private var savedOldValue: Any? = null + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + savedContext = context + savedOldValue = oldValue + } + + fun clearThreadContext(): Boolean { + if (savedContext == null) return false + savedContext = null + savedOldValue = null + return true + } + + override fun afterResume(state: Any?) { + savedContext?.let { context -> + restoreThreadContext(context, savedOldValue) + savedContext = null + savedOldValue = null + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } } diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt index 3f56f99d6c..a7915e43de 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt @@ -1,5 +1,41 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index c1898fbd65..9f52f61d78 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -3,85 +3,6 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -/** - * Defines elements in [CoroutineContext] that are installed into thread context - * every time the coroutine with this element in the context is resumed on a thread. - * - * Implementations of this interface define a type [S] of the thread-local state that they need to store on - * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. - * - * Example usage looks like this: - * - * ``` - * // Appends "name" of a coroutine to a current thread name when coroutine is executed - * class CoroutineName(val name: String) : ThreadContextElement { - * // declare companion object for a key of this element in coroutine context - * companion object Key : CoroutineContext.Key - * - * // provide the key of the corresponding context element - * override val key: CoroutineContext.Key - * get() = Key - * - * // this is invoked before coroutine is resumed on current thread - * override fun updateThreadContext(context: CoroutineContext): String { - * val previousName = Thread.currentThread().name - * Thread.currentThread().name = "$previousName # $name" - * return previousName - * } - * - * // this is invoked after coroutine has suspended on current thread - * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { - * Thread.currentThread().name = oldState - * } - * } - * - * // Usage - * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } - * ``` - * - * Every time this coroutine is resumed on a thread, UI thread name is updated to - * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when - * this coroutine suspends. - * - * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. - * - * ### Reentrancy and thread-safety - * - * Correct implementations of this interface must expect that calls to [restoreThreadContext] - * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. - * See [CopyableThreadContextElement] for advanced interleaving details. - * - * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state - * within an element accordingly. - */ -public interface ThreadContextElement : CoroutineContext.Element { - /** - * Updates context of the current thread. - * This function is invoked before the coroutine in the specified [context] is resumed in the current thread - * when the context of the coroutine this element. - * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - */ - public fun updateThreadContext(context: CoroutineContext): S - - /** - * Restores context of the current thread. - * This function is invoked after the coroutine in the specified [context] is suspended in the current thread - * if [updateThreadContext] was previously invoked on resume of this coroutine. - * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should - * be restored in the thread-local state by this function. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - * @param oldState the value returned by the previous invocation of [updateThreadContext]. - */ - public fun restoreThreadContext(context: CoroutineContext, oldState: S) -} - /** * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. * diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 8f21b13c25..5b876071f6 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -3,57 +3,6 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* -@JvmField -internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") - -// Used when there are >= 2 active elements in the context -@Suppress("UNCHECKED_CAST") -private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { - private val values = arrayOfNulls(n) - private val elements = arrayOfNulls>(n) - private var i = 0 - - fun append(element: ThreadContextElement<*>, value: Any?) { - values[i] = value - elements[i++] = element as ThreadContextElement - } - - fun restore(context: CoroutineContext) { - for (i in elements.indices.reversed()) { - elements[i]!!.restoreThreadContext(context, values[i]) - } - } -} - -// Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) -private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { - if (element is ThreadContextElement<*>) { - val inCount = countOrElement as? Int ?: 1 - return if (inCount == 0) element else inCount + 1 - } - return countOrElement - } - -// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one -private val findOne = - fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { - if (found != null) return found - return element as? ThreadContextElement<*> - } - -// Updates state for ThreadContextElements in the context using the given ThreadState -private val updateState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - if (element is ThreadContextElement<*>) { - state.append(element, element.updateThreadContext(state.context)) - } - return state - } - -internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! - // countOrElement is pre-cached in dispatched continuation // returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index 3f4c8d9a01..829d97d39c 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -1,7 +1,10 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* +import kotlin.concurrent.* import kotlin.coroutines.* +import kotlin.native.concurrent.ThreadLocal +import kotlin.native.ref.* internal actual object DefaultExecutor : CoroutineDispatcher(), Delay { @@ -40,14 +43,187 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo } // No debugging facilities on native -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +/** + * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. + * Used as a performance optimization to avoid stack walking where it is not necessary. + */ +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} + internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on native -internal actual class UndispatchedCoroutine actual constructor( +internal actual class UndispatchedCoroutineactual constructor ( context: CoroutineContext, uCont: Continuation -) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) +) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { + + /** + * The state of [ThreadContextElement]s associated with the current undispatched coroutine. + * It is stored in a thread local because this coroutine can be used concurrently in suspend-resume race scenario. + * See the following, boiled down example with inlined `withContinuationContext` body: + * ``` + * val state = saveThreadContext(ctx) + * try { + * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called + * // COROUTINE_SUSPENDED is returned + * } finally { + * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread + * // and it also calls saveThreadContext and clearThreadContext + * } + * ``` + * + * Usage note: + * + * This part of the code is performance-sensitive. + * It is a well-established pattern to wrap various activities into system-specific undispatched + * `withContext` for the sake of logging, MDC, tracing etc., meaning that there exists thousands of + * undispatched coroutines. + * [ThreadLocal.set] leaves a footprint in the corresponding Thread's `ThreadLocalMap`. + * We attempt to narrow down the lifetime of this thread local as much as possible: + * - It's never accessed when we are sure there are no thread context elements + * - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished. + */ + private val threadStateToRecover = ThreadLocal?>(this) + + /* + * Indicates that a coroutine has at least one thread context element associated with it + * and that 'threadStateToRecover' is going to be set in case of dispatchhing in order to preserve them. + * Better than nullable thread-local for easier debugging. + * + * It is used as a performance optimization to avoid 'threadStateToRecover' initialization + * (note: tl.get() initializes thread local), + * and is prone to false-positives as it is never reset: otherwise + * it may lead to logical data races between suspensions point where + * coroutine is yet being suspended in one thread while already being resumed + * in another. + */ + @Volatile + private var threadLocalIsSet = false + + init { + /* + * This is a hack for a very specific case in #2930 unless #3253 is implemented. + * 'ThreadLocalStressTest' covers this change properly. + * + * The scenario this change covers is the following: + * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function, + * e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking + * `withContext(tlElement)` which creates `UndispatchedCoroutine`. + * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()` + * and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both + * do thread context element tracking. + * 3) So thread locals never got chance to get properly set up via `saveThreadContext`, + * but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`. + * + * Here we detect precisely this situation and properly setup context to recover later. + * + */ + if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) { + /* + * We cannot just "read" the elements as there is no such API, + * so we update-restore it immediately and use the intermediate value + * as the initial state, leveraging the fact that thread context element + * is idempotent and such situations are increasingly rare. + */ + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + saveThreadContext(context, values) + } + } + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + threadLocalIsSet = true // Specify that thread-local is touched at all + threadStateToRecover.set(context to oldValue) + } + + fun clearThreadContext(): Boolean { + return !(threadLocalIsSet && threadStateToRecover.get() == null).also { + threadStateToRecover.remove() + } + } + + override fun afterResume(state: Any?) { + if (threadLocalIsSet) { + threadStateToRecover.get()?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + } + threadStateToRecover.remove() + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } } + +private class ThreadLocal(private val key: Any) { + @Suppress("UNCHECKED_CAST") + fun get(): T? = ThreadLocalMap[key] as? T + fun set(value: T) { ThreadLocalMap[key] = value } + fun remove() { ThreadLocalMap.remove(key) } +} + +@ThreadLocal +private object ThreadLocalMap: MutableMap by mutableMapOf() diff --git a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt index 3f56f99d6c..a7915e43de 100644 --- a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt @@ -1,5 +1,41 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.* import kotlin.coroutines.* -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} From c40b2927fe4a92764c75769219e3609e4bc7d37f Mon Sep 17 00:00:00 2001 From: Sergei Sysoev Date: Mon, 29 Jul 2024 02:40:36 +0200 Subject: [PATCH 2/2] Add `ThreadContextElement` tests --- .../common/test/ThreadContextElementTest.kt | 85 +++++++++++ ...Test.kt => ThreadContextElementJvmTest.kt} | 97 +++++------- .../test/ThreadContextElementNativeTest.kt | 141 ++++++++++++++++++ 3 files changed, 261 insertions(+), 62 deletions(-) create mode 100644 kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt rename kotlinx-coroutines-core/jvm/test/{ThreadContextElementTest.kt => ThreadContextElementJvmTest.kt} (73%) create mode 100644 kotlinx-coroutines-core/native/test/ThreadContextElementNativeTest.kt diff --git a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt new file mode 100644 index 0000000000..3ed4c14c97 --- /dev/null +++ b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt @@ -0,0 +1,85 @@ +package kotlinx.coroutines + +import kotlinx.coroutines.testing.* +import kotlin.coroutines.* +import kotlin.test.* +import kotlinx.coroutines.internal.* + +class ThreadContextElementTest : TestBase() { + interface TestThreadContextElement : ThreadContextElement { + companion object Key : CoroutineContext.Key + } + + @Test + fun updatesAndRestores() = runTest { + expect(1) + var updateCount = 0 + var restoreCount = 0 + val threadContextElement = object : TestThreadContextElement { + override fun updateThreadContext(context: CoroutineContext): Int { + updateCount++ + return 0 + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Int) { + restoreCount++ + } + + override val key: CoroutineContext.Key<*> + get() = TestThreadContextElement.Key + } + launch(Dispatchers.Unconfined + threadContextElement) { + assertEquals(1, updateCount) + assertEquals(0, restoreCount) + } + assertEquals(1, updateCount) + assertEquals(1, restoreCount) + finish(2) + } + + @Test + fun testUndispatched() = runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val data = MyData() + val element = MyElement(data) + val job = GlobalScope.launch( + context = Dispatchers.Default + exceptionHandler + element, + start = CoroutineStart.UNDISPATCHED + ) { + assertSame(data, threadContextElementThreadLocal.get()) + yield() + assertSame(data, threadContextElementThreadLocal.get()) + } + assertNull(threadContextElementThreadLocal.get()) + job.join() + assertNull(threadContextElementThreadLocal.get()) + } +} + +internal class MyData + +// declare thread local variable holding MyData +internal val threadContextElementThreadLocal = commonThreadLocal(Symbol("ThreadContextElementTest")) + +// declare context element holding MyData +internal class MyElement(val data: MyData) : ThreadContextElement { + // declare companion object for a key of this element in coroutine context + companion object Key : CoroutineContext.Key + + // provide the key of the corresponding context element + override val key: CoroutineContext.Key + get() = Key + + // this is invoked before coroutine is resumed on current thread + override fun updateThreadContext(context: CoroutineContext): MyData? { + val oldState = threadContextElementThreadLocal.get() + threadContextElementThreadLocal.set(data) + return oldState + } + + // this is invoked after coroutine has suspended on current thread + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + threadContextElementThreadLocal.set(oldState) + } +} + diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt similarity index 73% rename from kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt rename to kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt index 3b106c440d..94a25edee1 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt @@ -6,7 +6,7 @@ import kotlin.coroutines.* import kotlin.test.* import kotlinx.coroutines.flow.* -class ThreadContextElementTest : TestBase() { +class ThreadContextElementJvmTest : TestBase() { @Test fun testExample() = runTest { @@ -15,23 +15,23 @@ class ThreadContextElementTest : TestBase() { val mainThread = Thread.currentThread() val data = MyData() val element = MyElement(data) - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) val job = GlobalScope.launch(element + exceptionHandler) { assertTrue(mainThread != Thread.currentThread()) assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) withContext(mainDispatcher) { assertSame(mainThread, Thread.currentThread()) assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) } assertTrue(mainThread != Thread.currentThread()) assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) } - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) job.join() - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) } @Test @@ -43,13 +43,13 @@ class ThreadContextElementTest : TestBase() { context = Dispatchers.Default + exceptionHandler + element, start = CoroutineStart.UNDISPATCHED ) { - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) yield() - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) } - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) job.join() - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) } @Test @@ -58,22 +58,22 @@ class ThreadContextElementTest : TestBase() { newSingleThreadContext("withContext").use { val data = MyData() GlobalScope.async(Dispatchers.Default + MyElement(data)) { - assertSame(data, myThreadLocal.get()) + assertSame(data, threadContextElementThreadLocal.get()) expect(2) val newData = MyData() GlobalScope.async(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) + assertSame(newData, threadContextElementThreadLocal.get()) expect(3) }.await() withContext(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) + assertSame(newData, threadContextElementThreadLocal.get()) expect(4) } GlobalScope.async(it) { - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) expect(5) }.await() @@ -126,31 +126,31 @@ class ThreadContextElementTest : TestBase() { newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { withContext(it + CopyForChildCoroutineElement(MyData())) { val forBlockData = MyData() - myThreadLocal.setForBlock(forBlockData) { - assertSame(myThreadLocal.get(), forBlockData) + threadContextElementThreadLocal.setForBlock(forBlockData) { + assertSame(threadContextElementThreadLocal.get(), forBlockData) launch { - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) } launch { - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) // Modify value in child coroutine. Writes to the ThreadLocal and // the (copied) ThreadLocalElement's memory are not visible to peer or // ancestor coroutines, so this write is both threadsafe and coroutinesafe. val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) + threadContextElementThreadLocal.setForBlock(innerCoroutineData) { + assertSame(threadContextElementThreadLocal.get(), innerCoroutineData) } - assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored. + assertSame(threadContextElementThreadLocal.get(), forBlockData) // Asserts value was restored. } launch { val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) + threadContextElementThreadLocal.setForBlock(innerCoroutineData) { + assertSame(threadContextElementThreadLocal.get(), innerCoroutineData) } - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) } } - assertNull(myThreadLocal.get()) // Asserts value was restored to its origin + assertNull(threadContextElementThreadLocal.get()) // Asserts value was restored to its origin } } } @@ -193,58 +193,31 @@ class ThreadContextElementTest : TestBase() { @Test fun testThreadLocalFlowOn() = runTest { val myData = MyData() - myThreadLocal.set(myData) + threadContextElementThreadLocal.set(myData) expect(1) flow { - assertEquals(myData, myThreadLocal.get()) + assertEquals(myData, threadContextElementThreadLocal.get()) emit(1) } - .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default) + .flowOn(threadContextElementThreadLocal.asContextElement() + Dispatchers.Default) .single() - myThreadLocal.set(null) + threadContextElementThreadLocal.set(null) finish(2) } } -class MyData - -// declare thread local variable holding MyData -private val myThreadLocal = ThreadLocal() - -// declare context element holding MyData -class MyElement(val data: MyData) : ThreadContextElement { - // declare companion object for a key of this element in coroutine context - companion object Key : CoroutineContext.Key - - // provide the key of the corresponding context element - override val key: CoroutineContext.Key - get() = Key - - // this is invoked before coroutine is resumed on current thread - override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) - return oldState - } - - // this is invoked after coroutine has suspended on current thread - override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) - } -} - /** * A [ThreadContextElement] that implements copy semantics in [copyForChild]. */ -class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { +internal class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { companion object Key : CoroutineContext.Key override val key: CoroutineContext.Key get() = Key override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) + val oldState = threadContextElementThreadLocal.get() + threadContextElementThreadLocal.set(data) return oldState } @@ -253,7 +226,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle } override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) + threadContextElementThreadLocal.set(oldState) } /** @@ -268,7 +241,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle * thread and calls [restoreThreadContext]. */ override fun copyForChild(): CopyForChildCoroutineElement { - return CopyForChildCoroutineElement(myThreadLocal.get()) + return CopyForChildCoroutineElement(threadContextElementThreadLocal.get()) } } diff --git a/kotlinx-coroutines-core/native/test/ThreadContextElementNativeTest.kt b/kotlinx-coroutines-core/native/test/ThreadContextElementNativeTest.kt new file mode 100644 index 0000000000..257519e313 --- /dev/null +++ b/kotlinx-coroutines-core/native/test/ThreadContextElementNativeTest.kt @@ -0,0 +1,141 @@ +package kotlinx.coroutines + +import kotlinx.coroutines.testing.* +import kotlin.coroutines.* +import kotlin.test.* +import kotlin.native.concurrent.* + +class ThreadContextElementNativeTest : TestBase() { + + @OptIn(ObsoleteWorkersApi::class) + @Test + fun testExample() = runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val mainDispatcher = coroutineContext[ContinuationInterceptor]!! + val mainThread = Worker.current.id + val data = MyData() + val element = MyElement(data) + assertNull(threadContextElementThreadLocal.get()) + val job = GlobalScope.launch(element + exceptionHandler) { + assertTrue(mainThread != Worker.current.id) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, threadContextElementThreadLocal.get()) + withContext(mainDispatcher) { + assertSame(mainThread, Worker.current.id) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, threadContextElementThreadLocal.get()) + } + assertTrue(mainThread != Worker.current.id) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, threadContextElementThreadLocal.get()) + } + assertNull(threadContextElementThreadLocal.get()) + job.join() + assertNull(threadContextElementThreadLocal.get()) + } + + @Test + fun testUndispatched() = runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val data = MyData() + val element = MyElement(data) + val job = GlobalScope.launch( + context = Dispatchers.Default + exceptionHandler + element, + start = CoroutineStart.UNDISPATCHED + ) { + assertSame(data, threadContextElementThreadLocal.get()) + yield() + assertSame(data, threadContextElementThreadLocal.get()) + } + assertNull(threadContextElementThreadLocal.get()) + job.join() + assertNull(threadContextElementThreadLocal.get()) + } + + @Test + fun testWithContext() = runTest { + expect(1) + newSingleThreadContext("withContext").use { + val data = MyData() + GlobalScope.async(Dispatchers.Default + MyElement(data)) { + assertSame(data, threadContextElementThreadLocal.get()) + expect(2) + + val newData = MyData() + GlobalScope.async(it + MyElement(newData)) { + assertSame(newData, threadContextElementThreadLocal.get()) + expect(3) + }.await() + + withContext(it + MyElement(newData)) { + assertSame(newData, threadContextElementThreadLocal.get()) + expect(4) + } + + GlobalScope.async(it) { + assertNull(threadContextElementThreadLocal.get()) + expect(5) + }.await() + + expect(6) + }.await() + } + + finish(7) + } + + @Test + fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest { + var parentElement: MyElement? = null + var inheritedElement: MyElement? = null + + newSingleThreadContext("withContext").use { + withContext(it + MyElement(MyData())) { + parentElement = coroutineContext[MyElement.Key] + launch { + inheritedElement = coroutineContext[MyElement.Key] + } + } + } + + assertSame(inheritedElement, parentElement, + "Inner and outer coroutines did not have the same object reference to a" + + " ThreadContextElement that did not override `copyForChildCoroutine()`") + } + + class JobCaptor(val capturees: ArrayList = ArrayList()) : ThreadContextElement { + + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key<*> get() = Key + + override fun updateThreadContext(context: CoroutineContext) { + capturees.add(context.job) + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) { + } + } + + @Test + fun testWithContextJobAccess() = runTest { + val captor = JobCaptor() + val manuallyCaptured = ArrayList() + runBlocking(captor) { + manuallyCaptured += coroutineContext.job + withContext(CoroutineName("undispatched")) { + manuallyCaptured += coroutineContext.job + withContext(Dispatchers.IO) { + manuallyCaptured += coroutineContext.job + } + // Context restored, captured again + manuallyCaptured += coroutineContext.job + } + // Context restored, captured again + manuallyCaptured += coroutineContext.job + } + + assertEquals(manuallyCaptured, captor.capturees) + } +} +