@@ -68,7 +68,7 @@ class WebLLM(JSComponent):
         doc="Whether the model is loaded.",
     )

-    loading = param.Boolean(
+    model_loading = param.Boolean(
         default=False,
         doc="""
             Whether the model is currently loading.""",
@@ -81,56 +81,7 @@ class WebLLM(JSComponent):
         """
     )

-    _esm = """
-    import * as webllm from "https://esm.run/@mlc-ai/web-llm";
-
-    const engines = new Map()
-
-    export async function render({ model }) {
-      model.on("msg:custom", async (event) => {
-        if (event.type === 'load') {
-          model.loading = true
-          if (!engines.has(model.model_slug)) {
-            const initProgressCallback = (load_status) => {
-              model.load_status = load_status
-            }
-            try {
-              const mlc = await webllm.CreateMLCEngine(
-                model.model_slug,
-                { initProgressCallback }
-              )
-              engines.set(model.model_slug, mlc)
-              model.loaded = true
-            } catch (error) {
-              model.load_status = {
-                progress: 0,
-                text: error.message + " Try again later, or try a different size/quantization.",
-              };
-              model.loaded = false;
-            }
-          }
-          model.loading = false
-        } else if (event.type === 'completion') {
-          const engine = engines.get(model.model_slug)
-          if (engine == null) {
-            model.send_msg({ 'finish_reason': 'error' })
-          }
-          const chunks = await engine.chat.completions.create({
-            messages: event.messages,
-            temperature: model.temperature,
-            stream: true,
-          })
-          model.running = true
-          for await (const chunk of chunks) {
-            if (!model.running) {
-              break
-            }
-            model.send_msg(chunk.choices[0])
-          }
-        }
-      })
-    }
-    """
+    _esm = "webllm.js"

     def __init__(self, **params):
         """
@@ -150,12 +101,12 @@ def __init__(self, **params):

         self._history_input = pn.widgets.IntSlider.from_param(
             self.param.history,
-            disabled=self.param.loading,
+            disabled=self.param.model_loading,
             sizing_mode="stretch_width",
         )
         self._temperature_input = pn.widgets.FloatSlider.from_param(
             self.param.temperature,
-            disabled=self.param.loading,
+            disabled=self.param.model_loading,
             sizing_mode="stretch_width",
         )
         self._refresh_button = pn.widgets.ButtonIcon.from_param(
@@ -171,7 +122,7 @@ def __init__(self, **params):
         self._load_button = pn.widgets.Button.from_param(
             self.param.load_model,
             name=param.rx("Load ") + self.param.model_slug,
-            loading=self.param.loading,
+            loading=self.param.model_loading,
             align=("start", "end"),
             button_type="primary",
             description=None,  # override default text
@@ -199,7 +150,7 @@ def __init__(self, **params):
         load_progress = pn.Column(
             pn.indicators.Progress(
                 value=(load_status["progress"] * 100).rx.pipe(int),
-                visible=self.param.loading,
+                visible=self.param.model_loading,
                 sizing_mode="stretch_width",
                 margin=(5, 10, -10, 10),
                 height=30,
@@ -227,14 +178,16 @@ def __init__(self, **params):
     def _get_model_options(self, model_mapping):
         """
         Generates the model options for the nested select widget.
-
-        Args:
-            model_mapping (dict):
-                A dictionary mapping model names to parameters and quantizations.
-
+
+        Parameters
+        ----------
+        model_mapping : dict
+            A dictionary mapping model names to parameters and quantizations.
+
         Returns
         -------
-            dict: A dictionary representing the model options.
+        dict
+            A dictionary representing the model options.
         """
         model_options = {
             model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()}
@@ -255,20 +208,14 @@ def _update_model_select(self):
         if self.model_slug:
             model_params = ModelParam.from_model_slug(self.model_slug)
             value = model_params.to_dict(levels)
-        # TODO: Bug https://github.com/holoviz/panel/issues/7647
-        # self._model_select.param.update(
-        #     options=options,
-        #     levels=levels,
-        #     value=value,
-        # )
-        self._model_select = pn.widgets.NestedSelect(
+        self._model_select.param.update(
             options=options,
             levels=levels,
             value=value,
-            layout=self._model_select.layout,
         )
-        self._model_select_placeholder.object = self._model_select
-        self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values())
+        self.param["model_slug"].objects = sorted(
+            value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()
+        )

     def _update_model_slug(self, event):
         """
@@ -289,8 +236,9 @@ def _update_nested_select(self):
     @param.depends("load_model", watch=True)
     def _load_model(self):
         """Loads the model when the load_model event is triggered."""
-        if self.model_slug in self._card_header.object:
+        if self.model_slug in self._card_header.object or self.model_loading:
             return
+        self.model_loading = True
         self.load_status = {
             "progress": 0,
             "text": f"Preparing to load {self.model_slug}",
@@ -302,9 +250,9 @@ def _on_multiple_loads(self):
         if not self.multiple_loads and self.loaded:
             self._card.visible = False

-    @param.depends("loading", watch=True)
-    def _on_loading(self):
-        self._model_select.disabled = self.loading
+    @param.depends("model_loading", watch=True)
+    def _on_model_loading(self):
+        self._model_select.disabled = self.model_loading

     @param.depends("loaded", watch=True)
     def _on_loaded(self):
@@ -333,25 +281,37 @@ def _handle_msg(self, msg):
         if self.running:
             self._buffer.insert(0, msg)

-    async def create_completion(self, messages):
+    async def create_completion(self, messages, response_format=None, stream=False):
         """
         Creates a chat completion with the WebLLM.

-        Args:
-            messages (list):
-                A list of message dictionaries representing the chat history.
+        Parameters
+        ----------
+        messages : list
+            A list of message dictionaries representing the chat history.
+        response_format : dict, optional
+            The format to return the response in.
+        stream : bool, optional
+            Whether to stream the response chunks, by default False.

         Yields
-        ------
-            dict: The response chunks from the LLM.
+        -------
+        dict
+            The response chunks from the LLM.

         Raises
-        ------
-            RuntimeError: If the model is not loaded.
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
-        self._send_msg({"type": "completion", "messages": messages})
+        while self.model_loading:
+            await asyncio.sleep(0.1)
+        await asyncio.sleep(0.1)
+        if not self.loaded:
+            return
+        self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream})
         while True:
-            await asyncio.sleep(0.01)
+            await asyncio.sleep(0.05)
             if not self._buffer:
                 continue
             choice = self._buffer.pop()
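
Since `create_completion` is an async generator that first waits out `model_loading`, a caller can simply iterate it. A hedged usage sketch, assuming a `WebLLM` instance whose model has finished loading and OpenAI-style streaming choices like the ones the ESM code forwards via `send_msg` (the `delta`/`finish_reason` keys are assumptions from that protocol):

```python
async def collect_reply(web_llm, prompt):
    messages = [{"role": "user", "content": prompt}]
    reply = ""
    async for choice in web_llm.create_completion(messages, stream=True):
        # Each yielded choice mirrors chunk.choices[0] from web-llm:
        # "delta" carries incremental text, "finish_reason" ends the stream.
        reply += (choice.get("delta") or {}).get("content") or ""
        if choice.get("finish_reason"):
            break
    return reply
```
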
@@ -401,22 +361,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface):
         """
         Callback function for chat completion.

-        Args:
-            contents (str):
-                The current user message.
-            user (str):
-                The username of the user sending the message.
-            instance (ChatInterface):
-                The ChatInterface instance.
+        Parameters
+        ----------
+        contents : str
+            The current user message.
+        user : str
+            The username of the user sending the message.
+        instance : ChatInterface
+            The ChatInterface instance.

         Yields
-        ------
-            dict or str: Yields either the messages as dict or a markdown string
+        -------
+        dict or str
+            Yields either the messages as dict or a markdown string.

         Raises
-        ------
-            RuntimeError: If the model is not loaded
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
+        while self.model_loading:
+            await asyncio.sleep(0.1)
         if not self.loaded:
             return
         self.running = False
@@ -436,7 +401,7 @@ def menu(self):

         Returns
         -------
-            pn.widgets.NestedSelect: The model selection widget.
+        pn.widgets.NestedSelect: The model selection widget.
         """
         return self._card

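
End to end, the component is designed to drop into a `ChatInterface`. A hypothetical wiring sketch, assuming this module exposes `WebLLM` as shown in the diff (the model slug is illustrative):

```python
import panel as pn

pn.extension()

web_llm = WebLLM(model_slug="Llama-3.2-3B-Instruct-q4f16_1-MLC")
chat = pn.chat.ChatInterface(callback=web_llm.callback)
# menu() returns the model-selection card; callback streams completions
# once the user has loaded a model in the browser.
pn.Column(web_llm.menu(), chat).servable()
```
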