
Commit 165fe68

Add GPT-2 inference notebook using ORTEngine and Kinference
1 parent 293171f commit 165fe68

File tree

1 file changed: +302 -0 lines changed

notebooks/kinference/ORTGPT2.ipynb

@@ -0,0 +1,302 @@
{
"cells": [
{
"cell_type": "code",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-02-10T13:27:05.671253Z",
"start_time": "2025-02-10T13:27:05.206956Z"
}
},
"source": [
"val kinferencerVersion = \"0.2.26\"\n",
"val ktorVersion = \"3.0.3\"\n",
"\n",
"USE {\n",
"    repositories {\n",
"        mavenCentral()\n",
"        maven(\"https://packages.jetbrains.team/maven/p/ki/maven\")\n",
"        maven(\"https://packages.jetbrains.team/maven/p/grazi/grazie-platform-public\")\n",
"    }\n",
"    dependencies {\n",
"        implementation(\"io.kinference:inference-core-jvm:$kinferencerVersion\")\n",
"        implementation(\"io.kinference:inference-ort-jvm:$kinferencerVersion\")\n",
"        implementation(\"io.kinference:serializer-protobuf-jvm:$kinferencerVersion\")\n",
"        implementation(\"io.kinference:utils-common-jvm:$kinferencerVersion\")\n",
"        implementation(\"io.kinference:ndarray-core-jvm:$kinferencerVersion\")\n",
"\n",
"        implementation(\"io.ktor:ktor-client-core-jvm:$ktorVersion\")\n",
"        implementation(\"io.ktor:ktor-client-cio-jvm:$ktorVersion\")\n",
"\n",
"        implementation(\"org.slf4j:slf4j-api:2.0.9\")\n",
"        implementation(\"org.slf4j:slf4j-simple:2.0.9\")\n",
"\n",
"        implementation(\"ai.djl:api:0.28.0\")\n",
"        implementation(\"ai.djl.huggingface:tokenizers:0.28.0\")\n",
"    }\n",
"}"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-10T13:27:05.732096Z",
"start_time": "2025-02-10T13:27:05.678435Z"
}
},
"cell_type": "code",
"source": [
"import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer\n",
"import io.kinference.core.data.tensor.KITensor\n",
"import io.kinference.core.data.tensor.asTensor\n",
"import io.kinference.ndarray.arrays.FloatNDArray\n",
"import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke\n",
"import io.kinference.ort.ORTData\n",
"import io.kinference.ort.ORTEngine\n",
"import io.kinference.ort.data.tensor.ORTTensor\n",
"import io.kinference.utils.CommonDataLoader\n",
"import io.kinference.utils.inlines.InlineInt\n",
"import io.kinference.utils.toIntArray\n",
"import okio.Path.Companion.toPath\n",
"import io.kinference.core.KIONNXData\n",
"import io.kinference.ndarray.arrays.LongNDArray\n",
"import io.kinference.ndarray.arrays.NumberNDArrayCore\n",
"import io.ktor.client.HttpClient\n",
"import io.ktor.client.plugins.HttpTimeout\n",
"import io.ktor.client.request.prepareRequest\n",
"import io.ktor.client.statement.bodyAsChannel\n",
"import io.ktor.util.cio.writeChannel\n",
"import io.ktor.utils.io.copyAndClose\n",
"import java.io.File\n",
"import kotlinx.coroutines.runBlocking"
],
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-10T13:27:06.563564Z",
"start_time": "2025-02-10T13:27:06.070567Z"
}
},
"cell_type": "code",
"source": [
"/**\n",
" * Directory used to store cached files.\n",
" *\n",
" * This variable combines the user's current working directory\n",
" * with a \".cache\" subdirectory to create the path for storing cache files.\n",
" * It is used in various functions to check for existing files or directories,\n",
" * create new ones if they do not exist, and manage the caching of downloaded files.\n",
" */\n",
"val cacheDirectory = System.getProperty(\"user.dir\") + \"/.cache/\"\n",
"\n",
"/**\n",
" * Downloads a file from the given URL and saves it with the specified file name.\n",
" *\n",
" * Checks if the directory specified by `cacheDirectory` exists.\n",
" * If not, it creates the directory. If the file already exists,\n",
" * the download is skipped. Otherwise, the file is downloaded\n",
" * using an HTTP client with a 10-minute timeout setting.\n",
" *\n",
" * @param url The URL from which to download the file.\n",
" * @param fileName The name to use for the downloaded file.\n",
" * @param timeout Optional timeout duration for the download request, in milliseconds.\n",
" * Defaults to 600,000 milliseconds (10 minutes).\n",
" * Increase the timeout if you are not sure that the download of the particular model will fit into the default timeout.\n",
" */\n",
"suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {\n",
"    // Ensure the predefined path is treated as a directory\n",
"    val directory = File(cacheDirectory)\n",
"\n",
"    // Check if the directory exists, if not create it\n",
"    if (!directory.exists()) {\n",
"        println(\"Predefined directory doesn't exist. Creating directory at $cacheDirectory.\")\n",
"        directory.mkdirs() // Create the directory if it doesn't exist\n",
"    }\n",
"\n",
"    // Check if the file already exists\n",
"    val file = File(directory, fileName)\n",
"    if (file.exists()) {\n",
"        println(\"File already exists at ${file.absolutePath}. Skipping download.\")\n",
"        return // Exit the function if the file exists\n",
"    }\n",
"\n",
"    // Create an instance of HttpClient with custom timeout settings\n",
"    val client = HttpClient {\n",
"        install(HttpTimeout) {\n",
"            requestTimeoutMillis = timeout\n",
"        }\n",
"    }\n",
"\n",
"    // Download the file and write to the specified output path\n",
"    client.prepareRequest(url).execute { response ->\n",
"        response.bodyAsChannel().copyAndClose(file.writeChannel())\n",
"    }\n",
"\n",
"    client.close()\n",
"}\n",
"\n",
"/**\n",
" * Extracts the token ID with the highest probability from the output tensor.\n",
" *\n",
" * @param output A map containing the output tensors identified by their names.\n",
" * @param tokensSize The number of tokens in the sequence.\n",
" * @param outputName The name of the tensor containing the logits.\n",
" * @return The ID of the top token.\n",
" */\n",
"suspend fun extractTopToken(output: Map<String, KIONNXData<*>>, tokensSize: Int, outputName: String): Long {\n",
"    val logits = output[outputName]!! as KITensor\n",
"    val sliced = logits.data.slice(\n",
"        starts = intArrayOf(0, 0, tokensSize - 1, 0), // First batch, first element in the second dimension, last token, first vocab entry\n",
"        ends = intArrayOf(1, 1, tokensSize, 50257),   // Same batch, same second dimension, one token step, whole vocab (50257)\n",
"        steps = intArrayOf(1, 1, 1, 1)                // Step of 1 for each dimension\n",
"    ) as NumberNDArrayCore\n",
"    val softmax = sliced.softmax(axis = -1)\n",
"    val topK = softmax.topK(\n",
"        axis = -1,      // Apply top-k along the last dimension (vocabulary size)\n",
"        k = 1,          // Retrieve the top 1 element\n",
"        largest = true, // We want the largest probabilities (most probable tokens)\n",
"        sorted = false  // Sorting is unnecessary since we are only retrieving the top 1\n",
"    )\n",
"    val tokenId = (topK.second as LongNDArray)[intArrayOf(0, 0, 0, 0)]\n",
"\n",
"    return tokenId\n",
"}\n",
"\n",
"suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> {\n",
"    return outputs.map { (name, ortTensor) ->\n",
"        val ortTensor = ortTensor as ORTTensor\n",
"        val data = ortTensor.toFloatArray()\n",
"        val shape = ortTensor.shape.toIntArray()\n",
"        val ndArray = FloatNDArray(shape) { idx: InlineInt -> data[idx.value] }\n",
"        val kiTensor = ndArray.asTensor(name)\n",
"        return@map name to kiTensor\n",
"    }.toMap()\n",
"}"
],
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-10T13:27:06.737944Z",
"start_time": "2025-02-10T13:27:06.715791Z"
}
},
"cell_type": "code",
"source": [
"// Constants for input and output tensor names used in the GPT-2 model\n",
"val INPUT_TENSOR_NAME = \"input1\"\n",
"val OUTPUT_TENSOR_NAME = \"output1\" // We use only the logits tensor"
],
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-10T13:27:07.461104Z",
"start_time": "2025-02-10T13:27:07.440849Z"
}
},
"cell_type": "code",
"source": [
"val modelUrl = \"https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx\"\n",
"val modelName = \"gpt2-lm-head-10\"\n"
],
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-10T13:27:38.180925Z",
"start_time": "2025-02-10T13:27:09.793527Z"
}
},
"cell_type": "code",
"source": [
"runBlocking {\n",
"    println(\"Downloading model from: $modelUrl\")\n",
"    downloadFile(modelUrl, \"$modelName.onnx\") // GPT-2 from the model zoo is around 650 MB; adjust your timeout if needed\n",
"\n",
"    println(\"Loading model...\")\n",
"    val model = ORTEngine.loadModel(\"$cacheDirectory/$modelName.onnx\".toPath())\n",
"\n",
"    val tokenizer = HuggingFaceTokenizer.newInstance(\"gpt2\", mapOf(\"modelMaxLength\" to \"1024\"))\n",
"    val testString = \"Neurogenesis is most active during embryonic development and is responsible for producing \" +\n",
"        \"all the various types of neurons of the organism, but it continues throughout adult life \" +\n",
"        \"in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will \" +\n",
"        \"live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances.\"\n",
"    val encoded = tokenizer.encode(testString)\n",
"    val tokens = encoded.ids\n",
"    val tokensSize = tokens.size\n",
"\n",
"    val predictionLength = 34\n",
"    val outputTokens = LongArray(predictionLength) { 0 }\n",
"\n",
"    val input = ORTTensor(tokens, longArrayOf(1, 1, tokensSize.toLong()))\n",
"    var currentContext = input.clone(INPUT_TENSOR_NAME)\n",
"\n",
"    print(\"Here goes the test text for generation:\\n$testString\")\n",
"\n",
"    for (idx in 0 until predictionLength) {\n",
"        val inputTensor = listOf(currentContext)\n",
"        val output = model.predict(inputTensor)\n",
"\n",
"        outputTokens[idx] = extractTopToken(convertToKITensorMap(output), tokensSize + idx, OUTPUT_TENSOR_NAME)\n",
"\n",
"        val newTokenArray = tokens + outputTokens.slice(IntRange(0, idx))\n",
"        currentContext = ORTTensor(newTokenArray, longArrayOf(1, 1, tokensSize + idx + 1L), INPUT_TENSOR_NAME)\n",
"        print(tokenizer.decode(longArrayOf(outputTokens[idx])))\n",
"    }\n",
"    println(\"\\n\\nDone\")\n",
"}"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading model from: https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx\n",
"Loading model...\n",
"Here goes the test text for generation:\n",
"Neurogenesis is most active during embryonic development and is responsible for producing all the various types of neurons of the organism, but it continues throughout adult life in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances.\n",
"\n",
"The most common type of neurogenesis is the development of the hippocampus, which is the area of the brain that contains the hippocampus's electrical and chemical signals.\n",
"\n",
"Done\n"
]
}
],
"execution_count": 6
}
],
"metadata": {
"kernelspec": {
"display_name": "Kotlin",
"language": "kotlin",
"name": "kotlin"
},
"language_info": {
"name": "kotlin",
"version": "1.9.23",
"mimetype": "text/x-kotlin",
"file_extension": ".kt",
"pygments_lexer": "kotlin",
"codemirror_mode": "text/x-kotlin",
"nbconvert_exporter": ""
},
"ktnbPluginMetadata": {
"projectLibraries": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}
