|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "metadata": { |
| 6 | + "collapsed": true, |
| 7 | + "ExecuteTime": { |
| 8 | + "end_time": "2025-02-10T13:27:05.671253Z", |
| 9 | + "start_time": "2025-02-10T13:27:05.206956Z" |
| 10 | + } |
| 11 | + }, |
| 12 | + "source": [ |
| 13 | + "val kinferencerVersion = \"0.2.26\"\n", |
| 14 | + "val ktorVersion = \"3.0.3\"\n", |
| 15 | + "\n", |
| 16 | + "USE {\n", |
| 17 | + " repositories {\n", |
| 18 | + " mavenCentral()\n", |
| 19 | + " maven(\"https://packages.jetbrains.team/maven/p/ki/maven\")\n", |
| 20 | + " maven(\"https://packages.jetbrains.team/maven/p/grazi/grazie-platform-public\")\n", |
| 21 | + " }\n", |
| 22 | + " dependencies {\n", |
| 23 | + " implementation(\"io.kinference:inference-core-jvm:$kinferencerVersion\")\n", |
| 24 | + " implementation(\"io.kinference:inference-ort-jvm:$kinferencerVersion\")\n", |
| 25 | + " implementation(\"io.kinference:serializer-protobuf-jvm:$kinferencerVersion\")\n", |
| 26 | + " implementation(\"io.kinference:utils-common-jvm:$kinferencerVersion\")\n", |
| 27 | + " implementation(\"io.kinference:ndarray-core-jvm:$kinferencerVersion\")\n", |
| 28 | + "\n", |
| 29 | + " implementation(\"io.ktor:ktor-client-core-jvm:$ktorVersion\")\n", |
| 30 | + " implementation(\"io.ktor:ktor-client-cio-jvm:$ktorVersion\")\n", |
| 31 | + "\n", |
| 32 | + " implementation(\"org.slf4j:slf4j-api:2.0.9\")\n", |
| 33 | + " implementation(\"org.slf4j:slf4j-simple:2.0.9\")\n", |
| 34 | + "\n", |
| 35 | + " implementation(\"ai.djl:api:0.28.0\")\n", |
| 36 | + " implementation(\"ai.djl.huggingface:tokenizers:0.28.0\")\n", |
| 37 | + " }\n", |
| 38 | + "}" |
| 39 | + ], |
| 40 | + "outputs": [], |
| 41 | + "execution_count": 1 |
| 42 | + }, |
| 43 | + { |
| 44 | + "metadata": { |
| 45 | + "ExecuteTime": { |
| 46 | + "end_time": "2025-02-10T13:27:05.732096Z", |
| 47 | + "start_time": "2025-02-10T13:27:05.678435Z" |
| 48 | + } |
| 49 | + }, |
| 50 | + "cell_type": "code", |
| 51 | + "source": [ |
| 52 | + "import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer\n", |
| 53 | + "import io.kinference.core.data.tensor.KITensor\n", |
| 54 | + "import io.kinference.core.data.tensor.asTensor\n", |
| 55 | + "import io.kinference.ndarray.arrays.FloatNDArray\n", |
| 56 | + "import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke\n", |
| 57 | + "import io.kinference.ort.ORTData\n", |
| 58 | + "import io.kinference.ort.ORTEngine\n", |
| 59 | + "import io.kinference.ort.data.tensor.ORTTensor\n", |
| 60 | + "import io.kinference.utils.CommonDataLoader\n", |
| 61 | + "import io.kinference.utils.inlines.InlineInt\n", |
| 62 | + "import io.kinference.utils.toIntArray\n", |
| 63 | + "import okio.Path.Companion.toPath\n", |
| 64 | + "import io.kinference.core.KIONNXData\n", |
| 65 | + "import io.kinference.ndarray.arrays.LongNDArray\n", |
| 66 | + "import io.kinference.ndarray.arrays.NumberNDArrayCore\n", |
| 67 | + "import io.ktor.client.HttpClient\n", |
| 68 | + "import io.ktor.client.plugins.HttpTimeout\n", |
| 69 | + "import io.ktor.client.request.prepareRequest\n", |
| 70 | + "import io.ktor.client.statement.bodyAsChannel\n", |
| 71 | + "import io.ktor.util.cio.writeChannel\n", |
| 72 | + "import io.ktor.utils.io.copyAndClose\n", |
| 73 | + "import java.io.File\n", |
| 74 | + "import kotlinx.coroutines.runBlocking" |
| 75 | + ], |
| 76 | + "outputs": [], |
| 77 | + "execution_count": 2 |
| 78 | + }, |
| 79 | + { |
| 80 | + "metadata": { |
| 81 | + "ExecuteTime": { |
| 82 | + "end_time": "2025-02-10T13:27:06.563564Z", |
| 83 | + "start_time": "2025-02-10T13:27:06.070567Z" |
| 84 | + } |
| 85 | + }, |
| 86 | + "cell_type": "code", |
| 87 | + "source": [ |
| 88 | + "/**\n", |
| 89 | + " * Directory used to store cached files.\n", |
| 90 | + " *\n", |
| 91 | + " * This variable combines the user's current working directory\n", |
| 92 | + " * with a \".cache\" subdirectory to create the path for storing cache files.\n",
| 93 | + " * It is used in various functions to check for existing files or directories,\n", |
| 94 | + " * create new ones if they do not exist, and manage the caching of downloaded files.\n", |
| 95 | + " */\n", |
| 96 | + "val cacheDirectory = System.getProperty(\"user.dir\") + \"/.cache/\"\n", |
| 97 | + "\n", |
| 98 | + "/**\n", |
| 99 | + " * Downloads a file from the given URL and saves it with the specified file name.\n", |
| 100 | + " *\n", |
| 101 | + " * Checks if the directory specified by `cacheDirectory` exists.\n", |
| 102 | + " * If not, it creates the directory. If the file already exists,\n", |
| 103 | + " * the download is skipped. Otherwise, the file is downloaded\n", |
| 104 | + " * using an HTTP client with a 10-minute timeout setting.\n", |
| 105 | + " *\n", |
| 106 | + " * @param url The URL from which to download the file.\n", |
| 107 | + " * @param fileName The name to use for the downloaded file.\n", |
| 108 | + " * @param timeout Optional timeout duration for the download request, in milliseconds.\n", |
| 109 | + " * Defaults to 600,000 milliseconds (10 minutes).\n", |
| 110 | + " * Increase the timeout if you are not sure that the download for the particular model will fit into the default timeout.\n",
| 111 | + " */\n", |
| 112 | + "suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {\n", |
| 113 | + " // Ensure the predefined path is treated as a directory\n", |
| 114 | + " val directory = File(cacheDirectory)\n", |
| 115 | + "\n", |
| 116 | + " // Check if the directory exists, if not create it\n", |
| 117 | + " if (!directory.exists()) {\n", |
| 118 | + " println(\"Predefined directory doesn't exist. Creating directory at $cacheDirectory.\")\n", |
| 119 | + " directory.mkdirs() // Create the directory if it doesn't exist\n", |
| 120 | + " }\n", |
| 121 | + "\n", |
| 122 | + " // Check if the file already exists\n", |
| 123 | + " val file = File(directory, fileName)\n", |
| 124 | + " if (file.exists()) {\n", |
| 125 | + " println(\"File already exists at ${file.absolutePath}. Skipping download.\")\n", |
| 126 | + " return // Exit the function if the file exists\n", |
| 127 | + " }\n", |
| 128 | + "\n", |
| 129 | + " // Create an instance of HttpClient with custom timeout settings\n", |
| 130 | + " val client = HttpClient {\n", |
| 131 | + " install(HttpTimeout) {\n", |
| 132 | + " requestTimeoutMillis = timeout\n", |
| 133 | + " }\n", |
| 134 | + " }\n", |
| 135 | + "\n", |
| 136 | + " // Download the file and write to the specified output path\n", |
| 137 | + " client.prepareRequest(url).execute { response ->\n", |
| 138 | + " response.bodyAsChannel().copyAndClose(file.writeChannel())\n", |
| 139 | + " }\n", |
| 140 | + "\n", |
| 141 | + " client.close()\n", |
| 142 | + "}\n", |
| 143 | + "\n", |
| 144 | + "/**\n", |
| 145 | + " * Extracts the token ID with the highest probability from the output tensor.\n", |
| 146 | + " *\n", |
| 147 | + " * @param output A map containing the output tensors identified by their names.\n", |
| 148 | + " * @param tokensSize The number of tokens in the sequence.\n", |
| 149 | + " * @param outputName The name of the tensor containing the logits.\n", |
| 150 | + " * @return The ID of the top token.\n", |
| 151 | + " */\n", |
| 152 | + "suspend fun extractTopToken(output: Map<String, KIONNXData<*>>, tokensSize: Int, outputName: String): Long {\n", |
| 153 | + " val logits = output[outputName]!! as KITensor\n", |
| 154 | + " val sliced = logits.data.slice(\n", |
| 155 | + " starts = intArrayOf(0, 0, tokensSize - 1, 0), // First batch, first element in the second dimension, last token, first vocab entry\n", |
| 156 | + " ends = intArrayOf(1, 1, tokensSize, 50257), // Same batch, same second dimension, one token step, whole vocab (50257)\n", |
| 157 | + " steps = intArrayOf(1, 1, 1, 1) // Step of 1 for each dimension\n", |
| 158 | + " ) as NumberNDArrayCore\n", |
| 159 | + " val softmax = sliced.softmax(axis = -1)\n", |
| 160 | + " val topK = softmax.topK(\n", |
| 161 | + " axis = -1, // Apply top-k along the last dimension (vocabulary size)\n", |
| 162 | + " k = 1, // Retrieve the top 1 element\n", |
| 163 | + " largest = true, // We want the largest probabilities (most probable tokens)\n", |
| 164 | + " sorted = false // Sorting is unnecessary since we are only retrieving the top 1\n", |
| 165 | + " )\n", |
| 166 | + " val tokenId = (topK.second as LongNDArray)[intArrayOf(0, 0, 0, 0)]\n", |
| 167 | + "\n", |
| 168 | + " return tokenId\n", |
| 169 | + "}\n", |
| 170 | + "\n", |
| 171 | + "suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> {\n", |
| 172 | + " return outputs.map { (name, ortTensor) ->\n", |
| 173 | + " val ortTensor = ortTensor as ORTTensor\n", |
| 174 | + " val data = ortTensor.toFloatArray()\n", |
| 175 | + " val shape = ortTensor.shape.toIntArray()\n", |
| 176 | + " val ndArray = FloatNDArray(shape) { idx: InlineInt -> data[idx.value] }\n", |
| 177 | + " val kiTensor = ndArray.asTensor(name)\n", |
| 178 | + " return@map name to kiTensor\n", |
| 179 | + " }.toMap()\n", |
| 180 | + "}" |
| 181 | + ], |
| 182 | + "outputs": [], |
| 183 | + "execution_count": 3 |
| 184 | + }, |
| 185 | + { |
| 186 | + "metadata": { |
| 187 | + "ExecuteTime": { |
| 188 | + "end_time": "2025-02-10T13:27:06.737944Z", |
| 189 | + "start_time": "2025-02-10T13:27:06.715791Z" |
| 190 | + } |
| 191 | + }, |
| 192 | + "cell_type": "code", |
| 193 | + "source": [ |
| 194 | + "// Constants for input and output tensor names used in the GPT-2 model\n", |
| 195 | + "val INPUT_TENSOR_NAME = \"input1\"\n", |
| 196 | + "val OUTPUT_TENSOR_NAME = \"output1\" // We use only logits tensor" |
| 197 | + ], |
| 198 | + "outputs": [], |
| 199 | + "execution_count": 4 |
| 200 | + }, |
| 201 | + { |
| 202 | + "metadata": { |
| 203 | + "ExecuteTime": { |
| 204 | + "end_time": "2025-02-10T13:27:07.461104Z", |
| 205 | + "start_time": "2025-02-10T13:27:07.440849Z" |
| 206 | + } |
| 207 | + }, |
| 208 | + "cell_type": "code", |
| 209 | + "source": [ |
| 210 | + "val modelUrl = \"https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx\"\n", |
| 211 | + "val modelName = \"gpt2-lm-head-10\"\n" |
| 212 | + ], |
| 213 | + "outputs": [], |
| 214 | + "execution_count": 5 |
| 215 | + }, |
| 216 | + { |
| 217 | + "metadata": { |
| 218 | + "ExecuteTime": { |
| 219 | + "end_time": "2025-02-10T13:27:38.180925Z", |
| 220 | + "start_time": "2025-02-10T13:27:09.793527Z" |
| 221 | + } |
| 222 | + }, |
| 223 | + "cell_type": "code", |
| 224 | + "source": [ |
| 225 | + "runBlocking {\n", |
| 226 | + " println(\"Downloading model from: $modelUrl\")\n", |
| 227 | + " downloadFile(modelUrl, \"$modelName.onnx\") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed\n", |
| 228 | + "\n", |
| 229 | + " println(\"Loading model...\")\n", |
| 230 | + " val model = ORTEngine.loadModel(\"$cacheDirectory/$modelName.onnx\".toPath())\n", |
| 231 | + "\n", |
| 232 | + " val tokenizer = HuggingFaceTokenizer.newInstance(\"gpt2\", mapOf(\"modelMaxLength\" to \"1024\"))\n", |
| 233 | + " val testString = \"Neurogenesis is most active during embryonic development and is responsible for producing \" +\n", |
| 234 | + " \"all the various types of neurons of the organism, but it continues throughout adult life \" +\n", |
| 235 | + " \"in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will \" +\n", |
| 236 | + " \"live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances.\"\n", |
| 237 | + " val encoded = tokenizer.encode(testString)\n", |
| 238 | + " val tokens = encoded.ids\n", |
| 239 | + " val tokensSize = tokens.size\n", |
| 240 | + "\n", |
| 241 | + " val predictionLength = 34\n", |
| 242 | + " val outputTokens = LongArray(predictionLength) { 0 }\n", |
| 243 | + "\n", |
| 244 | + " val input = ORTTensor(tokens, longArrayOf(1, 1, tokensSize.toLong()))\n", |
| 245 | + " var currentContext = input.clone(INPUT_TENSOR_NAME)\n", |
| 246 | + "\n", |
| 247 | + " print(\"Here goes the test text for generation:\\n$testString\")\n", |
| 248 | + "\n", |
| 249 | + " for (idx in 0 until predictionLength) {\n", |
| 250 | + " val inputTensor = listOf(currentContext)\n", |
| 251 | + " val output = model.predict(inputTensor)\n", |
| 252 | + "\n", |
| 253 | + " outputTokens[idx] = extractTopToken(convertToKITensorMap(output), tokensSize + idx, OUTPUT_TENSOR_NAME)\n", |
| 254 | + "\n", |
| 255 | + " val newTokenArray = tokens + outputTokens.slice(IntRange(0, idx))\n", |
| 256 | + " currentContext = ORTTensor(newTokenArray, longArrayOf(1, 1, tokensSize + idx + 1L), INPUT_TENSOR_NAME)\n", |
| 257 | + " print(tokenizer.decode(longArrayOf(outputTokens[idx])))\n", |
| 258 | + " }\n", |
| 259 | + " println(\"\\n\\nDone\")\n", |
| 260 | + "}" |
| 261 | + ], |
| 262 | + "outputs": [ |
| 263 | + { |
| 264 | + "name": "stdout", |
| 265 | + "output_type": "stream", |
| 266 | + "text": [ |
| 267 | + "Downloading model from: https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx\n", |
| 268 | + "Loading model...\n", |
| 269 | + "Here goes the test text for generation:\n", |
| 270 | + "Neurogenesis is most active during embryonic development and is responsible for producing all the various types of neurons of the organism, but it continues throughout adult life in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances.\n", |
| 271 | + "\n", |
| 272 | + "The most common type of neurogenesis is the development of the hippocampus, which is the area of the brain that contains the hippocampus's electrical and chemical signals.\n", |
| 273 | + "\n", |
| 274 | + "Done\n" |
| 275 | + ] |
| 276 | + } |
| 277 | + ], |
| 278 | + "execution_count": 6 |
| 279 | + } |
| 280 | + ], |
| 281 | + "metadata": { |
| 282 | + "kernelspec": { |
| 283 | + "display_name": "Kotlin", |
| 284 | + "language": "kotlin", |
| 285 | + "name": "kotlin" |
| 286 | + }, |
| 287 | + "language_info": { |
| 288 | + "name": "kotlin", |
| 289 | + "version": "1.9.23", |
| 290 | + "mimetype": "text/x-kotlin", |
| 291 | + "file_extension": ".kt", |
| 292 | + "pygments_lexer": "kotlin", |
| 293 | + "codemirror_mode": "text/x-kotlin", |
| 294 | + "nbconvert_exporter": "" |
| 295 | + }, |
| 296 | + "ktnbPluginMetadata": { |
| 297 | + "projectLibraries": false |
| 298 | + } |
| 299 | + }, |
| 300 | + "nbformat": 4, |
| 301 | + "nbformat_minor": 0 |
| 302 | +} |
0 commit comments