diff --git a/acp-ktor-client/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorClientExtensions.kt b/acp-ktor-client/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorClientExtensions.kt index 5f83fdc..e5faa4d 100644 --- a/acp-ktor-client/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorClientExtensions.kt +++ b/acp-ktor-client/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorClientExtensions.kt @@ -4,14 +4,18 @@ import com.agentclientprotocol.protocol.Protocol import com.agentclientprotocol.protocol.ProtocolOptions import io.ktor.client.* import io.ktor.client.plugins.websocket.* -import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.* /** * Create a new [Protocol] on a websocket via [HttpClient]. * * The protocol should be started manually by the calling site. */ -public suspend fun HttpClient.acpProtocolOnClientWebSocket(url: String = ACP_PATH, protocolOptions: ProtocolOptions, requestBuilder: HttpRequestBuilder.() -> Unit = {}): Protocol { +public suspend fun HttpClient.acpProtocolOnClientWebSocket( + url: String = ACP_PATH, + protocolOptions: ProtocolOptions, + requestBuilder: HttpRequestBuilder.() -> Unit = {} +): Protocol { val webSocketSession = webSocketSession(urlString = url, block = requestBuilder) val webSocketTransport = WebSocketTransport(parentScope = webSocketSession, wss = webSocketSession) val protocol = Protocol(parentScope = webSocketSession, transport = webSocketTransport, options = protocolOptions) diff --git a/acp-ktor-server/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorServerExtensions.kt b/acp-ktor-server/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorServerExtensions.kt index e859ebf..b60c4ec 100644 --- a/acp-ktor-server/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorServerExtensions.kt +++ b/acp-ktor-server/src/commonMain/kotlin/com/agentclientprotocol/transport/AcpKtorServerExtensions.kt @@ -14,7 +14,11 @@ import kotlinx.coroutines.awaitCancellation * When [block] exits the websocket connection is closed. */ @KtorDsl -public fun Route.acpProtocolOnServerWebSocket(path: String = ACP_PATH, protocolOptions: ProtocolOptions, block: suspend (Protocol) -> Unit) { +public fun Route.acpProtocolOnServerWebSocket( + path: String = ACP_PATH, + protocolOptions: ProtocolOptions, + block: suspend (Protocol) -> Unit +) { webSocket(path) { val webSocketTransport = WebSocketTransport(parentScope = this, wss = this) val protocol = Protocol(parentScope = this, transport = webSocketTransport, options = protocolOptions) @@ -29,7 +33,12 @@ public fun Route.acpProtocolOnServerWebSocket(path: String = ACP_PATH, protocolO * When [block] exits the websocket connection is closed. */ @KtorDsl -public fun Application.acpProtocolOnServerWebSocket(path: String = ACP_PATH, protocolOptions: ProtocolOptions, withAuth: (Route.(Route.() -> Unit) -> Unit)?, block: suspend (Protocol) -> Unit) { +public fun Application.acpProtocolOnServerWebSocket( + path: String = ACP_PATH, + protocolOptions: ProtocolOptions, + withAuth: (Route.(Route.() -> Unit) -> Unit)?, + block: suspend (Protocol) -> Unit +) { routing { if (withAuth != null) { withAuth { diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/FeaturesTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/FeaturesTest.kt index d80f757..c0db6e3 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/FeaturesTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/FeaturesTest.kt @@ -1,12 +1,12 @@ package com.agentclientprotocol -import com.agentclientprotocol.agent.Agent -import com.agentclientprotocol.agent.AgentInfo -import com.agentclientprotocol.agent.AgentSession -import com.agentclientprotocol.agent.AgentSupport -import com.agentclientprotocol.agent.client -import com.agentclientprotocol.client.* -import com.agentclientprotocol.common.* +import com.agentclientprotocol.agent.* +import com.agentclientprotocol.client.Client +import com.agentclientprotocol.client.ClientInfo +import com.agentclientprotocol.common.ClientSessionOperations +import com.agentclientprotocol.common.Event +import com.agentclientprotocol.common.FileSystemOperations +import com.agentclientprotocol.common.SessionCreationParameters import com.agentclientprotocol.framework.ProtocolDriver import com.agentclientprotocol.model.* import kotlinx.coroutines.currentCoroutineContext @@ -15,9 +15,12 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.flow import kotlinx.serialization.json.JsonElement -import kotlin.test.* +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertTrue -open class TestClientSessionOperations(): ClientSessionOperations { +open class TestClientSessionOperations : ClientSessionOperations { override suspend fun requestPermissions( toolCall: SessionUpdate.ToolCallUpdate, permissions: List, @@ -42,7 +45,10 @@ open class TestAgentSession(override val sessionId: SessionId = SessionId("test- } } -class TestAgentSupport(val capabilities: AgentCapabilities = AgentCapabilities(), val createSessionFunc: suspend (SessionCreationParameters) -> AgentSession) : AgentSupport { +class TestAgentSupport( + val capabilities: AgentCapabilities = AgentCapabilities(), + val createSessionFunc: suspend (SessionCreationParameters) -> AgentSession +) : AgentSupport { val agentInitialized = kotlinx.coroutines.CompletableDeferred() override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { @@ -121,7 +127,10 @@ abstract class FeaturesTest(protocolDriver: ProtocolDriver) : ProtocolDriver by fun `change mode from client`() = testWithProtocols { clientProtocol, agentProtocol -> val client = Client(protocol = clientProtocol) - val modes = listOf(SessionMode(SessionModeId("ask"), "Ask mode", "Only conversations"), SessionMode(SessionModeId("Code"), "Coding mode", "Writes code")) + val modes = listOf( + SessionMode(SessionModeId("ask"), "Ask mode", "Only conversations"), + SessionMode(SessionModeId("Code"), "Coding mode", "Writes code") + ) val agentSupport = TestAgentSupport { object : TestAgentSession() { @@ -191,7 +200,6 @@ abstract class FeaturesTest(protocolDriver: ProtocolDriver) : ProtocolDriver by } - // @Test // fun `call agent extension from client`(): TestResult = testWithProtocols { clientProtocol, agentProtocol -> // diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/ProtocolTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/ProtocolTest.kt index 8298996..54ff95c 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/ProtocolTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/ProtocolTest.kt @@ -4,37 +4,22 @@ import com.agentclientprotocol.framework.ProtocolDriver import com.agentclientprotocol.model.AcpMethod import com.agentclientprotocol.model.AcpRequest import com.agentclientprotocol.model.AcpResponse -import com.agentclientprotocol.protocol.AcpExpectedError -import com.agentclientprotocol.protocol.JsonRpcException -import com.agentclientprotocol.protocol.acpFail -import com.agentclientprotocol.protocol.sendRequest -import com.agentclientprotocol.protocol.setRequestHandler +import com.agentclientprotocol.protocol.* import com.agentclientprotocol.rpc.JsonRpcErrorCode -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.NonCancellable -import kotlinx.coroutines.TimeoutCancellationException -import kotlinx.coroutines.awaitCancellation -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch +import kotlinx.coroutines.* import kotlinx.coroutines.test.TestResult -import kotlinx.coroutines.withContext -import kotlinx.coroutines.withTimeout -import kotlinx.coroutines.withTimeoutOrNull import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.put import kotlin.coroutines.cancellation.CancellationException -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue -import kotlin.test.fail +import kotlin.test.* import kotlin.time.Duration.Companion.milliseconds import kotlin.time.measureTimedValue @Serializable data class TestRequest(val message: String, override val _meta: JsonElement? = null) : AcpRequest + @Serializable data class TestResponse(val message: String, override val _meta: JsonElement? = null) : AcpResponse @@ -42,7 +27,11 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by val cancellationMessage = "Cancelled from test" companion object { - object TestMethod : AcpMethod.AcpRequestResponseMethod("test/testRequest", TestRequest.serializer(), TestResponse.serializer()) + object TestMethod : AcpMethod.AcpRequestResponseMethod( + "test/testRequest", + TestRequest.serializer(), + TestResponse.serializer() + ) } @Test @@ -63,8 +52,7 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by agentProtocol.setRequestHandler(TestMethod) { request -> try { awaitCancellation() - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { agentCeDeferred.complete(ce) throw ce } @@ -72,20 +60,21 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by launch { delay(500) - clientProtocol.cancelPendingOutgoingRequests(kotlinx.coroutines.CancellationException(cancellationMessage)) + clientProtocol.cancelPendingOutgoingRequests( + kotlinx.coroutines.CancellationException( + cancellationMessage + ) + ) } try { val response = withTimeout(2000) { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) } - } - catch (te: TimeoutCancellationException) { + } catch (te: TimeoutCancellationException) { fail("Request should be cancelled explicitly and not timed out") - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { // expected assertEquals(cancellationMessage, ce.message, "Cancellation exception should be propagated to client") - } - catch (e: Exception) { + } catch (e: Exception) { fail("Unexpected exception: ${e.message}", e) } val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } @@ -95,139 +84,139 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by } @Test - fun `request cancelled from client by coroutine cancel should be cancelled on agent`() = testWithProtocols { clientProtocol, agentProtocol -> - val agentCeDeferred = CompletableDeferred() - agentProtocol.setRequestHandler(TestMethod) { request -> - try { - awaitCancellation() - } - catch (ce: CancellationException) { - agentCeDeferred.complete(ce) - throw ce + fun `request cancelled from client by coroutine cancel should be cancelled on agent`() = + testWithProtocols { clientProtocol, agentProtocol -> + val agentCeDeferred = CompletableDeferred() + agentProtocol.setRequestHandler(TestMethod) { request -> + try { + awaitCancellation() + } catch (ce: CancellationException) { + agentCeDeferred.complete(ce) + throw ce + } } - } - val requestJob = launch { - clientProtocol.sendRequest(TestMethod, TestRequest("Test")) - } + val requestJob = launch { + clientProtocol.sendRequest(TestMethod, TestRequest("Test")) + } - delay(500) - requestJob.cancel(kotlinx.coroutines.CancellationException(cancellationMessage)) + delay(500) + requestJob.cancel(kotlinx.coroutines.CancellationException(cancellationMessage)) - val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } - assertNotNull(agentCe, "Cancellation exception should be propagated to agent") - assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") - } + val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } + assertNotNull(agentCe, "Cancellation exception should be propagated to agent") + assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") + } @Test - fun `request cancelled from client by coroutine cancel should wait for graceful cancellation`() = testWithProtocols { clientProtocol, agentProtocol -> - val agentCeDeferred = CompletableDeferred() - agentProtocol.setRequestHandler(TestMethod) { request -> - try { - awaitCancellation() - } - catch (ce: CancellationException) { - withContext(NonCancellable) { - // Wait for graceful cancellation - delay(900) // less than protocol graceful cancellation timeout - agentCeDeferred.complete(ce) + fun `request cancelled from client by coroutine cancel should wait for graceful cancellation`() = + testWithProtocols { clientProtocol, agentProtocol -> + val agentCeDeferred = CompletableDeferred() + agentProtocol.setRequestHandler(TestMethod) { request -> + try { + awaitCancellation() + } catch (ce: CancellationException) { + withContext(NonCancellable) { + // Wait for graceful cancellation + delay(900) // less than protocol graceful cancellation timeout + agentCeDeferred.complete(ce) + } + throw ce } - throw ce } - } - val clientRequestCeDeferred = CompletableDeferred() - val requestJob = launch { - try { - clientProtocol.sendRequest(TestMethod, TestRequest("Test")) - } - catch (ce: CancellationException) { - clientRequestCeDeferred.complete(ce) - throw ce + val clientRequestCeDeferred = CompletableDeferred() + val requestJob = launch { + try { + clientProtocol.sendRequest(TestMethod, TestRequest("Test")) + } catch (ce: CancellationException) { + clientRequestCeDeferred.complete(ce) + throw ce + } } - } - delay(500) - requestJob.cancel(kotlinx.coroutines.CancellationException(cancellationMessage)) - - withTimeout(5000) { - val cancellationException = measureTimedValue { clientRequestCeDeferred.await() } - assertEquals(cancellationMessage, cancellationException.value.message, "Cancellation exception should be propagated to client") - assertTrue(cancellationException.duration > 900.milliseconds, "Graceful cancellation should be performed") + delay(500) + requestJob.cancel(kotlinx.coroutines.CancellationException(cancellationMessage)) + + withTimeout(5000) { + val cancellationException = measureTimedValue { clientRequestCeDeferred.await() } + assertEquals( + cancellationMessage, + cancellationException.value.message, + "Cancellation exception should be propagated to client" + ) + assertTrue( + cancellationException.duration > 900.milliseconds, + "Graceful cancellation should be performed" + ) + } } - } @Test - fun `request cancelled from agent by cancelPendingIncomingRequests should be cancelled on client`() = testWithProtocols { clientProtocol, agentProtocol -> - val agentCeDeferred = CompletableDeferred() - agentProtocol.setRequestHandler(TestMethod) { request -> - try { - awaitCancellation() + fun `request cancelled from agent by cancelPendingIncomingRequests should be cancelled on client`() = + testWithProtocols { clientProtocol, agentProtocol -> + val agentCeDeferred = CompletableDeferred() + agentProtocol.setRequestHandler(TestMethod) { request -> + try { + awaitCancellation() + } catch (ce: CancellationException) { + agentCeDeferred.complete(ce) + throw ce + } } - catch (ce: CancellationException) { - agentCeDeferred.complete(ce) - throw ce + + launch { + delay(500) + agentProtocol.cancelPendingIncomingRequests(kotlinx.coroutines.CancellationException(cancellationMessage)) } - } - launch { - delay(500) - agentProtocol.cancelPendingIncomingRequests(kotlinx.coroutines.CancellationException(cancellationMessage)) - } + try { + val response = withTimeout(1000) { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) } + } catch (te: TimeoutCancellationException) { + fail("Request should be cancelled explicitly and not timed out") + } catch (ce: CancellationException) { + //expected + assertEquals(cancellationMessage, ce.message, "Cancellation exception should be propagated to client") + } catch (e: Exception) { + fail("Unexpected exception: ${e.message}", e) + } - try { - val response = withTimeout(1000) { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) } - } - catch (te: TimeoutCancellationException) { - fail("Request should be cancelled explicitly and not timed out") - } - catch (ce: CancellationException) { - //expected - assertEquals(cancellationMessage, ce.message, "Cancellation exception should be propagated to client") - } - catch (e: Exception) { - fail("Unexpected exception: ${e.message}", e) + val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } + assertNotNull(agentCe, "Cancellation exception should be propagated to agent") + assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") } - val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } - assertNotNull(agentCe, "Cancellation exception should be propagated to agent") - assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") - } - @Test - fun `request cancelled from agent by throwing CE should be cancelled on client`() = testWithProtocols { clientProtocol, agentProtocol -> - val agentCeDeferred = CompletableDeferred() - agentProtocol.setRequestHandler(TestMethod) { request -> - try { - delay(500) - throw kotlinx.coroutines.CancellationException(cancellationMessage) + fun `request cancelled from agent by throwing CE should be cancelled on client`() = + testWithProtocols { clientProtocol, agentProtocol -> + val agentCeDeferred = CompletableDeferred() + agentProtocol.setRequestHandler(TestMethod) { request -> + try { + delay(500) + throw kotlinx.coroutines.CancellationException(cancellationMessage) + } catch (ce: CancellationException) { + agentCeDeferred.complete(ce) + throw ce + } } - catch (ce: CancellationException) { - agentCeDeferred.complete(ce) - throw ce + + try { + val response = withTimeout(1000) { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) } + } catch (te: TimeoutCancellationException) { + fail("Request should be cancelled explicitly and not timed out") + } catch (ce: CancellationException) { + //expected + assertEquals(cancellationMessage, ce.message, "Cancellation exception should be propagated to client") + } catch (e: Exception) { + fail("Unexpected exception: ${e.message}", e) } - } - try { - val response = withTimeout(1000) { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) } - } - catch (te: TimeoutCancellationException) { - fail("Request should be cancelled explicitly and not timed out") - } - catch (ce: CancellationException) { - //expected - assertEquals(cancellationMessage, ce.message, "Cancellation exception should be propagated to client") - } - catch (e: Exception) { - fail("Unexpected exception: ${e.message}", e) + val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } + assertNotNull(agentCe, "Cancellation exception should be propagated to agent") + assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") } - val agentCe = withTimeoutOrNull(1000) { agentCeDeferred.await() } - assertNotNull(agentCe, "Cancellation exception should be propagated to agent") - assertEquals(cancellationMessage, agentCe.message, "Cancellation exception should be propagated to agent") - } - @Test fun `error is propagated to client (INTERNAL_ERROR)`() = testWithProtocols { clientProtocol, agentProtocol -> val errorMessage = "Test error from handler" @@ -238,8 +227,7 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by try { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) fail("Expected exception to be thrown") - } - catch (e: JsonRpcException) { + } catch (e: JsonRpcException) { assertEquals(errorMessage, e.message, "Error message should be propagated to client") assertEquals(JsonRpcErrorCode.INTERNAL_ERROR.code, e.code, "Error code should be INTERNAL_ERROR") } @@ -255,8 +243,7 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by try { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) fail("Expected JsonRpcException to be thrown") - } - catch (e: AcpExpectedError) { + } catch (e: AcpExpectedError) { assertEquals(errorMessage, e.message, "Error message should be propagated to client") } } @@ -273,11 +260,9 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by put("invalidField", "not a valid TestRequest") }) fail("Expected JsonRpcException to be thrown") - } - catch (e: SerializationException) { + } catch (e: SerializationException) { // expected - } - catch (e: Exception) { + } catch (e: Exception) { fail("Unexpected exception: ${e.message}", e) } } @@ -288,8 +273,7 @@ abstract class ProtocolTest(protocolDriver: ProtocolDriver) : ProtocolDriver by try { clientProtocol.sendRequest(TestMethod, TestRequest("Test")) fail("Expected JsonRpcException to be thrown") - } - catch (e: JsonRpcException) { + } catch (e: JsonRpcException) { assertEquals(JsonRpcErrorCode.METHOD_NOT_FOUND.code, e.code, "Error code should be METHOD_NOT_FOUND") } } diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt index ebad6c9..d1d958d 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt @@ -1,42 +1,18 @@ package com.agentclientprotocol -import com.agentclientprotocol.agent.Agent -import com.agentclientprotocol.agent.AgentInfo -import com.agentclientprotocol.agent.AgentSession -import com.agentclientprotocol.agent.AgentSupport -import com.agentclientprotocol.agent.client +import com.agentclientprotocol.agent.* import com.agentclientprotocol.client.Client import com.agentclientprotocol.client.ClientInfo import com.agentclientprotocol.common.ClientSessionOperations import com.agentclientprotocol.common.Event import com.agentclientprotocol.common.SessionCreationParameters import com.agentclientprotocol.framework.ProtocolDriver -import com.agentclientprotocol.model.AcpMethod -import com.agentclientprotocol.model.ContentBlock -import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION -import com.agentclientprotocol.model.PermissionOption -import com.agentclientprotocol.model.PermissionOptionId -import com.agentclientprotocol.model.PermissionOptionKind -import com.agentclientprotocol.model.PromptResponse -import com.agentclientprotocol.model.RequestPermissionOutcome -import com.agentclientprotocol.model.RequestPermissionResponse -import com.agentclientprotocol.model.SessionId -import com.agentclientprotocol.model.SessionNotification -import com.agentclientprotocol.model.SessionUpdate -import com.agentclientprotocol.model.StopReason -import com.agentclientprotocol.model.ToolCallId +import com.agentclientprotocol.model.* import com.agentclientprotocol.protocol.invoke -import io.github.oshai.kotlinlogging.KotlinLogging.logger -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.awaitCancellation -import kotlinx.coroutines.currentCoroutineContext -import kotlinx.coroutines.delay +import kotlinx.coroutines.* import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.launch -import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonElement import kotlin.test.Test import kotlin.test.assertContentEquals @@ -88,7 +64,15 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver content: List, _meta: JsonElement?, ): Flow = flow { - emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text(sessionParameters.cwd)))) + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text( + sessionParameters.cwd + ) + ) + ) + ) delay(100) for (block in content.filterIsInstance()) { emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(block))) @@ -157,7 +141,15 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver content: List, _meta: JsonElement?, ): Flow = flow { - emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text(sessionParameters.cwd)))) + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text( + sessionParameters.cwd + ) + ) + ) + ) emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 1")))) emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 2")))) emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("text 3")))) @@ -200,12 +192,13 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver newSession.prompt(listOf()).collect { event -> when (event) { is Event.PromptResponseEvent -> { - println( "Received prompt response: ${event.response}" ) + println("Received prompt response: ${event.response}") result = event.response responses.add(event.response.stopReason.toString()) } + is Event.SessionUpdateEvent -> { - println( "Received session update: ${(event.update as SessionUpdate.AgentMessageChunk).content}" ) + println("Received session update: ${(event.update as SessionUpdate.AgentMessageChunk).content}") responses.add(((event.update as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text) } } @@ -236,12 +229,10 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver // forever wait try { awaitCancellation() - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { agentSideCeDeferred.complete(ce) throw ce - } - catch (e: Exception) { + } catch (e: Exception) { agentSideCeDeferred.completeExceptionally(e) throw e } @@ -306,12 +297,10 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver // forever wait try { awaitCancellation() - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { agentSideCeDeferred.complete(ce) throw ce - } - catch (e: Exception) { + } catch (e: Exception) { agentSideCeDeferred.completeExceptionally(e) throw e } @@ -327,23 +316,24 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver } }) client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) - val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> object : ClientSessionOperations { - override suspend fun requestPermissions( - toolCall: SessionUpdate.ToolCallUpdate, - permissions: List, - _meta: JsonElement?, - ): RequestPermissionResponse { - TODO("Not yet implemented") - } + val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + TODO("Not yet implemented") + } - override suspend fun notify( - notification: SessionUpdate, - _meta: JsonElement?, - ) { - TODO("Not yet implemented") + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + TODO("Not yet implemented") + } } } - } val promptJob = launch { session.prompt(listOf(ContentBlock.Text("Test message"))).collect() } @@ -424,6 +414,7 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver when (event) { is Event.PromptResponseEvent -> { } + is Event.SessionUpdateEvent -> responses.add(((event.update as SessionUpdate.AgentMessageChunk).content as ContentBlock.Text).text) } @@ -433,253 +424,265 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver } @Test - fun `permission request should be cancelled by prompt cancellation on client`() = testWithProtocols { clientProtocol, agentProtocol -> - val permissionResponseCeDeferred = CompletableDeferred() - val client = Client(protocol = clientProtocol) - - val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { - override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { - return AgentInfo(clientInfo.protocolVersion) - } - - override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { - return object : AgentSession { - override val sessionId: SessionId = SessionId("test-session-id") + fun `permission request should be cancelled by prompt cancellation on client`() = + testWithProtocols { clientProtocol, agentProtocol -> + val permissionResponseCeDeferred = CompletableDeferred() + val client = Client(protocol = clientProtocol) + + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } - override suspend fun prompt( - content: List, - _meta: JsonElement?, - ): Flow = flow { - try { - val permissionResponse = currentCoroutineContext().client.requestPermissions( - SessionUpdate.ToolCallUpdate(toolCallId = ToolCallId("tool-id")), listOf( - PermissionOption( - optionId = PermissionOptionId("approve"), - name = "Approve", - kind = PermissionOptionKind.ALLOW_ONCE - ), - PermissionOption( - optionId = PermissionOptionId("reject"), - name = "Reject", - kind = PermissionOptionKind.REJECT_ONCE + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + return object : AgentSession { + override val sessionId: SessionId = SessionId("test-session-id") + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = flow { + try { + val permissionResponse = currentCoroutineContext().client.requestPermissions( + SessionUpdate.ToolCallUpdate(toolCallId = ToolCallId("tool-id")), listOf( + PermissionOption( + optionId = PermissionOptionId("approve"), + name = "Approve", + kind = PermissionOptionKind.ALLOW_ONCE + ), + PermissionOption( + optionId = PermissionOptionId("reject"), + name = "Reject", + kind = PermissionOptionKind.REJECT_ONCE + ) ) ) - ) - emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("Permission response: ${permissionResponse.outcome}")))) - } - catch (ce: CancellationException) { - println("Client cancellation exception caught") - throw ce + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("Permission response: ${permissionResponse.outcome}")))) + } catch (ce: CancellationException) { + println("Client cancellation exception caught") + throw ce + } } } } - } - override suspend fun loadSession( - sessionId: SessionId, - sessionParameters: SessionCreationParameters, - ): AgentSession { - TODO("Not yet implemented") - } - }) - client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) - val responses = mutableListOf() - val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> - object : ClientSessionOperations { - override suspend fun requestPermissions( - toolCall: SessionUpdate.ToolCallUpdate, - permissions: List, - _meta: JsonElement?, - ): RequestPermissionResponse { - try { - // wait forever - awaitCancellation() - } catch (ce: CancellationException) { - permissionResponseCeDeferred.complete(ce) - throw ce - } catch (e: Exception) { - permissionResponseCeDeferred.completeExceptionally(e) - throw e - } + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") } + }) + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + val responses = mutableListOf() + val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + try { + // wait forever + awaitCancellation() + } catch (ce: CancellationException) { + permissionResponseCeDeferred.complete(ce) + throw ce + } catch (e: Exception) { + permissionResponseCeDeferred.completeExceptionally(e) + throw e + } + } - override suspend fun notify( - notification: SessionUpdate, - _meta: JsonElement?, - ) { - TODO("Not yet implemented") + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + TODO("Not yet implemented") + } } } - } - val promptJob = launch { - session.prompt(listOf(ContentBlock.Text("Test message"))).collect() - } + val promptJob = launch { + session.prompt(listOf(ContentBlock.Text("Test message"))).collect() + } - delay(500) - promptJob.cancel(CancellationException("Test cancellation")) + delay(500) + promptJob.cancel(CancellationException("Test cancellation")) // val permissionResponseCe = withTimeout(100000) { permissionResponseCeDeferred.await() } - val permissionResponseCe = permissionResponseCeDeferred.await() - assertEquals("Test cancellation", permissionResponseCe.message, "Cancellation exception should be propagated to agent") - } + val permissionResponseCe = permissionResponseCeDeferred.await() + assertEquals( + "Test cancellation", + permissionResponseCe.message, + "Cancellation exception should be propagated to agent" + ) + } @Test - fun `permission request should be cancelled by prompt cancellation on client and wait for graceful cancellation`() = testWithProtocols { clientProtocol, agentProtocol -> - val permissionResponseCeDeferred = CompletableDeferred() - val client = Client(protocol = clientProtocol) - - val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { - override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { - return AgentInfo(clientInfo.protocolVersion) - } - - override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { - return object : AgentSession { - override val sessionId: SessionId = SessionId("test-session-id") + fun `permission request should be cancelled by prompt cancellation on client and wait for graceful cancellation`() = + testWithProtocols { clientProtocol, agentProtocol -> + val permissionResponseCeDeferred = CompletableDeferred() + val client = Client(protocol = clientProtocol) + + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } - override suspend fun prompt( - content: List, - _meta: JsonElement?, - ): Flow = flow { - try { - val permissionResponse = currentCoroutineContext().client.requestPermissions( - SessionUpdate.ToolCallUpdate(toolCallId = ToolCallId("tool-id")), listOf( - PermissionOption( - optionId = PermissionOptionId("approve"), - name = "Approve", - kind = PermissionOptionKind.ALLOW_ONCE - ), - PermissionOption( - optionId = PermissionOptionId("reject"), - name = "Reject", - kind = PermissionOptionKind.REJECT_ONCE + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + return object : AgentSession { + override val sessionId: SessionId = SessionId("test-session-id") + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = flow { + try { + val permissionResponse = currentCoroutineContext().client.requestPermissions( + SessionUpdate.ToolCallUpdate(toolCallId = ToolCallId("tool-id")), listOf( + PermissionOption( + optionId = PermissionOptionId("approve"), + name = "Approve", + kind = PermissionOptionKind.ALLOW_ONCE + ), + PermissionOption( + optionId = PermissionOptionId("reject"), + name = "Reject", + kind = PermissionOptionKind.REJECT_ONCE + ) ) ) - ) - emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("Permission response: ${permissionResponse.outcome}")))) - } - catch (ce: CancellationException) { - println("Client cancellation exception caught") - throw ce + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(ContentBlock.Text("Permission response: ${permissionResponse.outcome}")))) + } catch (ce: CancellationException) { + println("Client cancellation exception caught") + throw ce + } } } } - } - override suspend fun loadSession( - sessionId: SessionId, - sessionParameters: SessionCreationParameters, - ): AgentSession { - TODO("Not yet implemented") - } - }) - client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) - val responses = mutableListOf() - val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> - object : ClientSessionOperations { - override suspend fun requestPermissions( - toolCall: SessionUpdate.ToolCallUpdate, - permissions: List, - _meta: JsonElement?, - ): RequestPermissionResponse { - try { - // wait forever - awaitCancellation() - } catch (ce: CancellationException) { - permissionResponseCeDeferred.complete(ce) - throw ce - } catch (e: Exception) { - permissionResponseCeDeferred.completeExceptionally(e) - throw e - } + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") } + }) + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + val responses = mutableListOf() + val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + try { + // wait forever + awaitCancellation() + } catch (ce: CancellationException) { + permissionResponseCeDeferred.complete(ce) + throw ce + } catch (e: Exception) { + permissionResponseCeDeferred.completeExceptionally(e) + throw e + } + } - override suspend fun notify( - notification: SessionUpdate, - _meta: JsonElement?, - ) { - TODO("Not yet implemented") + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + TODO("Not yet implemented") + } } } - } - val promptJob = launch { - session.prompt(listOf(ContentBlock.Text("Test message"))).collect() - } + val promptJob = launch { + session.prompt(listOf(ContentBlock.Text("Test message"))).collect() + } - delay(500) - promptJob.cancel(CancellationException("Test cancellation")) + delay(500) + promptJob.cancel(CancellationException("Test cancellation")) // val permissionResponseCe = withTimeout(100000) { permissionResponseCeDeferred.await() } - val permissionResponseCe = permissionResponseCeDeferred.await() - assertEquals("Test cancellation", permissionResponseCe.message, "Cancellation exception should be propagated to agent") - } + val permissionResponseCe = permissionResponseCeDeferred.await() + assertEquals( + "Test cancellation", + permissionResponseCe.message, + "Cancellation exception should be propagated to agent" + ) + } @Test - fun `long session init on client and consequent session update should be properly handler`() = testWithProtocols { clientProtocol, agentProtocol -> - val notificationDeferred = CompletableDeferred() + fun `long session init on client and consequent session update should be properly handler`() = + testWithProtocols { clientProtocol, agentProtocol -> + val notificationDeferred = CompletableDeferred() - val client = Client(protocol = clientProtocol) + val client = Client(protocol = clientProtocol) - val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { - override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { - return AgentInfo(clientInfo.protocolVersion) - } - - override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { - val id = SessionId("test-session-id") - this@testWithProtocols.launch { - delay(200.milliseconds) - AcpMethod.ClientMethods.SessionUpdate(agentProtocol, SessionNotification(id, SessionUpdate.AvailableCommandsUpdate(listOf()))) + val agent = Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) } - return object : AgentSession { - override val sessionId: SessionId = id + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + val id = SessionId("test-session-id") + this@testWithProtocols.launch { + delay(200.milliseconds) + AcpMethod.ClientMethods.SessionUpdate( + agentProtocol, + SessionNotification(id, SessionUpdate.AvailableCommandsUpdate(listOf())) + ) + } - override suspend fun prompt( - content: List, - _meta: JsonElement?, - ): Flow = flow { - TODO() + return object : AgentSession { + override val sessionId: SessionId = id + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = flow { + TODO() + } } } - } - override suspend fun loadSession( - sessionId: SessionId, - sessionParameters: SessionCreationParameters, - ): AgentSession { - TODO("Not yet implemented") - } - }) - client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) - - val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> - // long session init - delay(1000.milliseconds) - return@newSession object : ClientSessionOperations { - override suspend fun requestPermissions( - toolCall: SessionUpdate.ToolCallUpdate, - permissions: List, - _meta: JsonElement?, - ): RequestPermissionResponse { - TODO() + override suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + ): AgentSession { + TODO("Not yet implemented") } + }) + client.initialize(ClientInfo(protocolVersion = LATEST_PROTOCOL_VERSION)) + + val session = client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + // long session init + delay(1000.milliseconds) + return@newSession object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + TODO() + } - override suspend fun notify( - notification: SessionUpdate, - _meta: JsonElement?, - ) { - notificationDeferred.complete(notification) + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + notificationDeferred.complete(notification) + } } } - } - val notification = withTimeout(5000.milliseconds) { - notificationDeferred.await() + val notification = withTimeout(5000.milliseconds) { + notificationDeferred.await() + } + assertTrue(notification is SessionUpdate.AvailableCommandsUpdate) } - assertTrue(notification is SessionUpdate.AvailableCommandsUpdate) - } } \ No newline at end of file diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/framework/WebSocketKtorProtocolDriver.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/framework/WebSocketKtorProtocolDriver.kt index 5338cd1..56cb8e5 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/framework/WebSocketKtorProtocolDriver.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/framework/WebSocketKtorProtocolDriver.kt @@ -27,7 +27,8 @@ class WebSocketKtorProtocolDriver : ProtocolDriver { val httpClient = createClient { install(io.ktor.client.plugins.websocket.WebSockets.Plugin) } - val clientProtocol = httpClient.acpProtocolOnClientWebSocket("acp", ProtocolOptions(protocolDebugName = "client protocol")) + val clientProtocol = + httpClient.acpProtocolOnClientWebSocket("acp", ProtocolOptions(protocolDebugName = "client protocol")) val agentProtocol = agentProtocolDeferred.await() agentProtocol.start() clientProtocol.start() diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/utils.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/utils.kt index 3b56580..b7d8040 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/utils.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/utils.kt @@ -25,6 +25,7 @@ suspend fun ClientSession.promptToList(message: String): List { is Event.PromptResponseEvent -> { it.response.stopReason.toString() } + is Event.SessionUpdateEvent -> { when (val update = it.update) { is SessionUpdate.AgentMessageChunk -> update.content.render() diff --git a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioProtocolTest.kt b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioProtocolTest.kt index 0f75fca..88811c1 100644 --- a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioProtocolTest.kt +++ b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioProtocolTest.kt @@ -2,5 +2,4 @@ package com.agentclientprotocol import com.agentclientprotocol.framework.StdioProtocolDriver -class StdioProtocolTest : ProtocolTest(StdioProtocolDriver()) { -} \ No newline at end of file +class StdioProtocolTest : ProtocolTest(StdioProtocolDriver()) \ No newline at end of file diff --git a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioSimpleAgentTest.kt b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioSimpleAgentTest.kt index 6e2c374..cb2c253 100644 --- a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioSimpleAgentTest.kt +++ b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/StdioSimpleAgentTest.kt @@ -2,5 +2,4 @@ package com.agentclientprotocol import com.agentclientprotocol.framework.StdioProtocolDriver -class StdioSimpleAgentTest : SimpleAgentTest(StdioProtocolDriver()) { -} \ No newline at end of file +class StdioSimpleAgentTest : SimpleAgentTest(StdioProtocolDriver()) \ No newline at end of file diff --git a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketProtocolTest.kt b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketProtocolTest.kt index 6dc5060..0384270 100644 --- a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketProtocolTest.kt +++ b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketProtocolTest.kt @@ -2,5 +2,4 @@ package com.agentclientprotocol import com.agentclientprotocol.framework.WebSocketKtorProtocolDriver -class WebSocketProtocolTest : ProtocolTest(WebSocketKtorProtocolDriver()) { -} \ No newline at end of file +class WebSocketProtocolTest : ProtocolTest(WebSocketKtorProtocolDriver()) \ No newline at end of file diff --git a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketSimpleAgentTest.kt b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketSimpleAgentTest.kt index ec455c5..2d2c19e 100644 --- a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketSimpleAgentTest.kt +++ b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/WebSocketSimpleAgentTest.kt @@ -2,5 +2,4 @@ package com.agentclientprotocol import com.agentclientprotocol.framework.WebSocketKtorProtocolDriver -class WebSocketSimpleAgentTest : SimpleAgentTest(WebSocketKtorProtocolDriver()) { -} \ No newline at end of file +class WebSocketSimpleAgentTest : SimpleAgentTest(WebSocketKtorProtocolDriver()) \ No newline at end of file diff --git a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/framework/StdioProtocolDriver.kt b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/framework/StdioProtocolDriver.kt index 064a37c..5692c86 100644 --- a/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/framework/StdioProtocolDriver.kt +++ b/acp-ktor-test/src/jvmTest/kotlin/com/agentclientprotocol/framework/StdioProtocolDriver.kt @@ -3,11 +3,7 @@ package com.agentclientprotocol.framework import com.agentclientprotocol.protocol.Protocol import com.agentclientprotocol.protocol.ProtocolOptions import com.agentclientprotocol.transport.StdioTransport -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.* import kotlinx.coroutines.test.TestResult import kotlinx.io.asSink import kotlinx.io.asSource @@ -37,8 +33,10 @@ class StdioProtocolDriver : ProtocolDriver { "agent" ) - val clientProtocol = Protocol(this, clientTransport, options = ProtocolOptions(protocolDebugName = "client protocol")) - val agentProtocol = Protocol(this, agentTransport, options = ProtocolOptions(protocolDebugName = "agent protocol")) + val clientProtocol = + Protocol(this, clientTransport, options = ProtocolOptions(protocolDebugName = "client protocol")) + val agentProtocol = + Protocol(this, agentTransport, options = ProtocolOptions(protocolDebugName = "agent protocol")) clientProtocol.start() agentProtocol.start() diff --git a/acp-ktor/src/commonMain/kotlin/com/agentclientprotocol/transport/WebSocketTransport.kt b/acp-ktor/src/commonMain/kotlin/com/agentclientprotocol/transport/WebSocketTransport.kt index 77cef5a..f7d88b5 100644 --- a/acp-ktor/src/commonMain/kotlin/com/agentclientprotocol/transport/WebSocketTransport.kt +++ b/acp-ktor/src/commonMain/kotlin/com/agentclientprotocol/transport/WebSocketTransport.kt @@ -1,8 +1,8 @@ package com.agentclientprotocol.transport import com.agentclientprotocol.rpc.ACPJson -import com.agentclientprotocol.rpc.decodeJsonRpcMessage import com.agentclientprotocol.rpc.JsonRpcMessage +import com.agentclientprotocol.rpc.decodeJsonRpcMessage import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.websocket.* import kotlinx.coroutines.CoroutineScope @@ -18,7 +18,8 @@ private val logger = KotlinLogging.logger {} public const val ACP_PATH: String = "acp" -public class WebSocketTransport(private val parentScope: CoroutineScope, private val wss: WebSocketSession) : BaseTransport() { +public class WebSocketTransport(private val parentScope: CoroutineScope, private val wss: WebSocketSession) : + BaseTransport() { private val scope = CoroutineScope(parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job])) private val sendChannel = Channel(Channel.UNLIMITED) @@ -29,7 +30,7 @@ public class WebSocketTransport(private val parentScope: CoroutineScope, private val jsonText = try { ACPJson.encodeToString(message) } catch (e: SerializationException) { - logger.trace(e) { "Failed to serialize message: ${message}" } + logger.trace(e) { "Failed to serialize message: $message" } fireError(e) continue } @@ -41,12 +42,10 @@ public class WebSocketTransport(private val parentScope: CoroutineScope, private logger.trace { "No more messages in channel, closing connection" } wss.close() wss.flush() - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { logger.trace(ce) { "Send job cancelled" } wss.close(CloseReason(CloseReason.Codes.NORMAL, "Cancelled")) - } - catch (e: Throwable) { + } catch (e: Throwable) { logger.trace(e) { "Failed to send message to channel" } fireError(e) wss.close(CloseReason(CloseReason.Codes.INTERNAL_ERROR, e.message ?: "Internal error")) @@ -75,16 +74,13 @@ public class WebSocketTransport(private val parentScope: CoroutineScope, private } } } - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { logger.trace(ce) { "Receive job cancelled" } wss.close(CloseReason(CloseReason.Codes.NORMAL, "Cancelled")) - } - catch (e: Throwable) { + } catch (e: Throwable) { logger.trace(e) { "Failed to receive message from channel" } fireError(e) - } - finally { + } finally { close() } logger.trace { "Exiting read job..." } diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/annotations/UnstableApi.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/annotations/UnstableApi.kt index 9caf2b5..410a88f 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/annotations/UnstableApi.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/annotations/UnstableApi.kt @@ -5,4 +5,4 @@ package com.agentclientprotocol.annotations level = RequiresOptIn.Level.WARNING ) @Retention(AnnotationRetention.BINARY) -public annotation class UnstableApi() \ No newline at end of file +public annotation class UnstableApi \ No newline at end of file diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/AcpCreatedSessionResponse.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/AcpCreatedSessionResponse.kt index 91dcccf..32d224c 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/AcpCreatedSessionResponse.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/AcpCreatedSessionResponse.kt @@ -4,6 +4,7 @@ import com.agentclientprotocol.annotations.UnstableApi public interface AcpCreatedSessionResponse : AcpWithMeta { public val modes: SessionModeState? + @UnstableApi public val models: SessionModelState? } \ No newline at end of file diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Capabilities.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Capabilities.kt index a06c3e7..6f05f4f 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Capabilities.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Capabilities.kt @@ -1,4 +1,3 @@ -@file:Suppress("unused") @file:OptIn(ExperimentalSerializationApi::class) package com.agentclientprotocol.model diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Content.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Content.kt index 05ab01b..1eda1f2 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Content.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Content.kt @@ -1,4 +1,3 @@ -@file:Suppress("unused") @file:OptIn(ExperimentalSerializationApi::class) package com.agentclientprotocol.model @@ -21,7 +20,7 @@ import kotlinx.serialization.json.JsonElement @JsonClassDiscriminator("type") public sealed class ContentBlock : AcpWithMeta { public abstract val annotations: Annotations? - + /** * Plain text content * diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Methods.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Methods.kt index e39fb2a..32edd21 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Methods.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Methods.kt @@ -1,11 +1,8 @@ -@file:Suppress("unused") - package com.agentclientprotocol.model import com.agentclientprotocol.annotations.UnstableApi import com.agentclientprotocol.rpc.MethodName import kotlinx.serialization.KSerializer -import kotlinx.serialization.builtins.serializer /** * Base interface for ACP method enums. @@ -14,60 +11,143 @@ import kotlinx.serialization.builtins.serializer */ public open class AcpMethod(public val methodName: MethodName) { - public open class AcpRequestResponseMethod( + public open class AcpRequestResponseMethod( method: String, public val requestSerializer: KSerializer, public val responseSerializer: KSerializer ) : AcpMethod(MethodName(method)) - public open class AcpSessionRequestResponseMethod(method: String, - requestSerializer: KSerializer, - responseSerializer: KSerializer + public open class AcpSessionRequestResponseMethod( + method: String, + requestSerializer: KSerializer, + responseSerializer: KSerializer ) : AcpRequestResponseMethod(method, requestSerializer, responseSerializer) where TRequest : AcpRequest, TRequest : AcpWithSessionId - public open class AcpNotificationMethod( + public open class AcpNotificationMethod( method: String, public val serializer: KSerializer, ) : AcpMethod(MethodName(method)) - public open class AcpSessionNotificationMethod(method: String, - serializer: KSerializer + public open class AcpSessionNotificationMethod( + method: String, + serializer: KSerializer ) : AcpNotificationMethod(method, serializer) where TNotification : AcpNotification, TNotification : AcpWithSessionId public object MetaMethods { - public object CancelRequest : AcpNotificationMethod("\$/cancelRequest", CancelRequestNotification.serializer()) + public object CancelRequest : + AcpNotificationMethod("\$/cancelRequest", CancelRequestNotification.serializer()) } public object AgentMethods { // Agent-side operations (methods that agents can call on clients) - public object Initialize : AcpRequestResponseMethod("initialize", InitializeRequest.serializer(), InitializeResponse.serializer()) - public object Authenticate : AcpRequestResponseMethod("authenticate", AuthenticateRequest.serializer(), AuthenticateResponse.serializer()) - public object SessionNew : AcpRequestResponseMethod("session/new", NewSessionRequest.serializer(), NewSessionResponse.serializer()) - public object SessionLoad : AcpRequestResponseMethod("session/load", LoadSessionRequest.serializer(), LoadSessionResponse.serializer()) + public object Initialize : AcpRequestResponseMethod( + "initialize", + InitializeRequest.serializer(), + InitializeResponse.serializer() + ) + + public object Authenticate : AcpRequestResponseMethod( + "authenticate", + AuthenticateRequest.serializer(), + AuthenticateResponse.serializer() + ) + + public object SessionNew : AcpRequestResponseMethod( + "session/new", + NewSessionRequest.serializer(), + NewSessionResponse.serializer() + ) + + public object SessionLoad : AcpRequestResponseMethod( + "session/load", + LoadSessionRequest.serializer(), + LoadSessionResponse.serializer() + ) // session specific - public object SessionPrompt : AcpSessionRequestResponseMethod("session/prompt", PromptRequest.serializer(), PromptResponse.serializer()) - public object SessionCancel : AcpSessionNotificationMethod("session/cancel", CancelNotification.serializer()) - public object SessionSetMode : AcpSessionRequestResponseMethod("session/set_mode", SetSessionModeRequest.serializer(), SetSessionModeResponse.serializer()) + public object SessionPrompt : AcpSessionRequestResponseMethod( + "session/prompt", + PromptRequest.serializer(), + PromptResponse.serializer() + ) + + public object SessionCancel : + AcpSessionNotificationMethod("session/cancel", CancelNotification.serializer()) + + public object SessionSetMode : AcpSessionRequestResponseMethod( + "session/set_mode", + SetSessionModeRequest.serializer(), + SetSessionModeResponse.serializer() + ) + @UnstableApi - public object SessionSetModel : AcpSessionRequestResponseMethod("session/set_model", SetSessionModelRequest.serializer(), SetSessionModelResponse.serializer()) + public object SessionSetModel : + AcpSessionRequestResponseMethod( + "session/set_model", + SetSessionModelRequest.serializer(), + SetSessionModelResponse.serializer() + ) } public object ClientMethods { // Client-side operations (methods that clients can call on agents) - public object SessionRequestPermission : AcpSessionRequestResponseMethod("session/request_permission", RequestPermissionRequest.serializer(), RequestPermissionResponse.serializer()) - public object SessionUpdate : AcpSessionNotificationMethod("session/update", SessionNotification.serializer()) + public object SessionRequestPermission : + AcpSessionRequestResponseMethod( + "session/request_permission", + RequestPermissionRequest.serializer(), + RequestPermissionResponse.serializer() + ) + + public object SessionUpdate : + AcpSessionNotificationMethod("session/update", SessionNotification.serializer()) // extensions - public object FsReadTextFile : AcpSessionRequestResponseMethod("fs/read_text_file", ReadTextFileRequest.serializer(), ReadTextFileResponse.serializer()) - public object FsWriteTextFile : AcpSessionRequestResponseMethod("fs/write_text_file", WriteTextFileRequest.serializer(), WriteTextFileResponse.serializer()) - public object TerminalCreate : AcpSessionRequestResponseMethod("terminal/create", CreateTerminalRequest.serializer(), CreateTerminalResponse.serializer()) - public object TerminalOutput : AcpSessionRequestResponseMethod("terminal/output", TerminalOutputRequest.serializer(), TerminalOutputResponse.serializer()) - public object TerminalRelease : AcpSessionRequestResponseMethod("terminal/release", ReleaseTerminalRequest.serializer(), ReleaseTerminalResponse.serializer()) - public object TerminalWaitForExit : AcpSessionRequestResponseMethod("terminal/wait_for_exit", WaitForTerminalExitRequest.serializer(), WaitForTerminalExitResponse.serializer()) - public object TerminalKill : AcpSessionRequestResponseMethod("terminal/kill", KillTerminalCommandRequest.serializer(), KillTerminalCommandResponse.serializer()) + public object FsReadTextFile : AcpSessionRequestResponseMethod( + "fs/read_text_file", + ReadTextFileRequest.serializer(), + ReadTextFileResponse.serializer() + ) + + public object FsWriteTextFile : AcpSessionRequestResponseMethod( + "fs/write_text_file", + WriteTextFileRequest.serializer(), + WriteTextFileResponse.serializer() + ) + + public object TerminalCreate : AcpSessionRequestResponseMethod( + "terminal/create", + CreateTerminalRequest.serializer(), + CreateTerminalResponse.serializer() + ) + + public object TerminalOutput : AcpSessionRequestResponseMethod( + "terminal/output", + TerminalOutputRequest.serializer(), + TerminalOutputResponse.serializer() + ) + + public object TerminalRelease : + AcpSessionRequestResponseMethod( + "terminal/release", + ReleaseTerminalRequest.serializer(), + ReleaseTerminalResponse.serializer() + ) + + public object TerminalWaitForExit : + AcpSessionRequestResponseMethod( + "terminal/wait_for_exit", + WaitForTerminalExitRequest.serializer(), + WaitForTerminalExitResponse.serializer() + ) + + public object TerminalKill : + AcpSessionRequestResponseMethod( + "terminal/kill", + KillTerminalCommandRequest.serializer(), + KillTerminalCommandResponse.serializer() + ) } diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Plan.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Plan.kt index 1dc7d92..7865eed 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Plan.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Plan.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.model import kotlinx.serialization.SerialName @@ -16,9 +14,14 @@ import kotlinx.serialization.json.JsonElement */ @Serializable public enum class PlanEntryPriority { - @SerialName("high") HIGH, - @SerialName("medium") MEDIUM, - @SerialName("low") LOW + @SerialName("high") + HIGH, + + @SerialName("medium") + MEDIUM, + + @SerialName("low") + LOW } /** @@ -30,9 +33,14 @@ public enum class PlanEntryPriority { */ @Serializable public enum class PlanEntryStatus { - @SerialName("pending") PENDING, - @SerialName("in_progress") IN_PROGRESS, - @SerialName("completed") COMPLETED + @SerialName("pending") + PENDING, + + @SerialName("in_progress") + IN_PROGRESS, + + @SerialName("completed") + COMPLETED } /** diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Requests.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Requests.kt index fd530d5..0c198a1 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Requests.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Requests.kt @@ -1,4 +1,3 @@ -@file:Suppress("unused") @file:OptIn(ExperimentalSerializationApi::class) package com.agentclientprotocol.model @@ -53,7 +52,7 @@ public data class HttpHeader( @Serializable public sealed class McpServer { public abstract val name: String - + /** * Stdio transport configuration * @@ -67,7 +66,7 @@ public sealed class McpServer { val args: List, val env: List ) : McpServer() - + /** * HTTP transport configuration * @@ -80,7 +79,7 @@ public sealed class McpServer { val url: String, val headers: List ) : McpServer() - + /** * SSE transport configuration * @@ -102,11 +101,20 @@ public sealed class McpServer { */ @Serializable public enum class StopReason { - @SerialName("end_turn") END_TURN, - @SerialName("max_tokens") MAX_TOKENS, - @SerialName("max_turn_requests") MAX_TURN_REQUESTS, - @SerialName("refusal") REFUSAL, - @SerialName("cancelled") CANCELLED + @SerialName("end_turn") + END_TURN, + + @SerialName("max_tokens") + MAX_TOKENS, + + @SerialName("max_turn_requests") + MAX_TURN_REQUESTS, + + @SerialName("refusal") + REFUSAL, + + @SerialName("cancelled") + CANCELLED } /** @@ -116,10 +124,17 @@ public enum class StopReason { */ @Serializable public enum class PermissionOptionKind { - @SerialName("allow_once") ALLOW_ONCE, - @SerialName("allow_always") ALLOW_ALWAYS, - @SerialName("reject_once") REJECT_ONCE, - @SerialName("reject_always") REJECT_ALWAYS + @SerialName("allow_once") + ALLOW_ONCE, + + @SerialName("allow_always") + ALLOW_ALWAYS, + + @SerialName("reject_once") + REJECT_ONCE, + + @SerialName("reject_always") + REJECT_ALWAYS } /** diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/SessionUpdate.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/SessionUpdate.kt index 7d08430..0051617 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/SessionUpdate.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/SessionUpdate.kt @@ -1,4 +1,3 @@ -@file:Suppress("unused") @file:OptIn(ExperimentalSerializationApi::class) package com.agentclientprotocol.model diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Terminal.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Terminal.kt index 3dc7ade..8070bb6 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Terminal.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Terminal.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.model import kotlinx.serialization.Serializable diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/ToolCall.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/ToolCall.kt index 5b2a9c4..1a339b2 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/ToolCall.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/ToolCall.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.model import kotlinx.serialization.SerialName @@ -16,16 +14,35 @@ import kotlinx.serialization.json.JsonElement */ @Serializable public enum class ToolKind { - @SerialName("read") READ, - @SerialName("edit") EDIT, - @SerialName("delete") DELETE, - @SerialName("move") MOVE, - @SerialName("search") SEARCH, - @SerialName("execute") EXECUTE, - @SerialName("think") THINK, - @SerialName("fetch") FETCH, - @SerialName("switch_mode") SWITCH_MODE, - @SerialName("other") OTHER + @SerialName("read") + READ, + + @SerialName("edit") + EDIT, + + @SerialName("delete") + DELETE, + + @SerialName("move") + MOVE, + + @SerialName("search") + SEARCH, + + @SerialName("execute") + EXECUTE, + + @SerialName("think") + THINK, + + @SerialName("fetch") + FETCH, + + @SerialName("switch_mode") + SWITCH_MODE, + + @SerialName("other") + OTHER } /** @@ -37,10 +54,17 @@ public enum class ToolKind { */ @Serializable public enum class ToolCallStatus { - @SerialName("pending") PENDING, - @SerialName("in_progress") IN_PROGRESS, - @SerialName("completed") COMPLETED, - @SerialName("failed") FAILED + @SerialName("pending") + PENDING, + + @SerialName("in_progress") + IN_PROGRESS, + + @SerialName("completed") + COMPLETED, + + @SerialName("failed") + FAILED } /** diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Types.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Types.kt index 6d74cb5..c35e053 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Types.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/model/Types.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.model import com.agentclientprotocol.annotations.UnstableApi @@ -93,8 +91,11 @@ public value class ModelId(public val value: String) { */ @Serializable public enum class Role { - @SerialName("assistant") ASSISTANT, - @SerialName("user") USER + @SerialName("assistant") + ASSISTANT, + + @SerialName("user") + USER } /** diff --git a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/rpc/JsonRpc.kt b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/rpc/JsonRpc.kt index fec6d94..0a7803b 100644 --- a/acp-model/src/commonMain/kotlin/com/agentclientprotocol/rpc/JsonRpc.kt +++ b/acp-model/src/commonMain/kotlin/com/agentclientprotocol/rpc/JsonRpc.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.rpc import kotlinx.serialization.ExperimentalSerializationApi diff --git a/acp-model/src/commonTest/kotlin/com/agentclientprotocol/rpc/JsonDecodeTest.kt b/acp-model/src/commonTest/kotlin/com/agentclientprotocol/rpc/JsonDecodeTest.kt index 0398cf9..0798024 100644 --- a/acp-model/src/commonTest/kotlin/com/agentclientprotocol/rpc/JsonDecodeTest.kt +++ b/acp-model/src/commonTest/kotlin/com/agentclientprotocol/rpc/JsonDecodeTest.kt @@ -52,7 +52,8 @@ class JsonDecodeTest { @Test fun testDecodeError() { try { - decodeJsonRpcMessage(""" + decodeJsonRpcMessage( + """ asdfasdfas "outcome": { "outcome": "selected", @@ -60,9 +61,11 @@ class JsonDecodeTest { } } } - """.trimIndent()) + """.trimIndent() + ) fail("Exception expected") - } catch (_: SerializationException) {} + } catch (_: SerializationException) { + } } } \ No newline at end of file diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt index 3a33082..03c9bc5 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt @@ -37,7 +37,7 @@ private val logger = KotlinLogging.logger {} public class Agent( public val protocol: Protocol, private val agentSupport: AgentSupport - ) { +) { internal class SessionWrapper( val agent: Agent, @@ -46,6 +46,7 @@ public class Agent( val protocol: Protocol ) { private class PromptSession(val currentRequestId: RequestId) + private val _activePrompt = atomic(null) internal suspend fun executeWithSession(block: suspend () -> T): T { @@ -56,7 +57,11 @@ public class Agent( suspend fun prompt(content: List, _meta: JsonElement? = null): PromptResponse { val currentRpcRequest = currentCoroutineContext().jsonRpcRequest - if (!_activePrompt.compareAndSet(null, PromptSession(currentRpcRequest.id))) error("There is already active prompt execution") + if (!_activePrompt.compareAndSet( + null, + PromptSession(currentRpcRequest.id) + ) + ) error("There is already active prompt execution") try { var response: PromptResponse? = null @@ -76,12 +81,10 @@ public class Agent( } return response ?: PromptResponse(StopReason.END_TURN) - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { logger.trace(ce) { "Prompt job cancelled" } return PromptResponse(StopReason.CANCELLED) - } - finally { + } finally { _activePrompt.getAndSet(null) } } @@ -120,12 +123,19 @@ public class Agent( private fun setHandlers(protocol: Protocol) { // Set up request handlers for incoming client requests protocol.setRequestHandler(AcpMethod.AgentMethods.Initialize) { params: InitializeRequest -> - val clientInfo = ClientInfo(params.protocolVersion, params.clientCapabilities, params.clientInfo, params._meta) + val clientInfo = + ClientInfo(params.protocolVersion, params.clientCapabilities, params.clientInfo, params._meta) _clientInfo.complete(clientInfo) val agentInfo = agentSupport.initialize(clientInfo) // see https://agentclientprotocol.com/protocol/initialization#version-negotiation val negotiatedVersion = min(params.protocolVersion, agentInfo.protocolVersion) - return@setRequestHandler InitializeResponse(negotiatedVersion, agentInfo.capabilities, agentInfo.authMethods, agentInfo.implementation, agentInfo._meta) + return@setRequestHandler InitializeResponse( + negotiatedVersion, + agentInfo.capabilities, + agentInfo.authMethods, + agentInfo.implementation, + agentInfo._meta + ) } protocol.setRequestHandler(AcpMethod.AgentMethods.Authenticate) { params: AuthenticateRequest -> @@ -145,7 +155,8 @@ public class Agent( protocol.setRequestHandler(AcpMethod.AgentMethods.SessionLoad) { params: LoadSessionRequest -> val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta) - val session = createSession(sessionParameters) { agentSupport.loadSession(params.sessionId, sessionParameters) } + val session = + createSession(sessionParameters) { agentSupport.loadSession(params.sessionId, sessionParameters) } return@setRequestHandler LoadSessionResponse( // maybe unify result of these two methods to have sessionId in both // sessionId = session.sessionId, @@ -181,7 +192,10 @@ public class Agent( } } - private suspend fun createSession(sessionParameters: SessionCreationParameters, sessionFactory: suspend (SessionCreationParameters) -> AgentSession): AgentSession { + private suspend fun createSession( + sessionParameters: SessionCreationParameters, + sessionFactory: suspend (SessionCreationParameters) -> AgentSession + ): AgentSession { val session = sessionFactory(sessionParameters) val clientInfo = getClientInfoOrThrow() @@ -198,11 +212,13 @@ public class Agent( return session } - private fun getSessionOrThrow(sessionId: SessionId): SessionWrapper = _sessions.value[sessionId] ?: acpFail("Session $sessionId not found") + private fun getSessionOrThrow(sessionId: SessionId): SessionWrapper = + _sessions.value[sessionId] ?: acpFail("Session $sessionId not found") } -internal class SessionWrapperContextElement(val sessionWrapper: Agent.SessionWrapper) : AbstractCoroutineContextElement(Key) { +internal class SessionWrapperContextElement(val sessionWrapper: Agent.SessionWrapper) : + AbstractCoroutineContextElement(Key) { object Key : CoroutineContext.Key } @@ -210,6 +226,7 @@ internal fun Agent.SessionWrapper.asContextElement() = SessionWrapperContextElem public val CoroutineContext.agent: Agent get() = this[SessionWrapperContextElement.Key]?.sessionWrapper?.agent ?: error("No agent data found in context") + /** * Returns client info associated with the current protocol. Throws an exception if the agent is still not initialized from the client side. */ @@ -220,5 +237,6 @@ public val CoroutineContext.clientInfo: ClientInfo * Returns a remote client connected to the counterpart via the current protocol */ public val CoroutineContext.client: ClientSessionOperations - get() = this[SessionWrapperContextElement.Key]?.sessionWrapper?.clientOperations ?: error("No remote client found in context") + get() = this[SessionWrapperContextElement.Key]?.sessionWrapper?.clientOperations + ?: error("No remote client found in context") diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentInfo.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentInfo.kt index 92ef361..22e097e 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentInfo.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentInfo.kt @@ -1,10 +1,6 @@ package com.agentclientprotocol.agent -import com.agentclientprotocol.model.AgentCapabilities -import com.agentclientprotocol.model.AuthMethod -import com.agentclientprotocol.model.Implementation -import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION -import com.agentclientprotocol.model.ProtocolVersion +import com.agentclientprotocol.model.* import kotlinx.serialization.json.JsonElement public class AgentInfo( diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt index 4bd99b0..2f45cc8 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt @@ -9,7 +9,9 @@ import kotlinx.serialization.json.JsonElement public interface AgentSupport { public suspend fun initialize(clientInfo: ClientInfo): AgentInfo - public suspend fun authenticate(methodId: AuthMethodId, _meta: JsonElement?): AuthenticateResponse = AuthenticateResponse() + public suspend fun authenticate(methodId: AuthMethodId, _meta: JsonElement?): AuthenticateResponse = + AuthenticateResponse() + public suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession public suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionCreationParameters): AgentSession } \ No newline at end of file diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/RemoteClientSessionOperations.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/RemoteClientSessionOperations.kt index 947fb66..0b8625b 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/RemoteClientSessionOperations.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/RemoteClientSessionOperations.kt @@ -6,13 +6,20 @@ import com.agentclientprotocol.protocol.RpcMethodsOperations import com.agentclientprotocol.protocol.invoke import kotlinx.serialization.json.JsonElement -internal class RemoteClientSessionOperations(private val rpc: RpcMethodsOperations, private val sessionId: SessionId, private val clientCapabilities: ClientCapabilities) : ClientSessionOperations { +internal class RemoteClientSessionOperations( + private val rpc: RpcMethodsOperations, + private val sessionId: SessionId, + private val clientCapabilities: ClientCapabilities +) : ClientSessionOperations { override suspend fun requestPermissions( toolCall: SessionUpdate.ToolCallUpdate, permissions: List, _meta: JsonElement?, ): RequestPermissionResponse { - return AcpMethod.ClientMethods.SessionRequestPermission(rpc, RequestPermissionRequest(sessionId, toolCall, permissions, _meta)) + return AcpMethod.ClientMethods.SessionRequestPermission( + rpc, + RequestPermissionRequest(sessionId, toolCall, permissions, _meta) + ) } override suspend fun notify( @@ -50,7 +57,10 @@ internal class RemoteClientSessionOperations(private val rpc: RpcMethodsOperatio _meta: JsonElement?, ): CreateTerminalResponse { if (!clientCapabilities.terminal) error("Client does not support terminal capability") - return AcpMethod.ClientMethods.TerminalCreate(rpc, CreateTerminalRequest(sessionId, command, args, cwd, env, outputByteLimit, _meta)) + return AcpMethod.ClientMethods.TerminalCreate( + rpc, + CreateTerminalRequest(sessionId, command, args, cwd, env, outputByteLimit, _meta) + ) } override suspend fun terminalOutput( @@ -74,7 +84,10 @@ internal class RemoteClientSessionOperations(private val rpc: RpcMethodsOperatio _meta: JsonElement?, ): WaitForTerminalExitResponse { if (!clientCapabilities.terminal) error("Client does not support terminal capability") - return AcpMethod.ClientMethods.TerminalWaitForExit(rpc, WaitForTerminalExitRequest(sessionId, terminalId, _meta)) + return AcpMethod.ClientMethods.TerminalWaitForExit( + rpc, + WaitForTerminalExitRequest(sessionId, terminalId, _meta) + ) } override suspend fun terminalKill( diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt index 0841b58..6c69dd6 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/Client.kt @@ -74,7 +74,14 @@ public class Client( val terminal = session.operations as? TerminalOperations ?: sessionMethodNotFound(AcpMethod.ClientMethods.TerminalCreate) return@setRequestHandler session.executeWithSession { - return@executeWithSession terminal.terminalCreate(params.command, params.args, params.cwd, params.env, params.outputByteLimit, params._meta) + return@executeWithSession terminal.terminalCreate( + params.command, + params.args, + params.cwd, + params.env, + params.outputByteLimit, + params._meta + ) } } @@ -138,8 +145,17 @@ public class Client( public suspend fun initialize(clientInfo: ClientInfo, _meta: JsonElement? = null): AgentInfo { _clientInfo.complete(clientInfo) - val initializeResponse = AcpMethod.AgentMethods.Initialize(protocol, InitializeRequest(clientInfo.protocolVersion, clientInfo.capabilities, clientInfo.implementation, _meta)) - val agentInfo = AgentInfo(initializeResponse.protocolVersion, initializeResponse.agentCapabilities, initializeResponse.authMethods, initializeResponse.agentInfo, initializeResponse._meta) + val initializeResponse = AcpMethod.AgentMethods.Initialize( + protocol, + InitializeRequest(clientInfo.protocolVersion, clientInfo.capabilities, clientInfo.implementation, _meta) + ) + val agentInfo = AgentInfo( + initializeResponse.protocolVersion, + initializeResponse.agentCapabilities, + initializeResponse.authMethods, + initializeResponse.agentInfo, + initializeResponse._meta + ) _agentInfo.complete(agentInfo) return agentInfo } @@ -161,7 +177,10 @@ public class Client( * See [ClientOperationsFactory.createClientOperations] for more details. * @return a [ClientSession] instance for the new session */ - public suspend fun newSession(sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession { + public suspend fun newSession( + sessionParameters: SessionCreationParameters, + operationsFactory: ClientOperationsFactory + ): ClientSession { return withInitializingSession { val newSessionResponse = AcpMethod.AgentMethods.SessionNew( protocol, @@ -172,7 +191,12 @@ public class Client( ) ) val sessionId = newSessionResponse.sessionId - return@withInitializingSession createSession(sessionId, sessionParameters, newSessionResponse, operationsFactory) + return@withInitializingSession createSession( + sessionId, + sessionParameters, + newSessionResponse, + operationsFactory + ) } } @@ -186,7 +210,11 @@ public class Client( * See [ClientOperationsFactory.createClientOperations] for more details. * @return a [ClientSession] instance for the new session */ - public suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, operationsFactory: ClientOperationsFactory): ClientSession { + public suspend fun loadSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + operationsFactory: ClientOperationsFactory + ): ClientSession { return withInitializingSession { val loadSessionResponse = AcpMethod.AgentMethods.SessionLoad( protocol, @@ -197,11 +225,21 @@ public class Client( sessionParameters._meta ) ) - return@withInitializingSession createSession(sessionId, sessionParameters, loadSessionResponse, operationsFactory) + return@withInitializingSession createSession( + sessionId, + sessionParameters, + loadSessionResponse, + operationsFactory + ) } } - private suspend fun createSession(sessionId: SessionId, sessionParameters: SessionCreationParameters, sessionResponse: AcpCreatedSessionResponse, factory: ClientOperationsFactory): ClientSession { + private suspend fun createSession( + sessionId: SessionId, + sessionParameters: SessionCreationParameters, + sessionResponse: AcpCreatedSessionResponse, + factory: ClientOperationsFactory + ): ClientSession { val sessionDeferred = CompletableDeferred() return runCatching { _sessions.update { it.put(sessionId, sessionDeferred) } @@ -238,7 +276,7 @@ public class Client( acpFail("Session $sessionId not found") } - private suspend fun withInitializingSession(block: suspend () -> T): T { + private suspend fun withInitializingSession(block: suspend () -> T): T { _currentlyInitializingSessionsCount.update { it + 1 } try { return block() diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientOperationsFactory.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientOperationsFactory.kt index 766b14a..8a9f16f 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientOperationsFactory.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientOperationsFactory.kt @@ -15,5 +15,8 @@ public fun interface ClientOperationsFactory { * * [sessionId] an existing id in the case when the session is being loaded or a new id when the session is newly created (id is returned from the agent) */ - public suspend fun createClientOperations(sessionId: SessionId, sessionResponse: AcpCreatedSessionResponse): ClientSessionOperations + public suspend fun createClientOperations( + sessionId: SessionId, + sessionResponse: AcpCreatedSessionResponse + ): ClientSessionOperations } \ No newline at end of file diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSession.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSession.kt index 0c2564e..0d4877e 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSession.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSession.kt @@ -46,6 +46,7 @@ public interface ClientSession { * @throws IllegalStateException if the mode changing is not supported. */ public val currentMode: StateFlow + /** * Changes the session mode to the specified mode. The real change will be reported by an agent via [currentMode] and [ClientSessionOperations.notify]. */ diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt index bd2135e..ec9e1d6 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/client/ClientSessionImpl.kt @@ -35,6 +35,7 @@ internal class ClientSessionImpl( private class PromptSession( val updateChannel: Channel ) + private val activePrompt = atomic(null) private val _currentMode by lazy { @@ -63,7 +64,8 @@ internal class ClientSessionImpl( } try { logger.trace { "Sending prompt request: $content" } - val promptResponse = AcpMethod.AgentMethods.SessionPrompt(protocol, PromptRequest(sessionId, content, _meta)) + val promptResponse = + AcpMethod.AgentMethods.SessionPrompt(protocol, PromptRequest(sessionId, content, _meta)) logger.trace { "Received prompt response: $promptResponse" } // after receiving prompt response we immediately close the current prompt channel @@ -122,6 +124,7 @@ internal class ClientSessionImpl( block() } } + /** * Routes notification to either active prompt or global notification channel */ @@ -143,17 +146,18 @@ internal class ClientSessionImpl( if (promptSession != null && !promptSession.updateChannel.isClosedForSend) { logger.trace { "Sending update to active prompt: $notification" } promptSession.updateChannel.send(notification) - } - else { + } else { logger.trace { "Notifying globally: $notification" } operations.notify(notification, _meta) } } - internal suspend fun handlePermissionResponse(toolCall: SessionUpdate.ToolCallUpdate, - permissions: List, - _meta: JsonElement?,): RequestPermissionResponse { - return operations.requestPermissions(toolCall, permissions, _meta) + internal suspend fun handlePermissionResponse( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return operations.requestPermissions(toolCall, permissions, _meta) } } diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/common/FileSystemOperations.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/common/FileSystemOperations.kt index 262124a..6a563bf 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/common/FileSystemOperations.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/common/FileSystemOperations.kt @@ -5,15 +5,20 @@ import com.agentclientprotocol.model.WriteTextFileResponse import kotlinx.serialization.json.JsonElement public interface FileSystemOperations { - public suspend fun fsReadTextFile(path: String, - line: UInt? = null, - limit: UInt? = null, - _meta: JsonElement? = null): ReadTextFileResponse { + public suspend fun fsReadTextFile( + path: String, + line: UInt? = null, + limit: UInt? = null, + _meta: JsonElement? = null + ): ReadTextFileResponse { throw NotImplementedError("Must be implemented by client when advertising fs.readTextFile capability") } - public suspend fun fsWriteTextFile(path: String, - content: String, - _meta: JsonElement? = null): WriteTextFileResponse { + + public suspend fun fsWriteTextFile( + path: String, + content: String, + _meta: JsonElement? = null + ): WriteTextFileResponse { throw NotImplementedError("Must be implemented by client when advertising fs.writeTextFile capability") } } \ No newline at end of file diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/common/TerminalOperations.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/common/TerminalOperations.kt index c170456..45e8a96 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/common/TerminalOperations.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/common/TerminalOperations.kt @@ -4,32 +4,42 @@ import com.agentclientprotocol.model.* import kotlinx.serialization.json.JsonElement public interface TerminalOperations { - public suspend fun terminalCreate(command: String, - args: List = emptyList(), - cwd: String? = null, - env: List = emptyList(), - outputByteLimit: ULong? = null, - _meta: JsonElement? = null): CreateTerminalResponse { + public suspend fun terminalCreate( + command: String, + args: List = emptyList(), + cwd: String? = null, + env: List = emptyList(), + outputByteLimit: ULong? = null, + _meta: JsonElement? = null + ): CreateTerminalResponse { throw NotImplementedError("Must be implemented by client when advertising terminal capability") } - public suspend fun terminalOutput(terminalId: String, - _meta: JsonElement? = null): TerminalOutputResponse { + public suspend fun terminalOutput( + terminalId: String, + _meta: JsonElement? = null + ): TerminalOutputResponse { throw NotImplementedError("Must be implemented by client when advertising terminal capability") } - public suspend fun terminalRelease(terminalId: String, - _meta: JsonElement? = null): ReleaseTerminalResponse { + public suspend fun terminalRelease( + terminalId: String, + _meta: JsonElement? = null + ): ReleaseTerminalResponse { throw NotImplementedError("Must be implemented by client when advertising terminal capability") } - public suspend fun terminalWaitForExit(terminalId: String, - _meta: JsonElement? = null): WaitForTerminalExitResponse { + public suspend fun terminalWaitForExit( + terminalId: String, + _meta: JsonElement? = null + ): WaitForTerminalExitResponse { throw NotImplementedError("Must be implemented by client when advertising terminal capability") } - public suspend fun terminalKill(terminalId: String, - _meta: JsonElement? = null): KillTerminalCommandResponse { + public suspend fun terminalKill( + terminalId: String, + _meta: JsonElement? = null + ): KillTerminalCommandResponse { throw NotImplementedError("Must be implemented by client when advertising terminal capability") } } \ No newline at end of file diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.extensions.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.extensions.kt index cb488de..2b1b13e 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.extensions.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.extensions.kt @@ -7,8 +7,6 @@ import com.agentclientprotocol.model.AcpResponse import com.agentclientprotocol.rpc.ACPJson import com.agentclientprotocol.rpc.JsonRpcRequest import kotlinx.serialization.json.JsonNull -import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.json.encodeToJsonElement import kotlin.coroutines.AbstractCoroutineContextElement import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext @@ -29,7 +27,7 @@ public suspend fun RpcMethodsOp /** * Send a notification (no response expected). */ -public fun RpcMethodsOperations.sendNotification( +public fun RpcMethodsOperations.sendNotification( method: AcpMethod.AcpNotificationMethod, notification: TNotification? = null, ) { @@ -40,7 +38,7 @@ public fun RpcMethodsOperations.sendNotificatio /** * Register a handler for incoming requests. */ -public fun RpcMethodsOperations.setRequestHandler( +public fun RpcMethodsOperations.setRequestHandler( method: AcpMethod.AcpRequestResponseMethod, additionalContext: CoroutineContext = EmptyCoroutineContext, handler: suspend (TRequest) -> TResponse @@ -51,10 +49,11 @@ public fun RpcMethodsOperations. ACPJson.encodeToJsonElement(method.responseSerializer, responseObject) } } + /** * Register a handler for incoming notifications. */ -public fun RpcMethodsOperations.setNotificationHandler( +public fun RpcMethodsOperations.setNotificationHandler( method: AcpMethod.AcpNotificationMethod, additionalContext: CoroutineContext = EmptyCoroutineContext, handler: suspend (TNotification) -> Unit @@ -65,11 +64,17 @@ public fun RpcMethodsOperations.setNotification } } -public suspend operator fun AcpMethod.AcpRequestResponseMethod.invoke(rpc: RpcMethodsOperations, request: TRequest): TResponse { +public suspend operator fun AcpMethod.AcpRequestResponseMethod.invoke( + rpc: RpcMethodsOperations, + request: TRequest +): TResponse { return rpc.sendRequest(this, request) } -public operator fun AcpMethod.AcpNotificationMethod.invoke(rpc: RpcMethodsOperations, notification: TNotification) { +public operator fun AcpMethod.AcpNotificationMethod.invoke( + rpc: RpcMethodsOperations, + notification: TNotification +) { return rpc.sendNotification(this, notification) } diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt index 4f72566..77ccf49 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.protocol import com.agentclientprotocol.model.AcpMethod diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/BaseTransport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/BaseTransport.kt index 02710c1..18a04ed 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/BaseTransport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/BaseTransport.kt @@ -12,9 +12,9 @@ private val logger = KotlinLogging.logger {} public abstract class BaseTransport : Transport { protected val _state: MutableStateFlow = MutableStateFlow(Transport.State.CREATED) - private val messageHandlers = atomic({}) - private val errorHandlers = atomic({}) - private val closeHandlers = atomic({}) + private val messageHandlers = atomic {} + private val errorHandlers = atomic {} + private val closeHandlers = atomic {} override fun onMessage(handler: MessageListener) { messageHandlers.update { old -> diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt index 5ed59df..69907f0 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt @@ -4,12 +4,10 @@ import com.agentclientprotocol.rpc.ACPJson import com.agentclientprotocol.rpc.JsonRpcMessage import com.agentclientprotocol.rpc.decodeJsonRpcMessage import com.agentclientprotocol.transport.Transport.State -import com.agentclientprotocol.util.checkCancelled import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.getAndUpdate -import kotlinx.coroutines.flow.update import kotlinx.io.* import kotlinx.serialization.encodeToString @@ -28,11 +26,13 @@ public class StdioTransport( private val output: Sink, private val name: String = StdioTransport::class.simpleName!!, ) : BaseTransport() { - private val childScope = CoroutineScope(parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job]) + CoroutineName(name)) + private val childScope = CoroutineScope( + parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job]) + CoroutineName(name) + ) private val receiveChannel = Channel(Channel.UNLIMITED) private val sendChannel = Channel(Channel.UNLIMITED) - + override fun start() { if (_state.getAndUpdate { State.STARTING } != State.CREATED) error("Transport is not in ${State.CREATED.name} state") // Start reading messages from input @@ -112,16 +112,13 @@ public class StdioTransport( logger.trace { "Joining read/write jobs..." } if (_state.getAndUpdate { State.STARTED } != State.STARTING) logger.warn { "Transport is not in ${State.STARTING.name} state" } joinAll(readJob, writeJob) - } - catch (ce: CancellationException) { + } catch (ce: CancellationException) { logger.trace(ce) { "Join cancelled" } // don't throw as error - } - catch (e: Exception) { + } catch (e: Exception) { logger.trace(e) { "Exception while waiting read/write jobs" } fireError(e) - } - finally { + } finally { childScope.cancel() if (_state.getAndUpdate { State.CLOSED } != State.CLOSING) logger.warn { "Transport is not in ${State.CLOSING.name} state" } fireClose() @@ -129,7 +126,7 @@ public class StdioTransport( } } } - + override fun send(message: JsonRpcMessage) { logger.trace { "Sending message: $message" } val channelResult = sendChannel.trySend(message) diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/Transport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/Transport.kt index 2b765c1..da8edd6 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/Transport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/Transport.kt @@ -1,5 +1,3 @@ -@file:Suppress("unused") - package com.agentclientprotocol.transport import com.agentclientprotocol.rpc.JsonRpcMessage @@ -18,6 +16,7 @@ public typealias CloseListener = () -> Unit */ public interface Transport : AutoCloseable { public enum class State { CREATED, STARTING, STARTED, CLOSING, CLOSED } + public val state: StateFlow /** diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportTest.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportTest.kt index 4024cce..d8cbc80 100644 --- a/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportTest.kt +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportTest.kt @@ -1,11 +1,6 @@ package com.agentclientprotocol.transport -import com.agentclientprotocol.rpc.JsonRpcMessage -import com.agentclientprotocol.rpc.JsonRpcNotification -import com.agentclientprotocol.rpc.JsonRpcRequest -import com.agentclientprotocol.rpc.JsonRpcResponse -import com.agentclientprotocol.rpc.MethodName -import com.agentclientprotocol.rpc.RequestId +import com.agentclientprotocol.rpc.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.onEach @@ -38,8 +33,7 @@ class StdioTransportTest { .onEach { observed.add(it) } .first { it == state } } - } - catch (_: TimeoutCancellationException) { + } catch (_: TimeoutCancellationException) { fail("Timed out waiting for state $state after $timeout, observed states: ${observed.joinToString { it.name }}, ${message?.let { ": $it" } ?: ""}") } } @@ -168,7 +162,7 @@ class StdioTransportTest { // Closing again should not throw transport.close() - expectState(Transport.State.CLOSED, message = "After 2 close") + expectState(Transport.State.CLOSED, message = "After 2 close") } @Test diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 1df1422..bed0464 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -7,14 +7,12 @@ kotlinx-collections-immutable = "0.3.8" ktor = "3.1.3" kotlin-logging = "7.0.0" slf4j = "2.0.16" -kotest = "6.0.3" atomicfu = "0.25.0" mavenPublish = "0.34.0" [libraries] kotlin-gradle-plugin = { module = "org.jetbrains.kotlin:kotlin-gradle-plugin", version.ref = "kotlin" } kotlin-serialization-plugin = { module = "org.jetbrains.kotlin:kotlin-serialization", version.ref = "kotlin" } -atomicfu-plugin = { module = "org.jetbrains.kotlinx:atomicfu-gradle-plugin", version.ref = "atomicfu" } maven-publish = { module = "com.vanniktech:gradle-maven-publish-plugin", version.ref = "mavenPublish" } kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } @@ -26,14 +24,11 @@ kotlinx-atomicfu = { module = "org.jetbrains.kotlinx:atomicfu", version.ref = "a ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" } ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" } ktor-server-websockets = { module = "io.ktor:ktor-server-websockets", version.ref = "ktor" } -ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor" } ktor-server-test-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor" } ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.ref = "ktor" } kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlin-logging" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } -kotest-assertions-json = { module = "io.kotest:kotest-assertions-json-jvm", version.ref = "kotest" } - [plugins] kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version = "0.16.3" } \ No newline at end of file diff --git a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/GeminiClientApp.kt b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/GeminiClientApp.kt index c46ab6e..f8fde3b 100644 --- a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/GeminiClientApp.kt +++ b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/GeminiClientApp.kt @@ -22,7 +22,7 @@ private val logger = KotlinLogging.logger {} */ suspend fun main() = coroutineScope { logger.info { "Starting Gemini ACP Client App" } - // Create process transport to start Gemini agent - val transport = createProcessStdioTransport(this, "gemini", "--experimental-acp") - runTerminalClient(transport) + // Create process transport to start Gemini agent + val transport = createProcessStdioTransport(this, "gemini", "--experimental-acp") + runTerminalClient(transport) } \ No newline at end of file diff --git a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/SimpleAgentSupport.kt b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/SimpleAgentSupport.kt index a77b3a1..020b979 100644 --- a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/SimpleAgentSupport.kt +++ b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/SimpleAgentSupport.kt @@ -43,11 +43,13 @@ class SimpleAgentSession( .joinToString(" ") { it.text } }" - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text(responseText) + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text(responseText) + ) ) - )) + ) // Simulate a tool call if client supports file operations if (clientCapabilities.fs?.readTextFile == true) { @@ -85,50 +87,58 @@ class SimpleAgentSession( ) ) - emit(Event.SessionUpdateEvent( - SessionUpdate.PlanUpdate(plan.entries) - )) + emit( + Event.SessionUpdateEvent( + SessionUpdate.PlanUpdate(plan.entries) + ) + ) } private suspend fun FlowCollector.simulateToolCall() { val toolCallId = ToolCallId("tool-${System.currentTimeMillis()}") // Start tool call - emit(Event.SessionUpdateEvent( - SessionUpdate.ToolCallUpdate( - toolCallId = toolCallId, - title = "Reading current directory", - kind = ToolKind.READ, - status = ToolCallStatus.PENDING, - locations = listOf(ToolCallLocation(".")), - content = emptyList() + emit( + Event.SessionUpdateEvent( + SessionUpdate.ToolCallUpdate( + toolCallId = toolCallId, + title = "Reading current directory", + kind = ToolKind.READ, + status = ToolCallStatus.PENDING, + locations = listOf(ToolCallLocation(".")), + content = emptyList() + ) ) - )) + ) delay(500) // Simulate work // Update to in progress - emit(Event.SessionUpdateEvent( - SessionUpdate.ToolCallUpdate( - toolCallId = toolCallId, - status = ToolCallStatus.IN_PROGRESS + emit( + Event.SessionUpdateEvent( + SessionUpdate.ToolCallUpdate( + toolCallId = toolCallId, + status = ToolCallStatus.IN_PROGRESS + ) ) - )) + ) delay(500) // Simulate more work // Complete the tool call - emit(Event.SessionUpdateEvent( - SessionUpdate.ToolCallUpdate( - toolCallId = toolCallId, - status = ToolCallStatus.COMPLETED, - content = listOf( - ToolCallContent.Content( - ContentBlock.Text("Directory listing completed successfully") + emit( + Event.SessionUpdateEvent( + SessionUpdate.ToolCallUpdate( + toolCallId = toolCallId, + status = ToolCallStatus.COMPLETED, + content = listOf( + ToolCallContent.Content( + ContentBlock.Text("Directory listing completed successfully") + ) ) ) ) - )) + ) } private suspend fun FlowCollector.demonstrateFileSystemOperations() { @@ -136,11 +146,13 @@ class SimpleAgentSession( val clientOperation = currentCoroutineContext().client // Example: Write a file - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nDemonstrating file system operations...") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nDemonstrating file system operations...") + ) ) - )) + ) val testContent = "Hello from ACP agent!" clientOperation.fsWriteTextFile("/tmp/acp_test.txt", testContent) @@ -148,17 +160,21 @@ class SimpleAgentSession( // Example: Read the file back val readResponse = clientOperation.fsReadTextFile("/tmp/acp_test.txt") - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nFile content read: ${readResponse.content}") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nFile content read: ${readResponse.content}") + ) ) - )) + ) } catch (e: Exception) { - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nFile system operation failed: ${e.message}") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nFile system operation failed: ${e.message}") + ) ) - )) + ) } } @@ -166,11 +182,13 @@ class SimpleAgentSession( try { val terminalOps = currentCoroutineContext().client - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nDemonstrating terminal operations...") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nDemonstrating terminal operations...") + ) ) - )) + ) // Example: Execute a simple command val createResponse = terminalOps.terminalCreate("echo", listOf("Hello from terminal!")) @@ -178,17 +196,21 @@ class SimpleAgentSession( val outputResponse = terminalOps.terminalOutput(createResponse.terminalId) terminalOps.terminalRelease(createResponse.terminalId) - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nTerminal output: ${outputResponse.output} (exit code: ${exitResponse.exitCode})") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nTerminal output: ${outputResponse.output} (exit code: ${exitResponse.exitCode})") + ) ) - )) + ) } catch (e: Exception) { - emit(Event.SessionUpdateEvent( - SessionUpdate.AgentMessageChunk( - ContentBlock.Text("\nTerminal operation failed: ${e.message}") + emit( + Event.SessionUpdateEvent( + SessionUpdate.AgentMessageChunk( + ContentBlock.Text("\nTerminal operation failed: ${e.message}") + ) ) - )) + ) } } } diff --git a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/TerminalClientSupport.kt b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/TerminalClientSupport.kt index 8b780c5..7611470 100644 --- a/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/TerminalClientSupport.kt +++ b/samples/kotlin-acp-client-sample/src/main/kotlin/com/agentclientprotocol/samples/TerminalClientSupport.kt @@ -1,11 +1,8 @@ package com.agentclientprotocol.samples -import com.agentclientprotocol.client.* -import com.agentclientprotocol.common.ClientSessionOperations -import com.agentclientprotocol.common.Event -import com.agentclientprotocol.common.FileSystemOperations -import com.agentclientprotocol.common.SessionCreationParameters -import com.agentclientprotocol.common.TerminalOperations +import com.agentclientprotocol.client.Client +import com.agentclientprotocol.client.ClientInfo +import com.agentclientprotocol.common.* import com.agentclientprotocol.model.* import com.agentclientprotocol.protocol.Protocol import com.agentclientprotocol.transport.Transport @@ -13,7 +10,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.CoroutineScope import kotlinx.serialization.json.JsonElement import java.nio.file.Paths -import java.util.UUID +import java.util.* import java.util.concurrent.ConcurrentHashMap import kotlin.io.path.absolutePathString import kotlin.io.path.readText @@ -37,7 +34,10 @@ class TerminalClientSessionOperations : ClientSessionOperations, FileSystemOpera val read = readln() val optionIndex = read.toIntOrNull() if (optionIndex != null && optionIndex in permissions.indices) { - return RequestPermissionResponse(RequestPermissionOutcome.Selected(permissions[optionIndex].optionId), _meta) + return RequestPermissionResponse( + RequestPermissionOutcome.Selected(permissions[optionIndex].optionId), + _meta + ) } println("Invalid option selected. Try again.") }