-
Notifications
You must be signed in to change notification settings - Fork 175
Add Server Streamable Http Transport #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging | |
| import io.ktor.http.HttpStatusCode | ||
| import io.ktor.server.application.Application | ||
| import io.ktor.server.application.install | ||
| import io.ktor.server.request.header | ||
| import io.ktor.server.response.respond | ||
| import io.ktor.server.routing.Routing | ||
| import io.ktor.server.routing.RoutingContext | ||
|
|
@@ -19,16 +20,20 @@ import kotlinx.atomicfu.atomic | |
| import kotlinx.atomicfu.update | ||
| import kotlinx.collections.immutable.PersistentMap | ||
| import kotlinx.collections.immutable.toPersistentMap | ||
| import io.modelcontextprotocol.kotlin.sdk.ErrorCode | ||
| import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport | ||
|
|
||
| private val logger = KotlinLogging.logger {} | ||
|
|
||
| internal class SseTransportManager(transports: Map<String, SseServerTransport> = emptyMap()) { | ||
| private val transports: AtomicRef<PersistentMap<String, SseServerTransport>> = atomic(transports.toPersistentMap()) | ||
| internal class TransportManager(transports: Map<String, AbstractTransport> = emptyMap()) { | ||
| private val transports: AtomicRef<PersistentMap<String, AbstractTransport>> = atomic(transports.toPersistentMap()) | ||
|
|
||
| fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] | ||
| fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId) | ||
|
|
||
| fun addTransport(transport: SseServerTransport) { | ||
| transports.update { it.put(transport.sessionId, transport) } | ||
| fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId] | ||
|
|
||
| fun addTransport(sessionId: String, transport: AbstractTransport) { | ||
| transports.update { it.put(sessionId, transport) } | ||
| } | ||
|
|
||
| fun removeTransport(sessionId: String) { | ||
|
|
@@ -48,14 +53,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { | |
| */ | ||
| @KtorDsl | ||
| public fun Routing.mcp(block: ServerSSESession.() -> Server) { | ||
| val sseTransportManager = SseTransportManager() | ||
| val transportManager = TransportManager() | ||
|
|
||
| sse { | ||
| mcpSseEndpoint("", sseTransportManager, block) | ||
| mcpSseEndpoint("", transportManager, block) | ||
| } | ||
|
|
||
| post { | ||
| mcpPostEndpoint(sseTransportManager) | ||
| mcpPostEndpoint(transportManager) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -74,18 +79,71 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { | |
| } | ||
| } | ||
|
|
||
| internal suspend fun ServerSSESession.mcpSseEndpoint( | ||
| /* | ||
| * Configures the Ktor Application to handle Model Context Protocol (MCP) over Streamable Http. | ||
| * It currently only works with JSON response. | ||
| */ | ||
| @KtorDsl | ||
| public fun Application.mcpStreamableHttp( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val transportManager = TransportManager() | ||
|
|
||
| routing { | ||
| post("/mcp") { | ||
| mcpStreamableHttpEndpoint( | ||
| transportManager, | ||
| enableDnsRebindingProtection, | ||
| allowedHosts, | ||
| allowedOrigins, | ||
| eventStore, | ||
| block, | ||
| ) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /* | ||
| * Configures the Ktor Application to handle Model Context Protocol (MCP) over stateless Streamable Http. | ||
| * It currently only works with JSON response. | ||
| */ | ||
| @KtorDsl | ||
| public fun Application.mcpStatelessStreamableHttp( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| routing { | ||
| post("/mcp") { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't implement because it's not needed for json response. Once we add SSE back, we can do it then. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http The server MUST either return Content-Type: text/event-stream in response to this HTTP GET, or else return HTTP 405 Method Not Allowed, indicating that the server does not offer an SSE stream at this endpoint Without processing (which responds with code 405), it seems the inspector was spamming errors. I replaced |
||
| mcpStatelessStreamableHttpEndpoint( | ||
| enableDnsRebindingProtection, | ||
| allowedHosts, | ||
| allowedOrigins, | ||
| eventStore, | ||
| block, | ||
| ) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private suspend fun ServerSSESession.mcpSseEndpoint( | ||
| postEndpoint: String, | ||
| sseTransportManager: SseTransportManager, | ||
| transportManager: TransportManager, | ||
| block: ServerSSESession.() -> Server, | ||
| ) { | ||
| val transport = mcpSseTransport(postEndpoint, sseTransportManager) | ||
| val transport = mcpSseTransport(postEndpoint, transportManager) | ||
|
|
||
| val server = block() | ||
|
|
||
| server.onClose { | ||
| logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } | ||
| sseTransportManager.removeTransport(transport.sessionId) | ||
| transportManager.removeTransport(transport.sessionId) | ||
| } | ||
|
|
||
| server.connect(transport) | ||
|
|
@@ -95,24 +153,106 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( | |
|
|
||
| internal fun ServerSSESession.mcpSseTransport( | ||
| postEndpoint: String, | ||
| sseTransportManager: SseTransportManager, | ||
| transportManager: TransportManager, | ||
| ): SseServerTransport { | ||
| val transport = SseServerTransport(postEndpoint, this) | ||
| sseTransportManager.addTransport(transport) | ||
| transportManager.addTransport(transport.sessionId, transport) | ||
| logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } | ||
|
|
||
| return transport | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) { | ||
| internal suspend fun RoutingContext.mcpStreamableHttpEndpoint( | ||
| transportManager: TransportManager, | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val sessionId = this.call.request.header(MCP_SESSION_ID_HEADER) | ||
| val transport = if (sessionId != null && transportManager.hasTransport(sessionId)) { | ||
| transportManager.getTransport(sessionId) | ||
| } else if (sessionId == null) { | ||
| val transport = StreamableHttpServerTransport( | ||
| enableDnsRebindingProtection = enableDnsRebindingProtection, | ||
| allowedHosts = allowedHosts, | ||
| allowedOrigins = allowedOrigins, | ||
| eventStore = eventStore, | ||
| enableJsonResponse = true, | ||
| ) | ||
|
|
||
| transport.setOnSessionInitialized { sessionId -> | ||
| transportManager.addTransport(sessionId, transport) | ||
|
|
||
| logger.info { "New StreamableHttp connection established and stored with sessionId: $sessionId" } | ||
| } | ||
|
|
||
| val server = block() | ||
| server.onClose { | ||
| logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } | ||
| } | ||
|
|
||
| server.connect(transport) | ||
|
|
||
| transport | ||
| } else { | ||
| null | ||
| } | ||
|
|
||
| if (transport == null) { | ||
| this.call.reject( | ||
| HttpStatusCode.BadRequest, | ||
| ErrorCode.Unknown(-32000), | ||
| "Bad Request: No valid session ID provided", | ||
| ) | ||
| return | ||
| } | ||
|
|
||
| (transport as StreamableHttpServerTransport).handleRequest(null, this.call) | ||
| logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val transport = StreamableHttpServerTransport( | ||
| enableDnsRebindingProtection = enableDnsRebindingProtection, | ||
| allowedHosts = allowedHosts, | ||
| allowedOrigins = allowedOrigins, | ||
| eventStore = eventStore, | ||
| enableJsonResponse = true, | ||
| ) | ||
| transport.setSessionIdGenerator(null) | ||
|
|
||
| logger.info { "New stateless StreamableHttp connection established without sessionId" } | ||
|
|
||
| val server = block() | ||
|
|
||
| server.onClose { | ||
| logger.info { "Server connection closed without sessionId" } | ||
| } | ||
|
|
||
| server.connect(transport) | ||
|
|
||
| transport.handleRequest(null, this.call) | ||
|
|
||
| logger.debug { "Server connected to transport without sessionId" } | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) { | ||
| val sessionId: String = call.request.queryParameters["sessionId"] ?: run { | ||
| call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") | ||
| return | ||
| } | ||
|
|
||
| logger.debug { "Received message for sessionId: $sessionId" } | ||
|
|
||
| val transport = sseTransportManager.getTransport(sessionId) | ||
| val transport = transportManager.getTransport(sessionId) as SseServerTransport? | ||
| if (transport == null) { | ||
| logger.warn { "Session not found for sessionId: $sessionId" } | ||
| call.respond(HttpStatusCode.NotFound, "Session not found") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.