Skip to content

Commit

Permalink
Add special type classes in shared-compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolay-egorov committed Sep 14, 2021
1 parent dad4df7 commit f16ffa0
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,48 @@ data class SerializedCompiledScriptsData(
}
}

@Serializable
data class SerializableTypeInfo(val type: Type = Type.Custom, val isPrimitive: Boolean = false, val fullType: String = "") {
companion object {
val ignoreSet = setOf("int", "double", "boolean", "char", "float", "byte", "string", "entry")

val propertyNamesForNullFilter = setOf("data", "size")

fun makeFromSerializedVariablesState(type: String?, isContainer: Boolean?): SerializableTypeInfo {
val fullType = type.orEmpty()
val enumType = fullType.toTypeEnum()
val isPrimitive = !(
if (fullType != "Entry") (isContainer ?: false)
else true
)

return SerializableTypeInfo(enumType, isPrimitive, fullType)
}
}
}

@Serializable
enum class Type {
Map,
Entry,
Array,
List,
Custom
}

fun String.toTypeEnum(): Type {
return when (this) {
"Map" -> Type.Map
"Entry" -> Type.Entry
"Array" -> Type.Array
"List" -> Type.List
else -> Type.Custom
}
}

@Serializable
data class SerializedVariablesState(
val type: String = "",
val type: SerializableTypeInfo = SerializableTypeInfo(),
val value: String? = null,
val isContainer: Boolean = false,
val stateId: String = ""
Expand Down
15 changes: 11 additions & 4 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package org.jetbrains.kotlinx.jupyter

import ch.qos.logback.classic.Level
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.encodeToJsonElement
import kotlinx.serialization.json.jsonObject
import org.jetbrains.annotations.TestOnly
import org.jetbrains.kotlinx.jupyter.LoggingManagement.disableLogging
import org.jetbrains.kotlinx.jupyter.LoggingManagement.mainLoggerLevel
Expand Down Expand Up @@ -81,7 +84,6 @@ class OkResponseWithMessage(
)
)
}

