
Commit 9d9818b

Merge pull request #9 from panel-extensions/improvements
Support streaming and non-streaming and response format
2 parents 2e8b402 + 9602804 commit 9d9818b

File tree

2 files changed: +162 -87 lines


src/panel_web_llm/main.py

Lines changed: 71 additions & 87 deletions
@@ -1,16 +1,23 @@
 """Panel components for the WebLLM interface."""
+from __future__ import annotations
 
 import asyncio
+from collections.abc import Mapping
+from typing import TYPE_CHECKING
 
 import panel as pn
 import param
 from panel.chat import ChatFeed
 from panel.chat import ChatInterface
 from panel.custom import JSComponent
+from panel.models import ReactiveESM
 
 from .models import ModelParam
 from .settings import MODEL_MAPPING
 
+if TYPE_CHECKING:
+    from bokeh.model import Model
+
 
 class WebLLM(JSComponent):
     """
@@ -81,56 +88,8 @@ class WebLLM(JSComponent):
         """
     )
 
-    _esm = """
-    import * as webllm from "https://esm.run/@mlc-ai/web-llm";
-
-    const engines = new Map()
-
-    export async function render({ model }) {
-      model.on("msg:custom", async (event) => {
-        if (event.type === 'load') {
-          model.loading = true
-          if (!engines.has(model.model_slug)) {
-            const initProgressCallback = (load_status) => {
-              model.load_status = load_status
-            }
-            try {
-              const mlc = await webllm.CreateMLCEngine(
-                model.model_slug,
-                { initProgressCallback }
-              )
-              engines.set(model.model_slug, mlc)
-              model.loaded = true
-            } catch (error) {
-              model.load_status = {
-                progress: 0,
-                text: error.message + " Try again later, or try a different size/quantization.",
-              };
-              model.loaded = false;
-            }
-          }
-          model.loading = false
-        } else if (event.type === 'completion') {
-          const engine = engines.get(model.model_slug)
-          if (engine == null) {
-            model.send_msg({ 'finish_reason': 'error' })
-          }
-          const chunks = await engine.chat.completions.create({
-            messages: event.messages,
-            temperature: model.temperature,
-            stream: true,
-          })
-          model.running = true
-          for await (const chunk of chunks) {
-            if (!model.running) {
-              break
-            }
-            model.send_msg(chunk.choices[0])
-          }
-        }
-      })
-    }
-    """
+    _esm = "webllm.js"
+    _rename = {"loading": "loading"}
 
     def __init__(self, **params):
         """
@@ -224,17 +183,30 @@ def __init__(self, **params):
         if pn.state.location:
             pn.state.location.sync(self, {"model_slug": "model_slug"})
 
+    def _set_on_model(self, msg: Mapping[str, Any], root: Model, model: Model) -> None:
+        if 'loading' in msg and isinstance(model, ReactiveESM):
+            model.data.loading = msg.pop('loading')
+        super()._set_on_model(msg, root, model)
+
+    def _get_properties(self, doc: Document | None) -> dict[str, Any]:
+        props = super()._get_properties(doc)
+        props.pop('loading', None)
+        props['data'].loading = self.loading
+        return props
+
     def _get_model_options(self, model_mapping):
         """
         Generates the model options for the nested select widget.
 
-        Args:
-            model_mapping (dict):
-                A dictionary mapping model names to parameters and quantizations.
+        Parameters
+        ----------
+        model_mapping : dict
+            A dictionary mapping model names to parameters and quantizations.
 
         Returns
         -------
-            dict: A dictionary representing the model options.
+        dict
+            A dictionary representing the model options.
         """
         model_options = {
             model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()}
@@ -255,20 +227,14 @@ def _update_model_select(self):
         if self.model_slug:
             model_params = ModelParam.from_model_slug(self.model_slug)
             value = model_params.to_dict(levels)
-        # TODO: Bug https://github.com/holoviz/panel/issues/7647
-        # self._model_select.param.update(
-        #     options=options,
-        #     levels=levels,
-        #     value=value,
-        # )
-        self._model_select = pn.widgets.NestedSelect(
+        self._model_select.param.update(
             options=options,
             levels=levels,
            value=value,
-            layout=self._model_select.layout,
        )
-        self._model_select_placeholder.object = self._model_select
-        self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values())
+        self.param["model_slug"].objects = sorted(
+            value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()
+        )
 
     def _update_model_slug(self, event):
@@ -289,8 +255,9 @@ def _update_nested_select(self):
     @param.depends("load_model", watch=True)
     def _load_model(self):
         """Loads the model when the load_model event is triggered."""
-        if self.model_slug in self._card_header.object:
+        if self.model_slug in self._card_header.object or self.loading:
             return
+        self.loading = True
         self.load_status = {
             "progress": 0,
             "text": f"Preparing to load {self.model_slug}",
@@ -333,25 +300,37 @@ def _handle_msg(self, msg):
         if self.running:
             self._buffer.insert(0, msg)
 
-    async def create_completion(self, messages):
+    async def create_completion(self, messages, response_format=None, stream=False):
         """
         Creates a chat completion with the WebLLM.
 
-        Args:
-            messages (list):
-                A list of message dictionaries representing the chat history.
+        Parameters
+        ----------
+        messages : list
+            A list of message dictionaries representing the chat history.
+        response_format : dict, optional
+            The format to return the response in.
+        stream : bool, optional
+            Whether to stream the response chunks, by default False.
 
         Yields
-        ------
-            dict: The response chunks from the LLM.
+        -------
+        dict
+            The response chunks from the LLM.
 
         Raises
-        ------
-            RuntimeError: If the model is not loaded.
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
-        self._send_msg({"type": "completion", "messages": messages})
+        while self.loading:
+            await asyncio.sleep(0.1)
+        await asyncio.sleep(0.1)
+        if not self.loaded:
+            return
+        self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream})
         while True:
-            await asyncio.sleep(0.01)
+            await asyncio.sleep(0.05)
             if not self._buffer:
                 continue
             choice = self._buffer.pop()
@@ -401,22 +380,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface):
         """
         Callback function for chat completion.
 
-        Args:
-            contents (str):
-                The current user message.
-            user (str):
-                The username of the user sending the message.
-            instance (ChatInterface):
-                The ChatInterface instance.
+        Parameters
+        ----------
+        contents : str
+            The current user message.
+        user : str
+            The username of the user sending the message.
+        instance : ChatInterface
+            The ChatInterface instance.
 
         Yields
-        ------
-            dict or str: Yields either the messages as dict or a markdown string
+        -------
+        dict or str
+            Yields either the messages as dict or a markdown string.
 
         Raises
-        ------
-            RuntimeError: If the model is not loaded
+        -------
+        RuntimeError
+            If the model is not loaded.
         """
+        while self.loading:
+            await asyncio.sleep(0.1)
         if not self.loaded:
            return
        self.running = False
@@ -436,7 +420,7 @@ def menu(self):
 
         Returns
         -------
-            pn.widgets.NestedSelect: The model selection widget.
+        pn.widgets.NestedSelect: The model selection widget.
         """

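Usage note: a minimal sketch of how the updated create_completion signature could be called, based only on the signature and docstring in the diff above. The import path, model slug, and chunk field names are assumptions for illustration, not part of this commit.

from panel_web_llm.main import WebLLM  # import path assumed from the package layout

llm = WebLLM(model_slug="Llama-3.2-1B-Instruct-q4f16_1-MLC")  # illustrative slug

async def demo():
    # Assumes the model has already been loaded (e.g. via the load_model trigger).
    messages = [{"role": "user", "content": "Write a haiku about WebLLM."}]

    # Streaming: choices arrive incrementally with a `delta` payload (see webllm.js below).
    async for choice in llm.create_completion(messages, stream=True):
        print(choice.get("delta", {}).get("content", ""), end="")

    # Non-streaming with a response format: a single completed choice is yielded.
    async for choice in llm.create_completion(
        messages,
        response_format={"type": "json_object"},  # forwarded to webllm.js unchanged
        stream=False,
    ):
        print(choice)

# `demo()` would be awaited from within a running Panel session,
# e.g. inside a ChatInterface callback, rather than via asyncio.run().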
src/panel_web_llm/webllm.js

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import * as webllm from "https://esm.run/@mlc-ai/web-llm";
+
+const engines = new Map()
+
+export async function render({ model }) {
+  model.on("msg:custom", async (event) => {
+    if (event.type === 'load') {
+      model.loading = true
+      if (!engines.has(model.model_slug)) {
+        const initProgressCallback = (load_status) => {
+          // Parse progress from cache loading messages like "[43/88]"
+          const match = load_status.text.match(/\[(\d+)\/(\d+)\]/)
+          if (match) {
+            const [_, current, total] = match
+            load_status.progress = current / total
+          }
+          model.load_status = load_status
+        }
+        try {
+          const mlc = await webllm.CreateMLCEngine(
+            model.model_slug,
+            { initProgressCallback }
+          )
+          engines.set(model.model_slug, mlc)
+          model.loaded = true
+        } catch (error) {
+          console.warn(error.message)
+          model.load_status = {
+            progress: 0,
+            text: error.message + " Try again later, or try a different size/quantization.",
+          };
+          model.loaded = false
+        }
+      }
+      model.loading = false
+    } else if (event.type === 'completion') {
+      const engine = engines.get(model.model_slug)
+      if (engine == null) {
+        model.send_msg({'finish_reason': 'error'})
+        return
+      }
+      model.running = true
+      const format = event.response_format
+      const chunks = await engine.chat.completions.create({
+        messages: event.messages,
+        temperature: model.temperature,
+        response_format: format ? { type: format.type, schema: format.schema ? JSON.stringify(format.schema) : undefined } : undefined,
+        stream: event.stream,
+      })
+      if (event.stream) {
+        let buffer = ""
+        let current = null
+        let lastChunk = null
+        let timeout = null
+        const sendBuffer = () => {
+          if (buffer) {
+            console.log(buffer)
+            model.send_msg({
+              delta: { content: buffer, role: current.delta.role },
+              index: current.index,
+              finish_reason: null
+            })
+            buffer = "";
+          }
+          if (lastChunk && lastChunk.finish_reason) {
+            model.send_msg(lastChunk)
+            lastChunk = null
+          }
+        }
+        timeout = setInterval(sendBuffer, 200)
+        for await (const chunk of chunks) {
+          if (!model.running) {
+            break
+          }
+          const choice = chunk.choices[0]
+          if (choice.delta.content) {
+            current = choice
+            buffer += choice.delta.content;
+          }
+          if (choice.finish_reason) {
+            lastChunk = choice;
+          }
+        }
+        clearTimeout(timeout)
+        sendBuffer()
+      } else {
+        model.send_msg(chunks.choices[0])
+      }
+    }
+  })
+}

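Design note: in the streaming branch above, deltas are accumulated in a buffer and flushed on a 200 ms interval (sendBuffer), so the Python side receives batched updates rather than one message per token; in the non-streaming branch a single completed choice is sent. For response_format, the handler forwards format.type as-is and serializes format.schema with JSON.stringify. A hedged sketch of a schema-constrained payload built on the Python side; the schema itself is an illustrative assumption:

# Illustrative JSON-schema payload; webllm.js stringifies the nested `schema`
# before passing it to engine.chat.completions.create.
response_format = {
    "type": "json_object",
    "schema": {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "score": {"type": "number"},
        },
        "required": ["title", "score"],
    },
}

# Passed through WebLLM.create_completion, which embeds it in the
# {"type": "completion", ...} custom message handled above:
# async for choice in llm.create_completion(messages, response_format=response_format, stream=False):
#     ...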