diff --git a/.gitignore b/.gitignore
index c0f10dc973..22657eee27 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ env*/
 /TODO.md
 /postgres-data/
 .DS_Store
-examples/pydantic_ai_examples/.chat_app_messages.sqlite
+.chat_app_messages.sqlite
 .cache/
 .vscode/
 /question_graph_history.json
diff --git a/examples/pydantic_ai_examples/chat_app.py b/examples/pydantic_ai_examples/chat_app.py
index f81211111b..39691827af 100644
--- a/examples/pydantic_ai_examples/chat_app.py
+++ b/examples/pydantic_ai_examples/chat_app.py
@@ -7,215 +7,96 @@
 from __future__ import annotations as _annotations
 
-import asyncio
-import json
-import sqlite3
-from collections.abc import AsyncIterator, Callable
-from concurrent.futures.thread import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from datetime import datetime, timezone
-from functools import partial
 from pathlib import Path
-from typing import Annotated, Any, Literal, TypeVar
 
 import fastapi
 import logfire
-from fastapi import Depends, Request
-from fastapi.responses import FileResponse, Response, StreamingResponse
-from typing_extensions import LiteralString, ParamSpec, TypedDict
-
-from pydantic_ai import Agent, UnexpectedModelBehavior
-from pydantic_ai.messages import (
-    ModelMessage,
-    ModelMessagesTypeAdapter,
-    ModelRequest,
-    ModelResponse,
-    TextPart,
-    UserPromptPart,
-)
+from fastapi import Depends, Request, Response
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.vercel_ai_elements.starlette import StarletteChat
+
+from .sqlite_database import Database
 
 # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
 logfire.configure(send_to_logfire='if-token-present')
 logfire.instrument_pydantic_ai()
 
-agent = Agent('openai:gpt-4o')
 THIS_DIR = Path(__file__).parent
 
+sql_schema = """
+create table if not exists memory(
+    id integer primary key,
+    user_id integer not null,
+    value text not null,
+    unique(user_id, value)
+);"""
 
 
 @asynccontextmanager
 async def lifespan(_app: fastapi.FastAPI):
-    async with Database.connect() as db:
+    async with Database.connect(sql_schema) as db:
         yield {'db': db}
 
 
