Skip to content

Commit

Permalink
Represent Subchain as a List<Continuation> to simplify stack gluing
Browse files Browse the repository at this point in the history
  • Loading branch information
kyay10 committed Jun 8, 2024
1 parent 8db7b99 commit ed66ec0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 14 deletions.
46 changes: 40 additions & 6 deletions library/src/commonMain/kotlin/cloning.kt
Original file line number Diff line number Diff line change
@@ -1,17 +1,51 @@
import kotlin.coroutines.Continuation
import kotlin.jvm.JvmInline

internal expect val Continuation<*>.isCompilerGenerated: Boolean
internal expect val Continuation<*>.completion: Continuation<*>
internal expect fun <T> Continuation<T>.copy(completion: Continuation<*>): Continuation<T>

internal fun <T, R> Continuation<T>.collectSubchain(prompt: Prompt<R>): Subchain<T, R> = Subchain(buildList {
this@collectSubchain.forEach {
add(it)
if (it is Hole<*> && it.prompt == prompt) return@buildList
}
})

internal inline fun Continuation<*>.forEach(block: (Continuation<*>) -> Unit) {
var current: Continuation<*> = this
while (true) {
block(current)
current = when (current) {
in CompilerGenerated -> current.completion
is CopyableContinuation -> current.completion
else -> error("Continuation $this is not see-through, so its stack can't be traversed")
}
}
}

// list is a list of continuations from the current continuation to the hole
// The last element is the hole itself, the first element is the current continuation
@Suppress("UNCHECKED_CAST")
internal fun <T, R> Continuation<T>.clone(prompt: Prompt<R>, replacement: Continuation<R>): Continuation<T> =
when {
this is Hole<*> && this.prompt == prompt -> replacement as Continuation<T>
isCompilerGenerated -> copy(completion.clone(prompt, replacement))
this is CopyableContinuation<T> -> copy(completion.clone(prompt, replacement))
else -> error("Continuation $this is not cloneable, but $prompt has not been found in the chain.")
@JvmInline
internal value class Subchain<T, R>(private val list: List<Continuation<*>>) {
fun replace(replacement: Continuation<R>): Continuation<T> {
var result: Continuation<*> = replacement
for (i in list.lastIndex - 1 downTo 0) result = when (val cont = list[i]) {
in CompilerGenerated -> cont.copy(result)
is CopyableContinuation -> cont.copy(result)
else -> error("Continuation $this is not cloneable")
}
return result as Continuation<T>
}

val hole: Hole<R> get() = list.last() as Hole<R>
}

internal object CompilerGenerated {
operator fun contains(cont: Continuation<*>): Boolean = cont.isCompilerGenerated
}

internal interface CopyableContinuation<T> : Continuation<T> {
val completion: Continuation<*>
fun copy(completion: Continuation<*>): CopyableContinuation<T>
Expand Down
18 changes: 10 additions & 8 deletions library/src/commonMain/kotlin/reset.kt
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import kotlinx.coroutines.CancellationException
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlin.jvm.JvmInline

@Target(AnnotationTarget.CLASS, AnnotationTarget.TYPE, AnnotationTarget.FUNCTION, AnnotationTarget.PROPERTY)
@DslMarker
public annotation class ResetDsl

public data class SubCont<in T, out R> internal constructor(
private val ekFragment: Continuation<T>,
private val prompt: Prompt<R>,
@JvmInline
public value class SubCont<in T, out R> internal constructor(
private val subchain: Subchain<T, R>,
) {
private val prompt get() = subchain.hole.prompt
private fun composedWith(
k: Continuation<R>, isDelimiting: Boolean, extraContext: CoroutineContext, rewindHandler: RewindHandler?
) = ekFragment.clone(prompt, Hole(k, prompt.takeIf { isDelimiting }, extraContext, rewindHandler))
) = subchain.replace(Hole(k, prompt.takeIf { isDelimiting }, extraContext, rewindHandler))

@ResetDsl
public suspend fun pushSubContWith(
Expand Down Expand Up @@ -91,8 +93,6 @@ internal data class Hole<T>(
val extraContext = rewindHandler?.onRewind(extraContext, completion.context) ?: extraContext
return copy(completion = completion as Continuation<T>, extraContext = extraContext)
}

internal fun withoutDelimiter(): Continuation<T> = completion
}

public fun CoroutineContext.promptParentContext(prompt: Prompt<*>): CoroutineContext? =
Expand All @@ -102,14 +102,16 @@ public fun CoroutineContext.promptContext(prompt: Prompt<*>): CoroutineContext?

private fun <T> CoroutineContext.holeFor(prompt: Prompt<T>, deleteDelimiter: Boolean): Continuation<T> {
val hole = this[prompt] ?: error("Prompt $prompt not set")
return if (deleteDelimiter) hole.withoutDelimiter() else hole
return if (deleteDelimiter) hole.completion else hole
}

@ResetDsl
public suspend fun <T, R> Prompt<R>.takeSubCont(
deleteDelimiter: Boolean = true, body: suspend (SubCont<T, R>) -> R
): T = suspendCoroutineUnintercepted { k ->
body.startCoroutine(SubCont(k, this), k.context.holeFor(this, deleteDelimiter))
val subchain = k.collectSubchain(this)
val hole = subchain.hole
body.startCoroutine(SubCont(subchain), if(deleteDelimiter) hole.completion else hole)
}

@Suppress("UNCHECKED_CAST")
Expand Down

0 comments on commit ed66ec0

Please sign in to comment.