diff --git a/src/panel_web_llm/main.py b/src/panel_web_llm/main.py
index f34a72b..49b5f4b 100644
--- a/src/panel_web_llm/main.py
+++ b/src/panel_web_llm/main.py
@@ -1,16 +1,25 @@
 """Panel components for the WebLLM interface."""
+from __future__ import annotations
 
 import asyncio
+from collections.abc import Mapping
+from typing import TYPE_CHECKING
+from typing import Any
 
 import panel as pn
 import param
 from panel.chat import ChatFeed
 from panel.chat import ChatInterface
 from panel.custom import JSComponent
+from panel.models import ReactiveESM
 
 from .models import ModelParam
 from .settings import MODEL_MAPPING
 
+if TYPE_CHECKING:
+    from bokeh.document import Document
+    from bokeh.model import Model
+
 
 class WebLLM(JSComponent):
     """
@@ -81,56 +88,8 @@ class WebLLM(JSComponent):
         """
     )
 
-    _esm = """
-    import * as webllm from "https://esm.run/@mlc-ai/web-llm";
-
-    const engines = new Map()
-
-    export async function render({ model }) {
-      model.on("msg:custom", async (event) => {
-        if (event.type === 'load') {
-          model.loading = true
-          if (!engines.has(model.model_slug)) {
-            const initProgressCallback = (load_status) => {
-              model.load_status = load_status
-            }
-            try {
-              const mlc = await webllm.CreateMLCEngine(
-                model.model_slug,
-                { initProgressCallback }
-              )
-              engines.set(model.model_slug, mlc)
-              model.loaded = true
-            } catch (error) {
-              model.load_status = {
-                progress: 0,
-                text: error.message + " Try again later, or try a different size/quantization.",
-              };
-              model.loaded = false;
-            }
-          }
-          model.loading = false
-        } else if (event.type === 'completion') {
-          const engine = engines.get(model.model_slug)
-          if (engine == null) {
-            model.send_msg({ 'finish_reason': 'error' })
-          }
-          const chunks = await engine.chat.completions.create({
-            messages: event.messages,
-            temperature: model.temperature,
-            stream: true,
-          })
-          model.running = true
-          for await (const chunk of chunks) {
-            if (!model.running) {
-              break
-            }
-            model.send_msg(chunk.choices[0])
-          }
-        }
-      })
-    }
-    """
+    _esm = "webllm.js"
+    _rename = {"loading": "loading"}
 
     def __init__(self, **params):
         """
@@ -224,17 +183,30 @@ def __init__(self, **params):
         if pn.state.location:
             pn.state.location.sync(self, {"model_slug": "model_slug"})
 
+    def _set_on_model(self, msg: Mapping[str, Any], root: Model, model: Model) -> None:
+        if 'loading' in msg and isinstance(model, ReactiveESM):
+            model.data.loading = msg.pop('loading')
+        super()._set_on_model(msg, root, model)
+
+    def _get_properties(self, doc: Document | None) -> dict[str, Any]:
+        props = super()._get_properties(doc)
+        props.pop('loading', None)
+        props['data'].loading = self.loading
+        return props
+
     def _get_model_options(self, model_mapping):
         """
         Generates the model options for the nested select widget.
 
-        Args:
-            model_mapping (dict):
-                A dictionary mapping model names to parameters and quantizations.
+        Parameters
+        ----------
+        model_mapping : dict
+            A dictionary mapping model names to parameters and quantizations.
 
         Returns
        -------