-app = fastapi.FastAPI(lifespan=lifespan)
-logfire.instrument_fastapi(app)
-
+@dataclass
+class Deps:
+    conn: Database
+    user_id: int
 
 
-@app.get('/')
-async def index() -> FileResponse:
-    return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
+chat_agent = Agent(
+    'openai:gpt-4.1',
+    deps_type=Deps,
+    instructions="""
+You are a helpful assistant.
 
-@app.get('/chat_app.ts')
-async def main_ts() -> FileResponse:
-    """Get the raw typescript code, it's compiled in the browser, forgive me."""
-    return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
+Always reply with markdown. ALWAYS use code fences for code examples and lines of code.
+""",
+)
 
 
-async def get_db(request: Request) -> Database:
-    return request.state.db
+@chat_agent.tool
+async def record_memory(ctx: RunContext[Deps], value: str) -> str:
+    """Use this tool to store information in memory."""
+    await ctx.deps.conn.execute(
+        'insert into memory(user_id, value) values(?, ?) on conflict do nothing',
+        ctx.deps.user_id,
+        value,
+        commit=True,
+    )
+    return 'Value added to memory.'
 
 
-@app.get('/chat/')
-async def get_chat(database: Database = Depends(get_db)) -> Response:
-    msgs = await database.get_messages()
-    return Response(
-        b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
-        media_type='text/plain',
+@chat_agent.tool
+async def retrieve_memories(ctx: RunContext[Deps], memory_contains: str) -> str:
+    """Retrieve memories about the user that contain the given string."""
+    rows = await ctx.deps.conn.fetchall(
+        'select value from memory where user_id = ? and value like ?',
+        ctx.deps.user_id,
+        f'%{memory_contains}%',
     )
+    return '\n'.join([row[0] for row in rows])
 
 
-class ChatMessage(TypedDict):
-    """Format of messages sent to the browser."""
-
-    role: Literal['user', 'model']
-    timestamp: str
-    content: str
-
-
-def to_chat_message(m: ModelMessage) -> ChatMessage:
-    first_part = m.parts[0]
-    if isinstance(m, ModelRequest):
-        if isinstance(first_part, UserPromptPart):
-            assert isinstance(first_part.content, str)
-            return {
-                'role': 'user',
-                'timestamp': first_part.timestamp.isoformat(),
-                'content': first_part.content,
-            }
-    elif isinstance(m, ModelResponse):
-        if isinstance(first_part, TextPart):
-            return {
-                'role': 'model',
-                'timestamp': m.timestamp.isoformat(),
-                'content': first_part.content,
-            }
-    raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
-
-
-@app.post('/chat/')
-async def post_chat(
-    prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
-) -> StreamingResponse:
-    async def stream_messages():
-        """Streams new line delimited JSON `Message`s to the client."""
-        # stream the user prompt so that can be displayed straight away
-        yield (
-            json.dumps(
-                {
-                    'role': 'user',
-                    'timestamp': datetime.now(tz=timezone.utc).isoformat(),
-                    'content': prompt,
-                }
-            ).encode('utf-8')
-            + b'\n'
-        )
-        # get the chat history so far to pass as context to the agent
-        messages = await database.get_messages()
-        # run the agent with the user prompt and the chat history
-        async with agent.run_stream(prompt, message_history=messages) as result:
-            async for text in result.stream_output(debounce_by=0.01):
-                # text here is a `str` and the frontend wants
-                # JSON encoded ModelResponse, so we create one
-                m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp())
-                yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
-
-        # add new messages (e.g. the user prompt and the agent response in this case) to the database
-        await database.add_messages(result.new_messages_json())
-
-    return StreamingResponse(stream_messages(), media_type='text/plain')
-
-
-P = ParamSpec('P')
-R = TypeVar('R')
+starlette_chat = StarletteChat(chat_agent)
 
+app = fastapi.FastAPI(lifespan=lifespan)
+logfire.instrument_fastapi(app)
 
 
-@dataclass
-class Database:
-    """Rudimentary database to store chat messages in SQLite.
-
-    The SQLite standard library package is synchronous, so we
-    use a thread pool executor to run queries asynchronously.
-    """
-
-    con: sqlite3.Connection
-    _loop: asyncio.AbstractEventLoop
-    _executor: ThreadPoolExecutor
-
-    @classmethod
-    @asynccontextmanager
-    async def connect(
-        cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
-    ) -> AsyncIterator[Database]:
-        with logfire.span('connect to DB'):
-            loop = asyncio.get_event_loop()
-            executor = ThreadPoolExecutor(max_workers=1)
-            con = await loop.run_in_executor(executor, cls._connect, file)
-            slf = cls(con, loop, executor)
-        try:
-            yield slf
-        finally:
-            await slf._asyncify(con.close)
-
-    @staticmethod
-    def _connect(file: Path) -> sqlite3.Connection:
-        con = sqlite3.connect(str(file))
-        con = logfire.instrument_sqlite3(con)
-        cur = con.cursor()
-        cur.execute(
-            'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
-        )
-        con.commit()
-        return con
-
-    async def add_messages(self, messages: bytes):
-        await self._asyncify(
-            self._execute,
-            'INSERT INTO messages (message_list) VALUES (?);',
-            messages,
-            commit=True,
-        )
-        await self._asyncify(self.con.commit)
-
-    async def get_messages(self) -> list[ModelMessage]:
-        c = await self._asyncify(
-            self._execute, 'SELECT message_list FROM messages order by id'
-        )
-        rows = await self._asyncify(c.fetchall)
-        messages: list[ModelMessage] = []
-        for row in rows:
-            messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
-        return messages
-
-    def _execute(
-        self, sql: LiteralString, *args: Any, commit: bool = False
-    ) -> sqlite3.Cursor:
-        cur = self.con.cursor()
-        cur.execute(sql, args)
-        if commit:
-            self.con.commit()
-        return cur
-
-    async def _asyncify(
-        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
-    ) -> R:
-        return await self._loop.run_in_executor(  # type: ignore
-            self._executor,
-            partial(func, **kwargs),
-            *args,  # type: ignore
-        )
+async def get_db(request: Request) -> Database:
+    return request.state.db
+
+
+@app.options('/api/chat')
+def options_chat():
+    pass
+
+
+@app.post('/api/chat')
+async def get_chat(request: Request, database: Database = Depends(get_db)) -> Response:
+    return await starlette_chat.dispatch_request(request, deps=Deps(database, 123))
 
 
 if __name__ == '__main__':
diff --git a/examples/pydantic_ai_examples/sqlite_database.py b/examples/pydantic_ai_examples/sqlite_database.py
new file mode 100644
index 0000000000..9d470a937a
--- /dev/null
+++ b/examples/pydantic_ai_examples/sqlite_database.py
@@ -0,0 +1,81 @@
+from __future__ import annotations as _annotations
+
+import asyncio
+import sqlite3
+from collections.abc import AsyncIterator, Callable
+from concurrent.futures.thread import ThreadPoolExecutor
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+from typing import Any, LiteralString, ParamSpec, TypeVar
+
+import logfire
+
+P = ParamSpec('P')
+R = TypeVar('R')
+
+
+@dataclass
+class Database:
+    """Rudimentary database to store chat messages in SQLite.
+
+    The SQLite standard library package is synchronous, so we
+    use a thread pool executor to run queries asynchronously.
+    """
+
+    con: sqlite3.Connection
+    _loop: asyncio.AbstractEventLoop
+    _executor: ThreadPoolExecutor
+
+    @classmethod
+    @asynccontextmanager
+    async def connect(
+        cls, schema_sql: str, file: Path = Path('.chat_app_messages.sqlite')
+    ) -> AsyncIterator[Database]:
+        with logfire.span('connect to DB'):
+            loop = asyncio.get_event_loop()
+            executor = ThreadPoolExecutor(max_workers=1)
+            con = await loop.run_in_executor(executor, cls._connect, schema_sql, file)
+            slf = cls(con, loop, executor)
+        try:
+            yield slf
+        finally:
+            await slf._asyncify(con.close)
+
+    @staticmethod
+    def _connect(schema_sql: str, file: Path) -> sqlite3.Connection:
+        con = sqlite3.connect(str(file))
+        con = logfire.instrument_sqlite3(con)
+        cur = con.cursor()
+        cur.execute(schema_sql)
+        con.commit()
+        return con
+
+    async def execute(self, sql: LiteralString, *args: Any, commit: bool = False):
+        await self._asyncify(self._execute, sql, *args, commit=commit)
+        if commit:
+            await self._asyncify(self.con.commit)
+
+    async def fetchall(self, sql: LiteralString, *args: Any) -> list[tuple[str, ...]]:
+        c = await self._asyncify(self._execute, sql, *args)
+        rows = await self._asyncify(c.fetchall)
+        return [tuple(row) for row in rows]
+
+    def _execute(
+        self, sql: LiteralString, *args: Any, commit: bool = False
+    ) -> sqlite3.Cursor:
+        cur = self.con.cursor()
+        cur.execute(sql, args)
+        if commit:
+            self.con.commit()
+        return cur
+
+    async def _asyncify(
+        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> R:
+        return await self._loop.run_in_executor(  # type: ignore
+            self._executor,
+            partial(func, **kwargs),
+            *args,  # type: ignore
+        )
diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
index 8d6c9ff293..54c6856099 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import asyncio
 import inspect
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
 from types import FrameType
 from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload
 
+import anyio
 from typing_extensions import Self, TypeIs, TypeVar
 
 from pydantic_graph import End
@@ -24,7 +26,7 @@
 from .._tool_manager import ToolManager
 from ..output import OutputDataT, OutputSpec
 from ..result import AgentStream, FinalResult, StreamedRunResult
-from ..run import AgentRun, AgentRunResult
+from ..run import AgentRun, AgentRunResult, AgentRunResultEvent
 from ..settings import ModelSettings
 from ..tools import (
     AgentDepsT,
@@ -543,6 +545,161 @@ async def on_complete() -> None:
         if not yielded:
             raise exceptions.AgentRunError('Agent run finished without producing a final result')  # pragma: no cover
 
+    @overload
+    def run_stream_events(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...
+
+    @overload
+    def run_stream_events(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: OutputSpec[RunOutputDataT],
+        message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...
+
+    def run_stream_events(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
+        """Run the agent with a user prompt in async mode and stream events from the run.
+
+        This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and
+        uses the `event_stream_handler` kwarg to get a stream of events from the run.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            async for event in agent.run_stream_events('What is the capital of France?'):
+                print(event)
+        ```
+
+        Arguments are the same as for [`self.run`][pydantic_ai.agent.AbstractAgent.run],
+        except that `event_stream_handler` is not allowed.
+
+        Args:
+            user_prompt: User input to start/continue the conversation.
+            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
+                output validators since output validators would expect an argument that matches the agent's output type.
+            message_history: History of the conversation so far.
+            deferred_tool_results: Optional results for deferred tool calls in the message history.
+            model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
+
+        Returns:
+            An async iterable of stream events `AgentStreamEvent` and finally an `AgentRunResultEvent` with the
+            final run result.
+ """ + # unfortunately this hack of returning a generator rather than defining it right here is + # required to allow overloads of this method to work in python's typing system, or at least with pyright + # or at least I couldn't make it work without + return self._run_stream_events( + user_prompt, + output_type=output_type, + message_history=message_history, + deferred_tool_results=deferred_tool_results, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) + + async def _run_stream_events( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + deferred_tool_results: DeferredToolResults | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.RunUsage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + send_stream, receive_stream = anyio.create_memory_object_stream[ + _messages.AgentStreamEvent | AgentRunResultEvent[Any] + ]() + + async def event_stream_handler( + _: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent] + ) -> None: + async for event in events: + await send_stream.send(event) + + async def run_agent() -> AgentRunResult[Any]: + try: + return await self.run( + user_prompt, + output_type=output_type, + message_history=message_history, + deferred_tool_results=deferred_tool_results, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + event_stream_handler=event_stream_handler, + ) + finally: + send_stream.close() + + task = asyncio.create_task(run_agent()) + + async for message in receive_stream: + yield message + + result = await task + yield AgentRunResultEvent(result) + @overload def iter( self, diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 0cc9481043..39a3d9080a 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -353,3 +353,16 @@ def timestamp(self) -> datetime: model_response = self.all_messages()[-1] assert isinstance(model_response, _messages.ModelResponse) return model_response.timestamp + + +@dataclasses.dataclass +class AgentRunResultEvent(Generic[OutputDataT]): + """An event indicating the agent run ended and containing the final result of the agent run.""" + + result: AgentRunResult[OutputDataT] + """The result of the run.""" + + _: dataclasses.KW_ONLY + + event_kind: Literal['agent_run_result'] = 'agent_run_result' + """Event type identifier, used as a discriminator.""" diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/__init__.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/_utils.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/_utils.py new file mode 100644 index 0000000000..6ef877a235 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/_utils.py @@ -0,0 +1,18 @@ +from abc import ABC +from typing import Any + +from pydantic import BaseModel, ConfigDict +from 
pydantic.alias_generators import to_camel + +__all__ = 'ProviderMetadata', 'CamelBaseModel' + +# technically this is recursive union of JSON types +# for to simplify validation, we call it Any +JSONValue = Any + +# Provider metadata types +ProviderMetadata = dict[str, dict[str, JSONValue]] + + +class CamelBaseModel(BaseModel, ABC): + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, extra='forbid') diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/request_types.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/request_types.py new file mode 100644 index 0000000000..d7dd60acb6 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/request_types.py @@ -0,0 +1,273 @@ +"""Convert to Python from. + +https://github.com/vercel/ai/blob/ai%405.0.34/packages/ai/src/ui/ui-messages.ts + +Mostly with Claude. +""" + +from typing import Annotated, Any, Literal + +from pydantic import Discriminator, TypeAdapter + +from ._utils import CamelBaseModel, ProviderMetadata + + +class TextUIPart(CamelBaseModel): + """A text part of a message.""" + + type: Literal['text'] = 'text' + + text: str + """The text content.""" + + state: Literal['streaming', 'done'] | None = None + """The state of the text part.""" + + provider_metadata: ProviderMetadata | None = None + """The provider metadata.""" + + +class ReasoningUIPart(CamelBaseModel): + """A reasoning part of a message.""" + + type: Literal['reasoning'] = 'reasoning' + + text: str + """The reasoning text.""" + + state: Literal['streaming', 'done'] | None = None + """The state of the reasoning part.""" + + provider_metadata: ProviderMetadata | None = None + """The provider metadata.""" + + +class SourceUrlUIPart(CamelBaseModel): + """A source part of a message.""" + + type: Literal['source-url'] = 'source-url' + source_id: str + url: str + title: str | None = None + provider_metadata: ProviderMetadata | None = None + + +class SourceDocumentUIPart(CamelBaseModel): + """A document source part of a message.""" + + type: Literal['source-document'] = 'source-document' + source_id: str + media_type: str + title: str + filename: str | None = None + provider_metadata: ProviderMetadata | None = None + + +class FileUIPart(CamelBaseModel): + """A file part of a message.""" + + type: Literal['file'] = 'file' + + media_type: str + """ + IANA media type of the file. + + @see https://www.iana.org/assignments/media-types/media-types.xhtml + """ + + filename: str | None = None + """Optional filename of the file.""" + + url: str + """ + The URL of the file. + It can either be a URL to a hosted file or a [Data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs). 
+ """ + + provider_metadata: ProviderMetadata | None = None + """The provider metadata.""" + + +class StepStartUIPart(CamelBaseModel): + """A step boundary part of a message.""" + + type: Literal['step-start'] = 'step-start' + + +class DataUIPart(CamelBaseModel): + """Data part with dynamic type based on data name.""" + + type: str # Will be f"data-{NAME}" + id: str | None = None + data: Any + + +# Tool part states as separate models +class ToolInputStreamingPart(CamelBaseModel): + """Tool part in input-streaming state.""" + + type: str # Will be f"tool-{NAME}" + tool_call_id: str + state: Literal['input-streaming'] = 'input-streaming' + input: Any | None = None + provider_executed: bool | None = None + + +class ToolInputAvailablePart(CamelBaseModel): + """Tool part in input-available state.""" + + type: str # Will be f"tool-{NAME}" + tool_call_id: str + state: Literal['input-available'] = 'input-available' + input: Any + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + + +class ToolOutputAvailablePart(CamelBaseModel): + """Tool part in output-available state.""" + + type: str # Will be f"tool-{NAME}" + tool_call_id: str + state: Literal['output-available'] = 'output-available' + input: Any + output: Any + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + preliminary: bool | None = None + + +class ToolOutputErrorPart(CamelBaseModel): + """Tool part in output-error state.""" + + type: str # Will be f"tool-{NAME}" + tool_call_id: str + state: Literal['output-error'] = 'output-error' + input: Any | None = None + raw_input: Any | None = None + error_text: str + provider_executed: bool | None = None + call_provider_metadata: ProviderMetadata | None = None + + +# Union of all tool part states +ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart + + +# Dynamic tool part states as separate models +class DynamicToolInputStreamingPart(CamelBaseModel): + """Dynamic tool part in input-streaming state.""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['input-streaming'] = 'input-streaming' + input: Any | None = None + + +class DynamicToolInputAvailablePart(CamelBaseModel): + """Dynamic tool part in input-available state.""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['input-available'] = 'input-available' + input: Any + call_provider_metadata: ProviderMetadata | None = None + + +class DynamicToolOutputAvailablePart(CamelBaseModel): + """Dynamic tool part in output-available state.""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['output-available'] = 'output-available' + input: Any + output: Any + call_provider_metadata: ProviderMetadata | None = None + preliminary: bool | None = None + + +class DynamicToolOutputErrorPart(CamelBaseModel): + """Dynamic tool part in output-error state.""" + + type: Literal['dynamic-tool'] = 'dynamic-tool' + tool_name: str + tool_call_id: str + state: Literal['output-error'] = 'output-error' + input: Any + error_text: str + call_provider_metadata: ProviderMetadata | None = None + + +# Union of all dynamic tool part states +DynamicToolUIPart = ( + DynamicToolInputStreamingPart + | DynamicToolInputAvailablePart + | DynamicToolOutputAvailablePart + | DynamicToolOutputErrorPart +) + + +UIMessagePart = ( + TextUIPart + | ReasoningUIPart + | ToolUIPart + 
+    | DynamicToolUIPart
+    | SourceUrlUIPart
+    | SourceDocumentUIPart
+    | FileUIPart
+    | DataUIPart
+    | StepStartUIPart
+)
+"""Union of all message part types."""
+
+
+class UIMessage(CamelBaseModel):
+    """A message as displayed in the UI by Vercel AI Elements."""
+
+    id: str
+    """A unique identifier for the message."""
+
+    role: Literal['system', 'user', 'assistant']
+    """The role of the message."""
+
+    metadata: Any | None = None
+    """The metadata of the message."""
+
+    parts: list[UIMessagePart]
+    """
+    The parts of the message. Use this for rendering the message in the UI.
+
+    System messages should be avoided (set the system prompt on the server instead).
+    They can have text parts.
+
+    User messages can have text parts and file parts.
+
+    Assistant messages can have text, reasoning, tool invocation, and file parts.
+    """
+
+
+class SubmitMessage(CamelBaseModel):
+    """Submit a message to the agent."""
+
+    trigger: Literal['submit-message']
+    id: str
+    messages: list[UIMessage]
+
+    model: str
+    web_search: bool
+
+
+class RegenerateMessage(CamelBaseModel):
+    """Ask the agent to regenerate a message."""
+
+    trigger: Literal['regenerate-message']
+    id: str
+    messages: list[UIMessage]
+    message_id: str
+
+
+RequestData = SubmitMessage | RegenerateMessage
+request_data_schema: TypeAdapter[RequestData] = TypeAdapter(Annotated[RequestData, Discriminator('trigger')])
diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_stream.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_stream.py
new file mode 100644
index 0000000000..82da300014
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_stream.py
@@ -0,0 +1,141 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+from typing import Any
+from uuid import uuid4
+
+from pydantic_core import to_json
+
+from .. import messages
+from ..agent import Agent
+from ..run import AgentRunResultEvent
+from ..tools import AgentDepsT
+from . import response_types as _t
+
+__all__ = 'sse_stream', 'VERCEL_AI_ELEMENTS_HEADERS', 'EventStreamer'
+# no idea if this is important, but vercel sends it, therefore so do I
+VERCEL_AI_ELEMENTS_HEADERS = {'x-vercel-ai-ui-message-stream': 'v1'}
+
+
+async def sse_stream(agent: Agent[AgentDepsT], user_prompt: str, deps: Any) -> AsyncIterator[str]:
+    """Stream events from an agent run as Vercel AI Elements events.
+
+    Args:
+        agent: The agent to run.
+        user_prompt: The user prompt to run the agent with.
+        deps: The dependencies to pass to the agent.
+
+    Yields:
+        An async iterator of text lines to stream over SSE.
+    """
+    event_streamer = EventStreamer()
+    async for event in agent.run_stream_events(user_prompt, deps=deps):
+        if not isinstance(event, AgentRunResultEvent):
+            async for chunk in event_streamer.event_to_chunks(event):
+                yield chunk.sse()
+    async for chunk in event_streamer.finish():
+        yield chunk.sse()
+
+
+@dataclass
+class EventStreamer:
+    """Logic for mapping pydantic-ai events to Vercel AI Elements events which can be streamed to a client over SSE."""
+
+    message_id: str = field(default_factory=lambda: uuid4().hex)
+    _final_result_tool_id: str | None = field(default=None, init=False)
+
+    async def event_to_chunks(self, event: messages.AgentStreamEvent) -> AsyncIterator[_t.AbstractSSEChunk]:  # noqa C901
+        """Convert pydantic-ai events to Vercel AI Elements events which can be streamed to a client over SSE.
+
+        Args:
+            event: The pydantic-ai event to convert.
+
+        Yields:
+            An async iterator of Vercel AI Elements events.
+        """
+        match event:
+            case messages.PartStartEvent(part=part):
+                match part:
+                    case messages.TextPart(content=content):
+                        yield _t.TextStartChunk(id=self.message_id)
+                        yield _t.TextDeltaChunk(id=self.message_id, delta=content)
+                    case (
+                        messages.ToolCallPart(tool_name=tool_name, tool_call_id=tool_call_id, args=args)
+                        | messages.BuiltinToolCallPart(tool_name=tool_name, tool_call_id=tool_call_id, args=args)
+                    ):
+                        yield _t.ToolInputStartChunk(tool_call_id=tool_call_id, tool_name=tool_name)
+                        if isinstance(args, str):
+                            yield _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=args)
+                        elif args is not None:
+                            yield (
+                                _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=_json_dumps(args))
+                            )
+
+                    case messages.BuiltinToolReturnPart(
+                        tool_name=tool_name, tool_call_id=tool_call_id, content=content
+                    ):
+                        yield _t.ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=content)
+
+                    case messages.ThinkingPart(content=content):
+                        yield _t.ReasoningStartChunk(id=self.message_id)
+                        yield _t.ReasoningDeltaChunk(id=self.message_id, delta=content)
+
+            case messages.PartDeltaEvent(delta=delta):
+                match delta:
+                    case messages.TextPartDelta(content_delta=content_delta):
+                        yield _t.TextDeltaChunk(id=self.message_id, delta=content_delta)
+                    case messages.ThinkingPartDelta(content_delta=content_delta):
+                        if content_delta:
+                            yield _t.ReasoningDeltaChunk(id=self.message_id, delta=content_delta)
+                    case messages.ToolCallPartDelta(args_delta=args, tool_call_id=tool_call_id):
+                        tool_call_id = tool_call_id or ''
+                        if isinstance(args, str):
+                            yield _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=args)
+                        elif args is not None:
+                            yield (
+                                _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=_json_dumps(args))
+                            )
+            case messages.FinalResultEvent(tool_name=tool_name, tool_call_id=tool_call_id):
+                if tool_call_id and tool_name:
+                    self._final_result_tool_id = tool_call_id
+                    yield _t.ToolInputStartChunk(tool_call_id=tool_call_id, tool_name=tool_name)
+            case messages.FunctionToolCallEvent():
+                pass
+                # print(f'TODO FunctionToolCallEvent {part}')
+            case messages.FunctionToolResultEvent(result=result):
+                match result:
+                    case messages.ToolReturnPart(tool_name=tool_name, tool_call_id=tool_call_id, content=content):
+                        yield _t.ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=content)
+                    case messages.RetryPromptPart(tool_name=tool_name, tool_call_id=tool_call_id, content=content):
+                        yield _t.ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=content)
+            case messages.BuiltinToolCallEvent(part=part):
+                tool_call_id = part.tool_call_id
+                tool_name = part.tool_name
+                args = part.args
+                yield _t.ToolInputStartChunk(tool_call_id=tool_call_id, tool_name=tool_name)
+                if isinstance(args, str):
+                    yield _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=args)
+                elif args is not None:
+                    yield _t.ToolInputDeltaChunk(tool_call_id=tool_call_id, input_text_delta=_json_dumps(args))
+            case messages.BuiltinToolResultEvent(result=result):
+                yield _t.ToolOutputAvailableChunk(tool_call_id=result.tool_call_id, output=result.content)
+
+    async def finish(self) -> AsyncIterator[_t.AbstractSSEChunk | DoneChunk]:
+        """Send extra messages required to close off the stream."""
+        if tool_call_id := self._final_result_tool_id:
+            yield _t.ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=None)
+        yield _t.FinishChunk()
+        yield DoneChunk()
+
+
+class DoneChunk:
+    def sse(self) -> str:
+        return '[DONE]'
+
+    def __str__(self) -> str:
+        return 'DoneChunk'
+
+
+def _json_dumps(obj: Any) -> str:
+    return to_json(obj).decode('utf-8')
diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_types.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_types.py
new file mode 100644
index 0000000000..8538dd9680
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_types.py
@@ -0,0 +1,216 @@
+"""Convert to Python from.
+
+https://github.com/vercel/ai/blob/ai%405.0.34/packages/ai/src/ui/ui-messages.ts
+
+Mostly with Claude.
+"""
+
+from typing import Any, Literal
+
+from ._utils import CamelBaseModel, ProviderMetadata
+
+
+class AbstractSSEChunk(CamelBaseModel):
+    """Abstract base class for response SSE events."""
+
+    def sse(self) -> str:
+        return self.model_dump_json(exclude_none=True, by_alias=True)
+
+
+class TextStartChunk(AbstractSSEChunk):
+    """Text start chunk."""
+
+    type: Literal['text-start'] = 'text-start'
+    id: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class TextDeltaChunk(AbstractSSEChunk):
+    """Text delta chunk."""
+
+    type: Literal['text-delta'] = 'text-delta'
+    delta: str
+    id: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class TextEndChunk(AbstractSSEChunk):
+    """Text end chunk."""
+
+    type: Literal['text-end'] = 'text-end'
+    id: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class ReasoningStartChunk(AbstractSSEChunk):
+    """Reasoning start chunk."""
+
+    type: Literal['reasoning-start'] = 'reasoning-start'
+    id: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class ReasoningDeltaChunk(AbstractSSEChunk):
+    """Reasoning delta chunk."""
+
+    type: Literal['reasoning-delta'] = 'reasoning-delta'
+    id: str
+    delta: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class ReasoningEndChunk(AbstractSSEChunk):
+    """Reasoning end chunk."""
+
+    type: Literal['reasoning-end'] = 'reasoning-end'
+    id: str
+    provider_metadata: ProviderMetadata | None = None
+
+
+class ErrorChunk(AbstractSSEChunk):
+    """Error chunk."""
+
+    type: Literal['error'] = 'error'
+    error_text: str
+
+
+class ToolInputAvailableChunk(AbstractSSEChunk):
+    """Tool input available chunk."""
+
+    type: Literal['tool-input-available'] = 'tool-input-available'
+    tool_call_id: str
+    tool_name: str
+    input: Any
+    provider_executed: bool | None = None
+    provider_metadata: ProviderMetadata | None = None
+    dynamic: bool | None = None
+
+
+class ToolInputErrorChunk(AbstractSSEChunk):
+    """Tool input error chunk."""
+
+    type: Literal['tool-input-error'] = 'tool-input-error'
+    tool_call_id: str
+    tool_name: str
+    input: Any
+    provider_executed: bool | None = None
+    provider_metadata: ProviderMetadata | None = None
+    dynamic: bool | None = None
+    error_text: str
+
+
+class ToolOutputAvailableChunk(AbstractSSEChunk):
+    """Tool output available chunk."""
+
+    type: Literal['tool-output-available'] = 'tool-output-available'
+    tool_call_id: str
+    output: Any
+    provider_executed: bool | None = None
+    dynamic: bool | None = None
+    preliminary: bool | None = None
+
+
+class ToolOutputErrorChunk(AbstractSSEChunk):
+    """Tool output error chunk."""
+
+    type: Literal['tool-output-error'] = 'tool-output-error'
+    tool_call_id: str
+    error_text: str
+    provider_executed: bool | None = None
+    dynamic: bool | None = None
+
+
+class ToolInputStartChunk(AbstractSSEChunk):
+    """Tool input start chunk."""
+
+    type: Literal['tool-input-start'] = 'tool-input-start'
+    tool_call_id: str
+    tool_name: str
+    provider_executed: bool | None = None
+    dynamic: bool | None = None
+
+
+class ToolInputDeltaChunk(AbstractSSEChunk):
+    """Tool input delta chunk."""
+
+    type: Literal['tool-input-delta'] = 'tool-input-delta'
+    tool_call_id: str
+    input_text_delta: str
+
+
+# Source chunk types
+class SourceUrlChunk(AbstractSSEChunk):
+    """Source URL chunk."""
+
+    type: Literal['source-url'] = 'source-url'
+    source_id: str
+    url: str
+    title: str | None = None
+    provider_metadata: ProviderMetadata | None = None
+
+
+class SourceDocumentChunk(AbstractSSEChunk):
+    """Source document chunk."""
+
+    type: Literal['source-document'] = 'source-document'
+    source_id: str
+    media_type: str
+    title: str
+    filename: str | None = None
+    provider_metadata: ProviderMetadata | None = None
+
+
+class FileChunk(AbstractSSEChunk):
+    """File chunk."""
+
+    type: Literal['file'] = 'file'
+    url: str
+    media_type: str
+
+
+class DataUIMessageChunk(AbstractSSEChunk):
+    """Data UI message chunk with dynamic type."""
+
+    type: str  # Will be f"data-{NAME}"
+    data: Any
+
+
+class StartStepChunk(AbstractSSEChunk):
+    """Start step chunk."""
+
+    type: Literal['start-step'] = 'start-step'
+
+
+class FinishStepChunk(AbstractSSEChunk):
+    """Finish step chunk."""
+
+    type: Literal['finish-step'] = 'finish-step'
+
+
+# Message lifecycle chunk types
+class StartChunk(AbstractSSEChunk):
+    """Start chunk."""
+
+    type: Literal['start'] = 'start'
+    message_id: str | None = None
+    message_metadata: Any | None = None
+
+
+class FinishChunk(AbstractSSEChunk):
+    """Finish chunk."""
+
+    type: Literal['finish'] = 'finish'
+    message_metadata: Any | None = None
+
+
+class AbortChunk(AbstractSSEChunk):
+    """Abort chunk."""
+
+    type: Literal['abort'] = 'abort'
+
+
+class MessageMetadataChunk(AbstractSSEChunk):
+    """Message metadata chunk."""
+
+    type: Literal['message-metadata'] = 'message-metadata'
+    message_metadata: Any
diff --git a/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/starlette.py b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/starlette.py
new file mode 100644
index 0000000000..59868afeef
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/vercel_ai_elements/starlette.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass
+from typing import Generic
+
+from pydantic import ValidationError
+
+from ..agent import Agent
+from ..tools import AgentDepsT
+from .request_types import RequestData, TextUIPart, request_data_schema
+from .response_stream import VERCEL_AI_ELEMENTS_HEADERS, sse_stream
+
+try:
+    from sse_starlette.sse import EventSourceResponse
+    from starlette.requests import Request
+    from starlette.responses import JSONResponse, Response
+except ImportError as e:
+    raise ImportError('To use Vercel AI Elements, please install starlette and sse_starlette') from e
+
+
+@dataclass
+class StarletteChat(Generic[AgentDepsT]):
+    """Starlette support for Pydantic AI's Vercel AI Elements integration.
+
+    This can be used with either FastAPI or Starlette apps.
+    """
+
+    agent: Agent[AgentDepsT]
+
+    async def dispatch_request(self, request: Request, deps: AgentDepsT) -> Response:
+        """Handle a request and return a streamed SSE response.
+
+        Args:
+            request: The incoming Starlette/FastAPI request.
+            deps: The dependencies for the agent.
+
+        Returns:
+            A streamed SSE response.
+        """
+        body = await request.body()
+        try:
+            data = request_data_schema.validate_json(body)
+        except ValidationError as e:
+            return JSONResponse({'errors': e.errors()}, status_code=422)
+        else:
+            return await self.handle_request_data(data, deps)
+
+    async def handle_request_data(self, data: RequestData, deps: AgentDepsT) -> Response:
+        """Handle request data that has already been validated and return a streamed SSE response.
+
+        Args:
+            data: The validated request data.
+            deps: The dependencies for the agent.
+
+        Returns:
+            A streamed SSE response.
+        """
+        if not data.messages:
+            return JSONResponse({'errors': 'no messages provided'})
+
+        message = data.messages[-1]
+        prompt: list[str] = []
+        for part in message.parts:
+            if isinstance(part, TextUIPart):
+                prompt.append(part.text)
+            else:
+                return JSONResponse({'errors': 'only text parts are supported for now'})
+
+        return EventSourceResponse(
+            sse_stream(self.agent, '\n'.join(prompt), deps=deps), headers=VERCEL_AI_ELEMENTS_HEADERS
+        )