Skip to content

Commit 67f1925

Browse files
authored
Await for session creation on client side (#11)
Fixes #3
1 parent 6bdaf7f commit 67f1925

File tree

2 files changed

+87
-4
lines changed
  • acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol
  • acp/src/commonMain/kotlin/com/agentclientprotocol/client

2 files changed

+87
-4
lines changed

acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import com.agentclientprotocol.common.ClientSessionOperations
1313
import com.agentclientprotocol.common.Event
1414
import com.agentclientprotocol.common.SessionParameters
1515
import com.agentclientprotocol.framework.ProtocolDriver
16+
import com.agentclientprotocol.model.AcpMethod
1617
import com.agentclientprotocol.model.ContentBlock
1718
import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION
1819
import com.agentclientprotocol.model.PermissionOption
@@ -22,9 +23,11 @@ import com.agentclientprotocol.model.PromptResponse
2223
import com.agentclientprotocol.model.RequestPermissionOutcome
2324
import com.agentclientprotocol.model.RequestPermissionResponse
2425
import com.agentclientprotocol.model.SessionId
26+
import com.agentclientprotocol.model.SessionNotification
2527
import com.agentclientprotocol.model.SessionUpdate
2628
import com.agentclientprotocol.model.StopReason
2729
import com.agentclientprotocol.model.ToolCallId
30+
import com.agentclientprotocol.protocol.invoke
2831
import kotlinx.coroutines.CancellationException
2932
import kotlinx.coroutines.CompletableDeferred
3033
import kotlinx.coroutines.awaitCancellation
@@ -39,6 +42,8 @@ import kotlinx.serialization.json.JsonElement
3942
import kotlin.test.Test
4043
import kotlin.test.assertContentEquals
4144
import kotlin.test.assertEquals
45+
import kotlin.test.assertTrue
46+
import kotlin.time.Duration.Companion.milliseconds
4247

4348
abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver by protocolDriver {
4449
@Test
@@ -591,5 +596,76 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
591596
assertEquals("Test cancellation", permissionResponseCe.message, "Cancellation exception should be propagated to agent")
592597
}
593598

599+
@Test
600+
fun `long session init on client and consequent session update should be properly handler`() = testWithProtocols { clientProtocol, agentProtocol ->
601+
val notificationDeferred = CompletableDeferred<SessionUpdate>()
602+
603+
val client = Client(protocol = clientProtocol, clientSupport = object : ClientSupport {
604+
override suspend fun createClientSession(
605+
session: ClientSession,
606+
_sessionResponseMeta: JsonElement?,
607+
): ClientSessionOperations {
608+
// long session init
609+
delay(1000.milliseconds)
610+
return object : ClientSessionOperations {
611+
override suspend fun requestPermissions(
612+
toolCall: SessionUpdate.ToolCallUpdate,
613+
permissions: List<PermissionOption>,
614+
_meta: JsonElement?,
615+
): RequestPermissionResponse {
616+
TODO()
617+
}
618+
619+
override suspend fun notify(
620+
notification: SessionUpdate,
621+
_meta: JsonElement?,
622+
) {
623+
notificationDeferred.complete(notification)
624+
}
625+
}
626+
}
627+
})
628+
629+
val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
630+
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
631+
return AgentInfo(clientInfo.protocolVersion)
632+
}
633+
634+
override suspend fun createSession(sessionParameters: SessionParameters): AgentSession {
635+
val id = SessionId("test-session-id")
636+
this@testWithProtocols.launch {
637+
delay(200.milliseconds)
638+
AcpMethod.ClientMethods.SessionUpdate(agentProtocol, SessionNotification(id, SessionUpdate.AvailableCommandsUpdate(listOf())))
639+
}
640+
641+
return object : AgentSession {
642+
override val sessionId: SessionId = id
643+
644+
override suspend fun prompt(
645+
content: List<ContentBlock>,
646+
_meta: JsonElement?,
647+
): Flow<Event> = flow {
648+
TODO()
649+
}
650+
}
651+
}
652+
653+
override suspend fun loadSession(
654+
sessionId: SessionId,
655+
sessionParameters: SessionParameters,
656+
): AgentSession {
657+
TODO("Not yet implemented")
658+
}
659+
})
660+
client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION))
661+
662+
val session = client.newSession(SessionParameters("/test/path", emptyList()))
663+
664+
val notification = withTimeout(5000.milliseconds) {
665+
notificationDeferred.await()
666+
}
667+
assertTrue(notification is SessionUpdate.AvailableCommandsUpdate)
668+
}
669+
594670

595671
}

acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class Client(
3939
private val handlerSideExtensions: List<HandlerSideExtension<*>> = emptyList(),
4040
private val remoteSideExtensions: List<RemoteSideExtension<*>> = emptyList(),
4141
) {
42-
private val _sessions = atomic(persistentMapOf<SessionId, ClientSessionImpl>())
42+
private val _sessions = atomic(persistentMapOf<SessionId, CompletableDeferred<ClientSessionImpl>>())
4343
private val _clientInfo = CompletableDeferred<ClientInfo>()
4444
private val _agentInfo = CompletableDeferred<AgentInfo>()
4545

@@ -123,6 +123,8 @@ public class Client(
123123
}
124124

125125
private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionParameters, _meta: JsonElement?): ClientSession {
126+
val sessionDeferred = CompletableDeferred<ClientSessionImpl>()
127+
_sessions.update { it.put(sessionId, sessionDeferred) }
126128
val agentInfo = agentInfo
127129
val extensionsMap =
128130
remoteSideExtensions.filter { it.isSupported(agentInfo.capabilities) }.associateBy(keySelector = { it }) {
@@ -135,11 +137,16 @@ public class Client(
135137
val session = ClientSessionImpl(this, sessionId, sessionParameters, protocol, RemoteSideExtensionInstantiation(extensionsMap)/*, modeState, modelState*/)
136138
val sessionApi = clientSupport.createClientSession(session, _meta)
137139
session.setApi(sessionApi)
138-
_sessions.update { it.put(sessionId, session) }
140+
sessionDeferred.complete(session)
139141
return session
140142
}
141143

142-
public fun getSession(sessionId: SessionId): ClientSession? = _sessions.value[sessionId]
144+
public fun getSession(sessionId: SessionId): ClientSession {
145+
val completableDeferred = _sessions.value[sessionId] ?: error("Session $sessionId not found")
146+
if (!completableDeferred.isCompleted) error("Session $sessionId not initialized yet")
147+
@OptIn(ExperimentalCoroutinesApi::class)
148+
return completableDeferred.getCompleted()
149+
}
143150

144-
private fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = _sessions.value[sessionId] ?: acpFail("Session $sessionId not found")
151+
private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = (_sessions.value[sessionId] ?: acpFail("Session $sessionId not found")).await()
145152
}

0 commit comments

Comments
 (0)