"""Panel components for the WebLLM interface."""
+from __future__ import annotations

import asyncio
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any

import panel as pn
import param
from panel.chat import ChatFeed
from panel.chat import ChatInterface
from panel.custom import JSComponent
+from panel.models import ReactiveESM

from .models import ModelParam
from .settings import MODEL_MAPPING

+if TYPE_CHECKING:
+    from bokeh.document import Document
+    from bokeh.model import Model
+

class WebLLM(JSComponent):
    """
@@ -81,56 +88,8 @@ class WebLLM(JSComponent):
        """
    )

-    _esm = """
-    import * as webllm from "https://esm.run/@mlc-ai/web-llm";
-
-    const engines = new Map()
-
-    export async function render({ model }) {
-      model.on("msg:custom", async (event) => {
-        if (event.type === 'load') {
-          model.loading = true
-          if (!engines.has(model.model_slug)) {
-            const initProgressCallback = (load_status) => {
-              model.load_status = load_status
-            }
-            try {
-              const mlc = await webllm.CreateMLCEngine(
-                model.model_slug,
-                { initProgressCallback }
-              )
-              engines.set(model.model_slug, mlc)
-              model.loaded = true
-            } catch (error) {
-              model.load_status = {
-                progress: 0,
-                text: error.message + " Try again later, or try a different size/quantization.",
-              };
-              model.loaded = false;
-            }
-          }
-          model.loading = false
-        } else if (event.type === 'completion') {
-          const engine = engines.get(model.model_slug)
-          if (engine == null) {
-            model.send_msg({ 'finish_reason': 'error' })
-          }
-          const chunks = await engine.chat.completions.create({
-            messages: event.messages,
-            temperature: model.temperature,
-            stream: true,
-          })
-          model.running = true
-          for await (const chunk of chunks) {
-            if (!model.running) {
-              break
-            }
-            model.send_msg(chunk.choices[0])
-          }
-        }
-      })
-    }
-    """
+    _esm = "webllm.js"
+    _rename = {"loading": "loading"}

    def __init__(self, **params):
        """
@@ -224,17 +183,30 @@ def __init__(self, **params):
        if pn.state.location:
            pn.state.location.sync(self, {"model_slug": "model_slug"})

+    def _set_on_model(self, msg: Mapping[str, Any], root: Model, model: Model) -> None:
+        if 'loading' in msg and isinstance(model, ReactiveESM):
+            model.data.loading = msg.pop('loading')
+        super()._set_on_model(msg, root, model)
+
+    def _get_properties(self, doc: Document | None) -> dict[str, Any]:
+        props = super()._get_properties(doc)
+        props.pop('loading', None)
+        props['data'].loading = self.loading
+        return props
+
    def _get_model_options(self, model_mapping):
        """
        Generates the model options for the nested select widget.

-        Args:
-            model_mapping (dict):
-                A dictionary mapping model names to parameters and quantizations.
+        Parameters
+        ----------
+        model_mapping : dict
+            A dictionary mapping model names to parameters and quantizations.

        Returns
        -------
-            dict: A dictionary representing the model options.
+        dict
+            A dictionary representing the model options.
        """
        model_options = {
            model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()}
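Note on the two overrides above: loading is Panel's built-in busy flag and lives on the outer Bokeh model (it drives the spinner overlay), so by default it never reaches the data model that the ESM render function receives. The _rename = {"loading": "loading"} entry keeps the property from being dropped, and these hooks mirror it onto model.data so webllm.js can both read the flag and report progress back. A sketch of the observable effect (illustrative, not a public API):

    llm = WebLLM()
    llm.loading = True  # forwarded to model.data.loading by _set_on_model,
                        # visible to the browser code as model.loading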
@@ -255,20 +227,14 @@ def _update_model_select(self):
        if self.model_slug:
            model_params = ModelParam.from_model_slug(self.model_slug)
            value = model_params.to_dict(levels)
-        # TODO: Bug https://github.com/holoviz/panel/issues/7647
-        # self._model_select.param.update(
-        #     options=options,
-        #     levels=levels,
-        #     value=value,
-        # )
-        self._model_select = pn.widgets.NestedSelect(
+        self._model_select.param.update(
            options=options,
            levels=levels,
            value=value,
-            layout=self._model_select.layout,
        )
-        self._model_select_placeholder.object = self._model_select
-        self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values())
+        self.param["model_slug"].objects = sorted(
+            value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()
+        )

    def _update_model_slug(self, event):
        """
@@ -289,8 +255,9 @@ def _update_nested_select(self):
    @param.depends("load_model", watch=True)
    def _load_model(self):
        """Loads the model when the load_model event is triggered."""
-        if self.model_slug in self._card_header.object:
+        if self.model_slug in self._card_header.object or self.loading:
            return
+        self.loading = True
        self.load_status = {
            "progress": 0,
            "text": f"Preparing to load {self.model_slug}",
@@ -333,25 +300,37 @@ def _handle_msg(self, msg):
        if self.running:
            self._buffer.insert(0, msg)

-    async def create_completion(self, messages):
+    async def create_completion(self, messages, response_format=None, stream=False):
        """
        Creates a chat completion with the WebLLM.

-        Args:
-            messages (list):
-                A list of message dictionaries representing the chat history.
+        Parameters
+        ----------
+        messages : list
+            A list of message dictionaries representing the chat history.
+        response_format : dict, optional
+            The format to return the response in.
+        stream : bool, optional
+            Whether to stream the response chunks, by default False.

        Yields
-        ------
-            dict: The response chunks from the LLM.
+        -------
+        dict
+            The response chunks from the LLM.

        Raises
-        ------
-            RuntimeError: If the model is not loaded.
+        -------
+        RuntimeError
+            If the model is not loaded.
        """
-        self._send_msg({"type": "completion", "messages": messages})
+        while self.loading:
+            await asyncio.sleep(0.1)
+        await asyncio.sleep(0.1)
+        if not self.loaded:
+            return
+        self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream})
        while True:
-            await asyncio.sleep(0.01)
+            await asyncio.sleep(0.05)
            if not self._buffer:
                continue
            choice = self._buffer.pop()
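A hedged usage sketch for create_completion: the slug below is hypothetical (real values come from MODEL_MAPPING), and the chunk layout assumes the OpenAI-style choices[0] dicts that web-llm emits:

    llm = WebLLM(model_slug="Llama-3.2-1B-Instruct-q4f16_1-MLC")  # assumed slug
    llm.param.trigger("load_model")

    async def demo():
        messages = [{"role": "user", "content": "Say hi in one sentence."}]
        async for choice in llm.create_completion(messages, stream=True):
            delta = choice.get("delta") or {}
            print(delta.get("content", ""), end="", flush=True)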
@@ -401,22 +380,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface):
        """
        Callback function for chat completion.

-        Args:
-            contents (str):
-                The current user message.
-            user (str):
-                The username of the user sending the message.
-            instance (ChatInterface):
-                The ChatInterface instance.
+        Parameters
+        ----------
+        contents : str
+            The current user message.
+        user : str
+            The username of the user sending the message.
+        instance : ChatInterface
+            The ChatInterface instance.

        Yields
-        ------
-            dict or str: Yields either the messages as dict or a markdown string
+        -------
+        dict or str
+            Yields either the messages as dict or a markdown string.

        Raises
-        ------
-            RuntimeError: If the model is not loaded
+        -------
+        RuntimeError
+            If the model is not loaded.
        """
+        while self.loading:
+            await asyncio.sleep(0.1)
        if not self.loaded:
            return
        self.running = False
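For context, the callback plugs straight into Panel's chat widgets. A minimal sketch, assuming the component itself is placed in the page so its ESM module runs in the browser:

    llm = WebLLM()
    chat = ChatInterface(callback=llm.callback, callback_user="WebLLM")
    # menu() returns the model-selection card defined on this class.
    pn.Column(llm, llm.menu(), chat).servable()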
@@ -436,7 +420,7 @@ def menu(self):

        Returns
        -------
-            pn.widgets.NestedSelect: The model selection widget.
+        pn.widgets.NestedSelect: The model selection widget.
        """
        return self._card