-            dict: A dictionary representing the model options.
+        dict
+            A dictionary representing the model options.
""" model_options = { model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()} @@ -255,20 +227,14 @@ def _update_model_select(self): if self.model_slug: model_params = ModelParam.from_model_slug(self.model_slug) value = model_params.to_dict(levels) - # TODO: Bug https://github.com/holoviz/panel/issues/7647 - # self._model_select.param.update( - # options=options, - # levels=levels, - # value=value, - # ) - self._model_select = pn.widgets.NestedSelect( + self._model_select.param.update( options=options, levels=levels, value=value, - layout=self._model_select.layout, ) - self._model_select_placeholder.object = self._model_select - self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()) + self.param["model_slug"].objects = sorted( + value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values() + ) def _update_model_slug(self, event): """ @@ -289,8 +255,9 @@ def _update_nested_select(self): @param.depends("load_model", watch=True) def _load_model(self): """Loads the model when the load_model event is triggered.""" - if self.model_slug in self._card_header.object: + if self.model_slug in self._card_header.object or self.loading: return + self.loading = True self.load_status = { "progress": 0, "text": f"Preparing to load {self.model_slug}", @@ -333,25 +300,37 @@ def _handle_msg(self, msg): if self.running: self._buffer.insert(0, msg) - async def create_completion(self, messages): + async def create_completion(self, messages, response_format=None, stream=False): """ Creates a chat completion with the WebLLM. - Args: - messages (list): - A list of message dictionaries representing the chat history. + Parameters + ---------- + messages : list + A list of message dictionaries representing the chat history. + response_format : dict, optional + The format to return the response in. + stream : bool, optional + Whether to stream the response chunks, by default False. Yields - ------ - dict: The response chunks from the LLM. + ------- + dict + The response chunks from the LLM. Raises - ------ - RuntimeError: If the model is not loaded. + ------- + RuntimeError + If the model is not loaded. """ - self._send_msg({"type": "completion", "messages": messages}) + while self.loading: + await asyncio.sleep(0.1) + await asyncio.sleep(0.1) + if not self.loaded: + return + self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream}) while True: - await asyncio.sleep(0.01) + await asyncio.sleep(0.05) if not self._buffer: continue choice = self._buffer.pop() @@ -401,22 +380,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface): """ Callback function for chat completion. - Args: - contents (str): - The current user message. - user (str): - The username of the user sending the message. - instance (ChatInterface): - The ChatInterface instance. + Parameters + ---------- + contents : str + The current user message. + user : str + The username of the user sending the message. + instance : ChatInterface + The ChatInterface instance. Yields - ------ - dict or str: Yields either the messages as dict or a markdown string + ------- + dict or str + Yields either the messages as dict or a markdown string. Raises - ------ - RuntimeError: If the model is not loaded + ------- + RuntimeError + If the model is not loaded. 
""" + while self.loading: + await asyncio.sleep(0.1) if not self.loaded: return self.running = False @@ -436,7 +420,7 @@ def menu(self): Returns ------- - pn.widgets.NestedSelect: The model selection widget. + pn.widgets.NestedSelect: The model selection widget. """ return self._card diff --git a/src/panel_web_llm/webllm.js b/src/panel_web_llm/webllm.js new file mode 100644 index 0000000..2ed04ec --- /dev/null +++ b/src/panel_web_llm/webllm.js @@ -0,0 +1,91 @@ +import * as webllm from "https://esm.run/@mlc-ai/web-llm"; + +const engines = new Map() + +export async function render({ model }) { + model.on("msg:custom", async (event) => { + if (event.type === 'load') { + model.loading = true + if (!engines.has(model.model_slug)) { + const initProgressCallback = (load_status) => { + // Parse progress from cache loading messages like "[43/88]" + const match = load_status.text.match(/\[(\d+)\/(\d+)\]/) + if (match) { + const [_, current, total] = match + load_status.progress = current / total + } + model.load_status = load_status + } + try { + const mlc = await webllm.CreateMLCEngine( + model.model_slug, + { initProgressCallback } + ) + engines.set(model.model_slug, mlc) + model.loaded = true + } catch (error) { + console.warn(error.message) + model.load_status = { + progress: 0, + text: error.message + " Try again later, or try a different size/quantization.", + }; + model.loaded = false + } + } + model.loading = false + } else if (event.type === 'completion') { + const engine = engines.get(model.model_slug) + if (engine == null) { + model.send_msg({'finish_reason': 'error'}) + return + } + model.running = true + const format = event.response_format + const chunks = await engine.chat.completions.create({ + messages: event.messages, + temperature: model.temperature, + response_format: format ? { type: format.type, schema: format.schema ? JSON.stringify(format.schema) : undefined } : undefined, + stream: event.stream, + }) + if (event.stream) { + let buffer = "" + let current = null + let lastChunk = null + let timeout = null + const sendBuffer = () => { + if (buffer) { + console.log(buffer) + model.send_msg({ + delta: { content: buffer, role: current.delta.role }, + index: current.index, + finish_reason: null + }) + buffer = ""; + } + if (lastChunk && lastChunk.finish_reason) { + model.send_msg(lastChunk) + lastChunk = null + } + } + timeout = setInterval(sendBuffer, 200) + for await (const chunk of chunks) { + if (!model.running) { + break + } + const choice = chunk.choices[0] + if (choice.delta.content) { + current = choice + buffer += choice.delta.content; + } + if (choice.finish_reason) { + lastChunk = choice; + } + } + clearTimeout(timeout) + sendBuffer() + } else { + model.send_msg(chunks.choices[0]) + } + } + }) +}