diff --git a/.env.example b/.env.example index 95288c0..da0ea88 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,7 @@ +# Vereist voor authenticatie (zijbalk met gesprekshistorie) +# Genereer met: python -c "import secrets; print(secrets.token_hex(32))" +CHAINLIT_AUTH_SECRET= + # Model in litellm-formaat: provider/model-naam # Zie https://docs.litellm.ai/docs/providers voor alle opties MODEL=anthropic/claude-sonnet-4-6 diff --git a/.gitignore b/.gitignore index 159775e..9b08717 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,8 @@ chainlit.md socket:* .uv_cache/ + +# Lokale data — niet in repo +data/sessions.db +data/sessions.db-shm +data/sessions.db-wal diff --git a/app.py b/app.py index bc5a874..711424b 100644 --- a/app.py +++ b/app.py @@ -8,9 +8,34 @@ import chainlit as cl +import persistence from agent import run +from data_layer import SQLiteDataLayer, init_db from report import generate_report +init_db() + + +@cl.data_layer +def get_data_layer() -> SQLiteDataLayer: + return SQLiteDataLayer() + + +@cl.header_auth_callback +def auth_callback(headers) -> cl.User: + return cl.User(identifier="local", metadata={"role": "user"}) + + +@cl.on_chat_resume +async def on_resume(thread: cl.types.ThreadDict) -> None: + metadata = thread.get("metadata") or {} + messages = metadata.get("messages", []) + cl.user_session.set("messages", messages) + cl.user_session.set("figures", []) + cl.user_session.set("turns", []) + cl.user_session.set("session_id", thread["id"]) + cl.user_session.set("turn_figures", []) + WELKOM = """Welkom! Ik kan je helpen met vragen over open Nederlandse onderwijsdata. Ik heb toegang tot: @@ -38,6 +63,7 @@ async def set_starters(): @cl.on_chat_start async def on_start(): + cl.user_session.set("session_id", cl.context.session.id) cl.user_session.set("messages", []) cl.user_session.set("figures", []) cl.user_session.set("turns", []) @@ -47,11 +73,54 @@ async def on_start(): await cl.Message(content="⚠️ Geen API key gevonden. Stel een omgevingsvariabele in (bijv. `ANTHROPIC_API_KEY`) en herstart de app.").send() return - await cl.Message(content=WELKOM).send() + await cl.Message(content=WELKOM + "\n\nTip: typ `/history` om eerdere gesprekken te hervatten.").send() + + +@cl.action_callback("resume_session") +async def on_resume_session(action: cl.Action): + session_id = action.payload["session_id"] + messages = persistence.load(session_id) + if not messages: + await cl.Message(content="Gesprek niet meer beschikbaar.").send() + return + + cl.user_session.set("messages", messages) + cl.user_session.set("session_id", session_id) + + user_turns = [(m["content"], messages[i + 1]["content"] if i + 1 < len(messages) else "") + for i, m in enumerate(messages) if m["role"] == "user"] + + summary_lines = "\n".join( + f"- **V:** {q[:80]}{'…' if len(q) > 80 else ''}\n **A:** {a[:120]}{'…' if len(a) > 120 else ''}" + for q, a in user_turns[-5:] + ) + await cl.Message( + content=f"Gesprek herladen ({len(user_turns)} vragen). Laatste uitwisselingen:\n\n{summary_lines}\n\nGa gerust verder." + ).send() + + +async def _show_history() -> None: + sessions = persistence.recent() + if not sessions: + await cl.Message(content="Nog geen opgeslagen gesprekken.").send() + return + actions = [ + cl.Action( + name="resume_session", + label=f"↩ {s['title'][:55]}{'…' if len(s['title']) > 55 else ''}", + payload={"session_id": s["id"]}, + ) + for s in sessions + ] + await cl.Message(content="**Vorige gesprekken:**", actions=actions).send() @cl.on_message async def on_message(message: cl.Message): + if message.content.strip().lower() in ("/history", "/gesprekken"): + await _show_history() + return + messages: list = cl.user_session.get("messages") messages.append({"role": "user", "content": message.content}) @@ -71,6 +140,13 @@ async def on_message(message: cl.Message): turns.append({"question": message.content, "answer": response_text, "figures": turn_figures}) cl.user_session.set("turns", turns) + # Persist LLM messages in thread metadata so on_chat_resume can restore them + thread_id = cl.context.session.thread_id + from chainlit.data import get_data_layer + dl = get_data_layer() + if dl: + await dl.update_thread(thread_id, metadata={"messages": messages}) + @cl.action_callback("download_rapport") async def on_download_rapport(action: cl.Action): diff --git a/data_layer.py b/data_layer.py new file mode 100644 index 0000000..668cec9 --- /dev/null +++ b/data_layer.py @@ -0,0 +1,271 @@ +""" +Minimale SQLite-gebaseerde Chainlit data layer. +Levert: sidebar met gespreksgeschiedenis + hervatten van eerdere chats. +""" + +import json +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Optional + +from chainlit.data.base import BaseDataLayer +from chainlit.step import StepDict +from chainlit.types import ( + Feedback, + PageInfo, + PaginatedResponse, + Pagination, + ThreadDict, + ThreadFilter, +) +from chainlit.user import PersistedUser, User + +_DB = Path("data/sessions.db") + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _conn() -> sqlite3.Connection: + _DB.parent.mkdir(exist_ok=True) + conn = sqlite3.connect(_DB) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + return conn + + +def init_db() -> None: + with _conn() as conn: + conn.executescript(""" + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + identifier TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + name TEXT, + created_at TEXT NOT NULL, + metadata TEXT NOT NULL DEFAULT '{}' + ); + CREATE TABLE IF NOT EXISTS steps ( + id TEXT PRIMARY KEY, + thread_id TEXT NOT NULL, + type TEXT NOT NULL DEFAULT 'undefined', + name TEXT, + input TEXT, + output TEXT, + created_at TEXT NOT NULL, + is_error INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_steps_thread ON steps(thread_id); + """) + + +def _row_to_step(row: sqlite3.Row) -> StepDict: + return StepDict( + id=row["id"], + threadId=row["thread_id"], + type=row["type"], + name=row["name"] or "", + input=row["input"] or "", + output=row["output"] or "", + createdAt=row["created_at"], + isError=bool(row["is_error"]), + streaming=False, + metadata={}, + ) + + +def _row_to_thread(row: sqlite3.Row, steps: List[StepDict] | None = None) -> ThreadDict: + return ThreadDict( + id=row["id"], + createdAt=row["created_at"], + name=row["name"], + userId=None, + userIdentifier=None, + tags=None, + metadata=json.loads(row["metadata"] or "{}"), + steps=steps or [], + elements=[], + ) + + +class SQLiteDataLayer(BaseDataLayer): + + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + with _conn() as conn: + row = conn.execute( + "SELECT * FROM users WHERE identifier=?", (identifier,) + ).fetchone() + if row: + return PersistedUser(id=row["id"], identifier=row["identifier"], createdAt=row["created_at"]) + return None + + async def create_user(self, user: User) -> Optional[PersistedUser]: + uid = user.identifier + ts = _now() + with _conn() as conn: + conn.execute( + "INSERT OR IGNORE INTO users(id, identifier, created_at) VALUES(?,?,?)", + (uid, uid, ts), + ) + row = conn.execute( + "SELECT created_at FROM users WHERE identifier=?", (uid,) + ).fetchone() + return PersistedUser(id=uid, identifier=uid, createdAt=row["created_at"]) + + async def delete_feedback(self, feedback_id: str) -> bool: + return True + + async def upsert_feedback(self, feedback: Feedback) -> str: + return "" + + async def get_thread_author(self, thread_id: str) -> str: + return "anonymous" + + async def list_threads( + self, pagination: Pagination, filters: ThreadFilter + ) -> PaginatedResponse[ThreadDict]: + limit = pagination.first or 20 + cursor_ts = pagination.cursor or "9999-99-99" + + search = f"%{filters.search}%" if filters.search else None + + with _conn() as conn: + if search: + rows = conn.execute( + """SELECT * FROM threads + WHERE created_at < ? AND name LIKE ? + ORDER BY created_at DESC LIMIT ?""", + (cursor_ts, search, limit + 1), + ).fetchall() + else: + rows = conn.execute( + """SELECT * FROM threads + WHERE created_at < ? + ORDER BY created_at DESC LIMIT ?""", + (cursor_ts, limit + 1), + ).fetchall() + + has_next = len(rows) > limit + rows = rows[:limit] + threads = [_row_to_thread(r) for r in rows] + + end_cursor = rows[-1]["created_at"] if rows else None + return PaginatedResponse( + pageInfo=PageInfo( + hasNextPage=has_next, + startCursor=rows[0]["created_at"] if rows else None, + endCursor=end_cursor, + ), + data=threads, + ) + + async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: + with _conn() as conn: + t = conn.execute( + "SELECT * FROM threads WHERE id=?", (thread_id,) + ).fetchone() + if not t: + return None + step_rows = conn.execute( + "SELECT * FROM steps WHERE thread_id=? ORDER BY created_at", + (thread_id,), + ).fetchall() + + steps = [_row_to_step(r) for r in step_rows] + return _row_to_thread(t, steps) + + async def update_thread( + self, + thread_id: str, + name: Optional[str] = None, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, + tags: Optional[List[str]] = None, + ) -> None: + with _conn() as conn: + exists = conn.execute( + "SELECT id FROM threads WHERE id=?", (thread_id,) + ).fetchone() + if exists: + if name is not None: + conn.execute( + "UPDATE threads SET name=? WHERE id=?", (name, thread_id) + ) + if metadata is not None: + conn.execute( + "UPDATE threads SET metadata=? WHERE id=?", + (json.dumps(metadata, ensure_ascii=False, default=str), thread_id), + ) + else: + conn.execute( + "INSERT INTO threads(id, name, created_at, metadata) VALUES(?,?,?,?)", + ( + thread_id, + name, + _now(), + json.dumps(metadata or {}, ensure_ascii=False, default=str), + ), + ) + + async def delete_thread(self, thread_id: str) -> bool: + with _conn() as conn: + conn.execute("DELETE FROM threads WHERE id=?", (thread_id,)) + return True + + async def create_step(self, step_dict: StepDict) -> None: + with _conn() as conn: + # Ensure parent thread exists + conn.execute( + "INSERT OR IGNORE INTO threads(id, name, created_at) VALUES(?,?,?)", + (step_dict.get("threadId"), None, _now()), + ) + conn.execute( + """INSERT OR REPLACE INTO steps + (id, thread_id, type, name, input, output, created_at, is_error) + VALUES(?,?,?,?,?,?,?,?)""", + ( + step_dict.get("id"), + step_dict.get("threadId"), + step_dict.get("type", "undefined"), + step_dict.get("name", ""), + step_dict.get("input", ""), + step_dict.get("output", ""), + step_dict.get("createdAt") or _now(), + int(bool(step_dict.get("isError", False))), + ), + ) + + async def update_step(self, step_dict: StepDict) -> None: + await self.create_step(step_dict) + + async def delete_step(self, step_id: str) -> bool: + with _conn() as conn: + conn.execute("DELETE FROM steps WHERE id=?", (step_id,)) + return True + + async def create_element(self, element) -> None: + pass + + async def get_element(self, thread_id: str, element_id: str): + return None + + async def delete_element(self, element_id: str, thread_id: Optional[str] = None) -> bool: + return True + + async def get_favorite_steps(self, thread_id: str) -> List[StepDict]: + return [] + + async def set_step_favorite(self, step_id: str, is_favorite: bool) -> None: + pass + + async def build_debug_url(self) -> str: + return "" + + async def close(self) -> None: + pass diff --git a/persistence.py b/persistence.py new file mode 100644 index 0000000..e8b4d37 --- /dev/null +++ b/persistence.py @@ -0,0 +1,64 @@ +import json +import sqlite3 +from datetime import datetime +from pathlib import Path + +_DB = Path("data/sessions.db") + + +def _connect() -> sqlite3.Connection: + _DB.parent.mkdir(exist_ok=True) + conn = sqlite3.connect(_DB) + conn.row_factory = sqlite3.Row + return conn + + +def init() -> None: + with _connect() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + title TEXT, + saved_at TEXT NOT NULL, + messages TEXT NOT NULL + ) + """) + + +def save(session_id: str, messages: list[dict]) -> None: + user_msgs = [m for m in messages if m["role"] == "user"] + title = user_msgs[0]["content"][:72] if user_msgs else "Gesprek" + serialisable = [ + m for m in messages + if isinstance(m.get("content"), str) or m.get("content") is None + ] + with _connect() as conn: + conn.execute( + """INSERT INTO sessions(id, title, saved_at, messages) VALUES(?,?,?,?) + ON CONFLICT(id) DO UPDATE SET title=excluded.title, + saved_at=excluded.saved_at, messages=excluded.messages""", + (session_id, title, datetime.now().isoformat(), + json.dumps(serialisable, ensure_ascii=False)), + ) + + +def load(session_id: str) -> list[dict] | None: + with _connect() as conn: + row = conn.execute( + "SELECT messages FROM sessions WHERE id=?", (session_id,) + ).fetchone() + return json.loads(row["messages"]) if row else None + + +def recent(limit: int = 4) -> list[dict]: + with _connect() as conn: + rows = conn.execute( + "SELECT id, title, saved_at FROM sessions ORDER BY saved_at DESC LIMIT ?", + (limit,), + ).fetchall() + return [dict(r) for r in rows] + + +def delete(session_id: str) -> None: + with _connect() as conn: + conn.execute("DELETE FROM sessions WHERE id=?", (session_id,))