Commit be8fa7a

Support streaming and non-streaming and response format
1 parent 2e8b402 commit be8fa7a

2 files changed: +154 −97 lines

src/panel_web_llm/main.py

Lines changed: 62 additions & 97 deletions
@@ -68,7 +68,7 @@ class WebLLM(JSComponent):
         doc="Whether the model is loaded.",
     )
 
-    loading = param.Boolean(
+    model_loading = param.Boolean(
         default=False,
         doc="""
             Whether the model is currently loading.""",
@@ -81,56 +81,7 @@ class WebLLM(JSComponent):
         """
     )
 
-    _esm = """
-    import * as webllm from "https://esm.run/@mlc-ai/web-llm";
-
-    const engines = new Map()
-
-    export async function render({ model }) {
-      model.on("msg:custom", async (event) => {
-        if (event.type === 'load') {
-          model.loading = true
-          if (!engines.has(model.model_slug)) {
-            const initProgressCallback = (load_status) => {
-              model.load_status = load_status
-            }
-            try {
-              const mlc = await webllm.CreateMLCEngine(
-                model.model_slug,
-                { initProgressCallback }
-              )
-              engines.set(model.model_slug, mlc)
-              model.loaded = true
-            } catch (error) {
-              model.load_status = {
-                progress: 0,
-                text: error.message + " Try again later, or try a different size/quantization.",
-              };
-              model.loaded = false;
-            }
-          }
-          model.loading = false
-        } else if (event.type === 'completion') {
-          const engine = engines.get(model.model_slug)
-          if (engine == null) {
-            model.send_msg({ 'finish_reason': 'error' })
-          }
-          const chunks = await engine.chat.completions.create({
-            messages: event.messages,
-            temperature: model.temperature,
-            stream: true,
-          })
-          model.running = true
-          for await (const chunk of chunks) {
-            if (!model.running) {
-              break
-            }
-            model.send_msg(chunk.choices[0])
-          }
-        }
-      })
-    }
-    """
+    _esm = "webllm.js"
 
     def __init__(self, **params):
         """
@@ -150,12 +101,12 @@ def __init__(self, **params):
 
         self._history_input = pn.widgets.IntSlider.from_param(
             self.param.history,
-            disabled=self.param.loading,
+            disabled=self.param.model_loading,
             sizing_mode="stretch_width",
         )
         self._temperature_input = pn.widgets.FloatSlider.from_param(
             self.param.temperature,
-            disabled=self.param.loading,
+            disabled=self.param.model_loading,
             sizing_mode="stretch_width",
         )
         self._refresh_button = pn.widgets.ButtonIcon.from_param(
@@ -171,7 +122,7 @@ def __init__(self, **params):
         self._load_button = pn.widgets.Button.from_param(
             self.param.load_model,
             name=param.rx("Load ") + self.param.model_slug,
-            loading=self.param.loading,
+            loading=self.param.model_loading,
             align=("start", "end"),
             button_type="primary",
             description=None,  # override default text
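Note: `from_param` keeps the button's `loading` spinner bound to the renamed parameter. A generic sketch of that binding pattern with illustrative names (not the project's code):

    import param
    import panel as pn

    class Task(param.Parameterized):
        busy = param.Boolean(default=False)
        run = param.Event(label="Run")

    task = Task()
    # The button shows its loading spinner whenever task.busy is True.
    button = pn.widgets.Button.from_param(task.param.run, loading=task.param.busy)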
@@ -199,7 +150,7 @@ def __init__(self, **params):
         load_progress = pn.Column(
             pn.indicators.Progress(
                 value=(load_status["progress"] * 100).rx.pipe(int),
-                visible=self.param.loading,
+                visible=self.param.model_loading,
                 sizing_mode="stretch_width",
                 margin=(5, 10, -10, 10),
                 height=30,
@@ -227,14 +178,16 @@ def __init__(self, **params):
     def _get_model_options(self, model_mapping):
         """
         Generates the model options for the nested select widget.
-
-        Args:
-            model_mapping (dict):
-                A dictionary mapping model names to parameters and quantizations.
-
+
+        Parameters
+        ----------
+        model_mapping : dict
+            A dictionary mapping model names to parameters and quantizations.
+
         Returns
         -------
-        dict: A dictionary representing the model options.
+        dict
+            A dictionary representing the model options.
         """
         model_options = {
             model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()}
@@ -255,20 +208,14 @@ def _update_model_select(self):
         if self.model_slug:
             model_params = ModelParam.from_model_slug(self.model_slug)
             value = model_params.to_dict(levels)
-        # TODO: Bug https://github.com/holoviz/panel/issues/7647
-        # self._model_select.param.update(
-        #     options=options,
-        #     levels=levels,
-        #     value=value,
-        # )
-        self._model_select = pn.widgets.NestedSelect(
+        self._model_select.param.update(
             options=options,
             levels=levels,
             value=value,
-            layout=self._model_select.layout,
         )
-        self._model_select_placeholder.object = self._model_select
-        self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values())
+        self.param["model_slug"].objects = sorted(
+            value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()
+        )
 
     def _update_model_slug(self, event):
         """
@@ -289,8 +236,9 @@ def _update_nested_select(self):
     @param.depends("load_model", watch=True)
     def _load_model(self):
         """Loads the model when the load_model event is triggered."""
-        if self.model_slug in self._card_header.object:
+        if self.model_slug in self._card_header.object or self.model_loading:
             return
+        self.model_loading = True
         self.load_status = {
             "progress": 0,
             "text": f"Preparing to load {self.model_slug}",
@@ -302,9 +250,9 @@ def _on_multiple_loads(self):
         if not self.multiple_loads and self.loaded:
             self._card.visible = False
 
-    @param.depends("loading", watch=True)
-    def _on_loading(self):
-        self._model_select.disabled = self.loading
+    @param.depends("model_loading", watch=True)
+    def _on_model_loading(self):
+        self._model_select.disabled = self.model_loading
 
     @param.depends("loaded", watch=True)
     def _on_loaded(self):
@@ -333,25 +281,37 @@ def _handle_msg(self, msg):
         if self.running:
             self._buffer.insert(0, msg)
 
-    async def create_completion(self, messages):
+    async def create_completion(self, messages, response_format=None, stream=False):
         """
         Creates a chat completion with the WebLLM.
 
-        Args:
-            messages (list):
-                A list of message dictionaries representing the chat history.
+        Parameters
+        ----------
+        messages : list
+            A list of message dictionaries representing the chat history.
+        response_format : dict, optional
+            The format to return the response in.
+        stream : bool, optional
+            Whether to stream the response chunks, by default False.
 
         Yields
-        ------
-        dict: The response chunks from the LLM.
+        -------
+        dict
+            The response chunks from the LLM.
 
         Raises
-        ------
-        RuntimeError: If the model is not loaded.
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
-        self._send_msg({"type": "completion", "messages": messages})
+        while self.model_loading:
+            await asyncio.sleep(0.1)
+        await asyncio.sleep(0.1)
+        if not self.loaded:
+            return
+        self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream})
         while True:
-            await asyncio.sleep(0.01)
+            await asyncio.sleep(0.05)
             if not self._buffer:
                 continue
             choice = self._buffer.pop()
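Note: the new `response_format` and `stream` arguments are passed through to the browser-side engine as-is. A hedged usage sketch of the streaming path; the calling code, chunk shape, and termination handling are assumptions based on this diff, not taken from the repository:

    # Assumes `llm` is a WebLLM instance whose model has finished loading.
    messages = [{"role": "user", "content": "Write a haiku about browsers."}]

    async def collect():
        text = ""
        async for chunk in llm.create_completion(messages, stream=True):
            # Streamed chunks carry OpenAI-style deltas (see webllm.js below);
            # a non-streaming call would instead yield a full message choice.
            text += chunk.get("delta", {}).get("content") or ""
        return text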
@@ -401,22 +361,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface):
         """
         Callback function for chat completion.
 
-        Args:
-            contents (str):
-                The current user message.
-            user (str):
-                The username of the user sending the message.
-            instance (ChatInterface):
-                The ChatInterface instance.
+        Parameters
+        ----------
+        contents : str
+            The current user message.
+        user : str
+            The username of the user sending the message.
+        instance : ChatInterface
+            The ChatInterface instance.
 
         Yields
-        ------
-        dict or str: Yields either the messages as dict or a markdown string
+        -------
+        dict or str
+            Yields either the messages as dict or a markdown string.
 
         Raises
-        ------
-        RuntimeError: If the model is not loaded
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
+        while self.model_loading:
+            await asyncio.sleep(0.1)
         if not self.loaded:
             return
         self.running = False
@@ -436,7 +401,7 @@ def menu(self):
 
         Returns
         -------
-        pn.widgets.NestedSelect: The model selection widget.
+        pn.widgets.NestedSelect: The model selection widget.
         """
         return self._card
 
src/panel_web_llm/webllm.js

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import * as webllm from "https://esm.run/@mlc-ai/web-llm";
+
+const engines = new Map()
+
+export async function render({ model }) {
+  model.on("msg:custom", async (event) => {
+    if (event.type === 'load') {
+      model.model_loading = true
+      if (!engines.has(model.model_slug)) {
+        console.log("loading model", model.model_slug)
+        const initProgressCallback = (load_status) => {
+          // Parse progress from cache loading messages like "[43/88]"
+          const match = load_status.text.match(/\[(\d+)\/(\d+)\]/)
+          if (match) {
+            const [_, current, total] = match
+            load_status.progress = current / total
+          }
+          model.load_status = load_status
+        }
+        try {
+          const mlc = await webllm.CreateMLCEngine(
+            model.model_slug,
+            { initProgressCallback }
+          )
+          engines.set(model.model_slug, mlc)
+          model.loaded = true
+        } catch (error) {
+          console.warn(error.message)
+          model.load_status = {
+            progress: 0,
+            text: error.message + " Try again later, or try a different size/quantization.",
+          };
+          model.loaded = false
+        }
+      }
+      model.model_loading = false
+    } else if (event.type === 'completion') {
+      const engine = engines.get(model.model_slug)
+      if (engine == null) {
+        model.send_msg({'finish_reason': 'error'})
+        return
+      }
+      model.running = true
+      const format = event.response_format
+      const chunks = await engine.chat.completions.create({
+        messages: event.messages,
+        temperature: model.temperature,
+        response_format: format ? { type: format.type, schema: format.schema ? JSON.stringify(format.schema) : undefined } : undefined,
+        stream: event.stream,
+      })
+      if (event.stream) {
+        let buffer = ""
+        let current = null
+        let lastChunk = null
+        let timeout = null
+        const sendBuffer = () => {
+          if (buffer) {
+            console.log(buffer)
+            model.send_msg({
+              delta: { content: buffer, role: current.delta.role },
+              index: current.index,
+              finish_reason: null
+            })
+            buffer = "";
+          }
+          if (lastChunk && lastChunk.finish_reason) {
+            model.send_msg(lastChunk)
+            lastChunk = null
+          }
+        }
+        timeout = setInterval(sendBuffer, 200)
+        for await (const chunk of chunks) {
+          if (!model.running) {
+            break
+          }
+          const choice = chunk.choices[0]
+          if (choice.delta.content) {
+            current = choice
+            buffer += choice.delta.content;
+          }
+          if (choice.finish_reason) {
+            lastChunk = choice;
+          }
+        }
+        clearTimeout(timeout)
+        sendBuffer()
+      } else {
+        model.send_msg(chunks.choices[0])
+      }
+    }
+  })
+}
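Note: the `response_format` branch above stringifies an optional JSON schema before handing it to web-llm. A hedged Python-side sketch of a schema-constrained request; the schema content is illustrative and the call is an assumption based on the new `create_completion` signature:

    # Hypothetical structured-output request; only the "type" and "schema" keys
    # are read by the JS side above.
    person_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
    response_format = {"type": "json_object", "schema": person_schema}
    # await llm.create_completion(messages, response_format=response_format, stream=False)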
