From 3165e72f4e2adab2b0ce05c4e78dd97e4e178ff5 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 20 Aug 2025 15:25:07 -0700 Subject: [PATCH 01/13] add 'litellm_utils' module --- .../jupyter_ai/litellm_utils/__init__.py | 2 + .../litellm_utils/test_toolcall_list.py | 52 ++++++++ .../jupyter_ai/litellm_utils/toolcall_list.py | 121 ++++++++++++++++++ .../litellm_utils/toolcall_types.py | 57 +++++++++ 4 files changed, 232 insertions(+) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py new file mode 100644 index 000000000..cd95e2b2d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -0,0 +1,2 @@ +from .toolcall_list import ToolCallList +from .toolcall_types import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py new file mode 100644 index 000000000..9069eb481 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py @@ -0,0 +1,52 @@ +from litellm.utils import ChatCompletionDeltaToolCall, Function +from .toolcall_list import ToolCallList + +class TestToolCallList(): + + def test_single_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + single tool. + """ + # Setup test + ID = "toolu_01TzXi4nFJErYThcdhnixn7e" + toolcall_list = ToolCallList() + toolcall_list += [ChatCompletionDeltaToolCall(id=ID, function=Function(arguments='', name='ls'), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='": "."}', name=None), type='function', index=0)] + + # Verify the resolved list of calls + resolved_toolcalls = toolcall_list.resolve() + assert len(resolved_toolcalls) == 1 + assert resolved_toolcalls[0] + + def test_two_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + two tools in parallel. 
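+
+        Each tool call arrives as an initial delta carrying the call `id` and
+        function `name` at a given `index`, followed by deltas at the same
+        `index` that carry fragments of the JSON argument string.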
+        """
+        # Setup test
+        ID_0 = 'toolu_0141FrNfT2LJg6odqbrdmLM6'
+        ID_1 = 'toolu_01DKqnaXVcyp1v1ABxhHC5Sg'
+        toolcall_list = ToolCallList()
+        toolcall_list += [ChatCompletionDeltaToolCall(id=ID_0, function=Function(arguments='', name='ls'), type='function', index=0)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path": ', name=None), type='function', index=0)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='"."}', name=None), type='function', index=0)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=ID_1, function=Function(arguments='', name='bash'), type='function', index=1)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=1)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"com', name=None), type='function', index=1)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='mand": "ech', name=None), type='function', index=1)]
+        toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='o \'hello\'"}', name=None), type='function', index=1)]
+
+        # Verify the resolved list of calls
+        resolved_toolcalls = toolcall_list.resolve()
+        assert len(resolved_toolcalls) == 2
+        assert resolved_toolcalls[0].id == ID_0
+        assert resolved_toolcalls[0].function.name == "ls"
+        assert resolved_toolcalls[0].function.arguments == { "path": "." }
+        assert resolved_toolcalls[1].id == ID_1
+        assert resolved_toolcalls[1].function.name == "bash"
+        assert resolved_toolcalls[1].function.arguments == { "command": "echo 'hello'" }
+
diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py
new file mode 100644
index 000000000..1e3effd3a
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py
@@ -0,0 +1,121 @@
+from litellm.utils import ChatCompletionDeltaToolCall, Function
+import json
+
+from .toolcall_types import ResolvedToolCall, ResolvedFunction
+
+class ToolCallList():
+    """
+    A helper object that defines a custom `__iadd__()` method which accepts a
+    `tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class
+    is used to aggregate the tool call deltas yielded from a LiteLLM response
+    stream and produce a list of tool calls.
+
+    After all tool call deltas are added, the `resolve()` method may be called
+    to return a list of resolved tool calls.
+
+    Example usage:
+
+    ```py
+    tool_call_list = ToolCallList()
+    reply_stream = await litellm.acompletion(..., stream=True)
+
+    async for chunk in reply_stream:
+        tool_call_delta = chunk.choices[0].delta.tool_calls
+        tool_call_list += tool_call_delta
+
+    tool_call_list.resolve()
+    ```
+    """
+
+    _aggregate: list[ChatCompletionDeltaToolCall]
+
+    def __init__(self):
+        self.size = None
+
+        # Initialize `_aggregate`
+        self._aggregate = []
+
+
+    def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList':
+        """
+        Adds a list of tool call deltas to this instance.
+
+        NOTE: This assumes that the `index` attribute on each entry in the
+        given list is accurate. If this assumption doesn't hold, we will need
+        to rework the logic here.
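+
+        For example, a delta with `Function(arguments='{"pa')` at index 0,
+        followed by one with `Function(arguments='th": "."}')` at the same
+        index, aggregates to the argument string `'{"path": "."}'`.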
+ """ + if other is None: + return self + + # Iterate through each delta + for delta in other: + # Ensure `self._aggregate` is at least of size `delta.index + 1` + for i in range(len(self._aggregate), delta.index + 1): + self._aggregate.append(ChatCompletionDeltaToolCall( + function=Function(arguments=""), + index=i, + )) + + # Find the corresponding target in the `self._aggregate` and add the + # delta on top of it. In most cases, the value of aggregate + # attribute is set as soon as any delta sets it to a non-`None` + # value. However, `delta.function.arguments` is a string that should + # be appended to the aggregate value of that attribute. + target = self._aggregate[delta.index] + if delta.type: + target.type = delta.type + if delta.id: + target.id = delta.id + if delta.function.name: + target.function.name = delta.function.name + if delta.function.arguments: + target.function.arguments += delta.function.arguments + + return self + + + def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': + """ + Alias for `__iadd__()`. + """ + return self.__iadd__(other) + + + def resolve(self) -> list[ResolvedToolCall]: + """ + Resolve the aggregated tool call delta lists into a list of tool calls. + """ + resolved_toolcalls: list[ResolvedToolCall] = [] + for i, raw_toolcall in enumerate(self._aggregate): + # Verify entries are at the correct index in the aggregated list + assert raw_toolcall.index == i + + # Verify each tool call specifies the name of the tool to run. + # + # TODO: Check if this may cause a runtime error. The docstring on + # `litellm.utils.Function` implies that `name` may be `None`. + assert raw_toolcall.function.name + + # Verify each tool call defines the type of tool it is calling. + assert raw_toolcall.type is not None + + # Parse the function argument string into a dictionary + resolved_fn_args = json.loads(raw_toolcall.function.arguments) + + # Add to the returned list + resolved_fn = ResolvedFunction( + name=raw_toolcall.function.name, + arguments=resolved_fn_args + ) + resolved_toolcall = ResolvedToolCall( + id=raw_toolcall.id, + type=raw_toolcall.type, + index=i, + function=resolved_fn + ) + resolved_toolcalls.append(resolved_toolcall) + + return resolved_toolcalls + + + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py new file mode 100644 index 000000000..9426439f0 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py @@ -0,0 +1,57 @@ +from __future__ import annotations +from pydantic import BaseModel +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + +class ResolvedFunction(BaseModel): + """ + A type-safe, parsed representation of `litellm.utils.Function`. + """ + + name: str + """ + Name of the tool function to be called. + + TODO: Check if this attribute is defined for non-function tools, e.g. tools + provided by a MCP server. The docstring on `litellm.utils.Function` implies + that `name` may be `None`. + """ + + arguments: dict + """ + Arguments to the tool function, as a dictionary. + """ + +class ResolvedToolCall(BaseModel): + """ + A type-safe, parsed representation of + `litellm.utils.ChatCompletionDeltaToolCall`. + """ + + id: str | None + """ + The ID of the tool call. This should always be provided by LiteLLM, this + type is left optional as we do not use this attribute. + """ + + type: str + """ + The 'type' of tool call. Usually 'function'. 
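+
+    OpenAI-compatible chat completion APIs currently emit "function" for
+    every tool call, so other values are not expected in practice.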
+ + TODO: Make this a union of string literals to ensure we are handling every + potential type of tool call. + """ + + function: ResolvedFunction + """ + The resolved function. See `ResolvedFunction` for more info. + """ + + index: int + """ + The index of this tool call. + + This is usually 0 unless the LLM supports parallel tool calling. + """ From d504b348c9ab641a0280b9dca1c70793c28e0f3d Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Sat, 23 Aug 2025 11:59:27 -0700 Subject: [PATCH 02/13] WIP: first working copy of Jupyternaut as an agent --- .../jupyter_ai/personas/base_persona.py | 211 +++++++++++++++--- .../personas/jupyternaut/jupyternaut.py | 53 +++-- .../jupyter-ai/jupyter_ai/tools/models.py | 15 +- 3 files changed, 222 insertions(+), 57 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 310901b1c..b74c778a7 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -1,29 +1,39 @@ +from __future__ import annotations import asyncio import os from abc import ABC, ABCMeta, abstractmethod from dataclasses import asdict from logging import Logger from time import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Tuple from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User from jupyterlab_chat.ychat import YChat +from litellm import ModelResponseStream, supports_function_calling +from litellm.utils import function_to_dict from pydantic import BaseModel from traitlets import MetaHasTraits from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness +from ..litellm_utils import ToolCallList, ResolvedToolCall + +# Import toolkits +from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit +from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit +from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit -# prevents a circular import -# types imported under this block have to be surrounded in single quotes on use if TYPE_CHECKING: from collections.abc import AsyncIterator - - from litellm import ModelResponseStream - from .persona_manager import PersonaManager + from ..tools import Toolkit +DEFAULT_TOOLKITS: dict[str, Toolkit] = { + "fs": fs_toolkit, + "codeexec": codeexec_toolkit, + "git": git_toolkit, +} class PersonaDefaults(BaseModel): """ @@ -237,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]: async def stream_message( self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> None: + ) -> Tuple[ResolvedToolCall, ToolCallList]: """ Takes an async iterator, dubbed the 'reply stream', and streams it to a new message by this persona in the YChat. The async iterator may yield @@ -247,21 +257,36 @@ async def stream_message( stream, then continuously updates it until the stream is closed. - Automatically manages its awareness state to show writing status. + + Returns a list of `ResolvedToolCall` objects. If this list is not empty, + the persona should run these tools. 
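+
+        Example (a sketch; the `acompletion()` arguments are elided):
+
+        ```py
+        reply_stream = await acompletion(..., stream=True)
+        tool_calls, tool_call_list = await self.stream_message(reply_stream)
+        ```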
""" stream_id: Optional[str] = None stream_interrupted = False try: self.awareness.set_local_state_field("isWriting", True) - async for chunk in reply_stream: - # Coerce LiteLLM stream chunk to a string delta - if not isinstance(chunk, str): - chunk = chunk.choices[0].delta.content + toolcall_list = ToolCallList() + resolved_toolcalls: list[ResolvedToolCall] = [] - # LiteLLM streams always terminate with an empty chunk, so we - # ignore and continue when this occurs. - if not chunk: + async for chunk in reply_stream: + # Compute `content_delta` and `tool_calls_delta` based on the + # type of object yielded by `reply_stream`. + if isinstance(chunk, ModelResponseStream): + delta = chunk.choices[0].delta + content_delta = delta.content + toolcalls_delta = delta.tool_calls + elif isinstance(chunk, str): + content_delta = chunk + toolcalls_delta = None + else: + raise Exception(f"Unrecognized type in stream_message(): {type(chunk)}") + + # LiteLLM streams always terminate with an empty chunk, so + # continue in this case. + if not (content_delta or toolcalls_delta): continue + # Terminate the stream if the user requested it. if ( stream_id and stream_id in self.message_interrupted.keys() @@ -280,34 +305,46 @@ async def stream_message( stream_interrupted = True break - if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) + # Append `content_delta` to the existing message. + if content_delta: + # Start the stream with an empty message on the initial reply. + # Bind the new message ID to `stream_id`. + if not stream_id: + stream_id = self.ychat.add_message( + NewMessage(body="", sender=self.id) + ) + self.message_interrupted[stream_id] = asyncio.Event() + self.awareness.set_local_state_field("isWriting", stream_id) + assert stream_id + + self.ychat.update_message( + Message( + id=stream_id, + body=content_delta, + time=time(), + sender=self.id, + raw_time=False, + ), + append=True, ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - - assert stream_id - self.ychat.update_message( - Message( - id=stream_id, - body=chunk, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) + if toolcalls_delta: + toolcall_list += toolcalls_delta + + # After the reply stream is complete, resolve the list of tool calls. + resolved_toolcalls = toolcall_list.resolve() except Exception as e: self.log.error( f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." ) self.log.exception(e) finally: + # Reset local state self.awareness.set_local_state_field("isWriting", False) - if stream_id: - # if stream was interrupted, add a tombstone - if stream_interrupted: + self.message_interrupted.pop(stream_id, None) + + # If stream was interrupted, add a tombstone and return `[]`, + # indicating that no tools should be run afterwards. + if stream_id and stream_interrupted: stream_tombstone = "\n\n(AI response stopped by user)" self.ychat.update_message( Message( @@ -319,8 +356,15 @@ async def stream_message( ), append=True, ) - if stream_id in self.message_interrupted.keys(): - del self.message_interrupted[stream_id] + return None + + # Otherwise return the resolved list. 
+ if len(resolved_toolcalls): + count = len(resolved_toolcalls) + names = sorted([tc.function.name for tc in resolved_toolcalls]) + self.log.info(f"AI response triggered {count} tool calls: {names}") + return resolved_toolcalls, toolcall_list + def send_message(self, body: str) -> None: """ @@ -361,7 +405,7 @@ def get_mcp_config(self) -> dict[str, Any]: Returns the MCP config for the current chat. """ return self.parent.get_mcp_config() - + def process_attachments(self, message: Message) -> Optional[str]: """ Process file attachments in the message and return their content as a string. @@ -431,6 +475,99 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]: self.log.error(f"Failed to resolve attachment {attachment_id}: {e}") return None + def get_tools(self, model_id: str) -> list[dict]: + """ + Returns the `tools` parameter which should be passed to + `litellm.acompletion()` for a given LiteLLM model ID. + + If the model does not support tool-calling, this method returns an empty + list. Otherwise, it returns the list of tools available in the current + environment. These may include: + + - The default set of tool functions in Jupyter AI, defined in the + `jupyter_ai_tools` package. + + - (TODO) Tools provided by MCP server configuration, if any. + + - (TODO) Web search. + + - (TODO) File search using vector store IDs. + + TODO: cache this + + TODO: Implement some permissions system so users can control what tools + are allowable. + + NOTE: The returned list is expected by LiteLLM to conform to the `tools` + parameter defintiion defined by the OpenAI API: + https://platform.openai.com/docs/guides/tools#available-tools + + NOTE: This API is a WIP and is very likely to change. + """ + # Return early if the model does not support tool calling + if not supports_function_calling(model=model_id): + return [] + + tool_descriptions = [] + + # Get all tools from `jupyter_ai_tools` and store their object descriptions + for toolkit_name, toolkit in DEFAULT_TOOLKITS.items(): + # TODO: make these tool permissions configurable. + for tool in toolkit.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + + # Prepend the toolkit name to each function name, hopefully + # ensuring every tool function has a unique name. + # e.g. 'git_add' => 'git__git_add' + # + # TODO: Actually ensure this instead of hoping. + desc['function']['name'] = f"{toolkit_name}__{desc['function']['name']}" + tool_descriptions.append(desc) + + # Finally, return the tool descriptions + return tool_descriptions + + + async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: + """ + Runs the tools specified in the list of tool calls returned by + `self.stream_message()`. Returns a list of dictionaries + `toolcall_outputs: list[dict]`, which should be appended directly to the + message history on the next invocation of the LLM. 
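+
+        Each entry follows the shape LiteLLM expects for tool results, e.g.:
+
+        ```py
+        {"tool_call_id": "toolu_...", "role": "tool", "name": "ls", "content": "..."}
+        ```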
+ """ + if not len(tools): + return [] + + tool_outputs: list[dict] = [] + for tool_call in tools: + # Get tool definition from the correct toolkit + toolkit_name, tool_name = tool_call.function.name.split("__") + assert toolkit_name in DEFAULT_TOOLKITS + tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) + + # Run tool and store its output + output = await tool_defn.callable(**tool_call.function.arguments) + + # Store the tool output in a dictionary accepted by LiteLLM + output_dict = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": output, + } + tool_outputs.append(output_dict) + + self.log.info(f"Ran {len(tools)} tool functions.") + return tool_outputs + + + def shutdown(self) -> None: """ Shuts the persona down. This method should: diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 05ec403e4..66f1805e2 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -9,6 +9,7 @@ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from ...litellm_utils import ResolvedToolCall class JupyternautPersona(BasePersona): @@ -37,22 +38,35 @@ async def process_message(self, message: Message) -> None: return model_id = self.config_manager.chat_model - model_args = self.config_manager.chat_model_args - context_as_messages = self.get_context_as_messages(model_id, message) - response_aiter = await acompletion( - **model_args, - model=model_id, - messages=[ - *context_as_messages, - { - "role": "user", - "content": message.body, - }, - ], - stream=True, - ) - await self.stream_message(response_aiter) + # `True` on the first LLM invocation, `False` on all invocations after. + initial_invocation = True + # List of tool calls requested by the LLM in the previous invocaiton. + tool_calls: list[ResolvedToolCall] = [] + tool_call_list = None + # List of tool call outputs computed in the previous invocation. + tool_call_outputs: list[dict] = [] + + # Loop until the AI is complete running all its tools. + while initial_invocation or len(tool_call_outputs): + messages = self.get_context_as_messages(model_id, message) + + # TODO: Find a better way to track tool calls + if not initial_invocation and tool_calls: + self.log.error(messages[-1]) + messages[-1]['tool_calls'] = tool_call_list._aggregate + messages.extend(tool_call_outputs) + + self.log.error(messages) + response_aiter = await acompletion( + model=model_id, + messages=messages, + tools=self.get_tools(model_id), + stream=True, + ) + tool_calls, tool_call_list = await self.stream_message(response_aiter) + initial_invocation = False + tool_call_outputs = await self.run_tools(tool_calls) def get_context_as_messages( self, model_id: str, message: Message @@ -79,16 +93,17 @@ def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]] """ Returns the current history as a list of messages accepted by `litellm.acompletion()`. + + NOTE: You should usually call the public `get_context_as_messages()` + method instead. """ # TODO: consider bounding history based on message size (e.g. total # char/token count) instead of message count. all_messages = self.ychat.get_messages() # gather last k * 2 messages and return - # we exclude the last message since that is the human message just - # submitted by a user. 
- start_idx = 0 if k is None else -2 * k - 1 - recent_messages: list[Message] = all_messages[start_idx:-1] + start_idx = 0 if k is None else -2 * k + recent_messages: list[Message] = all_messages[start_idx:] history: list[dict[str, Any]] = [] for msg in recent_messages: diff --git a/packages/jupyter-ai/jupyter_ai/tools/models.py b/packages/jupyter-ai/jupyter_ai/tools/models.py index 5b95b6174..e547f0c15 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/models.py +++ b/packages/jupyter-ai/jupyter_ai/tools/models.py @@ -135,7 +135,7 @@ class Toolkit(BaseModel): name: str description: Optional[str] = None - tools: set = Field(default_factory=set) + tools: set[Tool] = Field(default_factory=set) model_config = ConfigDict(arbitrary_types_allowed=True) def add_tool(self, tool: Tool): @@ -146,6 +146,19 @@ def add_tool(self, tool: Tool): """ self.tools.add(tool) + def get_tool_unsafe(self, tool_name: str) -> Tool: + """ + (WIP) Gets a tool by its name. This is just a temporary method which is + used to make Jupyternaut agentic before we implement the + read/write/execute/delete permissions. + """ + for tool in self.tools: + if tool_name == tool.name: + return tool + + raise Exception(f"Tool not found: {tool_name}") + + def get_tools( self, read: Optional[bool] = None, From 5aa46bf9a4a93d89403ab9f5c9fe3dfc76c13832 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Sat, 23 Aug 2025 19:05:28 -0700 Subject: [PATCH 03/13] clean up tool calling flow & show in chat --- .../jupyter_ai/litellm_utils/__init__.py | 4 +- .../litellm_utils/streaming_utils.py | 13 ++++ .../jupyter_ai/litellm_utils/toolcall_list.py | 78 +++++++++++++++---- .../litellm_utils/toolcall_types.py | 57 -------------- .../jupyter_ai/personas/base_persona.py | 46 ++++++----- .../personas/jupyternaut/jupyternaut.py | 65 +++++++++++----- 6 files changed, 151 insertions(+), 112 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py delete mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py index cd95e2b2d..787493764 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -1,2 +1,2 @@ -from .toolcall_list import ToolCallList -from .toolcall_types import * +from .toolcall_list import * +from .streaming_utils import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py new file mode 100644 index 000000000..febe3f7f2 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel +from .toolcall_list import ToolCallList + +class StreamResult(BaseModel): + id: str + """ + ID of the new message. + """ + + tool_calls: ToolCallList + """ + Tool calls requested by the LLM in its streamed response. 
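+
+    Call `resolve()` on this value to obtain the parsed `ResolvedToolCall`
+    entries, or `to_json()` to serialize it back into the message history.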
+ """ diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py index 1e3effd3a..654939ebb 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py @@ -1,9 +1,61 @@ from litellm.utils import ChatCompletionDeltaToolCall, Function import json +from pydantic import BaseModel +from typing import Any -from .toolcall_types import ResolvedToolCall, ResolvedFunction +class ResolvedFunction(BaseModel): + """ + A type-safe, parsed representation of `litellm.utils.Function`. + """ + + name: str + """ + Name of the tool function to be called. + + TODO: Check if this attribute is defined for non-function tools, e.g. tools + provided by a MCP server. The docstring on `litellm.utils.Function` implies + that `name` may be `None`. + """ + + arguments: dict[str, Any] + """ + Arguments to the tool function, as a dictionary. + """ + + +class ResolvedToolCall(BaseModel): + """ + A type-safe, parsed representation of + `litellm.utils.ChatCompletionDeltaToolCall`. + """ + + id: str | None + """ + The ID of the tool call. This should always be provided by LiteLLM, this + type is left optional as we do not use this attribute. + """ + + type: str + """ + The 'type' of tool call. Usually 'function'. -class ToolCallList(): + TODO: Make this a union of string literals to ensure we are handling every + potential type of tool call. + """ + + function: ResolvedFunction + """ + The resolved function. See `ResolvedFunction` for more info. + """ + + index: int + """ + The index of this tool call. + + This is usually 0 unless the LLM supports parallel tool calling. + """ + +class ToolCallList(BaseModel): """ A helper object that defines a custom `__iadd__()` method which accepts a `tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class @@ -27,14 +79,7 @@ class ToolCallList(): ``` """ - _aggregate: list[ChatCompletionDeltaToolCall] - - def __init__(self): - self.size = None - - # Initialize `_aggregate` - self._aggregate = [] - + _aggregate: list[ChatCompletionDeltaToolCall] = [] def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': """ @@ -116,6 +161,13 @@ def resolve(self) -> list[ResolvedToolCall]: resolved_toolcalls.append(resolved_toolcall) return resolved_toolcalls - - - \ No newline at end of file + + def to_json(self) -> list[dict[str, Any]]: + """ + Returns the list of tool calls as a Python dictionary that can be + JSON-serialized. + """ + return [ + model.model_dump() for model in self._aggregate + ] + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py deleted file mode 100644 index 9426439f0..000000000 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations -from pydantic import BaseModel -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Any - -class ResolvedFunction(BaseModel): - """ - A type-safe, parsed representation of `litellm.utils.Function`. - """ - - name: str - """ - Name of the tool function to be called. - - TODO: Check if this attribute is defined for non-function tools, e.g. tools - provided by a MCP server. The docstring on `litellm.utils.Function` implies - that `name` may be `None`. 
- """ - - arguments: dict - """ - Arguments to the tool function, as a dictionary. - """ - -class ResolvedToolCall(BaseModel): - """ - A type-safe, parsed representation of - `litellm.utils.ChatCompletionDeltaToolCall`. - """ - - id: str | None - """ - The ID of the tool call. This should always be provided by LiteLLM, this - type is left optional as we do not use this attribute. - """ - - type: str - """ - The 'type' of tool call. Usually 'function'. - - TODO: Make this a union of string literals to ensure we are handling every - potential type of tool call. - """ - - function: ResolvedFunction - """ - The resolved function. See `ResolvedFunction` for more info. - """ - - index: int - """ - The index of this tool call. - - This is usually 0 unless the LLM supports parallel tool calling. - """ diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index b74c778a7..21718dab7 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -5,7 +5,7 @@ from dataclasses import asdict from logging import Logger from time import time -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User @@ -17,7 +17,7 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, ResolvedToolCall +from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall # Import toolkits from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit @@ -247,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]: async def stream_message( self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> Tuple[ResolvedToolCall, ToolCallList]: + ) -> StreamResult: """ Takes an async iterator, dubbed the 'reply stream', and streams it to a new message by this persona in the YChat. The async iterator may yield @@ -263,12 +263,21 @@ async def stream_message( """ stream_id: Optional[str] = None stream_interrupted = False + tool_calls = ToolCallList() try: self.awareness.set_local_state_field("isWriting", True) - toolcall_list = ToolCallList() - resolved_toolcalls: list[ResolvedToolCall] = [] async for chunk in reply_stream: + # Start the stream with an empty message on the initial reply. + # Bind the new message ID to `stream_id`. + if not stream_id: + stream_id = self.ychat.add_message( + NewMessage(body="", sender=self.id) + ) + self.message_interrupted[stream_id] = asyncio.Event() + self.awareness.set_local_state_field("isWriting", stream_id) + assert stream_id + # Compute `content_delta` and `tool_calls_delta` based on the # type of object yielded by `reply_stream`. if isinstance(chunk, ModelResponseStream): @@ -307,16 +316,6 @@ async def stream_message( # Append `content_delta` to the existing message. if content_delta: - # Start the stream with an empty message on the initial reply. - # Bind the new message ID to `stream_id`. 
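+                # Creating the message on the first chunk, before any content
+                # is handled, ensures that tool-only replies still produce a
+                # chat message.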
- if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) - ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - assert stream_id - self.ychat.update_message( Message( id=stream_id, @@ -328,10 +327,8 @@ async def stream_message( append=True, ) if toolcalls_delta: - toolcall_list += toolcalls_delta + tool_calls += toolcalls_delta - # After the reply stream is complete, resolve the list of tool calls. - resolved_toolcalls = toolcall_list.resolve() except Exception as e: self.log.error( f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." @@ -358,12 +355,17 @@ async def stream_message( ) return None - # Otherwise return the resolved list. + # TODO: determine where this should live + resolved_toolcalls = tool_calls.resolve() if len(resolved_toolcalls): count = len(resolved_toolcalls) names = sorted([tc.function.name for tc in resolved_toolcalls]) self.log.info(f"AI response triggered {count} tool calls: {names}") - return resolved_toolcalls, toolcall_list + + return StreamResult( + id=stream_id, + tool_calls=tool_calls + ) def send_message(self, body: str) -> None: @@ -552,7 +554,9 @@ async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) # Run tool and store its output - output = await tool_defn.callable(**tool_call.function.arguments) + output = tool_defn.callable(**tool_call.function.arguments) + if asyncio.iscoroutine(output): + output = await output # Store the tool output in a dictionary accepted by LiteLLM output_dict = { diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 66f1805e2..3c350c1f4 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,4 +1,6 @@ from typing import Any, Optional +import time +import json from jupyterlab_chat.models import Message from litellm import acompletion @@ -9,7 +11,6 @@ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) -from ...litellm_utils import ResolvedToolCall class JupyternautPersona(BasePersona): @@ -39,34 +40,60 @@ async def process_message(self, message: Message) -> None: model_id = self.config_manager.chat_model - # `True` on the first LLM invocation, `False` on all invocations after. - initial_invocation = True - # List of tool calls requested by the LLM in the previous invocaiton. - tool_calls: list[ResolvedToolCall] = [] - tool_call_list = None + # `True` before the first LLM response is sent, `False` afterwards. + initial_response = True # List of tool call outputs computed in the previous invocation. tool_call_outputs: list[dict] = [] - # Loop until the AI is complete running all its tools. - while initial_invocation or len(tool_call_outputs): - messages = self.get_context_as_messages(model_id, message) - - # TODO: Find a better way to track tool calls - if not initial_invocation and tool_calls: - self.log.error(messages[-1]) - messages[-1]['tool_calls'] = tool_call_list._aggregate - messages.extend(tool_call_outputs) + # Initialize list of messages, including history and context + messages: list[dict] = self.get_context_as_messages(model_id, message) - self.log.error(messages) + # Loop until the AI is complete running all its tools. 
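+        # Each iteration: (1) invoke the LLM with the available tools,
+        # (2) stream its reply into the chat, (3) run any tools it requested,
+        # and (4) append the tool outputs to `messages` for the next turn.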
+ while initial_response or len(tool_call_outputs): + # Stream message to the chat response_aiter = await acompletion( model=model_id, messages=messages, tools=self.get_tools(model_id), stream=True, ) - tool_calls, tool_call_list = await self.stream_message(response_aiter) - initial_invocation = False - tool_call_outputs = await self.run_tools(tool_calls) + result = await self.stream_message(response_aiter) + initial_response = False + + # Append new reply to `messages` + reply = self.ychat.get_message(result.id) + tool_calls_json = result.tool_calls.to_json() + messages.append({ + "role": "assistant", + "content": reply.body, + "tool_calls": tool_calls_json + }) + + # Show tool call requests to YChat (not synced with `messages`) + if len(tool_calls_json): + self.ychat.update_message(Message( + id=result.id, + body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n", + sender=self.id, + time=time.time(), + raw_time=False + ), append=True) + + # Run tools and append outputs to `messages` + tool_call_outputs = await self.run_tools(result.tool_calls.resolve()) + messages.extend(tool_call_outputs) + + # Add tool call outputs to YChat (not synced with `messages`) + if tool_call_outputs: + self.ychat.update_message(Message( + id=result.id, + body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", + sender=self.id, + time=time.time(), + raw_time=False + ), append=True) + + def get_context_as_messages( self, model_id: str, message: Message From 7ba285dc5aaf7ef8780ffa02b6bf21f6a2f7dcf4 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:16:51 +0200 Subject: [PATCH 04/13] add temporary default toolkit --- .../jupyter_ai/personas/base_persona.py | 46 +- .../jupyter-ai/jupyter_ai/tools/__init__.py | 3 +- .../jupyter_ai/tools/default_toolkit.py | 255 ++++++++ .../jupyter_ai/tools/test_default_toolkit.py | 595 ++++++++++++++++++ 4 files changed, 867 insertions(+), 32 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py create mode 100644 packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 21718dab7..4e550cfbd 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -20,21 +20,13 @@ from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall # Import toolkits -from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit -from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit -from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit +from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: from collections.abc import AsyncIterator from .persona_manager import PersonaManager from ..tools import Toolkit -DEFAULT_TOOLKITS: dict[str, Toolkit] = { - "fs": fs_toolkit, - "codeexec": codeexec_toolkit, - "git": git_toolkit, -} - class PersonaDefaults(BaseModel): """ Data structure that represents the default settings of a persona. Each persona @@ -512,27 +504,19 @@ def get_tools(self, model_id: str) -> list[dict]: tool_descriptions = [] - # Get all tools from `jupyter_ai_tools` and store their object descriptions - for toolkit_name, toolkit in DEFAULT_TOOLKITS.items(): - # TODO: make these tool permissions configurable. 
- for tool in toolkit.get_tools(): - # Here, we are using a util function from LiteLLM to coerce - # each `Tool` struct into a tool description dictionary expected - # by LiteLLM. - desc = { - "type": "function", - "function": function_to_dict(tool.callable), - } - - # Prepend the toolkit name to each function name, hopefully - # ensuring every tool function has a unique name. - # e.g. 'git_add' => 'git__git_add' - # - # TODO: Actually ensure this instead of hoping. - desc['function']['name'] = f"{toolkit_name}__{desc['function']['name']}" - tool_descriptions.append(desc) + # Get all tools from the default toolkit and store their object descriptions + for tool in DEFAULT_TOOLKIT.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + tool_descriptions.append(desc) # Finally, return the tool descriptions + self.log.info(tool_descriptions) return tool_descriptions @@ -549,9 +533,9 @@ async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: tool_outputs: list[dict] = [] for tool_call in tools: # Get tool definition from the correct toolkit - toolkit_name, tool_name = tool_call.function.name.split("__") - assert toolkit_name in DEFAULT_TOOLKITS - tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) + # TODO: validation? + tool_name = tool_call.function.name + tool_defn = DEFAULT_TOOLKIT.get_tool_unsafe(tool_name) # Run tool and store its output output = tool_defn.callable(**tool_call.function.arguments) diff --git a/packages/jupyter-ai/jupyter_ai/tools/__init__.py b/packages/jupyter-ai/jupyter_ai/tools/__init__.py index 0252ac1a9..1f8e3afa3 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/tools/__init__.py @@ -1,5 +1,6 @@ """Tools package for Jupyter AI.""" from .models import Tool, Toolkit +from .default_toolkit import DEFAULT_TOOLKIT -__all__ = ["Tool", "Toolkit"] +__all__ = ["Tool", "Toolkit", "DEFAULT_TOOLKIT"] diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py new file mode 100644 index 000000000..79c8a7675 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -0,0 +1,255 @@ +from .models import Tool, Toolkit +from jupyter_ai_tools.toolkits.code_execution import bash + +import pathlib + + +def read(file_path: str, offset: int, limit: int) -> str: + """ + Read a subset of lines from a text file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be read. + offset : int + The line number at which to start reading (1-based indexing). + limit : int + Number of lines to read starting from *offset*. + If *offset + limit* exceeds the number of lines in the file, + all available lines after *offset* are returned. + + Returns + ------- + List[str] + List of lines (including line-ending characters) that were read. 
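+
+    Note that the implementation joins these lines and returns them as a
+    single ``str``; an empty string is returned when *offset* is beyond the
+    end of the file.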
+ + Examples + -------- + >>> # Suppose ``/tmp/example.txt`` contains 10 lines + >>> read('/tmp/example.txt', offset=3, limit=4) + ['third line\n', 'fourth line\n', 'fifth line\n', 'sixth line\n'] + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + # Normalize arguments + offset = max(1, int(offset)) + limit = max(0, int(limit)) + lines: list[str] = [] + + with path.open(encoding='utf-8', errors='replace') as f: + # Skip to offset + line_no = 0 + # Loop invariant: line_no := last read line + # After the loop exits, line_no == offset - 1, meaning the + # next line starts at `offset` + while line_no < offset - 1: + line = f.readline() + # Return early if offset exceeds number of lines in file + if line == "": + return "" + line_no += 1 + + # Append lines until limit is reached + while len(lines) < limit: + line = f.readline() + if line == "": + break + lines.append(line) + + return "".join(lines) + + +def edit( + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, +) -> None: + """ + Replace occurrences of a substring in a file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be edited. + old_string : str + Text that should be replaced. + new_string : str + Text that will replace *old_string*. + replace_all : bool, optional + If ``True`` all occurrences of *old_string* are replaced. + If ``False`` (default), only the first occurrence in the file is replaced. + + Returns + ------- + None + + Raises + ------ + FileNotFoundError + If *file_path* does not exist. + ValueError + If *old_string* is empty (replacing an empty string is ambiguous). + + Notes + ----- + The file is overwritten atomically: it is first read into memory, + the substitution is performed, and the file is written back. + This keeps the operation safe for short to medium-sized files. + + Examples + -------- + >>> # Replace only the first occurrence + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=False) + >>> # Replace all occurrences + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=True) + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + if old_string == "": + raise ValueError("old_string must not be empty") + + # Read the entire file + content = path.read_text(encoding="utf-8", errors="replace") + + # Perform replacement + if replace_all: + new_content = content.replace(old_string, new_string) + else: + new_content = content.replace(old_string, new_string, 1) + + # Write back + path.write_text(new_content, encoding="utf-8") + + +def write(file_path: str, content: str) -> None: + """ + Write content to a file, creating it if it doesn't exist. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be written. + content : str + Content to write to the file. + + Returns + ------- + None + + Raises + ------ + OSError + If the file cannot be written (e.g., permission denied, invalid path). + + Notes + ----- + This function will overwrite the file if it already exists. + The parent directory must exist; this function does not create directories. 
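+    Content is always written using UTF-8 encoding.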
+ + Examples + -------- + >>> write('/tmp/example.txt', 'Hello, world!') + >>> write('/tmp/data.json', '{"key": "value"}') + """ + path = pathlib.Path(file_path) + + # Write the content to the file + path.write_text(content, encoding="utf-8") + + +async def search_grep(pattern: str, include: str = "*") -> str: + """ + Search for text patterns in files using ripgrep. + + This function uses ripgrep (rg) to perform fast regex-based text searching + across files, with optional file filtering based on glob patterns. + + Parameters + ---------- + pattern : str + A regular expression pattern to search for. Ripgrep uses Rust regex + syntax which supports: + - Basic regex features: ., *, +, ?, ^, $, [], (), | + - Character classes: \w, \d, \s, \W, \D, \S + - Unicode categories: \p{L}, \p{N}, \p{P}, etc. + - Word boundaries: \b, \B + - Anchors: ^, $, \A, \z + - Quantifiers: {n}, {n,}, {n,m} + - Groups: (pattern), (?:pattern), (?Ppattern) + - Lookahead/lookbehind: (?=pattern), (?!pattern), (?<=pattern), (?>> search_grep(r"def\s+\w+", "*.py") + 'file.py:10:def my_function():' + + >>> search_grep(r"TODO|FIXME", "**/*.{py,js}") + 'app.py:25:# TODO: implement this + script.js:15:// FIXME: handle edge case' + + >>> search_grep(r"class\s+(\w+)", "src/**/*.py") + 'src/models.py:1:class User:' + """ + # Use bash tool to execute ripgrep + cmd_parts = ["rg", "--color=never", "--line-number", "--with-filename"] + + # Add glob pattern if specified + if include != "*": + cmd_parts.extend(["-g", include]) + + # Add the pattern (always quote it to handle special characters) + cmd_parts.append(pattern) + + # Join command with proper shell escaping + command = " ".join(f'"{part}"' if " " in part or any(c in part for c in "!*?[]{}()") else part for part in cmd_parts) + + try: + result = await bash(command) + return result + except Exception as e: + raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e + + +DEFAULT_TOOLKIT = Toolkit(name="jupyter-ai-default-toolkit") +DEFAULT_TOOLKIT.add_tool(Tool(callable=bash)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=read)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=edit)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=write)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=search_grep)) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py new file mode 100644 index 000000000..9db82dc41 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py @@ -0,0 +1,595 @@ +"""Tests for default_toolkit.py functions and toolkit configuration.""" + +import pathlib +import tempfile +import pytest +from unittest.mock import patch, mock_open + +from .default_toolkit import read, edit, write, search_grep, DEFAULT_TOOLKIT +from .models import Tool, Toolkit + + +class TestReadFunction: + """Test the read function.""" + + def test_read_valid_file(self): + """Test reading lines from a valid file.""" + # Create a temporary file with known content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\nline 4\nline 5\n") + temp_path = f.name + + try: + # Test reading from offset 2, limit 3 + result = read(temp_path, offset=2, limit=3) + assert result == "line 2\nline 3\nline 4\n" + + # Test reading from offset 1, limit 2 + result = read(temp_path, offset=1, limit=2) + assert result == "line 1\nline 2\n" + + # Test reading all lines from beginning + result = read(temp_path, offset=1, limit=10) + assert result == "line 1\nline 
2\nline 3\nline 4\nline 5\n" + + # Test reading from middle to end + result = read(temp_path, offset=4, limit=10) + assert result == "line 4\nline 5\n" + + finally: + # Clean up + pathlib.Path(temp_path).unlink() + + def test_read_empty_file(self): + """Test reading from an empty file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_single_line_file(self): + """Test reading from a file with one line.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("single line\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=1) + assert result == "single line\n" + + result = read(temp_path, offset=1, limit=5) + assert result == "single line\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_file_not_found(self): + """Test reading from a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + read("/nonexistent/path", offset=1, limit=5) + + def test_read_offset_beyond_file_length(self): + """Test reading with offset beyond file length.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\n") + temp_path = f.name + + try: + # Offset beyond file length should return empty string + result = read(temp_path, offset=10, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_negative_and_zero_values(self): + """Test read function with negative and zero offset/limit values.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\n") + temp_path = f.name + + try: + # Negative offset should be normalized to 1 + result = read(temp_path, offset=-5, limit=2) + assert result == "line 1\nline 2\n" + + # Zero offset should be normalized to 1 + result = read(temp_path, offset=0, limit=2) + assert result == "line 1\nline 2\n" + + # Zero limit should return empty string + result = read(temp_path, offset=1, limit=0) + assert result == "" + + # Negative limit should return empty string + result = read(temp_path, offset=1, limit=-5) + assert result == "" + + finally: + pathlib.Path(temp_path).unlink() + + def test_read_unicode_content(self): + """Test reading file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("línea 1 🚀\nlínea 2 ❤️\nlínea 3 🎉\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=2) + assert result == "línea 1 🚀\nlínea 2 ❤️\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_with_encoding_errors(self): + """Test reading file with encoding issues using replace errors handling.""" + # This test ensures the 'replace' error handling works properly + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("valid line\n") + temp_path = f.name + + try: + # The function should handle encoding errors gracefully + result = read(temp_path, offset=1, limit=1) + assert result == "valid line\n" + finally: + pathlib.Path(temp_path).unlink() + + +class TestEditFunction: + """Test the edit function.""" + + def test_edit_replace_first_occurrence(self): + """Test replacing the first occurrence of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + 
f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=False) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar foo baz foo" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_all_occurrences(self): + """Test replacing all occurrences of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=True) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar qux baz qux" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_multiline_content(self): + """Test editing multiline content.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nold content\nline 3\nold content\nline 5") + temp_path = f.name + + try: + edit(temp_path, "old content", "new content", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line 1\nnew content\nline 3\nnew content\nline 5" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_string_not_found(self): + """Test editing when the target string is not found.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + # This should not raise an error, just leave the file unchanged + edit(temp_path, "nonexistent", "replacement", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hello world" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_file_not_found(self): + """Test editing a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + edit("/nonexistent/path", "old", "new") + + def test_edit_empty_old_string(self): + """Test editing with an empty old_string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="old_string must not be empty"): + edit(temp_path, "", "replacement") + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_unicode_content(self): + """Test editing file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("hola 🌟 mundo 🌟 adiós") + temp_path = f.name + + try: + edit(temp_path, "🌟", "⭐", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hola ⭐ mundo ⭐ adiós" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_newline_characters(self): + """Test editing with newline characters.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line1\nold\nline3") + temp_path = f.name + + try: + edit(temp_path, "\nold\n", "\nnew\n", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line1\nnew\nline3" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_with_empty_string(self): + """Test replacing content with empty string (deletion).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("keep this DELETE_ME keep this too") + temp_path = f.name + + 
try: + edit(temp_path, "DELETE_ME ", "", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "keep this keep this too" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_atomicity(self): + """Test that edit operation is atomic (file is either fully updated or unchanged).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + original_content = "original content" + f.write(original_content) + temp_path = f.name + + try: + # Mock pathlib.Path.write_text to raise an exception + with patch.object(pathlib.Path, 'write_text', side_effect=IOError("Disk full")): + with pytest.raises(IOError): + edit(temp_path, "original", "modified") + + # File should remain unchanged due to the error + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == original_content + + finally: + pathlib.Path(temp_path).unlink() + + +class TestWriteFunction: + """Test the write function.""" + + def test_write_new_file(self): + """Test writing content to a new file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "new_file.txt" + test_content = "Hello, world!\nThis is a test." + + write(str(temp_path), test_content) + + # Verify the file was created and contains the correct content + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_overwrite_existing_file(self): + """Test overwriting an existing file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original content") + temp_path = f.name + + try: + new_content = "new content that replaces the old" + write(temp_path, new_content) + + # Verify the file was overwritten + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == new_content + finally: + pathlib.Path(temp_path).unlink() + + def test_write_empty_content(self): + """Test writing empty content to a file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "empty_file.txt" + + write(str(temp_path), "") + + # Verify the file exists and is empty + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == "" + + def test_write_multiline_content(self): + """Test writing multiline content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "multiline.txt" + test_content = "Line 1\nLine 2\nLine 3\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_unicode_content(self): + """Test writing unicode content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "unicode.txt" + test_content = "Hello 世界! 
🌍 Café naïve résumé" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_large_content(self): + """Test writing large content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "large.txt" + # Create content with 10000 lines + test_content = "\n".join([f"Line {i}" for i in range(10000)]) + + write(str(temp_path), test_content) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_special_characters(self): + """Test writing content with special characters.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "special.txt" + test_content = 'Content with "quotes", \ttabs, and \nnewlines\r\n' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_invalid_directory(self): + """Test writing to a non-existent directory.""" + invalid_path = "/nonexistent/directory/file.txt" + + with pytest.raises(OSError): + write(invalid_path, "test content") + + def test_write_permission_denied(self): + """Test writing to a file without write permissions.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original") + temp_path = f.name + + try: + # Make file read-only + pathlib.Path(temp_path).chmod(0o444) + + with pytest.raises(OSError): + write(temp_path, "new content") + + finally: + # Restore write permissions and clean up + pathlib.Path(temp_path).chmod(0o644) + pathlib.Path(temp_path).unlink() + + def test_write_binary_like_content(self): + """Test writing content that looks like binary data.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "binary_like.txt" + # Content with null bytes and other control characters + test_content = "Normal text\x00null byte\x01control char\xff" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_json_content(self): + """Test writing JSON-like content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "data.json" + test_content = '{"name": "test", "value": 42, "nested": {"key": "value"}}' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_code_content(self): + """Test writing code content with proper indentation.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "code.py" + test_content = '''def hello(): + """Say hello.""" + print("Hello, world!") + + if True: + return "success" +''' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_preserves_line_endings(self): + """Test that write preserves different line endings.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "line_endings.txt" + test_content = "Unix\nWindows\r\nMac\rMixed\r\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + +class TestSearchGrepFunction: + """Test the search_grep function.""" 
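For reference, the `edit` and `write` implementations exercised by the tests above are not shown in this patch. A minimal sketch consistent with the assertions — the error messages, the `replace_all` flag, and the UTF-8 handling are all inferred from the tests, and the real definitions may differ — might look like:

```py
import pathlib


def edit(file_path: str, old_string: str, new_string: str, replace_all: bool = False) -> None:
    """Replaces `old_string` with `new_string` in the file at `file_path`."""
    path = pathlib.Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    if not old_string:
        raise ValueError("old_string must not be empty")

    # Read the whole file, then write the result in a single call. Because
    # the file is written at most once, a failed write leaves the original
    # content intact, which is what the atomicity test checks.
    content = path.read_text(encoding="utf-8")
    count = -1 if replace_all else 1
    path.write_text(content.replace(old_string, new_string, count), encoding="utf-8")


def write(file_path: str, content: str) -> None:
    """Creates or overwrites the file at `file_path` with `content`."""
    # Propagates OSError for a missing parent directory or missing write
    # permission, matching the failure-mode tests above.
    pathlib.Path(file_path).write_text(content, encoding="utf-8")
```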
+ + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_integration(self, mock_bash): + """Test that search_grep correctly calls bash with proper arguments.""" + mock_bash.return_value = "test.py:1:def test():" + + result = await search_grep("def", "*.py") + + # Verify bash was called + mock_bash.assert_called_once() + call_args = mock_bash.call_args[0][0] + + # Check that the command contains expected parts + assert "rg" in call_args + assert "--color=never" in call_args + assert "--line-number" in call_args + assert "--with-filename" in call_args + assert "-g" in call_args + assert "*.py" in call_args + assert "def" in call_args + + assert result == "test.py:1:def test():" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_default_include(self, mock_bash): + """Test search_grep with default include pattern.""" + mock_bash.return_value = "" + + await search_grep("pattern") + + call_args = mock_bash.call_args[0][0] + # Should not contain -g flag when using default "*" pattern + assert "-g" not in call_args or "\"*\"" not in call_args + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_exception(self, mock_bash): + """Test search_grep handling of bash execution errors.""" + mock_bash.side_effect = Exception("Command failed") + + with pytest.raises(RuntimeError, match="Ripgrep search failed: Command failed"): + await search_grep("pattern", "*.txt") + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_basic_pattern(self, mock_bash): + """Test basic pattern searching.""" + mock_bash.return_value = "test1.py:1:def hello_world():\ntest2.py:2: def method(self):" + + result = await search_grep(r"def\s+\w+", "*.py") + + # Should find function definitions in both files + assert "test1.py" in result + assert "test2.py" in result + assert "def hello_world" in result + assert "def method" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_no_matches(self, mock_bash): + """Test search with no matches.""" + mock_bash.return_value = "" + + result = await search_grep("nonexistent_pattern", "*.txt") + assert result == "" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_with_include_pattern(self, mock_bash): + """Test search with file include pattern.""" + mock_bash.return_value = "script.py:1:import os" + + result = await search_grep("import", "*.py") + assert "script.py" in result + assert "readme.txt" not in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_special_characters(self, mock_bash): + """Test searching for patterns with special regex characters.""" + # Mock different return values for different calls + mock_bash.side_effect = [ + "special.txt:2:email: user@domain.com", + "special.txt:1:price: $10.99" + ] + + # Search for email pattern + result = await search_grep(r"\w+@\w+\.\w+", "*.txt") + assert "user@domain.com" in result + + # Search for price pattern + result = await search_grep(r"\$\d+\.\d+", "*.txt") + assert "$10.99" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_unicode_content(self, mock_bash): + """Test searching in files with unicode content.""" + mock_bash.return_value = "unicode.txt:1:Hello 世界" + + result = await 
search_grep("世界", "*.txt") + assert "世界" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_line_anchors(self, mock_bash): + """Test line anchor patterns (^ and $).""" + mock_bash.side_effect = [ + "anchors.txt:1:start of line", + "anchors.txt:3:line with end" + ] + + # Search for lines starting with specific text + result = await search_grep("^start", "*.txt") + assert "start of line" in result + + # Search for lines ending with specific text + result = await search_grep("end$", "*.txt") + assert "line with end" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_case_insensitive_pattern(self, mock_bash): + """Test case insensitive regex patterns.""" + mock_bash.return_value = "mixed_case.txt:1:TODO: fix this\nmixed_case.txt:2:todo: also this\nmixed_case.txt:3:ToDo: and this" + + # Case insensitive search + result = await search_grep("(?i)todo", "*.txt") + lines = result.strip().split('\n') if result.strip() else [] + assert len(lines) == 3 # Should match all three variants + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_complex_glob_patterns(self, mock_bash): + """Test various complex glob patterns.""" + mock_bash.return_value = "src/main.py:1:import sys\nsrc/utils.py:1:import os" + + # Test recursive search in src directory + result = await search_grep("import", "src/**/*.py") + assert "src/main.py" in result + assert "src/utils.py" in result + assert "test_main.py" not in result \ No newline at end of file From 4059489d7af208e74cbd023b5db8ccd2b31c81a6 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:22:04 +0200 Subject: [PATCH 05/13] update tool calling APIs --- .../jupyter_ai/litellm_utils/__init__.py | 1 + .../jupyter_ai/litellm_utils/run_tools.py | 50 ++++++++++++++++++ .../litellm_utils/streaming_utils.py | 2 +- .../jupyter_ai/litellm_utils/toolcall_list.py | 4 ++ .../jupyter_ai/personas/base_persona.py | 51 +++++-------------- .../personas/jupyternaut/jupyternaut.py | 4 +- 6 files changed, 70 insertions(+), 42 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py index 787493764..ff1c7d8c3 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -1,2 +1,3 @@ from .toolcall_list import * from .streaming_utils import * +from .run_tools import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py new file mode 100644 index 000000000..2b148b505 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py @@ -0,0 +1,50 @@ +import asyncio +from pydantic import BaseModel +from .toolcall_list import ToolCallList +from ..tools import Toolkit + + +class ToolCallOutput(BaseModel): + tool_call_id: str + role: str = "tool" + name: str + content: str + + +async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: + """ + Runs the tools specified in the list of tool calls returned by + `self.stream_message()`. + + Returns `list[ToolCallOutput]`. The outputs should be appended directly to + the message history on the next request made to the LLM. 
+ """ + tool_calls = tool_call_list.resolve() + if not len(tool_calls): + return [] + + tool_outputs: list[dict] = [] + for tool_call in tool_calls: + # Get tool definition from the correct toolkit + # TODO: validation? + tool_name = tool_call.function.name + tool_defn = toolkit.get_tool_unsafe(tool_name) + + # Run tool and store its output + try: + output = tool_defn.callable(**tool_call.function.arguments) + if asyncio.iscoroutine(output): + output = await output + except Exception as e: + output = str(e) + + # Store the tool output in a dictionary accepted by LiteLLM + output_dict = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": output, + } + tool_outputs.append(output_dict) + + return tool_outputs diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py index febe3f7f2..7251c88ed 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py @@ -7,7 +7,7 @@ class StreamResult(BaseModel): ID of the new message. """ - tool_calls: ToolCallList + tool_call_list: ToolCallList """ Tool calls requested by the LLM in its streamed response. """ diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py index 654939ebb..e7094e4f9 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py @@ -170,4 +170,8 @@ def to_json(self) -> list[dict[str, Any]]: return [ model.model_dump() for model in self._aggregate ] + + + def __len__(self) -> int: + return len(self._aggregate) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 4e550cfbd..a33ad69f3 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -17,7 +17,7 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall +from ..litellm_utils import ToolCallList, StreamResult, run_tools, ToolCallOutput # Import toolkits from ..tools.default_toolkit import DEFAULT_TOOLKIT @@ -255,7 +255,7 @@ async def stream_message( """ stream_id: Optional[str] = None stream_interrupted = False - tool_calls = ToolCallList() + tool_call_list = ToolCallList() try: self.awareness.set_local_state_field("isWriting", True) @@ -319,7 +319,7 @@ async def stream_message( append=True, ) if toolcalls_delta: - tool_calls += toolcalls_delta + tool_call_list += toolcalls_delta except Exception as e: self.log.error( @@ -348,15 +348,13 @@ async def stream_message( return None # TODO: determine where this should live - resolved_toolcalls = tool_calls.resolve() - if len(resolved_toolcalls): - count = len(resolved_toolcalls) - names = sorted([tc.function.name for tc in resolved_toolcalls]) - self.log.info(f"AI response triggered {count} tool calls: {names}") + count = len(tool_call_list) + if count > 0: + self.log.info(f"AI response triggered {count} tool calls.") return StreamResult( id=stream_id, - tool_calls=tool_calls + tool_call_list=tool_call_list ) @@ -520,38 +518,13 @@ def get_tools(self, model_id: str) -> list[dict]: return tool_descriptions - async def run_tools(self, tools: list[ResolvedToolCall]) -> 
list[dict]: + async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]: """ - Runs the tools specified in the list of tool calls returned by - `self.stream_message()`. Returns a list of dictionaries - `toolcall_outputs: list[dict]`, which should be appended directly to the - message history on the next invocation of the LLM. + Runs the tools specified in a given tool call list using the default + toolkit. """ - if not len(tools): - return [] - - tool_outputs: list[dict] = [] - for tool_call in tools: - # Get tool definition from the correct toolkit - # TODO: validation? - tool_name = tool_call.function.name - tool_defn = DEFAULT_TOOLKIT.get_tool_unsafe(tool_name) - - # Run tool and store its output - output = tool_defn.callable(**tool_call.function.arguments) - if asyncio.iscoroutine(output): - output = await output - - # Store the tool output in a dictionary accepted by LiteLLM - output_dict = { - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_call.function.name, - "content": output, - } - tool_outputs.append(output_dict) - - self.log.info(f"Ran {len(tools)} tool functions.") + tool_outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT) + self.log.info(f"Ran {len(tool_outputs)} tool functions.") return tool_outputs diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 3c350c1f4..914372ddf 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -62,7 +62,7 @@ async def process_message(self, message: Message) -> None: # Append new reply to `messages` reply = self.ychat.get_message(result.id) - tool_calls_json = result.tool_calls.to_json() + tool_calls_json = result.tool_call_list.to_json() messages.append({ "role": "assistant", "content": reply.body, @@ -80,7 +80,7 @@ async def process_message(self, message: Message) -> None: ), append=True) # Run tools and append outputs to `messages` - tool_call_outputs = await self.run_tools(result.tool_calls.resolve()) + tool_call_outputs = await self.run_tools(result.tool_call_list) messages.extend(tool_call_outputs) # Add tool call outputs to YChat (not synced with `messages`) From 056676eeb9f13b406e62ceabb3b54074bbb8c7c1 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:50:30 +0200 Subject: [PATCH 06/13] improve bash tool reliability, drop jupyter_ai_tools for now --- .../jupyter_ai/personas/base_persona.py | 2 +- .../jupyter_ai/tools/default_toolkit.py | 54 +++++++++++++++++-- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index a33ad69f3..37271e1ba 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -477,7 +477,7 @@ def get_tools(self, model_id: str) -> list[dict]: environment. These may include: - The default set of tool functions in Jupyter AI, defined in the - `jupyter_ai_tools` package. + the default toolkit from `jupyter_ai.tools`. - (TODO) Tools provided by MCP server configuration, if any. 
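Taken together, the helpers introduced in these patches compose into a simple agent loop. The sketch below shows the intended usage under the APIs as of this commit — `ToolCallList`, `run_tools()`, `to_json()`, and `DEFAULT_TOOLKIT` come from the patches above, while the loop itself and its function name are illustrative only:

```py
import litellm

from jupyter_ai.litellm_utils import ToolCallList, run_tools
from jupyter_ai.tools.default_toolkit import DEFAULT_TOOLKIT


async def chat_with_tools(model_id: str, messages: list[dict], tools: list[dict]) -> None:
    # Keep prompting the model until it stops requesting tool calls
    while True:
        stream = await litellm.acompletion(
            model=model_id, messages=messages, tools=tools, stream=True
        )

        # Aggregate content and tool call deltas from the response stream
        content = ""
        tool_call_list = ToolCallList()
        async for chunk in stream:
            delta = chunk.choices[0].delta
            if delta.content:
                content += delta.content
            if delta.tool_calls:
                tool_call_list += delta.tool_calls

        # Record the assistant turn, including any tool calls it requested
        assistant_msg: dict = {"role": "assistant", "content": content}
        if len(tool_call_list):
            assistant_msg["tool_calls"] = tool_call_list.to_json()
        messages.append(assistant_msg)

        # Run the requested tools; their outputs must be appended to the
        # message history before the next request to the LLM
        outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT)
        if not outputs:
            break
        messages.extend(outputs)
```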
diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py
index 79c8a7675..d850775af 100644
--- a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py
+++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py
@@ -1,7 +1,9 @@
-from .models import Tool, Toolkit
-from jupyter_ai_tools.toolkits.code_execution import bash
-
+import asyncio
import pathlib
+import shlex
+from typing import Optional
+
+from .models import Tool, Toolkit
def read(file_path: str, offset: int, limit: int) -> str:
@@ -247,6 +249,52 @@ async def search_grep(pattern: str, include: str = "*") -> str:
raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e
+async def bash(command: str, timeout: Optional[int] = None) -> str:
+ """Executes a bash command and returns the result
+
+ Args:
+ command: The bash command to execute
+ timeout: Optional timeout in seconds
+
+ Returns:
+ The command output (stdout and stderr combined)
+ """
+ # coerce `timeout` to the correct type. sometimes LLMs pass this as a string
+ if isinstance(timeout, str):
+ timeout = int(timeout)
+
+ # NOTE: the command is tokenized with `shlex` and run without a shell
+ proc = await asyncio.create_subprocess_exec(
+ *shlex.split(command),
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+
+ try:
+ stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
+ stdout = stdout.decode("utf-8")
+ stderr = stderr.decode("utf-8")
+
+ if proc.returncode != 0:
+ info = f"Command returned non-zero exit code {proc.returncode}. This usually indicates an error."
+ info += "\n\n" + fr"Original command: {command}"
+ if not (stdout or stderr):
+ info += "\n\nNo further information was given in stdout or stderr."
+ return info
+ if stdout:
+ info += f"\n\nstdout:\n\n```\n{stdout}\n```"
+ if stderr:
+ info += f"\n\nstderr:\n\n```\n{stderr}\n```"
+ return info
+
+ if stdout:
+ return stdout
+ return "Command executed successfully with exit code 0. No stdout/stderr was returned."
+
+ except asyncio.TimeoutError:
+ proc.kill()
+ return f"Command timed out after {timeout} seconds"
+
+
DEFAULT_TOOLKIT = Toolkit(name="jupyter-ai-default-toolkit")
DEFAULT_TOOLKIT.add_tool(Tool(callable=bash))
DEFAULT_TOOLKIT.add_tool(Tool(callable=read))
From 1311989613d03f982a0f1f0fe167339d0a5c76f3 Mon Sep 17 00:00:00 2001
From: "David L. 
Qiu" Date: Mon, 25 Aug 2025 14:58:00 +0200 Subject: [PATCH 07/13] add jai-tool-call web component to show tool calls & outputs --- .../jupyter_ai/personas/base_persona.py | 2 +- .../personas/jupyternaut/jupyternaut.py | 83 +++++++++--- packages/jupyter-ai/package.json | 1 + packages/jupyter-ai/src/index.ts | 2 + .../jupyter-ai/src/web-components/index.ts | 2 + .../src/web-components/jai-tool-call.tsx | 121 ++++++++++++++++++ .../jupyter-ai/src/web-components/plugin.ts | 65 ++++++++++ yarn.lock | 20 +++ 8 files changed, 278 insertions(+), 18 deletions(-) create mode 100644 packages/jupyter-ai/src/web-components/index.ts create mode 100644 packages/jupyter-ai/src/web-components/jai-tool-call.tsx create mode 100644 packages/jupyter-ai/src/web-components/plugin.ts diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 37271e1ba..a6abef496 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -518,7 +518,7 @@ def get_tools(self, model_id: str) -> list[dict]: return tool_descriptions - async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]: + async def run_tools(self, tool_call_list: ToolCallList) -> list[dict]: """ Runs the tools specified in a given tool call list using the default toolkit. diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 914372ddf..1e9669c6d 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -5,6 +5,7 @@ from jupyterlab_chat.models import Message from litellm import acompletion +from ...litellm_utils import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults from ..persona_manager import SYSTEM_USERNAME from .prompt_template import ( @@ -69,30 +70,78 @@ async def process_message(self, message: Message) -> None: "tool_calls": tool_calls_json }) - # Show tool call requests to YChat (not synced with `messages`) - if len(tool_calls_json): - self.ychat.update_message(Message( - id=result.id, - body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n", - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) + # Render tool calls in new message + if len(result.tool_call_list): + self.render_tool_calls(result) # Run tools and append outputs to `messages` tool_call_outputs = await self.run_tools(result.tool_call_list) messages.extend(tool_call_outputs) - # Add tool call outputs to YChat (not synced with `messages`) + # Render tool call outputs in new message if tool_call_outputs: - self.ychat.update_message(Message( - id=result.id, - body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) + self.render_tool_call_outputs( + message_id=result.id, + tool_call_outputs=tool_call_outputs + ) + def render_tool_calls(self, stream_result: StreamResult): + """ + Renders tool calls by appending the tool calls to a message. + """ + message_id = stream_result.id + tool_call_list = stream_result.tool_call_list + + for tool_call in tool_call_list.resolve(): + id = tool_call.id + index = tool_call.index + type_val = tool_call.type + function = tool_call.function.model_dump_json() + # We have to HTML-escape double quotes in the JSON string. 
+ function = function.replace('"', "&quot;")
+
+ self.ychat.update_message(Message(
+ id=message_id,
+ body=f'\n\n\n',
+ sender=self.id,
+ time=time.time(),
+ raw_time=False
+ ), append=True)
+
+
+ def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]):
+ # TODO
+ # self.ychat.update_message(Message(
+ # id=message_id,
+ # body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n",
+ # sender=self.id,
+ # time=time.time(),
+ # raw_time=False
+ # ), append=True)
+
+ # Updates the content of the last message directly
+ message = self.ychat.get_message(message_id)
+ body = message.body
+ for output in tool_call_outputs:
+ if not output['content']:
+ output['content'] = ""
+ output = ToolCallOutput(**output)
+ tool_id = output.tool_call_id
+ tool_output = output.model_dump_json()
+ tool_output = tool_output.replace('"', '&quot;')
+ body = body.replace(
+ f';
+ };
+ index: number;
+ output?: {
+ tool_call_id: string;
+ role: string;
+ name: string;
+ content: string | null;
+ };
+};
+
+export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null {
+ const [expanded, setExpanded] = useState(false);
+ console.log({
+ output: props.output
+ });
+ const toolComplete = !!(props.output && Object.keys(props.output).length > 0);
+ const hasOutput = !!(toolComplete && props.output?.content?.length);
+
+ const handleExpandClick = () => {
+ setExpanded(!expanded);
+ };
+
+ const statusIcon: JSX.Element = toolComplete ? (
+
+ ) : (
+
+ );
+
+ const statusText: JSX.Element = (
+
+ {toolComplete ? 'Ran' : 'Running'}{' '}
+
+ {props.function.name}
+ {' '}
+ tool
+ {toolComplete ? '.' : '...'}
+
+ );
+
+ const toolArgsJson = useMemo(
+ () => JSON.stringify(props.function.arguments, null, 2),
+ [props.function.arguments]
+ );
+
+ const toolArgsSection: JSX.Element | null =
+ toolArgsJson === '{}' ? null : (
+
+
+ Tool arguments
+
+
+          {toolArgsJson}
+        
+
+ ); + + const toolOutputSection: JSX.Element | null = hasOutput ? ( + + + Tool output + +
{props.output?.content}
+
+ ) : null; + + if (!props.id || !props.type || !props.function) { + return null; + } + + return ( + + + {statusIcon} + {statusText} + + + + + + + + + {toolArgsSection} + {toolOutputSection} + + + + ); +} diff --git a/packages/jupyter-ai/src/web-components/plugin.ts b/packages/jupyter-ai/src/web-components/plugin.ts new file mode 100644 index 000000000..ac0b0674a --- /dev/null +++ b/packages/jupyter-ai/src/web-components/plugin.ts @@ -0,0 +1,65 @@ +import { + JupyterFrontEnd, + JupyterFrontEndPlugin +} from '@jupyterlab/application'; +import r2wc from '@r2wc/react-to-web-component'; + +import { JaiToolCall } from './jai-tool-call'; +import { ISanitizer, Sanitizer } from '@jupyterlab/apputils'; +import { IRenderMime } from '@jupyterlab/rendermime'; + +/** + * Plugin that registers custom web components for usage in AI responses. + */ +export const webComponentsPlugin: JupyterFrontEndPlugin = + { + id: '@jupyter-ai/core:web-components', + autoStart: true, + provides: ISanitizer, + activate: (app: JupyterFrontEnd) => { + // Define the JaiToolCall web component + // ['id', 'type', 'function', 'index', 'output'] + const JaiToolCallWebComponent = r2wc(JaiToolCall, { + props: { + id: 'string', + type: 'string', + function: 'json', + index: 'number', + output: 'json' + } + }); + + // Register the web component + customElements.define('jai-tool-call', JaiToolCallWebComponent); + console.log("Registered custom 'jai-tool-call' web component."); + + // Finally, override the default Rendermime sanitizer to allow custom web + // components in the output. + class CustomSanitizer + extends Sanitizer + implements IRenderMime.ISanitizer + { + sanitize( + dirty: string, + customOptions: IRenderMime.ISanitizerOptions + ): string { + const options: IRenderMime.ISanitizerOptions = { + // default sanitizer options + ...(this as any)._options, + // custom sanitizer options (variable per call) + ...customOptions + }; + + return super.sanitize(dirty, { + ...options, + allowedTags: [...(options?.allowedTags ?? []), 'jai-tool-call'], + allowedAttributes: { + ...options?.allowedAttributes, + 'jai-tool-call': ['id', 'type', 'function', 'index', 'output'] + } + }); + } + } + return new CustomSanitizer(); + } + }; diff --git a/yarn.lock b/yarn.lock index 173eef6b3..d2ff65f5a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2256,6 +2256,7 @@ __metadata: "@lumino/widgets": ^2.3.2 "@mui/icons-material": ^5.11.0 "@mui/material": ^5.11.0 + "@r2wc/react-to-web-component": ^2.0.4 "@stylistic/eslint-plugin": ^3.0.1 "@types/jest": ^29 "@types/react-dom": ^18.2.0 @@ -4826,6 +4827,25 @@ __metadata: languageName: node linkType: hard +"@r2wc/core@npm:^1.0.0": + version: 1.2.0 + resolution: "@r2wc/core@npm:1.2.0" + checksum: e0dc23e8fd1f0d96193b67f5eb04b74b25b9f4609778e6ea2427c565eb590f458553cad307a2fdb3fc4614f6a576d7701b9bacf11775958bc560cc3b3b5aaae7 + languageName: node + linkType: hard + +"@r2wc/react-to-web-component@npm:^2.0.4": + version: 2.0.4 + resolution: "@r2wc/react-to-web-component@npm:2.0.4" + dependencies: + "@r2wc/core": ^1.0.0 + peerDependencies: + react: ^18.0.0 || ^19.0.0 + react-dom: ^18.0.0 || ^19.0.0 + checksum: 7b140ffd612173a30d74717d18efcf554774ef0ed0fe72f207ec21df707685ef5f4c34521e6840041665550c6461171dc32f12835f35beb1788ccac0c66c0e5c + languageName: node + linkType: hard + "@rjsf/core@npm:^5.13.4": version: 5.17.0 resolution: "@rjsf/core@npm:5.17.0" From 5223fa644b4c6d7eab78f341b9443cada9e36273 Mon Sep 17 00:00:00 2001 From: "David L. 
Qiu" Date: Sat, 30 Aug 2025 15:23:27 +0200 Subject: [PATCH 08/13] move litellm_utils => litellm_lib --- .../{litellm_utils => litellm_lib}/__init__.py | 4 ++-- .../{litellm_utils => litellm_lib}/run_tools.py | 8 -------- .../test_toolcall_list.py | 0 .../toolcall_list.py | 15 +++++++++------ .../streaming_utils.py => litellm_lib/types.py} | 8 ++++++++ .../jupyter_ai/personas/base_persona.py | 5 +---- .../personas/jupyternaut/jupyternaut.py | 13 +------------ 7 files changed, 21 insertions(+), 32 deletions(-) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/__init__.py (63%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/run_tools.py (90%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/test_toolcall_list.py (100%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/toolcall_list.py (92%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils/streaming_utils.py => litellm_lib/types.py} (64%) diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py similarity index 63% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py index ff1c7d8c3..edc2e51cc 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py @@ -1,3 +1,3 @@ -from .toolcall_list import * -from .streaming_utils import * from .run_tools import * +from .toolcall_list import * +from .types import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py similarity index 90% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py index 2b148b505..0bb815d20 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py @@ -1,16 +1,8 @@ import asyncio -from pydantic import BaseModel from .toolcall_list import ToolCallList from ..tools import Toolkit -class ToolCallOutput(BaseModel): - tool_call_id: str - role: str = "tool" - name: str - content: str - - async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: """ Runs the tools specified in the list of tool calls returned by diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py similarity index 100% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py similarity index 92% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py index e7094e4f9..311e20a6a 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -29,10 +29,9 @@ class ResolvedToolCall(BaseModel): `litellm.utils.ChatCompletionDeltaToolCall`. """ - id: str | None + id: str """ - The ID of the tool call. This should always be provided by LiteLLM, this - type is left optional as we do not use this attribute. + The ID of the tool call. 
""" type: str @@ -62,7 +61,7 @@ class ToolCallList(BaseModel): is used to aggregate the tool call deltas yielded from a LiteLLM response stream and produce a list of tool calls. - After all tool call deltas are added, the `process()` method may be called + After all tool call deltas are added, the `resolve()` method may be called to return a list of resolved tool calls. Example usage: @@ -75,7 +74,7 @@ class ToolCallList(BaseModel): tool_call_delta = chunk.choices[0].delta.tool_calls tool_call_list += tool_call_delta - tool_call_list.resolve() + tool_calls = tool_call_list.resolve() ``` """ @@ -128,7 +127,11 @@ def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallL def resolve(self) -> list[ResolvedToolCall]: """ - Resolve the aggregated tool call delta lists into a list of tool calls. + Returns the aggregated tool calls as `list[ResolvedToolCall]`. + + Raises an exception if any function arguments could not be parsed from + JSON into a dictionary. This method should only be called after the + stream completed without errors. """ resolved_toolcalls: list[ResolvedToolCall] = [] for i, raw_toolcall in enumerate(self._aggregate): diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py similarity index 64% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/types.py index 7251c88ed..b901711c3 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pydantic import BaseModel from .toolcall_list import ToolCallList @@ -11,3 +12,10 @@ class StreamResult(BaseModel): """ Tool calls requested by the LLM in its streamed response. 
""" + +class ToolCallOutput(BaseModel): + tool_call_id: str + role: str = "tool" + name: str + content: str + diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index a6abef496..99763f409 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -17,15 +17,12 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, StreamResult, run_tools, ToolCallOutput - -# Import toolkits +from ..litellm_lib import ToolCallList, StreamResult, run_tools from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: from collections.abc import AsyncIterator from .persona_manager import PersonaManager - from ..tools import Toolkit class PersonaDefaults(BaseModel): """ diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 1e9669c6d..7a6261a67 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,11 +1,10 @@ from typing import Any, Optional import time -import json from jupyterlab_chat.models import Message from litellm import acompletion -from ...litellm_utils import StreamResult, ToolCallOutput +from ...litellm_lib import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults from ..persona_manager import SYSTEM_USERNAME from .prompt_template import ( @@ -110,15 +109,6 @@ def render_tool_calls(self, stream_result: StreamResult): def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]): - # TODO - # self.ychat.update_message(Message( - # id=message_id, - # body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", - # sender=self.id, - # time=time.time(), - # raw_time=False - # ), append=True) - # Updates the content of the last message directly message = self.ychat.get_message(message_id) body = message.body @@ -134,7 +124,6 @@ def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict f' Date: Thu, 18 Sep 2025 10:22:44 -0700 Subject: [PATCH 09/13] migrate all agent logic to module using pocketflow --- .../jupyter_ai/default_flow/__init__.py | 1 + .../jupyter_ai/default_flow/default_flow.py | 335 ++++++++++++++++++ .../jupyter_ai/litellm_lib/run_tools.py | 22 +- .../jupyter_ai/litellm_lib/toolcall_list.py | 81 ++++- .../jupyter_ai/litellm_lib/types.py | 42 ++- .../jupyter_ai/personas/__init__.py | 4 +- .../jupyter_ai/personas/base_persona.py | 189 +--------- .../personas/jupyternaut/jupyternaut.py | 170 ++------- .../jupyter-ai/jupyter_ai/tools/models.py | 22 ++ 9 files changed, 509 insertions(+), 357 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/default_flow/__init__.py create mode 100644 packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py new file mode 100644 index 000000000..849952ada --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py @@ -0,0 +1 @@ +from .default_flow import * \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py new file mode 100644 index 000000000..9fbdda70d --- /dev/null +++ 
b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -0,0 +1,335 @@ +from pocketflow import AsyncNode, AsyncFlow +from jupyterlab_chat.models import Message, NewMessage +from jupyterlab_chat.ychat import YChat +from typing import Any, Optional, Tuple, TypedDict +from jinja2 import Template +from litellm import acompletion, ModelResponseStream +import time +import logging + +from ..litellm_lib import ToolCallList, run_tools, LitellmToolCallOutput +from ..tools import Toolkit +from ..personas import SYSTEM_USERNAME, PersonaAwareness + +DEFAULT_RESPONSE_TEMPLATE = """ +{{ content }} +{{ tool_call_ui_elements }} +""".strip() + +class DefaultFlowParams(TypedDict): + """ + Parameters expected by the default flow provided by Jupyter AI. + """ + + model_id: str + + ychat: YChat + + awareness: PersonaAwareness + + persona_id: str + + logger: logging.Logger + + model_args: dict[str, Any] | None + """ + Custom keyword arguments forwarded to `litellm.acompletion()`. Defaults to + `{}` if unset. + """ + + system_prompt: Optional[str] + """ + System prompt that will be used as the first message in the list of messages + sent to the language model. Unused if unset. + """ + + response_template: Template | None + """ + Jinja2 template used to template the response. If one is not given, + `DEFAULT_RESPONSE_TEMPLATE` is used. + + It should take `content: str` and `tool_call_ui_elements: str` as format arguments. + """ + + toolkit: Toolkit | None + """ + Toolkit of tools. Unused if unset. + """ + + history_size: int | None + """ + Number of messages preceding the message triggering this flow to include + in the prompt as context. Defaults to 2 if unset. + """ + +class JaiAsyncNode(AsyncNode): + """ + An AsyncNode with custom properties & helper methods used exclusively in the + Jupyter AI extension. + """ + + @property + def model_id(self) -> str: + return self.params["model_id"] + + @property + def ychat(self) -> YChat: + return self.params["ychat"] + + @property + def awareness(self) -> PersonaAwareness: + return self.params["awareness"] + + @property + def persona_id(self) -> str: + return self.params["persona_id"] + + @property + def model_args(self) -> dict[str, Any]: + return self.params.get("model_args", {}) + + @property + def system_prompt(self) -> Optional[str]: + return self.params.get("system_prompt") + + @property + def response_template(self) -> Template: + template = self.params.get("response_template") + # If response template was unspecified, use the default response + # template. + if not template: + template = Template(DEFAULT_RESPONSE_TEMPLATE) + + return template + + @property + def toolkit(self) -> Optional[Toolkit]: + return self.params.get("toolkit") + + @property + def history_size(self) -> int: + return self.params.get("history_size", 2) + + @property + def log(self) -> logging.Logger: + return self.params.get("logger") + + +class RootNode(JaiAsyncNode): + """ + The root node of the default flow provided by Jupyter AI. + """ + + async def prep_async(self, shared): + self.log.info("Running RootNode.prep_async()") + # Initialize `shared.litellm_messages` using the YChat message history + # if it is unset. + if not ('litellm_messages' in shared and isinstance(shared['litellm_messages'], list) and len(shared['litellm_messages']) > 0): + shared['litellm_messages'] = self._init_litellm_messages() + + # Return `shared.litellm_messages`. This is passed as the `prep_res` + # argument to `exec_async()`. 
+ return shared['litellm_messages'] + + + def _init_litellm_messages(self) -> list[dict]: + # Store the invoking message & the previous `params.history_size` messages + # as `ychat_messages`. + # TODO: ensure the invoking message is in this list + all_messages = self.ychat.get_messages() + ychat_messages: list[Message] = all_messages[-self.history_size - 1:] + + # Coerce each `Message` in `ychat_messages` to a dictionary following + # the OpenAI spec, and store it as `litellm_messages`. + litellm_messages: list[dict[str, Any]] = [] + for msg in ychat_messages: + role = ( + "assistant" + if msg.sender.startswith("jupyter-ai-personas::") + else "system" if msg.sender == SYSTEM_USERNAME else "user" + ) + litellm_messages.append({"role": role, "content": msg.body}) + + # Insert system message as a dictionary if present. + if self.system_prompt: + system_litellm_message = { + "role": "system", + "content": self.system_prompt + } + litellm_messages = [system_litellm_message, *litellm_messages] + + # Return `litellm_messages` + return litellm_messages + + + async def exec_async(self, prep_res: list[dict]): + self.log.info("Running RootNode.exec_async()") + # Gather arguments and start a reply stream via LiteLLM + reply_stream = await acompletion( + **self.model_args, + model=self.model_id, + messages=prep_res, + tools=self.toolkit.to_json(), + stream=True, + ) + + # Iterate over reply stream + content = "" + tool_calls = ToolCallList() + stream_id: str | None = None + async for chunk in reply_stream: + assert isinstance(chunk, ModelResponseStream) + delta = chunk.choices[0].delta + content_delta = delta.content + toolcalls_delta = delta.tool_calls + + # Continue early if an empty chunk was emitted. + # This sometimes happens with LiteLLM. + if not (content_delta or toolcalls_delta): + continue + + # Aggregate the content and tool calls from the deltas + if content_delta: + content += content_delta + if toolcalls_delta: + tool_calls += toolcalls_delta + + # Create a new message if one does not yet exist + if not stream_id: + stream_id = self.ychat.add_message(NewMessage( + sender=self.persona_id, + body="" + )) + assert stream_id + + # Update the reply + message_body = self.response_template.render({ + "content": content, + "tool_call_ui_elements": tool_calls.render() + }) + self.log.error(message_body) + self.ychat.update_message( + Message( + id=stream_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + + # Return message_id, content, and tool calls + return stream_id, content, tool_calls + + async def post_async(self, shared, prep_res, exec_res: Tuple[str, str, ToolCallList]): + self.log.info("Running RootNode.post_async()") + # Assert that `shared['litellm_messages']` is of the correct type, and + # that any tool calls returned are complete. 
+ message_id, content, tool_calls = exec_res + assert 'litellm_messages' in shared and isinstance(shared['litellm_messages'], list) + assert tool_calls.complete + + # Add AI response to `shared['litellm_messages']`, including tool calls + new_litellm_message = { + "role": "assistant", + "content": content + } + if len(tool_calls): + new_litellm_message['tool_calls'] = tool_calls.as_litellm_tool_calls() + shared['litellm_messages'].append(new_litellm_message) + + # Add message ID to `shared['prev_message_id']` + shared['prev_message_id'] = message_id + + # Add message content to `shared['prev_message_content]` + shared['prev_message_content'] = content + + # Add tool calls to `shared['next_tool_calls']` + shared['next_tool_calls'] = tool_calls + + # Trigger `ToolExecutorNode` if tools were called. + if len(tool_calls): + return "execute-tools" + return 'finish' + +class ToolExecutorNode(JaiAsyncNode): + """ + Node responsible for executing tool calls in the default flow. + """ + + + async def prep_async(self, shared): + self.log.info("Running ToolExecutorNode.prep_async()") + # Extract `shared['next_tool_calls']` and the ID of the last message + assert 'next_tool_calls' in shared and isinstance(shared['next_tool_calls'], ToolCallList) + assert 'prev_message_id' in shared and isinstance(shared['prev_message_id'], str) + + # Return list of tool calls as a list of dictionaries + return shared['prev_message_id'], shared['next_tool_calls'] + + async def exec_async(self, prep_res: Tuple[str, ToolCallList]) -> list[LitellmToolCallOutput]: + self.log.info("Running ToolExecutorNode.exec_async()") + message_id, tool_calls = prep_res + + # TODO: Run 1 tool at a time? + outputs = await run_tools(tool_calls, self.toolkit) + + for output in outputs: + self.log.error(output) + return outputs + + async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: list[LitellmToolCallOutput]): + self.log.info("Running ToolExecutorNode.post_async()") + + # Update last message to include outputs + prev_message_id = shared['prev_message_id'] + prev_message_content = shared['prev_message_content'] + tool_calls: ToolCallList = shared['next_tool_calls'] + message_body = self.response_template.render({ + "content": prev_message_content, + "tool_call_ui_elements": tool_calls.render( + outputs=exec_res + ) + }) + self.ychat.update_message( + Message( + id=prev_message_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + self.log.error(message_body) + + # Add tool outputs to `shared['litellm_messages']` + shared['litellm_messages'].extend(exec_res) + for msg in shared['litellm_messages']: + self.log.error(msg) + + # Delete shared state that is now stale + del shared['prev_message_id'] + del shared['prev_message_content'] + del shared['next_tool_calls'] + # This node will automatically return to `RootNode` after execution. 
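`RootNode` and `ToolExecutorNode` communicate only through PocketFlow's shared store, which is a plain `dict`. The keys they exchange are summarized by the illustrative TypedDict below; it is a documentation aid only and does not appear in the patch itself:

```py
from typing import TypedDict

from jupyter_ai.litellm_lib import ToolCallList


class DefaultFlowSharedStore(TypedDict, total=False):
    # LiteLLM-format message history for this run: the system prompt, prior
    # chat history, assistant replies, and tool call outputs, in order.
    litellm_messages: list[dict]

    # ID of the chat message streamed by `RootNode`. `ToolExecutorNode` reads
    # it to re-render that message with tool outputs, then deletes the key.
    prev_message_id: str

    # Body text of that message, kept so the message can be re-rendered
    # through the response template once outputs are available.
    prev_message_content: str

    # Tool calls dispatched by the last assistant turn. A non-empty list makes
    # `RootNode` return the "execute-tools" action instead of "finish".
    next_tool_calls: ToolCallList
```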
+ +async def run_default_flow(params: DefaultFlowParams): + # Initialize nodes + root_node = RootNode() + tool_executor_node = ToolExecutorNode() + + # Define state transitions + ## Flow to ToolExecutorNode if tool calls were dispatched + root_node - "execute-tools" >> tool_executor_node + ## Always flow back to RootNode after running tools + tool_executor_node >> root_node + ## End the flow if no tool calls were dispatched + root_node - "finish" >> AsyncNode() + + # Initialize flow and set its parameters + flow = AsyncFlow(start=root_node) + flow.set_params(params) + + # Finally, run the async node + await flow.run_async({}) + diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py index 0bb815d20..6ccb22d71 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py @@ -1,21 +1,29 @@ +from __future__ import annotations +from typing import TYPE_CHECKING import asyncio -from .toolcall_list import ToolCallList -from ..tools import Toolkit +if TYPE_CHECKING: + from ..tools import Toolkit + from .toolcall_list import ToolCallList + from .types import LitellmToolCallOutput -async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: + +async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[LitellmToolCallOutput]: """ Runs the tools specified in the list of tool calls returned by `self.stream_message()`. - Returns `list[ToolCallOutput]`. The outputs should be appended directly to - the message history on the next request made to the LLM. + Returns `list[LitellmToolCallOutput]`, a list of output dictionaries of the + type expected by LiteLLM. + + Each output in the list should be appended directly to the message history + on the next request made to the LLM. """ tool_calls = tool_call_list.resolve() if not len(tool_calls): return [] - tool_outputs: list[dict] = [] + tool_outputs: list[LitellmToolCallOutput] = [] for tool_call in tool_calls: # Get tool definition from the correct toolkit # TODO: validation? @@ -31,7 +39,7 @@ async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict output = str(e) # Store the tool output in a dictionary accepted by LiteLLM - output_dict = { + output_dict: LitellmToolCallOutput = { "tool_call_id": tool_call.id, "role": "tool", "name": tool_call.function.name, diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py index 311e20a6a..563f06746 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -2,6 +2,8 @@ import json from pydantic import BaseModel from typing import Any +from .types import LitellmToolCall, LitellmToolCallOutput, JaiToolCallProps +from jinja2 import Template class ResolvedFunction(BaseModel): """ @@ -54,6 +56,13 @@ class ResolvedToolCall(BaseModel): This is usually 0 unless the LLM supports parallel tool calling. 
""" +JAI_TOOL_CALL_TEMPLATE = Template(""" +{% for props in props_list %} + + +{% endfor %} +""".strip()) + class ToolCallList(BaseModel): """ A helper object that defines a custom `__iadd__()` method which accepts a @@ -165,16 +174,80 @@ def resolve(self) -> list[ResolvedToolCall]: return resolved_toolcalls - def to_json(self) -> list[dict[str, Any]]: + @property + def complete(self) -> bool: + for i, tool_call in enumerate(self._aggregate): + if tool_call.index != i: + return False + if not tool_call.function: + return False + if not tool_call.function.name: + return False + if not tool_call.type: + return False + if not tool_call.function.arguments: + return False + try: + json.loads(tool_call.function.arguments) + except Exception: + return False + + return True + + def as_litellm_tool_calls(self) -> list[LitellmToolCall]: """ - Returns the list of tool calls as a Python dictionary that can be - JSON-serialized. + Returns the current list of tool calls as a list of dictionaries. + + This should be set in the `tool_calls` key in the dictionary of the + LiteLLM assistant message responsible for dispatching these tool calls. """ return [ model.model_dump() for model in self._aggregate ] - + def render(self, outputs: list[LitellmToolCallOutput] | None = None) -> str: + """ + Renders this tool call list as a list of `` elements to + be shown in the chat. + """ + # Initialize list of props to render into tool call UI elements + props_list: list[JaiToolCallProps] = [] + + # Index all outputs if passed + outputs_by_id: dict[str, LitellmToolCallOutput] | None = None + if outputs: + outputs_by_id = {} + for output in outputs: + outputs_by_id[output['tool_call_id']] = output + + for tool_call in self._aggregate: + # Build the props for each tool call UI element + props: JaiToolCallProps = { + 'id': tool_call.id, + 'index': tool_call.index, + 'type': tool_call.type, + 'function_name': tool_call.function.name, + 'function_args': tool_call.function.arguments, + } + + # Add the output if present + if outputs_by_id and tool_call.id in outputs_by_id: + output = outputs_by_id[tool_call.id] + # Make sure to manually convert the dictionary to a JSON string + # first. Without doing this, Jinja2 will convert a dictionary to + # JSON using single quotes instead of double quotes, which + # cannot be parsed by the frontend. + output = json.dumps(output) + props['output'] = output + + props_list.append(props) + + # Render the tool call UI elements using the Jinja2 template and return + return JAI_TOOL_CALL_TEMPLATE.render({ + "props_list": props_list + }) + + def __len__(self) -> int: return len(self._aggregate) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py index b901711c3..314c54990 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -1,21 +1,39 @@ from __future__ import annotations -from pydantic import BaseModel -from .toolcall_list import ToolCallList +from typing import TypedDict, Literal, Optional -class StreamResult(BaseModel): + +class LitellmToolCall(TypedDict): id: str - """ - ID of the new message. - """ + type: Literal['function'] + function: str + index: int - tool_call_list: ToolCallList - """ - Tool calls requested by the LLM in its streamed response. 
- """ +class LitellmMessage(TypedDict): + role: Literal['assistant', 'user', 'system'] + content: str + tool_calls: Optional[list[LitellmToolCall]] -class ToolCallOutput(BaseModel): +class LitellmToolCallOutput(TypedDict): tool_call_id: str - role: str = "tool" + role: Literal['tool'] name: str content: str +class JaiToolCallProps(TypedDict): + id: str | None + + type: Literal['function'] | None + + index: int | None + + function_name: str | None + + function_args: str | None + """ + The arguments to the function as a dictionary converted to a JSON string. + """ + + output: str | None + """ + The `LitellmToolCallOutput` as a JSON string. + """ \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/__init__.py b/packages/jupyter-ai/jupyter_ai/personas/__init__.py index 6c0704f52..fb8dd1bfe 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/personas/__init__.py @@ -1,2 +1,2 @@ -from .base_persona import BasePersona, PersonaDefaults -from .persona_manager import PersonaManager +from .base_persona import * +from .persona_manager import * diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 99763f409..40fdb14db 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -4,24 +4,22 @@ from abc import ABC, ABCMeta, abstractmethod from dataclasses import asdict from logging import Logger -from time import time from typing import TYPE_CHECKING, Any, Optional from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User from jupyterlab_chat.ychat import YChat -from litellm import ModelResponseStream, supports_function_calling +from litellm import supports_function_calling from litellm.utils import function_to_dict from pydantic import BaseModel from traitlets import MetaHasTraits from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_lib import ToolCallList, StreamResult, run_tools +from ..litellm_lib import ToolCallList, run_tools from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: - from collections.abc import AsyncIterator from .persona_manager import PersonaManager class PersonaDefaults(BaseModel): @@ -234,127 +232,6 @@ def as_user_dict(self) -> dict[str, Any]: user = self.as_user() return asdict(user) - async def stream_message( - self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> StreamResult: - """ - Takes an async iterator, dubbed the 'reply stream', and streams it to a - new message by this persona in the YChat. The async iterator may yield - either strings or `litellm.ModelResponseStream` objects. Details: - - - Creates a new message upon receiving the first chunk from the reply - stream, then continuously updates it until the stream is closed. - - - Automatically manages its awareness state to show writing status. - - Returns a list of `ResolvedToolCall` objects. If this list is not empty, - the persona should run these tools. - """ - stream_id: Optional[str] = None - stream_interrupted = False - tool_call_list = ToolCallList() - try: - self.awareness.set_local_state_field("isWriting", True) - - async for chunk in reply_stream: - # Start the stream with an empty message on the initial reply. - # Bind the new message ID to `stream_id`. 
- if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) - ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - assert stream_id - - # Compute `content_delta` and `tool_calls_delta` based on the - # type of object yielded by `reply_stream`. - if isinstance(chunk, ModelResponseStream): - delta = chunk.choices[0].delta - content_delta = delta.content - toolcalls_delta = delta.tool_calls - elif isinstance(chunk, str): - content_delta = chunk - toolcalls_delta = None - else: - raise Exception(f"Unrecognized type in stream_message(): {type(chunk)}") - - # LiteLLM streams always terminate with an empty chunk, so - # continue in this case. - if not (content_delta or toolcalls_delta): - continue - - # Terminate the stream if the user requested it. - if ( - stream_id - and stream_id in self.message_interrupted.keys() - and self.message_interrupted[stream_id].is_set() - ): - try: - # notify the model provider that streaming was interrupted - # (this is essential to allow the model to stop generating) - await reply_stream.athrow( # type:ignore[attr-defined] - GenerationInterrupted() - ) - except GenerationInterrupted: - # do not let the exception bubble up in case if - # the provider did not handle it - pass - stream_interrupted = True - break - - # Append `content_delta` to the existing message. - if content_delta: - self.ychat.update_message( - Message( - id=stream_id, - body=content_delta, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - if toolcalls_delta: - tool_call_list += toolcalls_delta - - except Exception as e: - self.log.error( - f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." - ) - self.log.exception(e) - finally: - # Reset local state - self.awareness.set_local_state_field("isWriting", False) - self.message_interrupted.pop(stream_id, None) - - # If stream was interrupted, add a tombstone and return `[]`, - # indicating that no tools should be run afterwards. - if stream_id and stream_interrupted: - stream_tombstone = "\n\n(AI response stopped by user)" - self.ychat.update_message( - Message( - id=stream_id, - body=stream_tombstone, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - return None - - # TODO: determine where this should live - count = len(tool_call_list) - if count > 0: - self.log.info(f"AI response triggered {count} tool calls.") - - return StreamResult( - id=stream_id, - tool_call_list=tool_call_list - ) - - def send_message(self, body: str) -> None: """ Sends a new message to the chat from this persona. @@ -464,68 +341,6 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]: self.log.error(f"Failed to resolve attachment {attachment_id}: {e}") return None - def get_tools(self, model_id: str) -> list[dict]: - """ - Returns the `tools` parameter which should be passed to - `litellm.acompletion()` for a given LiteLLM model ID. - - If the model does not support tool-calling, this method returns an empty - list. Otherwise, it returns the list of tools available in the current - environment. These may include: - - - The default set of tool functions in Jupyter AI, defined in the - the default toolkit from `jupyter_ai.tools`. - - - (TODO) Tools provided by MCP server configuration, if any. - - - (TODO) Web search. - - - (TODO) File search using vector store IDs. 
- - TODO: cache this - - TODO: Implement some permissions system so users can control what tools - are allowable. - - NOTE: The returned list is expected by LiteLLM to conform to the `tools` - parameter defintiion defined by the OpenAI API: - https://platform.openai.com/docs/guides/tools#available-tools - - NOTE: This API is a WIP and is very likely to change. - """ - # Return early if the model does not support tool calling - if not supports_function_calling(model=model_id): - return [] - - tool_descriptions = [] - - # Get all tools from the default toolkit and store their object descriptions - for tool in DEFAULT_TOOLKIT.get_tools(): - # Here, we are using a util function from LiteLLM to coerce - # each `Tool` struct into a tool description dictionary expected - # by LiteLLM. - desc = { - "type": "function", - "function": function_to_dict(tool.callable), - } - tool_descriptions.append(desc) - - # Finally, return the tool descriptions - self.log.info(tool_descriptions) - return tool_descriptions - - - async def run_tools(self, tool_call_list: ToolCallList) -> list[dict]: - """ - Runs the tools specified in a given tool call list using the default - toolkit. - """ - tool_outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT) - self.log.info(f"Ran {len(tool_outputs)} tool functions.") - return tool_outputs - - - def shutdown(self) -> None: """ Shuts the persona down. This method should: diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 7a6261a67..c89934a51 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,16 +1,12 @@ -from typing import Any, Optional -import time - from jupyterlab_chat.models import Message -from litellm import acompletion -from ...litellm_lib import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults -from ..persona_manager import SYSTEM_USERNAME +from ...default_flow import run_default_flow, DefaultFlowParams from .prompt_template import ( JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from ...tools import DEFAULT_TOOLKIT class JupyternautPersona(BasePersona): @@ -31,6 +27,7 @@ def defaults(self): ) async def process_message(self, message: Message) -> None: + # Return early if no chat model is configured if not self.config_manager.chat_model: self.send_message( "No chat model is configured.\n\n" @@ -38,145 +35,28 @@ async def process_message(self, message: Message) -> None: ) return - model_id = self.config_manager.chat_model - - # `True` before the first LLM response is sent, `False` afterwards. - initial_response = True - # List of tool call outputs computed in the previous invocation. - tool_call_outputs: list[dict] = [] - - # Initialize list of messages, including history and context - messages: list[dict] = self.get_context_as_messages(model_id, message) - - # Loop until the AI is complete running all its tools. 
- while initial_response or len(tool_call_outputs): - # Stream message to the chat - response_aiter = await acompletion( - model=model_id, - messages=messages, - tools=self.get_tools(model_id), - stream=True, - ) - result = await self.stream_message(response_aiter) - initial_response = False - - # Append new reply to `messages` - reply = self.ychat.get_message(result.id) - tool_calls_json = result.tool_call_list.to_json() - messages.append({ - "role": "assistant", - "content": reply.body, - "tool_calls": tool_calls_json - }) - - # Render tool calls in new message - if len(result.tool_call_list): - self.render_tool_calls(result) - - # Run tools and append outputs to `messages` - tool_call_outputs = await self.run_tools(result.tool_call_list) - messages.extend(tool_call_outputs) - - # Render tool call outputs in new message - if tool_call_outputs: - self.render_tool_call_outputs( - message_id=result.id, - tool_call_outputs=tool_call_outputs - ) - - def render_tool_calls(self, stream_result: StreamResult): - """ - Renders tool calls by appending the tool calls to a message. - """ - message_id = stream_result.id - tool_call_list = stream_result.tool_call_list - - for tool_call in tool_call_list.resolve(): - id = tool_call.id - index = tool_call.index - type_val = tool_call.type - function = tool_call.function.model_dump_json() - # We have to HTML-escape double quotes in the JSON string. - function = function.replace('"', """) - - self.ychat.update_message(Message( - id=message_id, - body=f'\n\n\n', - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) - - - def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]): - # Updates the content of the last message directly - message = self.ychat.get_message(message_id) - body = message.body - for output in tool_call_outputs: - if not output['content']: - output['content'] = "" - output = ToolCallOutput(**output) - tool_id = output.tool_call_id - tool_output = output.model_dump_json() - tool_output = tool_output.replace('"', '"') - body = body.replace( - f' list[dict[str, Any]]: - """ - Returns the current context, including attachments and recent messages, - as a list of messages accepted by `litellm.acompletion()`. - """ - system_msg_args = JupyternautSystemPromptArgs( - model_id=model_id, - persona_name=self.name, - context=self.process_attachments(message), - ).model_dump() - - system_msg = { - "role": "system", - "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), + # Build default flow params + system_prompt = self._build_system_prompt(message) + flow_params: DefaultFlowParams = { + "persona_id": self.id, + "model_id": self.config_manager.chat_model, + "model_args": self.config_manager.chat_model_args, + "ychat": self.ychat, + "awareness": self.awareness, + "system_prompt": system_prompt, + "toolkit": DEFAULT_TOOLKIT, + "logger": self.log, } - context_as_messages = [system_msg, *self._get_history_as_messages()] - return context_as_messages - - def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: - """ - Returns the current history as a list of messages accepted by - `litellm.acompletion()`. + # Run default agent flow + await run_default_flow(flow_params) - NOTE: You should usually call the public `get_context_as_messages()` - method instead. - """ - # TODO: consider bounding history based on message size (e.g. total - # char/token count) instead of message count. 
- all_messages = self.ychat.get_messages() - - # gather last k * 2 messages and return - start_idx = 0 if k is None else -2 * k - recent_messages: list[Message] = all_messages[start_idx:] - - history: list[dict[str, Any]] = [] - for msg in recent_messages: - role = ( - "assistant" - if msg.sender.startswith("jupyter-ai-personas::") - else "system" if msg.sender == SYSTEM_USERNAME else "user" - ) - history.append({"role": role, "content": msg.body}) - - return history + def _build_system_prompt(self, message: Message) -> str: + context = self.process_attachments(message) + format_args = JupyternautSystemPromptArgs( + persona_name=self.name, + model_id=self.config_manager.chat_model, + context=context, + ) + system_prompt = JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(format_args.model_dump()) + return system_prompt \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/models.py b/packages/jupyter-ai/jupyter_ai/tools/models.py index e547f0c15..72b9c69ef 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/models.py +++ b/packages/jupyter-ai/jupyter_ai/tools/models.py @@ -1,5 +1,6 @@ import re from typing import Callable, Optional +from litellm.utils import function_to_dict from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -215,3 +216,24 @@ def get_tools( toolset.add(tool) return toolset + + def to_json(self) -> list[dict]: + """ + Returns a list of tool descriptions in the type expected by LiteLLM. + """ + tool_descriptions = [] + + # Get all tools from the default toolkit and store their object descriptions + for tool in self.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + tool_descriptions.append(desc) + + # Finally, return the tool descriptions + return tool_descriptions + From fd950bd2a91bebae28697805cbdce65e3d11fd99 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:23:47 -0700 Subject: [PATCH 10/13] add pocketflow as a dependency --- packages/jupyter-ai/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index d82914103..cdd5b4dd2 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "litellm>=1.73,<2", "jinja2>=3.0,<4", "python_dotenv>=1,<2", + "pocketflow==0.0.3", ] dynamic = ["version", "description", "authors", "urls", "keywords"] From 793fb4e11eba703490518552ffbb089b2417dc55 Mon Sep 17 00:00:00 2001 From: "David L. 
Qiu" Date: Thu, 18 Sep 2025 10:24:08 -0700 Subject: [PATCH 11/13] update jai-tool-call web component API to do less JSON parsing --- .../jupyter-ai/src/web-components/index.ts | 2 +- .../src/web-components/jai-tool-call.tsx | 52 +++++++++---------- .../{plugin.ts => web-components-plugin.ts} | 14 ++++- 3 files changed, 38 insertions(+), 30 deletions(-) rename packages/jupyter-ai/src/web-components/{plugin.ts => web-components-plugin.ts} (83%) diff --git a/packages/jupyter-ai/src/web-components/index.ts b/packages/jupyter-ai/src/web-components/index.ts index aea66be0d..5f5e58107 100644 --- a/packages/jupyter-ai/src/web-components/index.ts +++ b/packages/jupyter-ai/src/web-components/index.ts @@ -1,2 +1,2 @@ -export * from './plugin'; +export * from './web-components-plugin'; export * from './jai-tool-call'; diff --git a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx index dc6a52cc5..d9530c620 100644 --- a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx +++ b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo } from 'react'; +import React, { useState } from 'react'; import { Box, Typography, @@ -10,13 +10,11 @@ import ExpandMore from '@mui/icons-material/ExpandMore'; import CheckCircle from '@mui/icons-material/CheckCircle'; type JaiToolCallProps = { - id: string; - type: string; - function: { - name: string; - arguments: Record; - }; - index: number; + id?: string; + type?: string; + function_name?: string; + function_args?: string; + index?: number; output?: { tool_call_id: string; role: string; @@ -26,10 +24,10 @@ type JaiToolCallProps = { }; export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { - const [expanded, setExpanded] = useState(false); console.log({ - output: props.output + props }); + const [expanded, setExpanded] = useState(false); const toolComplete = !!(props.output && Object.keys(props.output).length > 0); const hasOutput = !!(toolComplete && props.output?.content?.length); @@ -47,29 +45,28 @@ export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { {toolComplete ? 'Ran' : 'Running'}{' '} - {props.function.name} + {props.function_name} {' '} tool {toolComplete ? '.' : '...'} ); - const toolArgsJson = useMemo( - () => JSON.stringify(props.function.arguments, null, 2), - [props.function.arguments] - ); + // const toolArgsJson = useMemo( + // () => JSON.stringify(props?.function_args ?? {}, null, 2), + // [props.function_args] + // ); - const toolArgsSection: JSX.Element | null = - toolArgsJson === '{}' ? null : ( - - - Tool arguments - -
-          {toolArgsJson}
-        </pre>
-      </Box>
-    );
+  const toolArgsSection: JSX.Element | null = props.function_args ? (
+    <Box>
+      <Typography>
+        Tool arguments
+      </Typography>
+      <pre>
+        {props.function_args}
+      </pre>
+    </Box>
+ ) : null; const toolOutputSection: JSX.Element | null = hasOutput ? ( @@ -80,12 +77,13 @@ export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { ) : null; - if (!props.id || !props.type || !props.function) { + if (!props.id || !props.type || !props.function_name) { return null; } return ( props: { id: 'string', type: 'string', - function: 'json', + function_name: 'string', + // this is deliberately not 'json' since `function_args` may be a + // partial JSON string. + function_args: 'string', index: 'number', output: 'json' } @@ -55,7 +58,14 @@ export const webComponentsPlugin: JupyterFrontEndPlugin allowedTags: [...(options?.allowedTags ?? []), 'jai-tool-call'], allowedAttributes: { ...options?.allowedAttributes, - 'jai-tool-call': ['id', 'type', 'function', 'index', 'output'] + 'jai-tool-call': [ + 'id', + 'type', + 'function_name', + 'function_args', + 'index', + 'output' + ] } }); } From 29f075d4d3b97718470abb40385965e9ada836f2 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:41:02 -0700 Subject: [PATCH 12/13] remove debug logs --- packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py index 9fbdda70d..b93980f74 100644 --- a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -207,7 +207,6 @@ async def exec_async(self, prep_res: list[dict]): "content": content, "tool_call_ui_elements": tool_calls.render() }) - self.log.error(message_body) self.ychat.update_message( Message( id=stream_id, @@ -274,8 +273,6 @@ async def exec_async(self, prep_res: Tuple[str, ToolCallList]) -> list[LitellmTo # TODO: Run 1 tool at a time? outputs = await run_tools(tool_calls, self.toolkit) - for output in outputs: - self.log.error(output) return outputs async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: list[LitellmToolCallOutput]): @@ -300,12 +297,9 @@ async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: raw_time=False, ) ) - self.log.error(message_body) # Add tool outputs to `shared['litellm_messages']` shared['litellm_messages'].extend(exec_res) - for msg in shared['litellm_messages']: - self.log.error(msg) # Delete shared state that is now stale del shared['prev_message_id'] From 95d2615d45788649e8f484ee93df11fcfdb3124a Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:48:46 -0700 Subject: [PATCH 13/13] show writing indicator while processing request --- .../jupyter-ai/jupyter_ai/default_flow/default_flow.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py index b93980f74..f885a6fbf 100644 --- a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -121,7 +121,6 @@ class RootNode(JaiAsyncNode): """ async def prep_async(self, shared): - self.log.info("Running RootNode.prep_async()") # Initialize `shared.litellm_messages` using the YChat message history # if it is unset. 
if not ('litellm_messages' in shared and isinstance(shared['litellm_messages'], list) and len(shared['litellm_messages']) > 0): @@ -325,5 +324,12 @@ async def run_default_flow(params: DefaultFlowParams): flow.set_params(params) # Finally, run the async node - await flow.run_async({}) + try: + params['awareness'].set_local_state_field("isWriting", True) + await flow.run_async({}) + except Exception as e: + # TODO: implement error handling + params['logger'].exception("Exception occurred while running default agent flow:") + finally: + params['awareness'].set_local_state_field("isWriting", False)
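
A note on the new `Toolkit.to_json()` method: `function_to_dict` derives an OpenAI-style tool description from a callable's signature and docstring. The sketch below sits outside the patch series and is illustrative only; the `ls` function is a made-up tool, and it assumes the numpydoc-style docstrings that `litellm.utils.function_to_dict` is able to parse.

# Illustrative only: a docstring-annotated callable and the tool
# description that Toolkit.to_json() would wrap around it.
from litellm.utils import function_to_dict

def ls(path: str) -> str:
    """
    Lists the files under a directory.

    Parameters
    ----------
    path : str
        The directory to list.
    """
    return ""

desc = {"type": "function", "function": function_to_dict(ls)}
# desc["function"] should resemble:
# {"name": "ls", "description": "Lists the files under a directory.",
#  "parameters": {"type": "object", "properties": {"path": {...}}, ...}}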
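The `LitellmMessage` and `LitellmToolCallOutput` TypedDicts describe the message sequence that the flow threads between completion calls. Below is a hedged sketch of that sequence after a single tool round; the IDs and contents are illustrative, not taken from a real response.

# Illustrative message list: an assistant turn that requested a tool,
# followed by the tool's output, as fed back into litellm.acompletion().
messages = [
    {"role": "user", "content": "List the current directory."},
    {
        "role": "assistant",
        "content": "",
        # The resolved tool calls, serialized back into the OpenAI format.
        "tool_calls": [{
            "id": "toolu_abc123",  # illustrative ID
            "type": "function",
            "function": {"name": "ls", "arguments": '{"path": "."}'},
        }],
    },
    {
        "role": "tool",
        "tool_call_id": "toolu_abc123",
        "name": "ls",
        "content": '["pyproject.toml", "jupyter_ai"]',
    },
]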
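On the frontend, `JaiToolCallProps` arrives as HTML attributes on a `<jai-tool-call>` element, so JSON placed in `function_args` or `output` must be attribute-escaped (the removed inline renderer replaced `"` with `&quot;` for the same reason). The helper below is a sketch; `to_jai_tool_call_html` and `resolved` are hypothetical names, not part of the patches.

# Hypothetical helper: serialize one resolved tool call into the
# <jai-tool-call> attributes declared by JaiToolCallProps.
import json
from html import escape

def to_jai_tool_call_html(resolved) -> str:
    function_args = json.dumps(resolved.function.arguments)
    return (
        f'<jai-tool-call id="{escape(resolved.id, quote=True)}" '
        f'type="function" index="{resolved.index}" '
        f'function_name="{escape(resolved.function.name, quote=True)}" '
        f'function_args="{escape(function_args, quote=True)}">'
        "</jai-tool-call>"
    )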
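Lastly, on why the web-components plugin registers `function_args` as a `'string'` attribute rather than `'json'`: while a response is still streaming, the accumulated arguments may be only a prefix of a JSON object, and eager parsing would raise. A minimal demonstration:

# A streamed prefix of '{"path": "."}' is not yet valid JSON.
import json

partial = '{"path'
try:
    json.loads(partial)
except json.JSONDecodeError:
    # Keep rendering the raw string; parse only once the stream closes
    # and the arguments form a complete JSON object.
    pass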