Skip to content

Add Streamable Http Transport #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions api/kotlin-sdk.api
Original file line number Diff line number Diff line change
Expand Up @@ -1485,8 +1485,9 @@ public final class io/modelcontextprotocol/kotlin/sdk/PingRequest$Companion {

public class io/modelcontextprotocol/kotlin/sdk/Progress : io/modelcontextprotocol/kotlin/sdk/ProgressBase {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/Progress$Companion;
public synthetic fun <init> (IILjava/lang/Double;Lkotlinx/serialization/internal/SerializationConstructorMarker;)V
public fun <init> (ILjava/lang/Double;)V
public synthetic fun <init> (IILjava/lang/Double;Ljava/lang/String;Lkotlinx/serialization/internal/SerializationConstructorMarker;)V
public fun <init> (ILjava/lang/Double;Ljava/lang/String;)V
public fun getMessage ()Ljava/lang/String;
public fun getProgress ()I
public fun getTotal ()Ljava/lang/Double;
public static final synthetic fun write$Self (Lio/modelcontextprotocol/kotlin/sdk/Progress;Lkotlinx/serialization/encoding/CompositeEncoder;Lkotlinx/serialization/descriptors/SerialDescriptor;)V
Expand All @@ -1509,6 +1510,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/Progress$Companion {

public abstract interface class io/modelcontextprotocol/kotlin/sdk/ProgressBase {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ProgressBase$Companion;
public abstract fun getMessage ()Ljava/lang/String;
public abstract fun getProgress ()I
public abstract fun getTotal ()Ljava/lang/Double;
}
Expand All @@ -1519,15 +1521,17 @@ public final class io/modelcontextprotocol/kotlin/sdk/ProgressBase$Companion {

public final class io/modelcontextprotocol/kotlin/sdk/ProgressNotification : io/modelcontextprotocol/kotlin/sdk/ClientNotification, io/modelcontextprotocol/kotlin/sdk/ProgressBase, io/modelcontextprotocol/kotlin/sdk/ServerNotification {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification$Companion;
public fun <init> (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;)V
public synthetic fun <init> (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;Ljava/lang/String;)V
public synthetic fun <init> (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()I
public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/RequestId;
public final fun component3 ()Lkotlinx/serialization/json/JsonObject;
public final fun component4 ()Ljava/lang/Double;
public final fun copy (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;)Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;
public final fun component5 ()Ljava/lang/String;
public final fun copy (ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;ILio/modelcontextprotocol/kotlin/sdk/RequestId;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Double;Ljava/lang/String;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/ProgressNotification;
public fun equals (Ljava/lang/Object;)Z
public fun getMessage ()Ljava/lang/String;
public fun getMethod ()Lio/modelcontextprotocol/kotlin/sdk/Method;
public fun getProgress ()I
public final fun getProgressToken ()Lio/modelcontextprotocol/kotlin/sdk/RequestId;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.sse.*
import io.modelcontextprotocol.kotlin.sdk.*
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlin.collections.HashMap
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

/**
* Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages.
*
* Creates a new StreamableHttp server transport.
*/
@OptIn(ExperimentalAtomicApi::class)
public class StreamableHttpServerTransport(
private val isStateful: Boolean = false,
private val enableJSONResponse: Boolean = false,
): AbstractTransport() {
private val standalone = "standalone"
private val streamMapping: HashMap<String, ServerSSESession> = hashMapOf()
private val requestToStreamMapping: HashMap<RequestId, String> = hashMapOf()
private val requestResponseMapping: HashMap<RequestId, JSONRPCMessage> = hashMapOf()
private val callMapping: HashMap<String, ApplicationCall> = hashMapOf()
private val started: AtomicBoolean = AtomicBoolean(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)

public var sessionId: String? = null
private set

override suspend fun start() {
if (!started.compareAndSet(false, true)) {
error("StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically.")
}
}

override suspend fun send(message: JSONRPCMessage) {
var requestId: RequestId? = null

if (message is JSONRPCResponse) {
requestId = message.id
}

if (requestId == null) {
val standaloneSSE = streamMapping[standalone] ?: return

standaloneSSE.send(
event = "message",
data = McpJson.encodeToString(message),
)
return
}

val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId")
val correspondingStream =
streamMapping[streamId] ?: error("No connection established for request id $requestId")
val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId")

if (!enableJSONResponse) {
correspondingStream.send(
event = "message",
data = McpJson.encodeToString(message),
)
}

requestResponseMapping[requestId] = message
val relatedIds =
requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null }

if (!allResponsesReady) return

if (enableJSONResponse) {
correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString())
correspondingCall.response.status(HttpStatusCode.OK)
if (sessionId != null) {
correspondingCall.response.header(MCP_SESSION_ID, sessionId!!)
}
val responses = relatedIds.map { requestResponseMapping[it] }
if (responses.size == 1) {
correspondingCall.respond(responses[0]!!)
} else {
correspondingCall.respond(responses)
}
callMapping.remove(streamId)
} else {
correspondingStream.close()
streamMapping.remove(streamId)
}

for (id in relatedIds) {
requestToStreamMapping.remove(id)
requestResponseMapping.remove(id)
}

}

override suspend fun close() {
streamMapping.values.forEach {
it.close()
}
streamMapping.clear()
requestToStreamMapping.clear()
requestResponseMapping.clear()
// TODO Check if we need to clear the callMapping or if call timeout after awhile
_onClose.invoke()
Comment on lines +118 to +119
Copy link
Preview

Copilot AI Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review whether callMapping should be cleared in the close() method to free resources and avoid potential memory leaks.

Suggested change
// TODO Check if we need to clear the callMapping or if call timeout after awhile
_onClose.invoke()
callMapping.clear() // Clear callMapping to free resources and avoid potential memory leaks

Copilot uses AI. Check for mistakes.

}

@OptIn(ExperimentalUuidApi::class)
public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) {
try {
if (!validateHeaders(call)) return

val messages = parseBody(call)

if (messages.isEmpty()) return

val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method.Defined.Initialize.value }
if (hasInitializationRequest) {
if (initialized.load() && sessionId != null) {
call.response.status(HttpStatusCode.BadRequest)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Defined.InvalidRequest,
message = "Invalid Request: Server already initialized"
)
)
)
return
}

if (messages.size > 1) {
call.response.status(HttpStatusCode.BadRequest)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Defined.InvalidRequest,
message = "Invalid Request: Only one initialization request is allowed"
)
)
)
return
}

if (isStateful) {
sessionId = Uuid.random().toString()
}
initialized.store(true)
}

