Skip to content

Commit 1162179

Browse files
devcrocodkpavlov
authored andcommitted
Add atomicref for logingLevel and refactor test
1 parent ae09c0b commit 1162179

File tree

2 files changed

+40
-50
lines changed
  • kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server
  • kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client

2 files changed

+40
-50
lines changed

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
2828
import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
2929
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
3030
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
31+
import kotlinx.atomicfu.AtomicRef
32+
import kotlinx.atomicfu.atomic
3133
import kotlinx.coroutines.CompletableDeferred
3234
import kotlinx.serialization.json.JsonObject
3335

@@ -65,14 +67,7 @@ public open class ServerSession(
6567
* The current logging level set by the client.
6668
* When null, all messages are sent (no filtering).
6769
*/
68-
private var currentLoggingLevel: LoggingLevel? = null
69-
70-
/**
71-
* Map of LoggingLevel to severity index for comparison.
72-
* Higher index means higher severity.
73-
*/
74-
private val loggingLevelSeverity: Map<LoggingLevel, Int> = LoggingLevel.entries.withIndex()
75-
.associate { (index, level) -> level to index }
70+
private val currentLoggingLevel: AtomicRef<LoggingLevel?> = atomic(null)
7671

7772
init {
7873
// Core protocol handlers
@@ -87,7 +82,7 @@ public open class ServerSession(
8782
// Logging level handler
8883
if (options.capabilities.logging != null) {
8984
setRequestHandler<LoggingMessageNotification.SetLevelRequest>(Defined.LoggingSetLevel) { request, _ ->
90-
currentLoggingLevel = request.level
85+
currentLoggingLevel.value = request.level
9186
logger.debug { "Logging level set to: ${request.level}" }
9287
EmptyRequestResult()
9388
}
@@ -190,7 +185,7 @@ public open class ServerSession(
190185
*/
191186
public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) {
192187
if (serverCapabilities.logging != null) {
193-
if (!isMessageIgnored(notification.params.level)) {
188+
if (isMessageAccepted(notification.params.level)) {
194189
logger.trace { "Sending logging message: ${notification.params.data}" }
195190
notification(notification)
196191
} else {
@@ -421,11 +416,16 @@ public open class ServerSession(
421416
* @return true if the message should be ignored (filtered out), false otherwise.
422417
*/
423418
private fun isMessageIgnored(level: LoggingLevel): Boolean {
424-
val current = currentLoggingLevel ?: return false // If no level is set, don't filter
419+
val current = currentLoggingLevel.value ?: return false // If no level is set, don't filter
425420

426-
val messageSeverity = loggingLevelSeverity[level] ?: return false
427-
val currentSeverity = loggingLevelSeverity[current] ?: return false
428-
429-
return messageSeverity < currentSeverity
421+
return level.ordinal < current.ordinal
430422
}
423+
424+
/**
425+
* Checks if a message with the given level should be accepted based on the current logging level.
426+
*
427+
* @param level The level of the message to check.
428+
* @return true if the message should be accepted (not filtered out), false otherwise.
429+
*/
430+
private fun isMessageAccepted(level: LoggingLevel): Boolean = !isMessageIgnored(level)
431431
}

kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -928,52 +928,42 @@ class ClientTest {
928928
val serverSession = serverSessionResult.await()
929929

930930
// Set logging level to warning
931-
val result = client.setLoggingLevel(LoggingLevel.warning)
931+
val minLevel = LoggingLevel.warning
932+
val result = client.setLoggingLevel(minLevel)
932933
assertEquals(EmptyJsonObject, result._meta)
933934

934935
// Send messages of different levels
935-
serverSession.sendLoggingMessage(
936-
LoggingMessageNotification(
937-
params = LoggingMessageNotification.Params(
938-
level = LoggingLevel.debug,
939-
data = buildJsonObject { put("message", "Debug - should be filtered") },
940-
),
941-
),
936+
val testMessages = listOf(
937+
LoggingLevel.debug to "Debug - should be filtered",
938+
LoggingLevel.info to "Info - should be filtered",
939+
LoggingLevel.warning to "Warning - should pass",
940+
LoggingLevel.error to "Error - should pass",
942941
)
943942

944-
serverSession.sendLoggingMessage(
945-
LoggingMessageNotification(
946-
params = LoggingMessageNotification.Params(
947-
level = LoggingLevel.info,
948-
data = buildJsonObject { put("message", "Info - should be filtered") },
949-
),
950-
),
951-
)
952-
953-
serverSession.sendLoggingMessage(
954-
LoggingMessageNotification(
955-
params = LoggingMessageNotification.Params(
956-
level = LoggingLevel.warning,
957-
data = buildJsonObject { put("message", "Warning - should pass") },
943+
testMessages.forEach { (level, message) ->
944+
serverSession.sendLoggingMessage(
945+
LoggingMessageNotification(
946+
params = LoggingMessageNotification.Params(
947+
level = level,
948+
data = buildJsonObject { put("message", message) },
949+
),
958950
),
959-
),
960-
)
951+
)
952+
}
961953

962-
serverSession.sendLoggingMessage(
963-
LoggingMessageNotification(
964-
params = LoggingMessageNotification.Params(
965-
level = LoggingLevel.error,
966-
data = buildJsonObject { put("message", "Error - should pass") },
967-
),
968-
),
969-
)
954+
delay(100)
970955

971956
// Only warning and error should be received
972957
assertEquals(2, receivedMessages.size, "Should receive only 2 messages (warning and error)")
973958

974-
val levels = receivedMessages.map { it.params.level }
975-
assertTrue(levels.contains(LoggingLevel.warning), "Should receive warning message")
976-
assertTrue(levels.contains(LoggingLevel.error), "Should receive error message")
959+
// Verify all received messages have severity >= minLevel
960+
receivedMessages.forEach { message ->
961+
val messageSeverity = message.params.level.ordinal
962+
assertTrue(
963+
messageSeverity >= minLevel.ordinal,
964+
"Received message with level ${message.params.level} should have severity >= $minLevel",
965+
)
966+
}
977967
}
978968

979969
@Test

0 commit comments

Comments
 (0)