socket.send(
makeReplyMessage(
requestMsg,
Expand All @@ -91,7 +93,7 @@ class OkResponseWithMessage(
"engine" to Json.encodeToJsonElement(requestMsg.data.header?.session),
"status" to Json.encodeToJsonElement("ok"),
"started" to Json.encodeToJsonElement(startedTime),
"eval_metadata" to Json.encodeToJsonElement(metadata),
"eval_metadata" to Json.encodeToJsonElement(metadata.convertToNullIfEmpty()),
),
content = ExecuteReply(
MessageStatus.OK,
Expand Down Expand Up @@ -316,7 +318,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
if (data.isEmpty()) return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
log.debug("Message data: $data")
val messageContent = getVariablesDescriptorsFromJson(data)
GlobalScope.launch(Dispatchers.Default) {
connection.launchJob {
repl.serializeVariables(
messageContent.topLevelDescriptorName,
messageContent.descriptorsState,
Expand All @@ -342,7 +344,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
}
}
is SerializationRequest -> {
GlobalScope.launch(Dispatchers.Default) {
connection.launchJob {
if (content.topLevelDescriptorName.isNotEmpty()) {
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, commID = content.commId, content.pathToDescriptor) { result ->
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
Expand Down Expand Up @@ -539,3 +541,8 @@ fun JupyterConnection.evalWithIO(repl: ReplForJupyter, srcMessage: Message, body
KernelStreams.setStreams(false, out, err)
}
}

fun EvaluatedSnippetMetadata?.convertToNullIfEmpty(): JsonElement? {
val jsonNode = Json.encodeToJsonElement(this)
return if (jsonNode is JsonNull || jsonNode?.jsonObject.isEmpty()) null else jsonNode
}
7 changes: 4 additions & 3 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,10 @@ class ReplForJupyterImpl(
val newImports: List<String>
val oldDeclarations: MutableMap<String, Int> = mutableMapOf()
oldDeclarations.putAll(internalEvaluator.getVariablesDeclarationInfo())
val jupyterId = evalData.jupyterId
val result = try {
log.debug("Current cell id: ${evalData.jupyterId}")
executor.execute(evalData.code, evalData.displayHandler, currentCellId = evalData.jupyterId - 1) { internalId, codeToExecute ->
log.debug("Current cell id: $jupyterId")
executor.execute(evalData.code, evalData.displayHandler, currentCellId = jupyterId - 1) { internalId, codeToExecute ->
if (evalData.storeHistory) {
cell = notebook.addCell(internalId, codeToExecute, EvalData(evalData))
}
Expand Down Expand Up @@ -444,7 +445,7 @@ class ReplForJupyterImpl(
// printVars()
// printUsagesInfo(jupyterId, cellVariables[jupyterId - 1])
val variablesCells: Map<String, Int> = notebook.variablesState.mapValues { internalEvaluator.findVariableCell(it.key) }
val serializedData = variablesSerializer.serializeVariables(jupyterId - 1, notebook.variablesState, oldDeclarations, variablesCells, notebook.unchangedVariables())
val serializedData = variablesSerializer.serializeVariables(jupyterId - 1, notebook.variablesState, oldDeclarations, variablesCells, notebook.unchangedVariables)

GlobalScope.launch(Dispatchers.Default) {
variablesSerializer.tryValidateCache(jupyterId - 1, notebook.cellVariables)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import org.jetbrains.kotlinx.jupyter.api.VariableState
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializableTypeInfo
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
import java.lang.reflect.Field
import kotlin.contracts.ExperimentalContracts
Expand Down Expand Up @@ -32,14 +33,14 @@ enum class PropertiesType {
}

@Serializable
data class SerializedCommMessageContent(
data class VariablesStateCommMessageContent(
val topLevelDescriptorName: String,
val descriptorsState: Map<String, SerializedVariablesState>,
val pathToDescriptor: List<String> = emptyList()
)

fun getVariablesDescriptorsFromJson(json: JsonObject): SerializedCommMessageContent {
return Json.decodeFromJsonElement<SerializedCommMessageContent>(json)
fun getVariablesDescriptorsFromJson(json: JsonObject): VariablesStateCommMessageContent {
return Json.decodeFromJsonElement<VariablesStateCommMessageContent>(json)
}

class ProcessedSerializedVarsState(
Expand Down Expand Up @@ -216,7 +217,12 @@ class VariablesSerializer(
} else {
""
}
val serializedVersion = SerializedVariablesState(simpleTypeName, stringedValue, true, varID)
val serializedVersion = SerializedVariablesState(
SerializableTypeInfo.makeFromSerializedVariablesState(simpleTypeName, true),
stringedValue,
true,
varID
)
val descriptors = serializedVersion.fieldDescriptor

// only for set case
Expand Down Expand Up @@ -700,7 +706,12 @@ class VariablesSerializer(
""
}

val serializedVariablesState = SerializedVariablesState(type, getProperString(value), isContainer, finalID)
val serializedVariablesState = SerializedVariablesState(
SerializableTypeInfo.makeFromSerializedVariablesState(simpleTypeName, isContainer),
getProperString(value),
isContainer,
finalID
)

return ProcessedSerializedVarsState(serializedVariablesState, membersProperties?.toTypedArray())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNotEquals
import kotlin.test.assertNotNull
import kotlin.test.fail

class ReplTests : AbstractSingleReplTest() {
Expand Down Expand Up @@ -834,14 +834,13 @@ class ReplVarsTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 1
)
var state = repl.notebook.unchangedVariables()
val res = eval(
eval(
"""
l += 11111
""".trimIndent(),
jupyterId = 2
).metadata.evaluatedVariablesState
state = repl.notebook.unchangedVariables()
val state: Set<String> = repl.notebook.unchangedVariables
assertEquals(1, state.size)
assertTrue(state.contains("m"))
}
Expand Down Expand Up @@ -911,7 +910,7 @@ class ReplVarsTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 1
)
var state = repl.notebook.unchangedVariables()
var state = repl.notebook.unchangedVariables
assertEquals(3, state.size)

eval(
Expand All @@ -922,7 +921,7 @@ class ReplVarsTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 2
)
state = repl.notebook.unchangedVariables()
state = repl.notebook.unchangedVariables
assertEquals(0, state.size)

eval(
Expand All @@ -931,7 +930,7 @@ class ReplVarsTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 3
)
state = repl.notebook.unchangedVariables()
state = repl.notebook.unchangedVariables
assertEquals(1, state.size)
}
}
Expand Down Expand Up @@ -967,7 +966,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
assertEquals(listOf(1, 2, 3, 4).toString().substring(1, actualContainer.value!!.length + 1), actualContainer.value)

val serializer = repl.variablesSerializer
val newData = serializer.doIncrementalSerialization(0, "x", "data", actualContainer)
serializer.doIncrementalSerialization(0, "x", "data", actualContainer)
}

@Test
Expand All @@ -983,7 +982,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
assertEquals(2, varsData.size)
assertTrue(varsData.containsKey("x"))
assertTrue(varsData.containsKey("f"))
var unchangedVariables = repl.notebook.unchangedVariables()
var unchangedVariables = repl.notebook.unchangedVariables
assertTrue(unchangedVariables.isNotEmpty())

eval(
Expand All @@ -992,7 +991,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 1
)
unchangedVariables = repl.notebook.unchangedVariables()
unchangedVariables = repl.notebook.unchangedVariables
assertTrue(unchangedVariables.contains("x"))
assertTrue(unchangedVariables.contains("f"))
}
Expand Down Expand Up @@ -1056,7 +1055,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {

val serializer = repl.variablesSerializer

val newData = serializer.doIncrementalSerialization(0, "c", "i", descriptor["i"]!!)
serializer.doIncrementalSerialization(0, "c", "i", descriptor["i"]!!)
}

@Test
Expand Down Expand Up @@ -1345,7 +1344,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 1
)
val state = repl.notebook.unchangedVariables()
val state = repl.notebook.unchangedVariables
val setOfCell = setOf("x", "f", "z")
assertTrue(state.isNotEmpty())
assertEquals(setOfCell, state)
Expand All @@ -1372,7 +1371,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 1
)
var state = repl.notebook.unchangedVariables()
var state = repl.notebook.unchangedVariables
val setOfCell = setOf("x", "f", "z")
assertTrue(state.isNotEmpty())
assertEquals(setOfCell, state)
Expand All @@ -1396,9 +1395,9 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 3
)
state = repl.notebook.unchangedVariables()
// assertTrue(state.isNotEmpty())
// assertEquals(state, setOfPrevCell)
state = repl.notebook.unchangedVariables
assertTrue(state.isEmpty())
// assertEquals(state, setOfPrevCell)

eval(
"""
Expand All @@ -1408,20 +1407,20 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
""".trimIndent(),
jupyterId = 4
)
state = repl.notebook.unchangedVariables()
state = repl.notebook.unchangedVariables
assertTrue(state.isEmpty())
}

@Test
fun testSerializationClearInfo() {
var res = eval(
eval(
"""
val x = listOf(1, 2, 3, 4)
""".trimIndent(),
jupyterId = 1
).metadata.evaluatedVariablesState
var state = repl.notebook.unchangedVariables()
res = eval(
repl.notebook.unchangedVariables
eval(
"""
val x = listOf(1, 2, 3, 4)
""".trimIndent(),
Expand Down

0 comments on commit f16ffa0

Please sign in to comment.