if (!validateSession(call)) {
return
}

val hasRequests = messages.any { it is JSONRPCRequest }
val streamId = Uuid.random().toString()

if (!hasRequests) {
call.respondNullable(HttpStatusCode.Accepted)
} else {
if (!enableJSONResponse) {
call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString())

if (sessionId != null) {
call.response.header(MCP_SESSION_ID, sessionId!!)
}
}

for (message in messages) {
if (message is JSONRPCRequest) {
streamMapping[streamId] = session
callMapping[streamId] = call
requestToStreamMapping[message.id] = streamId
}
}
}
for (message in messages) {
_onMessage.invoke(message)
}
} catch (e: Exception) {
call.response.status(HttpStatusCode.BadRequest)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = e.message ?: "Parse error"
)
)
)
_onError.invoke(e)
}
}

public suspend fun handleGetRequest(call: ApplicationCall, session: ServerSSESession) {
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()
if (!acceptHeader.contains("text/event-stream")) {
call.response.status(HttpStatusCode.NotAcceptable)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = "Not Acceptable: Client must accept text/event-stream"
)
)
)
}

if (!validateSession(call)) {
return
}

if (sessionId != null) {
call.response.header(MCP_SESSION_ID, sessionId!!)
}

if (streamMapping[standalone] != null) {
call.response.status(HttpStatusCode.Conflict)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = "Conflict: Only one SSE stream is allowed per session"
)
)
)
session.close()
return
}

// TODO: Equivalent of typescript res.writeHead(200, headers).flushHeaders();
streamMapping[standalone] = session
}

public suspend fun handleDeleteRequest(call: ApplicationCall) {
if (!validateSession(call)) {
return
}
close()
call.respondNullable(HttpStatusCode.OK)
}

private suspend fun validateSession(call: ApplicationCall): Boolean {
if (sessionId == null) {
return true
}

if (!initialized.load()) {
call.response.status(HttpStatusCode.BadRequest)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = "Bad Request: Server not initialized"
)
)
)
return false
}
return true
}

private suspend fun validateHeaders(call: ApplicationCall): Boolean {
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()

if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) {
call.response.status(HttpStatusCode.NotAcceptable)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = "Not Acceptable: Client must accept both application/json and text/event-stream"
)
)
)
return false
}

val contentType = call.request.contentType()
if (contentType != ContentType.Application.Json) {
call.response.status(HttpStatusCode.UnsupportedMediaType)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Unknown(-32000),
message = "Unsupported Media Type: Content-Type must be application/json"
)
)
)
return false
}

return true
}

private suspend fun parseBody(
call: ApplicationCall,
): List<JSONRPCMessage> {
val messages = mutableListOf<JSONRPCMessage>()
when (val body = call.receive<JsonElement>()) {
is JsonObject -> messages.add(McpJson.decodeFromJsonElement(body))
is JsonArray -> messages.addAll(McpJson.decodeFromJsonElement<List<JSONRPCMessage>>(body))
else -> {
call.response.status(HttpStatusCode.BadRequest)
call.respond(
JSONRPCResponse(
id = null,
error = JSONRPCError(
code = ErrorCode.Defined.InvalidRequest,
message = "Invalid Request: Server already initialized"
)
)
)
return listOf()
}
}
return messages
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package io.modelcontextprotocol.kotlin.sdk.shared

internal const val MCP_SESSION_ID = "mcp-session-id"
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public data class JSONRPCNotification(
*/
@Serializable
public class JSONRPCResponse(
public val id: RequestId,
public val id: RequestId?,
public val jsonrpc: String = JSONRPC_VERSION,
public val result: RequestResult? = null,
public val error: JSONRPCError? = null,
Expand Down
Loading