Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.agentclientprotocol.common.ClientSessionOperations
import com.agentclientprotocol.common.Event
import com.agentclientprotocol.common.SessionParameters
import com.agentclientprotocol.framework.ProtocolDriver
import com.agentclientprotocol.model.AcpMethod
import com.agentclientprotocol.model.ContentBlock
import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION
import com.agentclientprotocol.model.PermissionOption
Expand All @@ -22,9 +23,11 @@ import com.agentclientprotocol.model.PromptResponse
import com.agentclientprotocol.model.RequestPermissionOutcome
import com.agentclientprotocol.model.RequestPermissionResponse
import com.agentclientprotocol.model.SessionId
import com.agentclientprotocol.model.SessionNotification
import com.agentclientprotocol.model.SessionUpdate
import com.agentclientprotocol.model.StopReason
import com.agentclientprotocol.model.ToolCallId
import com.agentclientprotocol.protocol.invoke
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.awaitCancellation
Expand All @@ -39,6 +42,8 @@ import kotlinx.serialization.json.JsonElement
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.milliseconds

abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver by protocolDriver {
@Test
Expand Down Expand Up @@ -591,5 +596,76 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
assertEquals("Test cancellation", permissionResponseCe.message, "Cancellation exception should be propagated to agent")
}

@Test
fun `long session init on client and consequent session update should be properly handler`() = testWithProtocols { clientProtocol, agentProtocol ->
val notificationDeferred = CompletableDeferred<SessionUpdate>()

val client = Client(protocol = clientProtocol, clientSupport = object : ClientSupport {
override suspend fun createClientSession(
session: ClientSession,
_sessionResponseMeta: JsonElement?,
): ClientSessionOperations {
// long session init
delay(1000.milliseconds)
return object : ClientSessionOperations {
override suspend fun requestPermissions(
toolCall: SessionUpdate.ToolCallUpdate,
permissions: List<PermissionOption>,
_meta: JsonElement?,
): RequestPermissionResponse {
TODO()
}

override suspend fun notify(
notification: SessionUpdate,
_meta: JsonElement?,
) {
notificationDeferred.complete(notification)
}
}
}
})

val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
return AgentInfo(clientInfo.protocolVersion)
}

override suspend fun createSession(sessionParameters: SessionParameters): AgentSession {
val id = SessionId("test-session-id")
[email protected] {
delay(200.milliseconds)
AcpMethod.ClientMethods.SessionUpdate(agentProtocol, SessionNotification(id, SessionUpdate.AvailableCommandsUpdate(listOf())))
}

return object : AgentSession {
override val sessionId: SessionId = id

override suspend fun prompt(
content: List<ContentBlock>,
_meta: JsonElement?,
): Flow<Event> = flow {
TODO()
}
}
}

override suspend fun loadSession(
sessionId: SessionId,
sessionParameters: SessionParameters,
): AgentSession {
TODO("Not yet implemented")
}
})
client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION))

val session = client.newSession(SessionParameters("/test/path", emptyList()))

val notification = withTimeout(5000.milliseconds) {
notificationDeferred.await()
}
assertTrue(notification is SessionUpdate.AvailableCommandsUpdate)
}


}
15 changes: 11 additions & 4 deletions acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class Client(
private val handlerSideExtensions: List<HandlerSideExtension<*>> = emptyList(),
private val remoteSideExtensions: List<RemoteSideExtension<*>> = emptyList(),
) {
private val _sessions = atomic(persistentMapOf<SessionId, ClientSessionImpl>())
private val _sessions = atomic(persistentMapOf<SessionId, CompletableDeferred<ClientSessionImpl>>())
private val _clientInfo = CompletableDeferred<ClientInfo>()
private val _agentInfo = CompletableDeferred<AgentInfo>()

Expand Down Expand Up @@ -123,6 +123,8 @@ public class Client(
}

private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionParameters, _meta: JsonElement?): ClientSession {
val sessionDeferred = CompletableDeferred<ClientSessionImpl>()
_sessions.update { it.put(sessionId, sessionDeferred) }
val agentInfo = agentInfo
val extensionsMap =
remoteSideExtensions.filter { it.isSupported(agentInfo.capabilities) }.associateBy(keySelector = { it }) {
Expand All @@ -135,11 +137,16 @@ public class Client(
val session = ClientSessionImpl(this, sessionId, sessionParameters, protocol, RemoteSideExtensionInstantiation(extensionsMap)/*, modeState, modelState*/)
val sessionApi = clientSupport.createClientSession(session, _meta)
session.setApi(sessionApi)
_sessions.update { it.put(sessionId, session) }
sessionDeferred.complete(session)
return session
}

public fun getSession(sessionId: SessionId): ClientSession? = _sessions.value[sessionId]
public fun getSession(sessionId: SessionId): ClientSession {
val completableDeferred = _sessions.value[sessionId] ?: error("Session $sessionId not found")
if (!completableDeferred.isCompleted) error("Session $sessionId not initialized yet")
@OptIn(ExperimentalCoroutinesApi::class)
return completableDeferred.getCompleted()
}

private fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = _sessions.value[sessionId] ?: acpFail("Session $sessionId not found")
private suspend fun getSessionOrThrow(sessionId: SessionId): ClientSessionImpl = (_sessions.value[sessionId] ?: acpFail("Session $sessionId not found")).await()
}
Loading