Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ public open class Client(private val clientInfo: Implementation, options: Client
* @throws IllegalStateException If the server does not support logging.
*/
public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult =
request<EmptyRequestResult>(SetLevelRequest(level), options)
request(SetLevelRequest(level), options)

/**
* Retrieves a prompt by name from the server.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import io.modelcontextprotocol.kotlin.sdk.InitializedNotification
import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION
import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.Method
import io.modelcontextprotocol.kotlin.sdk.Method.Defined
Expand All @@ -27,6 +28,8 @@ import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CompletableDeferred
import kotlinx.serialization.json.JsonObject

Expand All @@ -43,22 +46,6 @@ public open class ServerSession(
@Suppress("ktlint:standard:backing-property-naming")
private var _onClose: () -> Unit = {}

init {
// Core protocol handlers
setRequestHandler<InitializeRequest>(Method.Defined.Initialize) { request, _ ->
handleInitialize(request)
}
setNotificationHandler<InitializedNotification>(Method.Defined.NotificationsInitialized) {
_onInitialized()
CompletableDeferred(Unit)
}
}

/**
* The capabilities supported by the server, related to the session.
*/
private val serverCapabilities = options.capabilities

/**
* The client's reported capabilities after initialization.
*/
Expand All @@ -71,6 +58,37 @@ public open class ServerSession(
public var clientVersion: Implementation? = null
private set

/**
* The capabilities supported by the server, related to the session.
*/
private val serverCapabilities = options.capabilities

/**
* The current logging level set by the client.
* When null, all messages are sent (no filtering).
*/
private val currentLoggingLevel: AtomicRef<LoggingLevel?> = atomic(null)

init {
// Core protocol handlers
setRequestHandler<InitializeRequest>(Defined.Initialize) { request, _ ->
handleInitialize(request)
}
setNotificationHandler<InitializedNotification>(Defined.NotificationsInitialized) {
_onInitialized()
CompletableDeferred(Unit)
}

// Logging level handler
if (options.capabilities.logging != null) {
setRequestHandler<LoggingMessageNotification.SetLevelRequest>(Defined.LoggingSetLevel) { request, _ ->
currentLoggingLevel.value = request.level
logger.debug { "Logging level set to: ${request.level}" }
EmptyRequestResult()
}
}
}

/**
* Registers a callback to be invoked when the server has completed initialization.
*/
Expand Down Expand Up @@ -160,12 +178,20 @@ public open class ServerSession(

/**
* Sends a logging message notification to the client.
* Messages are filtered based on the current logging level set by the client.
* If no logging level is set, all messages are sent.
*
* @param notification The logging message notification.
*/
public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) {
logger.trace { "Sending logging message: ${notification.params.data}" }
notification(notification)
if (serverCapabilities.logging != null) {
if (isMessageAccepted(notification.params.level)) {
logger.trace { "Sending logging message: ${notification.params.data}" }
notification(notification)
} else {
logger.trace { "Filtering out logging message with level ${notification.params.level}" }
}
}
}

/**
Expand Down Expand Up @@ -318,6 +344,7 @@ public open class ServerSession(

Defined.LoggingSetLevel -> {
if (serverCapabilities.logging == null) {
logger.error { "Server does not support logging (required for $method)" }
throw IllegalStateException("Server does not support logging (required for $method)")
}
}
Expand Down Expand Up @@ -381,4 +408,24 @@ public open class ServerSession(
instructions = instructions,
)
}

/**
* Checks if a message with the given level should be ignored based on the current logging level.
*
* @param level The level of the message to check.
* @return true if the message should be ignored (filtered out), false otherwise.
*/
private fun isMessageIgnored(level: LoggingLevel): Boolean {
val current = currentLoggingLevel.value ?: return false // If no level is set, don't filter

return level.ordinal < current.ordinal
}

/**
* Checks if a message with the given level should be accepted based on the current logging level.
*
* @param level The level of the message to check.
* @return true if the message should be accepted (not filtered out), false otherwise.
*/
private fun isMessageAccepted(level: LoggingLevel): Boolean = !isMessageIgnored(level)
}
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,86 @@ class ClientTest {
)
}

@Test
fun `should handle logging setLevel request`() = runTest {
val server = Server(
Implementation(name = "test server", version = "1.0"),
ServerOptions(
capabilities = ServerCapabilities(
logging = EmptyJsonObject,
),
),
)

val client = Client(
clientInfo = Implementation(name = "test client", version = "1.0"),
options = ClientOptions(
capabilities = ClientCapabilities(),
),
)

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val receivedMessages = mutableListOf<LoggingMessageNotification>()
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) { notification ->
receivedMessages.add(notification)
CompletableDeferred(Unit)
}

val serverSessionResult = CompletableDeferred<ServerSession>()

listOf(
launch {
client.connect(clientTransport)
println("Client connected")
},
launch {
serverSessionResult.complete(server.connect(serverTransport))
println("Server connected")
},
).joinAll()

val serverSession = serverSessionResult.await()

// Set logging level to warning
val minLevel = LoggingLevel.warning
val result = client.setLoggingLevel(minLevel)
assertEquals(EmptyJsonObject, result._meta)

// Send messages of different levels
val testMessages = listOf(
LoggingLevel.debug to "Debug - should be filtered",
LoggingLevel.info to "Info - should be filtered",
LoggingLevel.warning to "Warning - should pass",
LoggingLevel.error to "Error - should pass",
)

testMessages.forEach { (level, message) ->
serverSession.sendLoggingMessage(
LoggingMessageNotification(
params = LoggingMessageNotification.Params(
level = level,
data = buildJsonObject { put("message", message) },
),
),
)
}

delay(100)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awaitility is more correct than delay

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's common code, I used delay


// Only warning and error should be received
assertEquals(2, receivedMessages.size, "Should receive only 2 messages (warning and error)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't awaitility be used here?

It will be needed for sure to verify rate-limiting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's common code, I used delay


// Verify all received messages have severity >= minLevel
receivedMessages.forEach { message ->
val messageSeverity = message.params.level.ordinal
assertTrue(
messageSeverity >= minLevel.ordinal,
"Received message with level ${message.params.level} should have severity >= $minLevel",
)
}
}

@Test
fun `should handle server elicitation`() = runTest {
val client = Client(
Expand Down
Loading