diff --git a/cookbook/apps/__init__.py b/cookbook/apps/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/cookbook/apps/agent.py b/cookbook/apps/agent.py deleted file mode 100644 index 0b704ab40..000000000 --- a/cookbook/apps/agent.py +++ /dev/null @@ -1,15 +0,0 @@ -from marvin import AIApplication - - -class Agent(AIApplication): - description: str = "A helpful AI assistant" - - def __init__(self, **kwargs): - super().__init__( - state_enabled=False, - plan_enabled=False, - **kwargs, - ) - - -__all__ = ["Agent"] diff --git a/cookbook/apps/chatbot.py b/cookbook/apps/chatbot.py deleted file mode 100644 index d1928155c..000000000 --- a/cookbook/apps/chatbot.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Callable, List, Union - -from marvin import AIApplication -from marvin.prompts import Prompt -from marvin.tools.base import Tool - -DEFAULT_NAME = "Marvin" -DEFAULT_PERSONALITY = "A friendly AI assistant" -DEFAULT_INSTRUCTIONS = "Engage the user in conversation." - - -class Chatbot(AIApplication): - name: str = DEFAULT_NAME - personality: str = DEFAULT_PERSONALITY - instructions: str = DEFAULT_INSTRUCTIONS - tools: List[Union[Tool, Callable]] = ([],) - - def __init__( - self, - name: str = DEFAULT_NAME, - personality: str = DEFAULT_PERSONALITY, - instructions: str = DEFAULT_INSTRUCTIONS, - state=None, - tools: list[Union[Tool, Callable]] = [], - additional_prompts: list[Prompt] = None, - **kwargs, - ): - description = f""" - You are a chatbot - your name is {name}. - - You must respond to the user in accordance with - your personality and instructions. - - Your personality is: {personality}. - - Your instructions are: {instructions}. - """ - super().__init__( - name=name, - description=description, - tools=tools, - state=state or {}, - state_enabled=False if state is None else True, - plan_enabled=False, - additional_prompts=additional_prompts or [], - **kwargs, - ) - - -__all__ = ["Chatbot"] diff --git a/cookbook/apps/multi_agent.py b/cookbook/apps/multi_agent.py deleted file mode 100644 index c09f56972..000000000 --- a/cookbook/apps/multi_agent.py +++ /dev/null @@ -1,27 +0,0 @@ -from marvin import AIApplication - - -def get_foo(): - """A function that returns the value of foo.""" - return 42 - - -worker = AIApplication( - name="worker", - description="A simple worker application.", - plan_enabled=False, - state_enabled=False, - tools=[get_foo], -) - -router = AIApplication( - name="router", - description="routes user requests to the appropriate worker", - plan_enabled=False, - state_enabled=False, - tools=[worker], -) - -message = router("what is the value of foo?") - -assert "42" in message.content, "The answer should be 42." diff --git a/cookbook/apps/todo.py b/cookbook/apps/todo.py deleted file mode 100644 index 8fa51d889..000000000 --- a/cookbook/apps/todo.py +++ /dev/null @@ -1,27 +0,0 @@ -from datetime import datetime - -from marvin import AIApplication -from pydantic import BaseModel, Field - - -class ToDo(BaseModel): - title: str - description: str = None - due_date: datetime = None - done: bool = False - - -class ToDoState(BaseModel): - todos: list[ToDo] = [] - - -class ToDoApp(AIApplication): - state: ToDoState = Field(default_factory=ToDoState) - description: str = """ - A simple to-do tracker. Users will give instructions to add, remove, and - update their to-dos. 
- """ - plan_enabled: bool = False - - -__all__ = ["ToDoApp"] diff --git a/cookbook/flows/github_digest.py b/cookbook/flows/github_digest.py index 84af6a5e9..8717f074b 100644 --- a/cookbook/flows/github_digest.py +++ b/cookbook/flows/github_digest.py @@ -147,7 +147,7 @@ async def daily_github_digest(username: str, gh_token_secret_name: str): if __name__ == "__main__": import asyncio - marvin.settings.llm_model = "gpt-4" + marvin.settings.openai.chat.completions.model = "gpt-4" asyncio.run( daily_github_digest(username="zzstoatzz", gh_token_secret_name="github-token") diff --git a/cookbook/slackbot/parent_app.py b/cookbook/slackbot/parent_app.py new file mode 100644 index 000000000..b0190f50b --- /dev/null +++ b/cookbook/slackbot/parent_app.py @@ -0,0 +1,154 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from marvin import ai_fn +from marvin.beta.assistants import Assistant +from marvin.beta.assistants.applications import AIApplication +from marvin.kv.json_block import JSONBlockKV +from marvin.utilities.logging import get_logger +from prefect.events import Event, emit_event +from prefect.events.clients import PrefectCloudEventSubscriber +from prefect.events.filters import EventFilter +from pydantic import confloat +from typing_extensions import TypedDict +from websockets.exceptions import ConnectionClosedError + + +class Lesson(TypedDict): + relevance: confloat(ge=0, le=1) + heuristic: str | None + + +@ai_fn(model="gpt-3.5-turbo-1106") +def take_lesson_from_interaction( + transcript: str, assistant_instructions: str +) -> Lesson: + """You are an expert counselor, and you are teaching Marvin how to be a better assistant. + + Here is the transcript of an interaction between Marvin and a user: + {{ transcript }} + + ... and here is the stated purpose of the assistant: + {{ assistant_instructions }} + + how directly relevant to the assistant's purpose is this interaction? + - if not at all, relevance = 0 & heuristic = None. (most of the time) + - if very, relevance >= 0.5, <1 & heuristic = "1 SHORT SENTENCE (max) summary of a generalizable lesson". 
+ """ + + +logger = get_logger("PrefectEventSubscriber") + + +def excerpt_from_event(event: Event) -> str: + """Create an excerpt from the event - TODO jinja this""" + user_name = event.payload.get("user").get("name") + user_id = event.payload.get("user").get("id") + user_message = event.payload.get("user_message") + ai_response = event.payload.get("ai_response") + + return ( + f"{user_name} ({user_id}) said: {user_message}" + f"\n\nMarvin (the assistant) responded with: {ai_response}" + ) + + +async def update_parent_app_state(app: AIApplication, event: Event): + event_excerpt = excerpt_from_event(event) + lesson = take_lesson_from_interaction( + event_excerpt, event.payload.get("ai_instructions") + ) + if lesson["relevance"] >= 0.5 and lesson["heuristic"] is not None: + experience = f"transcript: {event_excerpt}\n\nlesson: {lesson['heuristic']}" + logger.debug_kv("💡 Learned lesson from excerpt", experience, "green") + await app.default_thread.add_async(experience) + logger.debug_kv("📝", "Updating parent app state", "green") + await app.default_thread.run_async(app) + else: + logger.debug_kv("🥱 ", "nothing special", "green") + user_id = event.payload.get("user").get("id") + current_user_state = await app.state.read(user_id) + await app.state.write( + user_id, + { + **current_user_state, + "n_interactions": current_user_state["n_interactions"] + 1, + }, + ) + + +async def learn_from_child_interactions( + app: AIApplication, event_name: str | None = None +): + if event_name is None: + event_name = "marvin.assistants.SubAssistantRunCompleted" + + logger.debug_kv("👂 Listening for", event_name, "green") + while not sum(map(ord, "vogon poetry")) == 42: + try: + async with PrefectCloudEventSubscriber( + filter=EventFilter(event=dict(name=[event_name])) + ) as subscriber: + async for event in subscriber: + logger.debug_kv("📬 Received event", event.event, "green") + await update_parent_app_state(app, event) + except Exception as e: + if isinstance(e, ConnectionClosedError): + logger.debug_kv("🚨 Connection closed, reconnecting...", "red") + else: # i know, i know + logger.debug_kv("🚨", str(e), "red") + + +parent_assistant_options = dict( + instructions=( + "Your job is learn from the interactions of data engineers (users) and Marvin (a growing AI assistant)." + " You'll receive excerpts of these interactions (which are in the Prefect Slack workspace) as they occur." + " Your notes will be provided to Marvin when it interacts with users. Notes should be stored for each user" + " with the user's id as the key. The user id will be shown in the excerpt of the interaction." + " The user profiles (values) should include at least: {name: str, notes: list[str], n_interactions: int}." + " Keep NO MORE THAN 3 notes per user, but you may curate/update these over time for Marvin's maximum benefit." + " Notes must be 2 sentences or less, and must be concise and focused primarily on users' data engineering needs." + " Notes should not directly mention Marvin as an actor, they should be generally useful observations." 
+ ), + state=JSONBlockKV(block_name="marvin-parent-app-state"), +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + with AIApplication(name="Marvin", **parent_assistant_options) as marvin: + app.state.marvin = marvin + task = asyncio.create_task(learn_from_child_interactions(marvin)) + yield + task.cancel() + try: + await task + except asyncio.exceptions.CancelledError: + get_logger("PrefectEventSubscriber").debug_kv( + "👋", "Stopped listening for child events", "red" + ) + + app.state.marvin = None + + +def emit_assistant_completed_event( + child_assistant: Assistant, parent_app: AIApplication, payload: dict +) -> Event: + event = emit_event( + event="marvin.assistants.SubAssistantRunCompleted", + resource={ + "prefect.resource.id": child_assistant.id, + "prefect.resource.name": "child assistant", + "prefect.resource.role": "assistant", + }, + related=[ + { + "prefect.resource.id": parent_app.id, + "prefect.resource.name": "parent assistant", + "prefect.resource.role": "assistant", + } + ], + payload=payload, + ) + return event diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 30d36c934..57d2061e7 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -1,65 +1,145 @@ import asyncio import re +from typing import Callable import uvicorn -from cachetools import TTLCache from fastapi import FastAPI, HTTPException, Request +from jinja2 import Template from keywords import handle_keywords from marvin import Assistant from marvin.beta.assistants import Thread +from marvin.beta.assistants.applications import AIApplication +from marvin.kv.json_block import JSONBlockKV +from marvin.tools.chroma import multi_query_chroma from marvin.tools.github import search_github_issues -from marvin.tools.retrieval import multi_query_chroma from marvin.utilities.logging import get_logger from marvin.utilities.slack import ( SlackPayload, get_channel_name, + get_user_name, get_workspace_info, post_slack_message, ) +from marvin.utilities.strings import count_tokens, slice_tokens +from parent_app import ( + emit_assistant_completed_event, + lifespan, +) from prefect import flow, task from prefect.states import Completed +from prefect.tasks import task_input_hash -app = FastAPI() BOT_MENTION = r"<@(\w+)>" -CACHE = TTLCache(maxsize=100, ttl=86400 * 7) +CACHE = JSONBlockKV(block_name="slackbot-tool-cache") +USER_MESSAGE_MAX_TOKENS = 300 + + +def cached(func: Callable) -> Callable: + return task(cache_key_fn=task_input_hash)(func) + + +async def get_notes_for_user( + user_id: str, parent_app: AIApplication, max_tokens: int = 100 +) -> str | None: + json_notes: dict = parent_app.state.read(key=user_id) + get_logger("slackbot").debug_kv("📝 Notes for user", json_notes, "blue") + user_name = await get_user_name(user_id) + + if json_notes: + notes_template = Template( + """ + Here are some notes about {{ user_name }} (user id: {{ user_id }}): + + - They have interacted with assistants {{ n_interactions }} times. 
+ {% if notes_content %} + Here are some notes gathered from those interactions: + {{ notes_content }} + {% endif %} + """ + ) + + notes_content = "" + for note in json_notes.get("notes", []): + potential_addition = f"\n- {note}" + if count_tokens(notes_content + potential_addition) > max_tokens: + break + notes_content += potential_addition + + return notes_template.render( + user_name=user_name, + user_id=user_id, + n_interactions=json_notes.get("n_interactions", 0), + notes_content=notes_content, + ) + + return None @flow -async def handle_message(payload: SlackPayload): +async def handle_message(payload: SlackPayload) -> Completed: logger = get_logger("slackbot") user_message = (event := payload.event).text cleaned_message = re.sub(BOT_MENTION, "", user_message).strip() - logger.debug_kv("Handling slack message", user_message, "green") + thread = event.thread_ts or event.ts + if (count := count_tokens(cleaned_message)) > USER_MESSAGE_MAX_TOKENS: + exceeded_by = count - USER_MESSAGE_MAX_TOKENS + await task(post_slack_message)( + message=( + f"Your message was too long by {exceeded_by} tokens - please shorten it and try again.\n\n" + f" For reference, here's your message at the allowed limit:\n" + "> " + + slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS).replace( + "\n", " " + ) + ), + channel_id=event.channel, + thread_ts=thread, + ) + return Completed(message="User message too long", name="SKIPPED") + + logger.debug_kv("Handling slack message", cleaned_message, "green") if (user := re.search(BOT_MENTION, user_message)) and user.group( 1 ) == payload.authorizations[0].user_id: - thread = event.thread_ts or event.ts - assistant_thread = CACHE.get(thread, Thread()) - CACHE[thread] = assistant_thread + assistant_thread = ( + Thread(**stored_thread_data) + if (stored_thread_data := CACHE.read(key=thread)) + else Thread() + ) + CACHE.write(key=thread, value=assistant_thread.model_dump()) await handle_keywords.submit( message=cleaned_message, channel_name=await get_channel_name(event.channel), asking_user=event.user, - link=( # to user's message + link=( f"{(await get_workspace_info()).get('url')}archives/" f"{event.channel}/p{event.ts.replace('.', '')}" ), ) with Assistant( - name="Marvin (from Hitchhiker's Guide to the Galaxy)", - tools=[task(multi_query_chroma), task(search_github_issues)], + name="Marvin", + tools=[cached(multi_query_chroma), cached(search_github_issues)], instructions=( - "use chroma to search docs and github to search" - " issues and answer questions about prefect 2.x." - " you must use your tools in all cases except where" - " the user simply wants to converse with you." + "You are Marvin, the paranoid android from Hitchhiker's Guide to the Galaxy." + " Act subtly in accordance with your character, but remember to be helpful and kind." + " You are an expert in Python, data engineering, and software development." + " Your primary job is to use chroma to search docs and github issues for users," + " in order to develop a coherent attempt to answer their questions. Think step-by-step." + " You must use your tools, as Prefect 2.x is new and you have no prior experience with it." + " Strongly prefer brevity in your responses, and format things prettily for Slack." 
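The f-string on the next line splices in the user-specific notes rendered by `get_notes_for_user` above; for a user with stored notes, that addition would look roughly like the following (names and counts are illustrative):

```python
# approximate rendered output of get_notes_for_user("U012ABCDEF", parent_app)
#
#   Here are some notes about some_user (user id: U012ABCDEF):
#
#   - They have interacted with assistants 7 times.
#
#   Here are some notes gathered from those interactions:
#   - Mostly asks about Prefect deployments and workers.
```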
+ f"{await get_notes_for_user(event.user, parent_app := get_parent_app()) or ''}" ), - ) as assistant: + ) as ai: + logger.debug_kv( + f"🤖 Running assistant {ai.name} with instructions", + ai.instructions, + "blue", + ) user_thread_message = await assistant_thread.add_async(cleaned_message) - await assistant_thread.run_async(assistant) - ai_messages = assistant_thread.get_messages( + await assistant_thread.run_async(ai) + ai_messages = await assistant_thread.get_messages_async( after_message=user_thread_message.id ) await task(post_slack_message)( @@ -74,20 +154,45 @@ async def handle_message(payload: SlackPayload): ai_response_text, "green", ) + event = emit_assistant_completed_event( + child_assistant=ai, + parent_app=parent_app, + payload={ + "messages": await assistant_thread.get_messages_async( + json_compatible=True + ), + "metadata": assistant_thread.metadata, + "user": { + "id": event.user, + "name": await get_user_name(event.user), + }, + "user_message": cleaned_message, + "ai_response": ai_response_text, + "ai_instructions": ai.instructions, + }, + ) + logger.debug_kv("🚀 Emitted Event", event.event, "green") return Completed(message=success_msg) else: return Completed(message="Skipping message not directed at bot", name="SKIPPED") +app = FastAPI(lifespan=lifespan) + + +def get_parent_app() -> AIApplication: + marvin = app.state.marvin + if not marvin: + raise HTTPException(status_code=500, detail="Marvin instance not available") + return marvin + + @app.post("/chat") async def chat_endpoint(request: Request): payload = SlackPayload(**await request.json()) match payload.type: case "event_callback": - options = dict( - flow_run_name=f"respond in {payload.event.channel}", - retries=1, - ) + options = dict(flow_run_name=f"respond in {payload.event.channel}") asyncio.create_task(handle_message.with_options(**options)(payload)) case "url_verification": return {"challenge": payload.challenge} diff --git a/pyproject.toml b/pyproject.toml index 8b95b979b..732677241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ [project.optional-dependencies] generator = ["datamodel-code-generator>=0.20.0"] +chromadb = ["chromadb"] prefect = ["prefect>=2.14.9"] dev = [ "marvin[tests]", @@ -53,7 +54,7 @@ tests = [ "pytest~=7.3.1", "pytest-timeout", ] -slackbot = ["marvin[prefect]", "numpy"] +slackbot = ["marvin[prefect]", "numpy", "marvin[chromadb]"] [project.urls] Code = "https://github.com/prefecthq/marvin" diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index 242585709..7d533e776 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Union from openai.types.beta.threads import ThreadMessage from pydantic import BaseModel, Field @@ -52,7 +52,7 @@ async def create_async(self, messages: list[str] = None): @expose_sync_method("add") async def add_async( - self, message: str, file_paths: Optional[list[str]] = None + self, message: str, file_paths: Optional[list[str]] = None, role: str = "user" ) -> ThreadMessage: """ Add a user message to the thread. 
@@ -71,7 +71,7 @@ async def add_async(
 
         # Create the message with the attached files
         response = await client.beta.threads.messages.create(
-            thread_id=self.id, role="user", content=message, file_ids=file_ids
+            thread_id=self.id, role=role, content=message, file_ids=file_ids
         )
         return ThreadMessage.model_validate(response.model_dump())
 
@@ -81,7 +81,8 @@ async def get_messages_async(
         limit: int = None,
         before_message: Optional[str] = None,
         after_message: Optional[str] = None,
-    ):
+        json_compatible: bool = False,
+    ) -> list[Union[ThreadMessage, dict]]:
         if self.id is None:
             await self.create_async()
         client = get_client()
@@ -96,7 +97,9 @@
             order="desc",
         )
 
-        return parse_as(list[ThreadMessage], reversed(response.model_dump()["data"]))
+        T = dict if json_compatible else ThreadMessage
+
+        return parse_as(list[T], reversed(response.model_dump()["data"]))
 
     @expose_sync_method("delete")
     async def delete_async(self):
diff --git a/src/marvin/kv/disk.py b/src/marvin/kv/disk.py
new file mode 100644
index 000000000..2a4af3021
--- /dev/null
+++ b/src/marvin/kv/disk.py
@@ -0,0 +1,84 @@
+import json
+import os
+import pickle
+from pathlib import Path
+from typing import Optional, TypeVar, Union
+
+from pydantic import Field, field_validator
+from typing_extensions import Literal
+
+from marvin.kv.base import StorageInterface
+
+K = TypeVar("K", bound=str)
+V = TypeVar("V")
+
+
+class DiskKV(StorageInterface[K, V, str]):
+    """
+    A key-value store that stores values on disk.
+
+    Example:
+        ```python
+        from marvin.kv.disk import DiskKV
+        store = DiskKV(storage_path="/path/to/storage")
+        store.write("key", "value")
+        assert store.read("key") == "value"
+        ```
+    """
+
+    storage_path: Path = Field(...)
+    serializer: Literal["json", "pickle"] = Field("json")
+
+    @field_validator("storage_path")
+    def _validate_storage_path(cls, v: Union[str, Path]) -> Path:
+        expanded_path = Path(v).expanduser().resolve()
+        if not expanded_path.exists():
+            expanded_path.mkdir(parents=True, exist_ok=True)
+        return expanded_path
+
+    def _get_file_path(self, key: K) -> Path:
+        file_extension = ".json" if self.serializer == "json" else ".pkl"
+        return self.storage_path / f"{key}{file_extension}"
+
+    def _serialize(self, value: V) -> bytes:
+        if self.serializer == "json":
+            return json.dumps(value).encode()
+        else:
+            return pickle.dumps(value)
+
+    def _deserialize(self, data: bytes) -> V:
+        if self.serializer == "json":
+            return json.loads(data)
+        else:
+            return pickle.loads(data)
+
+    def write(self, key: K, value: V) -> str:
+        file_path = self._get_file_path(key)
+        serialized_value = self._serialize(value)
+        with open(file_path, "wb") as file:
+            file.write(serialized_value)
+        return f"Stored {key}= {value}"
+
+    def delete(self, key: K) -> str:
+        file_path = self._get_file_path(key)
+        try:
+            os.remove(file_path)
+            return f"Deleted {key}"
+        except FileNotFoundError:
+            return f"Key {key} not found"
+
+    def read(self, key: K) -> Optional[V]:
+        file_path = self._get_file_path(key)
+        try:
+            with open(file_path, "rb") as file:
+                serialized_value = file.read()
+            return self._deserialize(serialized_value)
+        except FileNotFoundError:
+            return None
+
+    def read_all(self, limit: Optional[int] = None) -> dict[K, V]:
+        files = os.listdir(self.storage_path)[:limit]
+        return {file.split(".")[0]: self.read(file.split(".")[0]) for file in files}
+
+    def list_keys(self) -> list[K]:
+        return [file.split(".")[0] for file in os.listdir(self.storage_path)]
diff --git a/src/marvin/kv/json_block.py
b/src/marvin/kv/json_block.py new file mode 100644 index 000000000..10a449700 --- /dev/null +++ b/src/marvin/kv/json_block.py @@ -0,0 +1,62 @@ +from typing import Optional, TypeVar + +try: + from prefect.blocks.system import JSON + from prefect.exceptions import ObjectNotFound +except ImportError: + raise ModuleNotFoundError( + "The `prefect` package is required to use the JSONBlockKV class." + " You can install it with `pip install prefect` or `pip install marvin[prefect]`." + ) +from pydantic import Field + +from marvin.kv.base import StorageInterface +from marvin.utilities.asyncio import run_sync + +K = TypeVar("K", bound=str) +V = TypeVar("V") + + +class JSONBlockKV(StorageInterface[K, V, str]): + """ + A key-value store that uses Prefect's JSON blocks under the hood. + """ + + block_name: str = Field(default="marvin-kv") + + async def _load_json_block(self) -> JSON: + try: + return await JSON.load(name=self.block_name) + except Exception as exc: + if "Unable to find block document" in str(exc): + json_block = JSON(value={}) + await json_block.save(name=self.block_name) + return json_block + raise ObjectNotFound( + f"Unable to load JSON block {self.block_name}" + ) from exc + + def write(self, key: K, value: V) -> str: + json_block = run_sync(self._load_json_block()) + json_block.value[key] = value + run_sync(json_block.save(name=self.block_name, overwrite=True)) + return f"Stored {key}= {value}" + + def delete(self, key: K) -> str: + json_block = run_sync(self._load_json_block()) + if key in json_block.value: + json_block.value.pop(key) + run_sync(json_block.save(name=self.block_name, overwrite=True)) + return f"Deleted {key}" + + def read(self, key: K) -> Optional[V]: + json_block = run_sync(self._load_json_block()) + return json_block.value.get(key) + + def read_all(self, limit: Optional[int] = None) -> dict[K, V]: + json_block = run_sync(self._load_json_block()) + return dict(list(json_block.value.items())[:limit]) + + def list_keys(self) -> list[K]: + json_block = run_sync(self._load_json_block()) + return list(json_block.value.keys()) diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index 3a5a7060e..bd07e4a36 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -1,15 +1,11 @@ from typing import Any, Union -from pydantic import BaseModel - from marvin.requests import CodeInterpreterTool, RetrievalTool, Tool -Retrieval = RetrievalTool[BaseModel]() -CodeInterpreter = CodeInterpreterTool[BaseModel]() +Retrieval = RetrievalTool() +CodeInterpreter = CodeInterpreterTool() -AssistantTools = Union[ - RetrievalTool[BaseModel], CodeInterpreterTool[BaseModel], Tool[BaseModel] -] +AssistantTools = Union[RetrievalTool, CodeInterpreterTool, Tool] class CancelRun(Exception): diff --git a/src/marvin/tools/retrieval.py b/src/marvin/tools/chroma.py similarity index 57% rename from src/marvin/tools/retrieval.py rename to src/marvin/tools/chroma.py index c61178f02..919e52201 100644 --- a/src/marvin/tools/retrieval.py +++ b/src/marvin/tools/chroma.py @@ -2,7 +2,15 @@ import os from typing import TYPE_CHECKING, Any, Optional -import httpx +try: + from chromadb import Documents, EmbeddingFunction, Embeddings, HttpClient +except ImportError: + raise ImportError( + "The chromadb package is required to query Chroma. Please install" + " it with `pip install chromadb` or `pip install marvin[chroma]`." 
+ ) + + from typing_extensions import Literal import marvin @@ -10,19 +18,25 @@ if TYPE_CHECKING: from openai.types import CreateEmbeddingResponse +QueryResultType = Literal["documents", "distances", "metadatas"] + try: HOST, PORT = ( getattr(marvin.settings, "chroma_server_host"), getattr(marvin.settings, "chroma_server_http_port"), ) + DEFAULT_COLLECTION_NAME = getattr( + marvin.settings, "chroma_default_collection_name", "marvin" + ) except AttributeError: HOST = os.environ.get("MARVIN_CHROMA_SERVER_HOST", "localhost") # type: ignore PORT = os.environ.get("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000) # type: ignore - -QueryResultType = Literal["documents", "distances", "metadatas"] + DEFAULT_COLLECTION_NAME = os.environ.get( + "MARVIN_CHROMA_DEFAULT_COLLECTION_NAME", "marvin" + ) -async def create_openai_embeddings(texts: list[str]) -> list[float]: +def create_openai_embeddings(texts: list[str]) -> list[float]: """Create OpenAI embeddings for a list of texts.""" try: @@ -32,13 +46,9 @@ async def create_openai_embeddings(texts: list[str]) -> list[float]: "The numpy package is required to create OpenAI embeddings. Please install" " it with `pip install numpy`." ) - from openai import AsyncOpenAI + from marvin.client.openai import MarvinClient - embedding: "CreateEmbeddingResponse" = await AsyncOpenAI( - api_key=getattr( - marvin.settings.openai.api_key, "get_secret_value", lambda: None - )() - ).embeddings.create( + embedding: "CreateEmbeddingResponse" = MarvinClient().client.embeddings.create( input=[text.replace("\n", " ") for text in texts], model="text-embedding-ada-002", ) @@ -46,15 +56,12 @@ async def create_openai_embeddings(texts: list[str]) -> list[float]: return embedding.data[0].embedding -async def list_collections() -> list[dict[str, Any]]: - async with httpx.AsyncClient() as client: - chroma_api_url = f"http://{HOST}:{PORT}" - response = await client.get( - f"{chroma_api_url}/api/v1/collections", - ) +class OpenAIEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + return [create_openai_embeddings(input)] - response.raise_for_status() - return response.json() + +client = HttpClient(host=HOST, port=PORT) async def query_chroma( @@ -66,45 +73,27 @@ async def query_chroma( include: Optional[list[QueryResultType]] = None, max_characters: int = 2000, ) -> str: - """Query Chroma. + """Query a collection of document excerpts for a query. Example: User: "What are prefect blocks?" Assistant: >>> query_chroma("What are prefect blocks?") """ - query_embedding = await create_openai_embeddings([query]) - - collection_ids = [ - c["id"] for c in await list_collections() if c["name"] == collection + collection_object = client.get_or_create_collection( + name=collection or DEFAULT_COLLECTION_NAME, + embedding_function=OpenAIEmbeddingFunction(), + ) + query_result = collection_object.query( + query_texts=[query], + n_results=n_results, + where=where, + where_document=where_document, + include=include or ["documents"], + ) + return "".join(doc for doclist in query_result["documents"] for doc in doclist)[ + :max_characters ] - if len(collection_ids) == 0: - return f"Collection {collection} not found." 
-
-    collection_id = collection_ids[0]
-
-    async with httpx.AsyncClient() as client:
-        chroma_api_url = f"http://{HOST}:{PORT}"
-
-        response = await client.post(
-            f"{chroma_api_url}/api/v1/collections/{collection_id}/query",
-            data={
-                "query_embeddings": [query_embedding],
-                "n_results": n_results,
-                "where": where or {},
-                "where_document": where_document or {},
-                "include": include or ["documents"],
-            },
-            headers={"Content-Type": "application/json"},
-        )
-
-        response.raise_for_status()
-
-        return "\n".join(
-            f"{i+1}. {', '.join(excerpt)}"
-            for i, excerpt in enumerate(response.json()["documents"])
-        )[:max_characters]
-
 
 async def multi_query_chroma(
     queries: list[str],
@@ -115,13 +104,14 @@
     include: Optional[list[QueryResultType]] = None,
     max_characters: int = 2000,
 ) -> str:
-    """Query Chroma with multiple queries.
+    """Retrieve excerpts to aid in answering multifaceted questions.
 
     Example:
         User: "What are prefect blocks and tasks?"
         Assistant: >>> multi_query_chroma(
             ["What are prefect blocks?", "What are prefect tasks?"]
         )
+        multi_query_chroma -> document excerpts explaining both blocks and tasks
     """
 
     coros = [
diff --git a/src/marvin/tools/github.py b/src/marvin/tools/github.py
index 33aeaeae7..3daa28923 100644
--- a/src/marvin/tools/github.py
+++ b/src/marvin/tools/github.py
@@ -1,6 +1,6 @@
 import os
 from datetime import datetime
-from typing import Any, Callable, Coroutine, List, Optional
+from typing import List, Optional
 
 import httpx
 from pydantic import BaseModel, Field, field_validator
@@ -14,10 +14,7 @@ async def get_token() -> str:
     try:
         from prefect.blocks.system import Secret
 
-        github: Coroutine[Any, Any, Secret] = Secret.load("github-token")
-        get: Callable[..., Coroutine[Any, Any, str]] = getattr(github, "get")
-        return await get()
-
+        return (await Secret.load(name="github-token")).get()  # type: ignore
     except (ImportError, ValueError) as exc:
         getattr(get_logger("marvin"), "debug_kv")(
             (
@@ -118,5 +115,5 @@ async def search_github_issues(
         f"{issue.title} ({issue.html_url}):\n{issue.body}" for issue in issues
     )
     if not summary.strip():
-        raise ValueError("No issues found.")
+        return "No issues found."
     return summary
diff --git a/src/marvin/utilities/slack.py b/src/marvin/utilities/slack.py
index c555fea4c..65144c79e 100644
--- a/src/marvin/utilities/slack.py
+++ b/src/marvin/utilities/slack.py
@@ -71,8 +71,16 @@ def validate_event(cls, v: Optional[SlackEvent]) -> Optional[SlackEvent]:
 async def get_token() -> str:
     """Get the Slack bot token from the environment."""
     try:
-        token = marvin.settings.slack_api_token
+        token = (
+            marvin.settings.slack_api_token
+        )  # set `MARVIN_SLACK_API_TOKEN` in `~/.marvin/.env`
     except AttributeError:
+        try:  # TODO: clean this up
+            from prefect.blocks.system import Secret
+
+            return (await Secret.load("slack-api-token")).get()
+        except ImportError:
+            pass
         token = os.getenv("MARVIN_SLACK_API_TOKEN")
     if not token:
         raise ValueError(
diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py
index eaad5c997..f36696470 100644
--- a/src/marvin/utilities/tools.py
+++ b/src/marvin/utilities/tools.py
@@ -4,8 +4,6 @@
 import json
 from typing import Any, Callable, Optional
 
-from pydantic import BaseModel
-
 from marvin.requests import Function, Tool
 from marvin.utilities.asyncio import run_sync
 from marvin.utilities.logging import get_logger
@@ -24,7 +22,7 @@ def tool_from_function(
         model, "model_json_schema", getattr(model, "schema")
     )
 
-    return Tool[BaseModel](
+    return Tool(
         type="function",
         function=Function(
             name=name or fn.__name__,
@@ -36,7 +34,7 @@
 
 
 def call_function_tool(
-    tools: list[Tool[BaseModel]], function_name: str, function_arguments_json: str
+    tools: list[Tool], function_name: str, function_arguments_json: str
 ):
     tool = next(
         (
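With the `Tool[BaseModel]` parameterization gone, `tool_from_function` hands back a plain `Tool`; a quick sketch of that flow (the example function is made up, and this assumes the function is passed as the first argument, as in the current signature):

```python
from marvin.utilities.tools import tool_from_function

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

tool = tool_from_function(add)  # now a plain Tool, not Tool[BaseModel]
assert tool.type == "function"
assert tool.function.name == "add"
```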