From 961a320ab2664d13dc35cb4e792a2646ac6f1223 Mon Sep 17 00:00:00 2001 From: nikolay-egorov Date: Mon, 9 Aug 2021 11:38:30 +0300 Subject: [PATCH] Add suspended cache validation; Add possibility to remove old variables from cache; some improvements --- .../kotlinx/jupyter/message_types.kt | 3 +- .../org/jetbrains/kotlinx/jupyter/protocol.kt | 8 +- .../org/jetbrains/kotlinx/jupyter/repl.kt | 39 ++-- .../repl/impl/InternalEvaluatorImpl.kt | 18 +- .../kotlinx/jupyter/serializationUtils.kt | 182 ++++++++++++++---- .../org/jetbrains/kotlinx/jupyter/util.kt | 22 ++- .../jupyter/test/repl/AbstractReplTest.kt | 5 - .../kotlinx/jupyter/test/repl/ReplTests.kt | 169 ++++++++++++++-- 8 files changed, 375 insertions(+), 71 deletions(-) diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt index 48c3c68ff..81ac465cf 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt @@ -578,7 +578,8 @@ class ListErrorsReply( class SerializationRequest( val cellId: Int, val descriptorsState: Map, - val topLevelDescriptorName: String = "" + val topLevelDescriptorName: String = "", + val pathToDescriptor: List = emptyList() ) : MessageContent() @Serializable diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt index 8f0d23db7..86336f958 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt @@ -315,7 +315,11 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup val messageContent = getVariablesDescriptorsFromJson(data) GlobalScope.launch(Dispatchers.Default) { - repl.serializeVariables(messageContent.topLevelDescriptorName, messageContent.descriptorsState) { result -> + repl.serializeVariables( + messageContent.topLevelDescriptorName, + messageContent.descriptorsState, + messageContent.pathToDescriptor + ) { result -> sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result)) } } @@ -337,7 +341,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup is SerializationRequest -> { GlobalScope.launch(Dispatchers.Default) { if (content.topLevelDescriptorName.isNotEmpty()) { - repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState) { result -> + repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState, content.pathToDescriptor) { result -> sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result)) } } else { diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt index 63fedbebb..6d4b41620 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt @@ -5,6 +5,10 @@ import jupyter.kotlin.DependsOn import jupyter.kotlin.KotlinContext import jupyter.kotlin.KotlinKernelHostProvider import jupyter.kotlin.Repository +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.launch +import org.jetbrains.annotations.TestOnly import org.jetbrains.kotlin.config.KotlinCompilerVersion import org.jetbrains.kotlinx.jupyter.api.Code import org.jetbrains.kotlinx.jupyter.api.ExecutionCallback @@ -48,6 +52,7 @@ import org.jetbrains.kotlinx.jupyter.repl.CellExecutor import org.jetbrains.kotlinx.jupyter.repl.CompletionResult import org.jetbrains.kotlinx.jupyter.repl.ContextUpdater import org.jetbrains.kotlinx.jupyter.repl.EvalResult +import org.jetbrains.kotlinx.jupyter.repl.EvalResultEx import org.jetbrains.kotlinx.jupyter.repl.InternalEvaluator import org.jetbrains.kotlinx.jupyter.repl.KotlinCompleter import org.jetbrains.kotlinx.jupyter.repl.ListErrorsResult @@ -120,7 +125,8 @@ interface ReplForJupyter { suspend fun serializeVariables(cellId: Int, descriptorsState: Map, callback: (SerializationReply) -> Unit) - suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map, callback: (SerializationReply) -> Unit) + suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map, pathToDescriptor: List = emptyList(), + callback: (SerializationReply) -> Unit) val homeDir: File? @@ -191,7 +197,7 @@ class ReplForJupyterImpl( override val variablesSerializer = VariablesSerializer() - private val librariesScanner = LibrariesScanner(notebook) + val librariesScanner = LibrariesScanner(notebook) private val resourcesProcessor = LibraryResourcesProcessorImpl() override var outputConfig @@ -347,7 +353,7 @@ class ReplForJupyterImpl( ) private var evalContextEnabled = false - private fun withEvalContext(action: () -> EvalResult): EvalResult { + private fun withEvalContext(action: () -> T): T { return synchronized(this) { evalContextEnabled = true try { @@ -365,14 +371,14 @@ class ReplForJupyterImpl( else context.compilationConfiguration.asSuccess() } - /** - * Used for debug purposes. - * @see ReplCommand - */ + @TestOnly + @Suppress("unused") private fun printVariables(isHtmlFormat: Boolean = false) = log.debug( if (isHtmlFormat) notebook.variablesReportAsHTML() else notebook.variablesReport() ) + @TestOnly + @Suppress("unused") private fun printUsagesInfo(cellId: Int, usedVariables: Set?) { log.debug(buildString { if (usedVariables == null || usedVariables.isEmpty()) { @@ -386,7 +392,7 @@ class ReplForJupyterImpl( }) } - fun evalEx(code: Code, displayHandler: DisplayHandler?, jupyterId: Int): EvalResult { + fun evalEx(code: Code, displayHandler: DisplayHandler?, jupyterId: Int): EvalResultEx { return withEvalContext { rethrowAsLibraryException(LibraryProblemPart.BEFORE_CELL_CALLBACKS) { beforeCellExecution.forEach { executor.execute(it) } @@ -426,8 +432,15 @@ class ReplForJupyterImpl( // printUsagesInfo(jupyterId, cellVariables[jupyterId - 1]) val serializedData = variablesSerializer.serializeVariables(jupyterId - 1, notebook.variablesState, notebook.unchangedVariables()) - EvalResult( + GlobalScope.launch(Dispatchers.Default) { + variablesSerializer.tryValidateCache(jupyterId - 1, notebook.cellVariables) + } + + EvalResultEx( result.result.value, + rendered, + result.scriptInstance, + result.result.name, EvaluatedSnippetMetadata(newClasspath, compiledData, newImports, serializedData), ) } @@ -525,8 +538,9 @@ class ReplForJupyterImpl( doWithLock(SerializationArgs(descriptorsState, cellId = cellId, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables) } - override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map, callback: (SerializationReply) -> Unit) { - doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(), ::doSerializeVariables) + override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map, pathToDescriptor: List, + callback: (SerializationReply) -> Unit) { + doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback, pathToDescriptor = pathToDescriptor), serializationQueue, SerializationReply(), ::doSerializeVariables) } private fun doSerializeVariables(args: SerializationArgs): SerializationReply { @@ -537,7 +551,7 @@ class ReplForJupyterImpl( finalAns } args.descriptorsState.forEach { (name, state) -> - resultMap[name] = variablesSerializer.doIncrementalSerialization(cellId - 1, name, state) + resultMap[name] = variablesSerializer.doIncrementalSerialization(cellId - 1, name, state, args.pathToDescriptor) } log.debug("Serialization cellID: $cellId") log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}") @@ -581,6 +595,7 @@ class ReplForJupyterImpl( val descriptorsState: Map, var cellId: Int = -1, val topLevelVarName: String = "", + val pathToDescriptor: List = emptyList(), override val callback: (SerializationReply) -> Unit ) : LockQueueArgs diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/InternalEvaluatorImpl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/InternalEvaluatorImpl.kt index 976cb63ed..383bf6246 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/InternalEvaluatorImpl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/InternalEvaluatorImpl.kt @@ -151,7 +151,6 @@ internal class InternalEvaluatorImpl( private fun updateVariablesState(cellId: Int) { variablesWatcher.removeOldUsages(cellId) - variablesHolder.forEach { val state = it.value as VariableStateImpl val oldValue = state.stringValue @@ -173,11 +172,23 @@ internal class InternalEvaluatorImpl( it.name }.toHashSet() val ans = mutableMapOf() + // maybe remove known declarations + val addedDeclarations = mutableSetOf() + fields.forEach { property -> if (!memberKPropertiesNames.contains(property.name)) return@forEach val state = VariableStateImpl(property, cellClassInstance) variablesWatcher.addDeclaration(cellId, property.name) + addedDeclarations.add(property.name) + + // try check values + if (variablesHolder.containsKey(property.name)) { + val seenState = variablesHolder[property.name] + if (seenState?.value?.equals(state.value) == true) { + addedDeclarations.remove(property.name) + } + } // it was val, now it's var if (isValField(property)) { @@ -189,6 +200,9 @@ internal class InternalEvaluatorImpl( ans[property.name] = state } + // remove old + variablesWatcher.removeOldDeclarations(cellId, addedDeclarations) + return ans } @@ -199,7 +213,7 @@ internal class InternalEvaluatorImpl( private fun updateDataAfterExecution(lastExecutionCellId: Int, resultValue: ResultValue) { variablesWatcher.ensureStorageCreation(lastExecutionCellId) variablesHolder += getVisibleVariables(resultValue, lastExecutionCellId) - + // remove unreached variables updateVariablesState(lastExecutionCellId) } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt index 44592c5f8..1f2098a0e 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState import java.lang.reflect.Field import kotlin.contracts.ExperimentalContracts import kotlin.contracts.contract +import kotlin.math.abs import kotlin.reflect.KClass import kotlin.reflect.KProperty import kotlin.reflect.KProperty1 @@ -32,7 +33,8 @@ enum class PropertiesType { @Serializable data class SerializedCommMessageContent( val topLevelDescriptorName: String, - val descriptorsState: Map + val descriptorsState: Map, + val pathToDescriptor: List = emptyList() ) fun getVariablesDescriptorsFromJson(json: JsonObject): SerializedCommMessageContent { @@ -73,14 +75,42 @@ data class RuntimeObjectWrapper( fun Any?.toObjectWrapper(): RuntimeObjectWrapper = RuntimeObjectWrapper(this) -class VariablesSerializer(private val serializationDepth: Int = 2, private val serializationLimit: Int = 10000) { +/** + * Provides contract for using threshold-based removal heuristic. + * Every serialization-related info in [T] would be removed once [isShouldRemove] == true. + * Default: T = Int, cellID + */ +interface ClearableSerializer { + fun isShouldRemove(currentState: T): Boolean + + suspend fun clearStateInfo(currentState: T) +} + +class VariablesSerializer( + private val serializationDepth: Int = 2, + private val serializationLimit: Int = 10000, + private val cellCountRemovalThreshold: Int = 5, + // let's make this flag customizable from Jupyter config menu + val shouldRemoveOldVariablesFromCache: Boolean = true +) : ClearableSerializer { fun MutableMap.addDescriptor(value: Any?, name: String = value.toString()) { + val typeName = if (value != null) value::class.simpleName else "null" this[name] = createSerializeVariableState( name, - if (value != null) value::class.simpleName else "null", + typeName, value ).serializedVariablesState + if (typeName != null && typeName == "Entry") { + val descriptor = this[name] + value as Map.Entry<*, *> + val valueType = if (value.value != null) value.value!!::class.simpleName else "null" + descriptor!!.fieldDescriptor[value.key.toString()] = createSerializeVariableState( + value.key.toString(), + valueType, + value.value + ).serializedVariablesState + } } /** @@ -94,7 +124,8 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s "Array", "Map", "Set", - "Collection" + "Collection", + "LinkedValues" ) fun isStandardType(type: String): Boolean = containersTypes.contains(type) @@ -114,16 +145,18 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s if (value != null) value::class.declaredMemberProperties else { null } - } catch (ex: Exception) {null} + } catch (ex: Exception) { null } val serializedVersion = SerializedVariablesState(simpleTypeName, getProperString(value), true) val descriptors = serializedVersion.fieldDescriptor // only for set case - if (simpleTypeName == "Set" && kProperties == null) { + if (simpleTypeName == "Set" && kProperties == null && value != null) { value as Set<*> val size = value.size descriptors["size"] = createSerializeVariableState( - "size", "Int", size + "size", + "Int", + size ).serializedVariablesState descriptors.addDescriptor(value, "data") } @@ -233,7 +266,59 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s */ private val serializedVariablesCache: MutableMap = mutableMapOf() + private val removedFromSightVariables: MutableSet = mutableSetOf() + + private suspend fun clearOldData(currentCellId: Int, cellVariables: Map>) { + fun removeFromCache(cellId: Int) { + val oldDeclarations = cellVariables[cellId] + oldDeclarations?.let { oldSet -> + oldSet.forEach { varName -> + serializedVariablesCache.remove(varName) + removedFromSightVariables.add(varName) + } + } + } + + val setToRemove = mutableSetOf() + computedDescriptorsPerCell.forEach { (cellNumber, _) -> + if (abs(currentCellId - cellNumber) >= cellCountRemovalThreshold) { + setToRemove.add(cellNumber) + } + } + log.debug("Removing old info about cells: $setToRemove") + setToRemove.forEach { + clearStateInfo(it) + if (shouldRemoveOldVariablesFromCache) { + removeFromCache(it) + } + } + } + + override fun isShouldRemove(currentState: Int): Boolean { + return computedDescriptorsPerCell.size >= cellCountRemovalThreshold + } + + override suspend fun clearStateInfo(currentState: Int) { + computedDescriptorsPerCell.remove(currentState) + seenObjectsPerCell.remove(currentState) + } + + suspend fun tryValidateCache(currentCellId: Int, cellVariables: Map>) { + if (!isShouldRemove(currentCellId)) return + clearOldData(currentCellId, cellVariables) + } + fun serializeVariables(cellId: Int, variablesState: Map, unchangedVariables: Set): Map { + fun removeNonExistingEntries() { + val toRemoveSet = mutableSetOf() + serializedVariablesCache.forEach { (name, _) -> + if (!variablesState.containsKey(name)) { + toRemoveSet.add(name) + } + } + toRemoveSet.forEach { serializedVariablesCache.remove(it) } + } + if (!isSerializationActive) return emptyMap() if (seenObjectsPerCell.containsKey(cellId)) { @@ -243,15 +328,35 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s return emptyMap() } currentSerializeCount = 0 + val neededEntries = variablesState.filterKeys { + val wasRedeclared = !unchangedVariables.contains(it) + if (wasRedeclared) { + removedFromSightVariables.remove(it) + } + (unchangedVariables.contains(it) || serializedVariablesCache[it]?.value != variablesState[it]?.stringValue) && + !removedFromSightVariables.contains(it) + } + log.debug("Variables state as is: $variablesState") + log.debug("Serializing variables after filter: $neededEntries") + log.debug("Unchanged variables: $unchangedVariables") - val neededEntries = variablesState.filterKeys { unchangedVariables.contains(it) } - + // remove previous data + computedDescriptorsPerCell[cellId]?.instancesPerState?.clear() val serializedData = neededEntries.mapValues { serializeVariableState(cellId, it.key, it.value) } + serializedVariablesCache.putAll(serializedData) + removeNonExistingEntries() + log.debug(serializedVariablesCache.entries.toString()) + return serializedVariablesCache } - fun doIncrementalSerialization(cellId: Int, propertyName: String, serializedVariablesState: SerializedVariablesState): SerializedVariablesState { + fun doIncrementalSerialization( + cellId: Int, + propertyName: String, + serializedVariablesState: SerializedVariablesState, + pathToDescriptor: List = emptyList() + ): SerializedVariablesState { if (!isSerializationActive) return serializedVariablesState val cellDescriptors = computedDescriptorsPerCell[cellId] ?: return serializedVariablesState @@ -301,9 +406,14 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s seenObjectsPerCell.putIfAbsent(cellId, mutableMapOf()) if (isOverride) { + val instances = computedDescriptorsPerCell[cellId]?.instancesPerState computedDescriptorsPerCell[cellId] = ProcessedDescriptorsState() + if (instances != null) { + computedDescriptorsPerCell[cellId]!!.instancesPerState += instances + } } val currentCellDescriptors = computedDescriptorsPerCell[cellId] + // TODO should we stack? currentCellDescriptors!!.processedSerializedVarsToJavaProperties[serializedVersion] = processedData.propertiesData currentCellDescriptors.processedSerializedVarsToKTProperties[serializedVersion] = processedData.kPropertiesData @@ -383,8 +493,11 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s val isArrayType = checkForPossibleArray(callInstance) computedDescriptorsPerCell[cellId]!!.instancesPerState += instancesPerState - if (descriptor.size == 2 && descriptor.containsKey("data")) { - val listData = descriptor["data"]?.fieldDescriptor ?: return + if (descriptor.size == 2 && (descriptor.containsKey("data") || descriptor.containsKey("element"))) { + val singleElemMode = descriptor.containsKey("element") + val listData = if (!singleElemMode) descriptor["data"]?.fieldDescriptor else { + descriptor["element"]?.fieldDescriptor + } ?: return if (descriptor.containsKey("size") && descriptor["size"]?.value == "null") { descriptor.remove("size") descriptor.remove("data") @@ -420,19 +533,6 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s ) } } - /* - if (descriptor.size == 2 && descriptor.containsKey("data")) { - val listData = descriptor["data"]?.fieldDescriptor ?: return - if (callInstance is Collection<*>) { - callInstance.forEach { - listData.addDescriptor(it) - } - } else if (callInstance is Array<*>) { - callInstance.forEach { - listData.addDescriptor(it) - } - } - }*/ } /** @@ -488,7 +588,11 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s val returnType = property.type returnType.simpleName } else { - value?.toString() + if (value != null) { + value::class.simpleName + } else { + value?.toString() + } } } @@ -520,14 +624,14 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s !(it.name.startsWith("script$") || it.name.startsWith("serialVersionUID")) } - val isContainer = if (membersProperties != null) ( - !primitiveWrappersSet.contains(javaClass) && membersProperties.isNotEmpty() || value is Set<*> || value::class.java.isArray || javaClass.isMemberClass - ) else false val type = if (value != null && value::class.java.isArray) { "Array" } else { simpleTypeName.toString() } + val isContainer = if (membersProperties != null) ( + !primitiveWrappersSet.contains(javaClass) && type != "Entry" && membersProperties.isNotEmpty() || value is Set<*> || value::class.java.isArray || (javaClass.isMemberClass && type != "Entry") + ) else false if (value != null && standardContainersUtilizer.isStandardType(type)) { return standardContainersUtilizer.serializeContainer(type, value) @@ -595,11 +699,21 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s } fun getProperString(value: Any?): String { - fun print(builder: StringBuilder, containerSize: Int, index: Int, value: Any?) { + fun print(builder: StringBuilder, containerSize: Int, index: Int, value: Any?, mapMode: Boolean = false) { if (index != containerSize - 1) { - builder.append(value, ", ") + if (mapMode) { + value as Map.Entry<*, *> + builder.append(value.key, '=', value.value, "\n") + } else { + builder.append(value, ", ") + } } else { - builder.append(value) + if (mapMode) { + value as Map.Entry<*, *> + builder.append(value.key, '=', value.value) + } else { + builder.append(value) + } } } @@ -637,9 +751,11 @@ fun getProperString(value: Any?): String { val isMap = kClass.isMap() if (isMap) { value as Map<*, *> + val size = value.size + var ind = 0 return buildString { value.forEach { - append(it.key, '=', it.value, "\n") + print(this, size, ind++, it, true) } } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/util.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/util.kt index ef99bd8d1..85af1230a 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/util.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/util.kt @@ -84,7 +84,25 @@ class VariablesUsagesPerCellWatcher { private val variablesDeclarationInfo: MutableMap = mutableMapOf() private val unchangedVariables: MutableSet = mutableSetOf() -// private val unchangedVariables: MutableSet = mutableSetOf() + + fun removeOldDeclarations(address: K, newDeclarations: Set) { + // removeIf? + cellVariables[address]?.forEach { + val predicate = newDeclarations.contains(it) && variablesDeclarationInfo[it] != address + if (predicate) { + variablesDeclarationInfo.remove(it) + unchangedVariables.remove(it) + } +// predicate + } + + // add old declarations as unchanged + variablesDeclarationInfo.forEach { (name, _) -> + if (!newDeclarations.contains(name)) { + unchangedVariables.add(name) + } + } + } fun addDeclaration(address: K, variableRef: V) { ensureStorageCreation(address) @@ -114,7 +132,7 @@ class VariablesUsagesPerCellWatcher { // remove known modifying usages in this cell cellVariables[newAddress]?.removeIf { val predicate = variablesDeclarationInfo[it] != newAddress - if (predicate) { + if (predicate && variablesDeclarationInfo.containsKey(it)) { unchangedVariables.add(it) } predicate diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/AbstractReplTest.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/AbstractReplTest.kt index 02db7b3f7..527fced36 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/AbstractReplTest.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/AbstractReplTest.kt @@ -3,7 +3,6 @@ package org.jetbrains.kotlinx.jupyter.test.repl import kotlinx.coroutines.runBlocking import org.jetbrains.kotlinx.jupyter.ReplForJupyter import org.jetbrains.kotlinx.jupyter.ReplForJupyterImpl -import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState import org.jetbrains.kotlinx.jupyter.dependencies.ResolverConfig import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider import org.jetbrains.kotlinx.jupyter.libraries.KERNEL_LIBRARIES @@ -67,7 +66,3 @@ abstract class AbstractReplTest { protected val homeDir = File("") } } - -fun Map.mapValuesToStrings(): Map { - return this.mapValues { it.value.value } -} diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt index 09afbc61a..5e1e13856 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt @@ -464,7 +464,7 @@ class ReplVarsTest : AbstractSingleReplTest() { "y" to "0", "z" to "47" ) - assertEquals(res.metadata.evaluatedVariablesState.mapValuesToStrings(), varsUpdate) + assertEquals(res.metadata.evaluatedVariablesState.mapValues { it.value.value }, varsUpdate) assertFalse(repl.notebook.variablesState.isEmpty()) val varsState = repl.notebook.variablesState assertEquals("1", varsState.getStringValue("x")) @@ -680,12 +680,12 @@ class ReplVarsTest : AbstractSingleReplTest() { private var z = 1 z += x """.trimIndent(), - jupyterId = 1 + jupyterId = 2 ) assertTrue(state.isNotEmpty()) - // TODO discuss if we really want this - val setOfCell = setOf("z", "f", "x") + // TODO discuss if we really want "z", "f", "x" + val setOfCell = setOf("z") assertTrue(state.containsValue(setOfCell)) } @@ -780,6 +780,51 @@ class ReplVarsTest : AbstractSingleReplTest() { assertEquals(state[0], setOfPrevCell) assertEquals(state[1], setOfNextCell) } + + @Test + fun unchangedVariablesGapedRedefinition() { + eval( + """ + private val x = "abcd" + var f = 47 + internal val z = 47 + """.trimIndent(), + jupyterId = 1 + ) + var state = repl.notebook.unchangedVariables() + assertEquals(3, state.size) + + eval( + """ + private val x = "abcd" + internal val z = 47 + """.trimIndent(), + jupyterId = 2 + ) + state = repl.notebook.unchangedVariables() + assertEquals(1, state.size) + assertTrue(state.contains("f")) + + eval( + """ + private val x = "abcd" + var f = 47 + internal val z = 47 + """.trimIndent(), + jupyterId = 2 + ) + state = repl.notebook.unchangedVariables() + assertEquals(0, state.size) + + eval( + """ + var f = 47 + """.trimIndent(), + jupyterId = 3 + ) + state = repl.notebook.unchangedVariables() + assertEquals(2, state.size) + } } class ReplVarsSerializationTest : AbstractSingleReplTest() { @@ -814,7 +859,33 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { val serializer = repl.variablesSerializer val newData = serializer.doIncrementalSerialization(0, "data", actualContainer) - val a = 1 + } + + @Test + fun testUnchangedVarsRedefinition() { + val res = eval( + """ + val x = listOf(1, 2, 3, 4) + var f = 47 + """.trimIndent(), + jupyterId = 1 + ) + val varsData = res.metadata.evaluatedVariablesState + assertEquals(2, varsData.size) + assertTrue(varsData.containsKey("x")) + assertTrue(varsData.containsKey("f")) + var unchangedVariables = repl.notebook.unchangedVariables() + assertTrue(unchangedVariables.isNotEmpty()) + + eval( + """ + val x = listOf(1, 2, 3, 4) + """.trimIndent(), + jupyterId = 1 + ) + unchangedVariables = repl.notebook.unchangedVariables() + assertTrue(unchangedVariables.contains("x")) + assertTrue(unchangedVariables.contains("f")) } @Test @@ -877,7 +948,6 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { val serializer = repl.variablesSerializer val newData = serializer.doIncrementalSerialization(0, "i", descriptor["i"]!!) - val a = 1 } @Test @@ -902,7 +972,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { assertEquals(4, receivedDescriptor.size) var values = 1 - receivedDescriptor.forEach { (name, state) -> + receivedDescriptor.forEach { (_, state) -> val fieldDescriptor = state!!.fieldDescriptor assertEquals(0, fieldDescriptor.size) assertTrue(state.isContainer) @@ -913,16 +983,45 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { val serializationAns = serializer.doIncrementalSerialization(0, depthMostNode.key, depthMostNode.value!!) } + @Test + fun incrementalUpdateTestWithPath() { + val res = eval( + """ + val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4)) + """.trimIndent(), + jupyterId = 1 + ) + val varsData = res.metadata.evaluatedVariablesState + val listData = varsData["x"]!! + assertEquals(2, listData.fieldDescriptor.size) + val actualContainer = listData.fieldDescriptor.entries.first().value!! + val serializer = repl.variablesSerializer + val path = listOf("x", "a") + + val newData = serializer.doIncrementalSerialization(0, listData.fieldDescriptor.entries.first().key, actualContainer, path) + val receivedDescriptor = newData.fieldDescriptor + assertEquals(4, receivedDescriptor.size) + + var values = 1 + receivedDescriptor.forEach { (_, state) -> + val fieldDescriptor = state!!.fieldDescriptor + assertEquals(0, fieldDescriptor.size) + assertTrue(state.isContainer) + assertEquals("${values++}", state.value) + } + } + @Test fun testMapContainer() { val res = eval( """ val x = mapOf(1 to "a", 2 to "b", 3 to "c", 4 to "c") + val m = mapOf(1 to "a") """.trimIndent(), jupyterId = 1 ) val varsData = res.metadata.evaluatedVariablesState - assertEquals(1, varsData.size) + assertEquals(2, varsData.size) assertTrue(varsData.containsKey("x")) val mapData = varsData["x"]!! @@ -940,8 +1039,8 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { val serializer = repl.variablesSerializer - val newData = serializer.doIncrementalSerialization(0, "values", valuesDescriptor) - val newDescriptor = newData.fieldDescriptor + var newData = serializer.doIncrementalSerialization(0, "values", valuesDescriptor) + var newDescriptor = newData.fieldDescriptor assertEquals("4", newDescriptor["size"]!!.value) assertEquals(3, newDescriptor["data"]!!.fieldDescriptor.size) val ansSet = mutableSetOf("a", "b", "c") @@ -951,8 +1050,26 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { ansSet.remove(state.value) } assertTrue(ansSet.isEmpty()) - } + val entriesDescriptor = listDescriptors["entries"]!! + assertEquals("4", valuesDescriptor.fieldDescriptor["size"]!!.value) + assertTrue(valuesDescriptor.fieldDescriptor["data"]!!.isContainer) + newData = serializer.doIncrementalSerialization(0, "entries", entriesDescriptor) + newDescriptor = newData.fieldDescriptor + assertEquals("4", newDescriptor["size"]!!.value) + assertEquals(4, newDescriptor["data"]!!.fieldDescriptor.size) + ansSet.add("1=a") + ansSet.add("2=b") + ansSet.add("3=c") + ansSet.add("4=c") + + newDescriptor["data"]!!.fieldDescriptor.forEach { (_, state) -> + assertFalse(state!!.isContainer) + assertTrue(ansSet.contains(state.value)) + ansSet.remove(state.value) + } + assertTrue(ansSet.isEmpty()) + } @Test fun testSetContainer() { @@ -1045,7 +1162,6 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { assertTrue(innerList.isContainer) val receivedDescriptor = innerList.fieldDescriptor - assertEquals(4, receivedDescriptor.size) var values = 1 receivedDescriptor.forEach { (_, state) -> @@ -1058,6 +1174,32 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { } } + @Test + fun testUnchangedVariablesSameCell() { + eval( + """ + private val x = "abcd" + var f = 47 + internal val z = 47 + """.trimIndent(), + jupyterId = 1 + ) + val state = repl.notebook.unchangedVariables() + val setOfCell = setOf("x", "f", "z") + assertTrue(state.isNotEmpty()) + assertEquals(setOfCell, state) + + eval( + """ + private val x = "44" + var f = 47 + """.trimIndent(), + jupyterId = 1 + ) + assertTrue(state.isNotEmpty()) + // it's ok that there's more info, cache's data would filter out + assertEquals(setOf("f", "x", "z"), state) + } @Test fun testUnchangedVariables() { @@ -1096,7 +1238,6 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { assertTrue(state.isNotEmpty()) assertEquals(state, setOfPrevCell) - eval( """ private val x = 341 @@ -1104,7 +1245,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() { """.trimIndent(), jupyterId = 1 ) - assertTrue(state.isEmpty()) + assertTrue(state.contains("f")) eval( """