158 changes: 71 additions & 87 deletions src/panel_web_llm/main.py
@@ -1,16 +1,23 @@
"""Panel components for the WebLLM interface."""
from __future__ import annotations

import asyncio
from collections.abc import Mapping
from typing import TYPE_CHECKING
from typing import Any

import panel as pn
import param
from panel.chat import ChatFeed
from panel.chat import ChatInterface
from panel.custom import JSComponent
from panel.models import ReactiveESM

from .models import ModelParam
from .settings import MODEL_MAPPING

if TYPE_CHECKING:
    from bokeh.document import Document
    from bokeh.model import Model


class WebLLM(JSComponent):
"""
@@ -81,56 +88,8 @@ class WebLLM(JSComponent):
"""
)

_esm = """
import * as webllm from "https://esm.run/@mlc-ai/web-llm";

const engines = new Map()

export async function render({ model }) {
model.on("msg:custom", async (event) => {
if (event.type === 'load') {
model.loading = true
if (!engines.has(model.model_slug)) {
const initProgressCallback = (load_status) => {
model.load_status = load_status
}
try {
const mlc = await webllm.CreateMLCEngine(
model.model_slug,
{ initProgressCallback }
)
engines.set(model.model_slug, mlc)
model.loaded = true
} catch (error) {
model.load_status = {
progress: 0,
text: error.message + " Try again later, or try a different size/quantization.",
};
model.loaded = false;
}
}
model.loading = false
} else if (event.type === 'completion') {
const engine = engines.get(model.model_slug)
if (engine == null) {
model.send_msg({ 'finish_reason': 'error' })
}
const chunks = await engine.chat.completions.create({
messages: event.messages,
temperature: model.temperature,
stream: true,
})
model.running = true
for await (const chunk of chunks) {
if (!model.running) {
break
}
model.send_msg(chunk.choices[0])
}
}
})
}
"""
_esm = "webllm.js"
_rename = {"loading": "loading"}

def __init__(self, **params):
"""
@@ -224,17 +183,30 @@ def __init__(self, **params):
if pn.state.location:
pn.state.location.sync(self, {"model_slug": "model_slug"})

def _set_on_model(self, msg: Mapping[str, Any], root: Model, model: Model) -> None:
if 'loading' in msg and isinstance(model, ReactiveESM):
model.data.loading = msg.pop('loading')
super()._set_on_model(msg, root, model)

def _get_properties(self, doc: Document | None) -> dict[str, Any]:
props = super()._get_properties(doc)
props.pop('loading', None)
props['data'].loading = self.loading
return props

def _get_model_options(self, model_mapping):
"""
Generates the model options for the nested select widget.

Args:
model_mapping (dict):
A dictionary mapping model names to parameters and quantizations.
Parameters
----------
model_mapping : dict
A dictionary mapping model names to parameters and quantizations.

Returns
-------
dict: A dictionary representing the model options.
dict
A dictionary representing the model options.
"""
model_options = {
model_name: {parameters: list(quantizations.keys()) for parameters, quantizations in model_mapping[model_name].items()}
@@ -255,20 +227,14 @@ def _update_model_select(self):
if self.model_slug:
model_params = ModelParam.from_model_slug(self.model_slug)
value = model_params.to_dict(levels)
# TODO: Bug https://github.com/holoviz/panel/issues/7647
# self._model_select.param.update(
# options=options,
# levels=levels,
# value=value,
# )
self._model_select = pn.widgets.NestedSelect(
self._model_select.param.update(
options=options,
levels=levels,
value=value,
layout=self._model_select.layout,
)
self._model_select_placeholder.object = self._model_select
self.param["model_slug"].objects = sorted(value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values())
self.param["model_slug"].objects = sorted(
value for models in MODEL_MAPPING.values() for sizes in models.values() for value in sizes.values()
)

def _update_model_slug(self, event):
"""
@@ -289,8 +255,9 @@ def _update_nested_select(self):
@param.depends("load_model", watch=True)
def _load_model(self):
"""Loads the model when the load_model event is triggered."""
if self.model_slug in self._card_header.object:
if self.model_slug in self._card_header.object or self.loading:
return
self.loading = True
self.load_status = {
"progress": 0,
"text": f"Preparing to load {self.model_slug}",
@@ -333,25 +300,37 @@ def _handle_msg(self, msg):
if self.running:
self._buffer.insert(0, msg)

async def create_completion(self, messages):
async def create_completion(self, messages, response_format=None, stream=False):
"""
Creates a chat completion with the WebLLM.

Args:
messages (list):
A list of message dictionaries representing the chat history.
Parameters
----------
messages : list
A list of message dictionaries representing the chat history.
response_format : dict, optional
The format to return the response in.
stream : bool, optional
Whether to stream the response chunks, by default False.

Yields
------
dict: The response chunks from the LLM.
-------
dict
The response chunks from the LLM.

Raises
------
RuntimeError: If the model is not loaded.
-------
RuntimeError
If the model is not loaded.
"""
self._send_msg({"type": "completion", "messages": messages})
while self.loading:
await asyncio.sleep(0.1)
await asyncio.sleep(0.1)
if not self.loaded:
return
self._send_msg({"type": "completion", "messages": messages, "response_format": response_format, "stream": stream})
while True:
await asyncio.sleep(0.01)
await asyncio.sleep(0.05)
if not self._buffer:
continue
choice = self._buffer.pop()
@@ -401,22 +380,27 @@ async def callback(self, contents: str, user: str, instance: ChatInterface):
"""
Callback function for chat completion.

Args:
contents (str):
The current user message.
user (str):
The username of the user sending the message.
instance (ChatInterface):
The ChatInterface instance.
Parameters
----------
contents : str
The current user message.
user : str
The username of the user sending the message.
instance : ChatInterface
The ChatInterface instance.

Yields
------
dict or str: Yields either the messages as dict or a markdown string
-------
dict or str
Yields either the messages as dict or a markdown string.

Raises
------
RuntimeError: If the model is not loaded
-------
RuntimeError
If the model is not loaded.
"""
while self.loading:
await asyncio.sleep(0.1)
if not self.loaded:
return
self.running = False
@@ -436,7 +420,7 @@ def menu(self):

Returns
-------
pn.widgets.NestedSelect: The model selection widget.
pn.widgets.NestedSelect: The model selection widget.
"""
return self._card

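For context, a minimal usage sketch of the Python side after this change. This is a hedged sketch, not part of the diff: the import path, the model slug, and the layout wiring are assumptions for illustration; the only APIs taken from the code above are WebLLM, model_slug, menu, and callback.

import panel as pn

from panel_web_llm.main import WebLLM  # assumed import path, based on this repo's src layout

pn.extension()

# Hypothetical slug; in practice the value must be one of the options derived from MODEL_MAPPING.
web_llm = WebLLM(model_slug="Llama-3.2-1B-Instruct-q4f16_1-MLC")

# The component's own `callback` streams completions into a ChatInterface;
# the model itself is loaded via the controls exposed by `web_llm.menu`.
chat = pn.chat.ChatInterface(callback=web_llm.callback)
pn.Column(web_llm, web_llm.menu, chat).servable()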
91 changes: 91 additions & 0 deletions src/panel_web_llm/webllm.js
@@ -0,0 +1,91 @@
import * as webllm from "https://esm.run/@mlc-ai/web-llm";

const engines = new Map()

export async function render({ model }) {
model.on("msg:custom", async (event) => {
if (event.type === 'load') {
model.loading = true
if (!engines.has(model.model_slug)) {
const initProgressCallback = (load_status) => {
// Parse progress from cache loading messages like "[43/88]"
const match = load_status.text.match(/\[(\d+)\/(\d+)\]/)
if (match) {
const [_, current, total] = match
load_status.progress = current / total
}
model.load_status = load_status
}
try {
const mlc = await webllm.CreateMLCEngine(
model.model_slug,
{ initProgressCallback }
)
engines.set(model.model_slug, mlc)
model.loaded = true
} catch (error) {
console.warn(error.message)
model.load_status = {
progress: 0,
text: error.message + " Try again later, or try a different size/quantization.",
};
model.loaded = false
}
}
model.loading = false
} else if (event.type === 'completion') {
const engine = engines.get(model.model_slug)
if (engine == null) {
model.send_msg({'finish_reason': 'error'})
return
}
model.running = true
const format = event.response_format
const chunks = await engine.chat.completions.create({
messages: event.messages,
temperature: model.temperature,
response_format: format ? { type: format.type, schema: format.schema ? JSON.stringify(format.schema) : undefined } : undefined,
stream: event.stream,
})
if (event.stream) {
let buffer = ""
let current = null
let lastChunk = null
let timeout = null
const sendBuffer = () => {
if (buffer) {
console.log(buffer)
model.send_msg({
delta: { content: buffer, role: current.delta.role },
index: current.index,
finish_reason: null
})
buffer = "";
}
if (lastChunk && lastChunk.finish_reason) {
model.send_msg(lastChunk)
lastChunk = null
}
}
timeout = setInterval(sendBuffer, 200)
for await (const chunk of chunks) {
if (!model.running) {
break
}
const choice = chunk.choices[0]
if (choice.delta.content) {
current = choice
buffer += choice.delta.content;
}
if (choice.finish_reason) {
lastChunk = choice;
}
}
clearInterval(timeout)
sendBuffer()
} else {
model.send_msg(chunks.choices[0])
}
}
})
}
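As a usage note, the new response_format and stream pass-through shown above could be exercised from create_completion roughly as follows. This is a sketch under assumptions: the schema content and function name are illustrative, and the chunk shape mirrors the send_msg payloads in webllm.js (OpenAI-style choices with a delta when streaming, a message otherwise).

# Hypothetical structured-output request; webllm.js serializes the schema before handing it to web-llm.
schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}, "population": {"type": "integer"}},
    "required": ["city", "population"],
}

async def ask_structured(web_llm):
    messages = [{"role": "user", "content": "Return one large city and its population as JSON."}]
    async for chunk in web_llm.create_completion(
        messages,
        response_format={"type": "json_object", "schema": schema},
        stream=False,
    ):
        # With stream=False the JS side sends a single choice; the JSON text sits under its "message" content.
        return chunk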