Skip to content

Commit

Permalink
Make serialization contain comm_id to respect jupyter comm handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolay-egorov committed Sep 13, 2021
1 parent 64114f0 commit dad4df7
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ data class SerializedVariablesState(

@Serializable
class SerializationReply(
val cellId: Int = 1,
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
val cell_id: Int = 1,
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
val comm_id: String = ""
)

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,15 @@ class SerializationRequest(
val cellId: Int,
val descriptorsState: Map<String, SerializedVariablesState>,
val topLevelDescriptorName: String = "",
val pathToDescriptor: List<String> = emptyList()
val pathToDescriptor: List<String> = emptyList(),
val commId: String = ""
) : MessageContent()

@Serializable
class SerializationReply(
val cellId: Int = 1,
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
val cell_id: Int = 1,
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap(),
val comm_id: String = ""
) : MessageContent()

@Serializable(MessageDataSerializer::class)
Expand Down
10 changes: 6 additions & 4 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -307,21 +307,23 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(mapOf())))
}
is CommOpen -> {
if (!content.commId.equals(MessageType.SERIALIZATION_REQUEST.name, ignoreCase = true)) {
if (!content.targetName.equals("kotlin_serialization", ignoreCase = true)) {
send(makeReplyMessage(msg, MessageType.NONE))
return
}
log.debug("Message type in CommOpen: $msg, ${msg.type}")
val data = content.data ?: return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))

if (data.isEmpty()) return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
log.debug("Message data: $data")
val messageContent = getVariablesDescriptorsFromJson(data)
GlobalScope.launch(Dispatchers.Default) {
repl.serializeVariables(
messageContent.topLevelDescriptorName,
messageContent.descriptorsState,
content.commId,
messageContent.pathToDescriptor
) { result ->
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result))
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_MSG, content = result))
}
}
}
Expand All @@ -342,7 +344,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
is SerializationRequest -> {
GlobalScope.launch(Dispatchers.Default) {
if (content.topLevelDescriptorName.isNotEmpty()) {
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, content.pathToDescriptor) { result ->
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, commID = content.commId, content.pathToDescriptor) { result ->
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
}
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ interface ReplForJupyter {

suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)

suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String> = emptyList(),
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String = "", pathToDescriptor: List<String> = emptyList(),
callback: (SerializationReply) -> Unit)

val homeDir: File?
Expand Down Expand Up @@ -552,9 +552,8 @@ class ReplForJupyterImpl(
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
}

override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String>,
callback: (SerializationReply) -> Unit) {
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, commID: String, pathToDescriptor: List<String>, callback: (SerializationReply) -> Unit) {
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, comm_id = commID ,pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables)
}

private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
Expand All @@ -569,7 +568,7 @@ class ReplForJupyterImpl(
}
log.debug("Serialization cellID: $cellId")
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
return SerializationReply(cellId, resultMap)
return SerializationReply(cellId, resultMap, args.comm_id)
}


Expand Down Expand Up @@ -610,6 +609,7 @@ class ReplForJupyterImpl(
var cellId: Int = -1,
val topLevelVarName: String = "",
val pathToDescriptor: List<String> = emptyList(),
val comm_id: String = "",
override val callback: (SerializationReply) -> Unit
) : LockQueueArgs<SerializationReply>

Expand Down

0 comments on commit dad4df7

Please sign in to comment.