diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py new file mode 100644 index 000000000..849952ada --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py @@ -0,0 +1 @@ +from .default_flow import * \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py new file mode 100644 index 000000000..f885a6fbf --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -0,0 +1,335 @@ +from pocketflow import AsyncNode, AsyncFlow +from jupyterlab_chat.models import Message, NewMessage +from jupyterlab_chat.ychat import YChat +from typing import Any, Optional, Tuple, TypedDict +from jinja2 import Template +from litellm import acompletion, ModelResponseStream +import time +import logging + +from ..litellm_lib import ToolCallList, run_tools, LitellmToolCallOutput +from ..tools import Toolkit +from ..personas import SYSTEM_USERNAME, PersonaAwareness + +DEFAULT_RESPONSE_TEMPLATE = """ +{{ content }} +{{ tool_call_ui_elements }} +""".strip() + +class DefaultFlowParams(TypedDict): + """ + Parameters expected by the default flow provided by Jupyter AI. + """ + + model_id: str + + ychat: YChat + + awareness: PersonaAwareness + + persona_id: str + + logger: logging.Logger + + model_args: dict[str, Any] | None + """ + Custom keyword arguments forwarded to `litellm.acompletion()`. Defaults to + `{}` if unset. + """ + + system_prompt: Optional[str] + """ + System prompt that will be used as the first message in the list of messages + sent to the language model. Unused if unset. + """ + + response_template: Template | None + """ + Jinja2 template used to template the response. If one is not given, + `DEFAULT_RESPONSE_TEMPLATE` is used. + + It should take `content: str` and `tool_call_ui_elements: str` as format arguments. + """ + + toolkit: Toolkit | None + """ + Toolkit of tools. Unused if unset. + """ + + history_size: int | None + """ + Number of messages preceding the message triggering this flow to include + in the prompt as context. Defaults to 2 if unset. + """ + +class JaiAsyncNode(AsyncNode): + """ + An AsyncNode with custom properties & helper methods used exclusively in the + Jupyter AI extension. + """ + + @property + def model_id(self) -> str: + return self.params["model_id"] + + @property + def ychat(self) -> YChat: + return self.params["ychat"] + + @property + def awareness(self) -> PersonaAwareness: + return self.params["awareness"] + + @property + def persona_id(self) -> str: + return self.params["persona_id"] + + @property + def model_args(self) -> dict[str, Any]: + return self.params.get("model_args", {}) + + @property + def system_prompt(self) -> Optional[str]: + return self.params.get("system_prompt") + + @property + def response_template(self) -> Template: + template = self.params.get("response_template") + # If response template was unspecified, use the default response + # template. + if not template: + template = Template(DEFAULT_RESPONSE_TEMPLATE) + + return template + + @property + def toolkit(self) -> Optional[Toolkit]: + return self.params.get("toolkit") + + @property + def history_size(self) -> int: + return self.params.get("history_size", 2) + + @property + def log(self) -> logging.Logger: + return self.params.get("logger") + + +class RootNode(JaiAsyncNode): + """ + The root node of the default flow provided by Jupyter AI. 
+ """ + + async def prep_async(self, shared): + # Initialize `shared.litellm_messages` using the YChat message history + # if it is unset. + if not ('litellm_messages' in shared and isinstance(shared['litellm_messages'], list) and len(shared['litellm_messages']) > 0): + shared['litellm_messages'] = self._init_litellm_messages() + + # Return `shared.litellm_messages`. This is passed as the `prep_res` + # argument to `exec_async()`. + return shared['litellm_messages'] + + + def _init_litellm_messages(self) -> list[dict]: + # Store the invoking message & the previous `params.history_size` messages + # as `ychat_messages`. + # TODO: ensure the invoking message is in this list + all_messages = self.ychat.get_messages() + ychat_messages: list[Message] = all_messages[-self.history_size - 1:] + + # Coerce each `Message` in `ychat_messages` to a dictionary following + # the OpenAI spec, and store it as `litellm_messages`. + litellm_messages: list[dict[str, Any]] = [] + for msg in ychat_messages: + role = ( + "assistant" + if msg.sender.startswith("jupyter-ai-personas::") + else "system" if msg.sender == SYSTEM_USERNAME else "user" + ) + litellm_messages.append({"role": role, "content": msg.body}) + + # Insert system message as a dictionary if present. + if self.system_prompt: + system_litellm_message = { + "role": "system", + "content": self.system_prompt + } + litellm_messages = [system_litellm_message, *litellm_messages] + + # Return `litellm_messages` + return litellm_messages + + + async def exec_async(self, prep_res: list[dict]): + self.log.info("Running RootNode.exec_async()") + # Gather arguments and start a reply stream via LiteLLM + reply_stream = await acompletion( + **self.model_args, + model=self.model_id, + messages=prep_res, + tools=self.toolkit.to_json(), + stream=True, + ) + + # Iterate over reply stream + content = "" + tool_calls = ToolCallList() + stream_id: str | None = None + async for chunk in reply_stream: + assert isinstance(chunk, ModelResponseStream) + delta = chunk.choices[0].delta + content_delta = delta.content + toolcalls_delta = delta.tool_calls + + # Continue early if an empty chunk was emitted. + # This sometimes happens with LiteLLM. + if not (content_delta or toolcalls_delta): + continue + + # Aggregate the content and tool calls from the deltas + if content_delta: + content += content_delta + if toolcalls_delta: + tool_calls += toolcalls_delta + + # Create a new message if one does not yet exist + if not stream_id: + stream_id = self.ychat.add_message(NewMessage( + sender=self.persona_id, + body="" + )) + assert stream_id + + # Update the reply + message_body = self.response_template.render({ + "content": content, + "tool_call_ui_elements": tool_calls.render() + }) + self.ychat.update_message( + Message( + id=stream_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + + # Return message_id, content, and tool calls + return stream_id, content, tool_calls + + async def post_async(self, shared, prep_res, exec_res: Tuple[str, str, ToolCallList]): + self.log.info("Running RootNode.post_async()") + # Assert that `shared['litellm_messages']` is of the correct type, and + # that any tool calls returned are complete. 
+ message_id, content, tool_calls = exec_res + assert 'litellm_messages' in shared and isinstance(shared['litellm_messages'], list) + assert tool_calls.complete + + # Add AI response to `shared['litellm_messages']`, including tool calls + new_litellm_message = { + "role": "assistant", + "content": content + } + if len(tool_calls): + new_litellm_message['tool_calls'] = tool_calls.as_litellm_tool_calls() + shared['litellm_messages'].append(new_litellm_message) + + # Add message ID to `shared['prev_message_id']` + shared['prev_message_id'] = message_id + + # Add message content to `shared['prev_message_content]` + shared['prev_message_content'] = content + + # Add tool calls to `shared['next_tool_calls']` + shared['next_tool_calls'] = tool_calls + + # Trigger `ToolExecutorNode` if tools were called. + if len(tool_calls): + return "execute-tools" + return 'finish' + +class ToolExecutorNode(JaiAsyncNode): + """ + Node responsible for executing tool calls in the default flow. + """ + + + async def prep_async(self, shared): + self.log.info("Running ToolExecutorNode.prep_async()") + # Extract `shared['next_tool_calls']` and the ID of the last message + assert 'next_tool_calls' in shared and isinstance(shared['next_tool_calls'], ToolCallList) + assert 'prev_message_id' in shared and isinstance(shared['prev_message_id'], str) + + # Return list of tool calls as a list of dictionaries + return shared['prev_message_id'], shared['next_tool_calls'] + + async def exec_async(self, prep_res: Tuple[str, ToolCallList]) -> list[LitellmToolCallOutput]: + self.log.info("Running ToolExecutorNode.exec_async()") + message_id, tool_calls = prep_res + + # TODO: Run 1 tool at a time? + outputs = await run_tools(tool_calls, self.toolkit) + + return outputs + + async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: list[LitellmToolCallOutput]): + self.log.info("Running ToolExecutorNode.post_async()") + + # Update last message to include outputs + prev_message_id = shared['prev_message_id'] + prev_message_content = shared['prev_message_content'] + tool_calls: ToolCallList = shared['next_tool_calls'] + message_body = self.response_template.render({ + "content": prev_message_content, + "tool_call_ui_elements": tool_calls.render( + outputs=exec_res + ) + }) + self.ychat.update_message( + Message( + id=prev_message_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + + # Add tool outputs to `shared['litellm_messages']` + shared['litellm_messages'].extend(exec_res) + + # Delete shared state that is now stale + del shared['prev_message_id'] + del shared['prev_message_content'] + del shared['next_tool_calls'] + # This node will automatically return to `RootNode` after execution. 
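For orientation, the two nodes above communicate only through the `shared` dict that `pocketflow` threads between them. A rough sketch of its shape after `RootNode.post_async()` returns `"execute-tools"`, based on the keys set above — the values are illustrative placeholders, not real model output:

```py
# Illustrative only: the shared-state contract between RootNode and
# ToolExecutorNode. RootNode fills these keys; ToolExecutorNode reads them,
# appends the tool outputs to `litellm_messages`, and deletes the stale keys
# before control flows back to RootNode.
shared = {
    # Conversation history in the OpenAI message format accepted by LiteLLM.
    "litellm_messages": [
        {"role": "system", "content": "You are Jupyternaut."},
        {"role": "user", "content": "List the files in this directory."},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_0",
                    "type": "function",
                    "index": 0,
                    "function": {"name": "ls", "arguments": '{"path": "."}'},
                }
            ],
        },
    ],
    # ID and body of the streamed chat message, kept so ToolExecutorNode can
    # re-render the same message with tool outputs attached.
    "prev_message_id": "<ychat-message-id>",
    "prev_message_content": "",
    # In practice this is the ToolCallList aggregated during streaming.
    "next_tool_calls": None,
}
```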
+ +async def run_default_flow(params: DefaultFlowParams): + # Initialize nodes + root_node = RootNode() + tool_executor_node = ToolExecutorNode() + + # Define state transitions + ## Flow to ToolExecutorNode if tool calls were dispatched + root_node - "execute-tools" >> tool_executor_node + ## Always flow back to RootNode after running tools + tool_executor_node >> root_node + ## End the flow if no tool calls were dispatched + root_node - "finish" >> AsyncNode() + + # Initialize flow and set its parameters + flow = AsyncFlow(start=root_node) + flow.set_params(params) + + # Finally, run the async node + try: + params['awareness'].set_local_state_field("isWriting", True) + await flow.run_async({}) + except Exception as e: + # TODO: implement error handling + params['logger'].exception("Exception occurred while running default agent flow:") + finally: + params['awareness'].set_local_state_field("isWriting", False) + diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py new file mode 100644 index 000000000..edc2e51cc --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py @@ -0,0 +1,3 @@ +from .run_tools import * +from .toolcall_list import * +from .types import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py new file mode 100644 index 000000000..6ccb22d71 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py @@ -0,0 +1,50 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +import asyncio + +if TYPE_CHECKING: + from ..tools import Toolkit + from .toolcall_list import ToolCallList + from .types import LitellmToolCallOutput + + +async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[LitellmToolCallOutput]: + """ + Runs the tools specified in the list of tool calls returned by + `self.stream_message()`. + + Returns `list[LitellmToolCallOutput]`, a list of output dictionaries of the + type expected by LiteLLM. + + Each output in the list should be appended directly to the message history + on the next request made to the LLM. + """ + tool_calls = tool_call_list.resolve() + if not len(tool_calls): + return [] + + tool_outputs: list[LitellmToolCallOutput] = [] + for tool_call in tool_calls: + # Get tool definition from the correct toolkit + # TODO: validation? 
+ tool_name = tool_call.function.name + tool_defn = toolkit.get_tool_unsafe(tool_name) + + # Run tool and store its output + try: + output = tool_defn.callable(**tool_call.function.arguments) + if asyncio.iscoroutine(output): + output = await output + except Exception as e: + output = str(e) + + # Store the tool output in a dictionary accepted by LiteLLM + output_dict: LitellmToolCallOutput = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": output, + } + tool_outputs.append(output_dict) + + return tool_outputs diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py new file mode 100644 index 000000000..9069eb481 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py @@ -0,0 +1,52 @@ +from litellm.utils import ChatCompletionDeltaToolCall, Function +from .toolcall_list import ToolCallList + +class TestToolCallList(): + + def test_single_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + single tool. + """ + # Setup test + ID = "toolu_01TzXi4nFJErYThcdhnixn7e" + toolcall_list = ToolCallList() + toolcall_list += [ChatCompletionDeltaToolCall(id=ID, function=Function(arguments='', name='ls'), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='": "."}', name=None), type='function', index=0)] + + # Verify the resolved list of calls + resolved_toolcalls = toolcall_list.resolve() + assert len(resolved_toolcalls) == 1 + assert resolved_toolcalls[0] + + def test_two_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + two tools in parallel. 
+ """ + # Setup test + ID_0 = 'toolu_0141FrNfT2LJg6odqbrdmLM6' + ID_1 = 'toolu_01DKqnaXVcyp1v1ABxhHC5Sg' + toolcall_list = ToolCallList() + toolcall_list += [ChatCompletionDeltaToolCall(id=ID_0, function=Function(arguments='', name='ls'), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path": ', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='"."}', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=ID_1, function=Function(arguments='', name='bash'), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"com', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='mand": "ech', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='o \'hello\'"}', name=None), type='function', index=1)] + + # Verify the resolved list of calls + resolved_toolcalls = toolcall_list.resolve() + assert len(resolved_toolcalls) == 2 + assert resolved_toolcalls[0].id == ID_0 + assert resolved_toolcalls[0].function.name == "ls" + assert resolved_toolcalls[0].function.arguments == { "path": "." } + assert resolved_toolcalls[1].id == ID_1 + assert resolved_toolcalls[1].function.name == "bash" + assert resolved_toolcalls[1].function.arguments == { "command": "echo \'hello\'" } + diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py new file mode 100644 index 000000000..563f06746 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -0,0 +1,253 @@ +from litellm.utils import ChatCompletionDeltaToolCall, Function +import json +from pydantic import BaseModel +from typing import Any +from .types import LitellmToolCall, LitellmToolCallOutput, JaiToolCallProps +from jinja2 import Template + +class ResolvedFunction(BaseModel): + """ + A type-safe, parsed representation of `litellm.utils.Function`. + """ + + name: str + """ + Name of the tool function to be called. + + TODO: Check if this attribute is defined for non-function tools, e.g. tools + provided by a MCP server. The docstring on `litellm.utils.Function` implies + that `name` may be `None`. + """ + + arguments: dict[str, Any] + """ + Arguments to the tool function, as a dictionary. + """ + + +class ResolvedToolCall(BaseModel): + """ + A type-safe, parsed representation of + `litellm.utils.ChatCompletionDeltaToolCall`. + """ + + id: str + """ + The ID of the tool call. + """ + + type: str + """ + The 'type' of tool call. Usually 'function'. + + TODO: Make this a union of string literals to ensure we are handling every + potential type of tool call. + """ + + function: ResolvedFunction + """ + The resolved function. See `ResolvedFunction` for more info. + """ + + index: int + """ + The index of this tool call. + + This is usually 0 unless the LLM supports parallel tool calling. 
+ """ + +JAI_TOOL_CALL_TEMPLATE = Template(""" +{% for props in props_list %} + + +{% endfor %} +""".strip()) + +class ToolCallList(BaseModel): + """ + A helper object that defines a custom `__iadd__()` method which accepts a + `tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class + is used to aggregate the tool call deltas yielded from a LiteLLM response + stream and produce a list of tool calls. + + After all tool call deltas are added, the `resolve()` method may be called + to return a list of resolved tool calls. + + Example usage: + + ```py + tool_call_list = ToolCallList() + reply_stream = await litellm.acompletion(..., stream=True) + + async for chunk in reply_stream: + tool_call_delta = chunk.choices[0].delta.tool_calls + tool_call_list += tool_call_delta + + tool_calls = tool_call_list.resolve() + ``` + """ + + _aggregate: list[ChatCompletionDeltaToolCall] = [] + + def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': + """ + Adds a list of tool call deltas to this instance. + + NOTE: This assumes the 'index' attribute on each entry in this list to + be accurate. If this assumption doesn't hold, we will need to rework the + logic here. + """ + if other is None: + return self + + # Iterate through each delta + for delta in other: + # Ensure `self._aggregate` is at least of size `delta.index + 1` + for i in range(len(self._aggregate), delta.index + 1): + self._aggregate.append(ChatCompletionDeltaToolCall( + function=Function(arguments=""), + index=i, + )) + + # Find the corresponding target in the `self._aggregate` and add the + # delta on top of it. In most cases, the value of aggregate + # attribute is set as soon as any delta sets it to a non-`None` + # value. However, `delta.function.arguments` is a string that should + # be appended to the aggregate value of that attribute. + target = self._aggregate[delta.index] + if delta.type: + target.type = delta.type + if delta.id: + target.id = delta.id + if delta.function.name: + target.function.name = delta.function.name + if delta.function.arguments: + target.function.arguments += delta.function.arguments + + return self + + + def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': + """ + Alias for `__iadd__()`. + """ + return self.__iadd__(other) + + + def resolve(self) -> list[ResolvedToolCall]: + """ + Returns the aggregated tool calls as `list[ResolvedToolCall]`. + + Raises an exception if any function arguments could not be parsed from + JSON into a dictionary. This method should only be called after the + stream completed without errors. + """ + resolved_toolcalls: list[ResolvedToolCall] = [] + for i, raw_toolcall in enumerate(self._aggregate): + # Verify entries are at the correct index in the aggregated list + assert raw_toolcall.index == i + + # Verify each tool call specifies the name of the tool to run. + # + # TODO: Check if this may cause a runtime error. The docstring on + # `litellm.utils.Function` implies that `name` may be `None`. + assert raw_toolcall.function.name + + # Verify each tool call defines the type of tool it is calling. 
+ assert raw_toolcall.type is not None + + # Parse the function argument string into a dictionary + resolved_fn_args = json.loads(raw_toolcall.function.arguments) + + # Add to the returned list + resolved_fn = ResolvedFunction( + name=raw_toolcall.function.name, + arguments=resolved_fn_args + ) + resolved_toolcall = ResolvedToolCall( + id=raw_toolcall.id, + type=raw_toolcall.type, + index=i, + function=resolved_fn + ) + resolved_toolcalls.append(resolved_toolcall) + + return resolved_toolcalls + + @property + def complete(self) -> bool: + for i, tool_call in enumerate(self._aggregate): + if tool_call.index != i: + return False + if not tool_call.function: + return False + if not tool_call.function.name: + return False + if not tool_call.type: + return False + if not tool_call.function.arguments: + return False + try: + json.loads(tool_call.function.arguments) + except Exception: + return False + + return True + + def as_litellm_tool_calls(self) -> list[LitellmToolCall]: + """ + Returns the current list of tool calls as a list of dictionaries. + + This should be set in the `tool_calls` key in the dictionary of the + LiteLLM assistant message responsible for dispatching these tool calls. + """ + return [ + model.model_dump() for model in self._aggregate + ] + + def render(self, outputs: list[LitellmToolCallOutput] | None = None) -> str: + """ + Renders this tool call list as a list of `` elements to + be shown in the chat. + """ + # Initialize list of props to render into tool call UI elements + props_list: list[JaiToolCallProps] = [] + + # Index all outputs if passed + outputs_by_id: dict[str, LitellmToolCallOutput] | None = None + if outputs: + outputs_by_id = {} + for output in outputs: + outputs_by_id[output['tool_call_id']] = output + + for tool_call in self._aggregate: + # Build the props for each tool call UI element + props: JaiToolCallProps = { + 'id': tool_call.id, + 'index': tool_call.index, + 'type': tool_call.type, + 'function_name': tool_call.function.name, + 'function_args': tool_call.function.arguments, + } + + # Add the output if present + if outputs_by_id and tool_call.id in outputs_by_id: + output = outputs_by_id[tool_call.id] + # Make sure to manually convert the dictionary to a JSON string + # first. Without doing this, Jinja2 will convert a dictionary to + # JSON using single quotes instead of double quotes, which + # cannot be parsed by the frontend. 
+ output = json.dumps(output) + props['output'] = output + + props_list.append(props) + + # Render the tool call UI elements using the Jinja2 template and return + return JAI_TOOL_CALL_TEMPLATE.render({ + "props_list": props_list + }) + + + def __len__(self) -> int: + return len(self._aggregate) + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py new file mode 100644 index 000000000..314c54990 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -0,0 +1,39 @@ +from __future__ import annotations +from typing import TypedDict, Literal, Optional + + +class LitellmToolCall(TypedDict): + id: str + type: Literal['function'] + function: str + index: int + +class LitellmMessage(TypedDict): + role: Literal['assistant', 'user', 'system'] + content: str + tool_calls: Optional[list[LitellmToolCall]] + +class LitellmToolCallOutput(TypedDict): + tool_call_id: str + role: Literal['tool'] + name: str + content: str + +class JaiToolCallProps(TypedDict): + id: str | None + + type: Literal['function'] | None + + index: int | None + + function_name: str | None + + function_args: str | None + """ + The arguments to the function as a dictionary converted to a JSON string. + """ + + output: str | None + """ + The `LitellmToolCallOutput` as a JSON string. + """ \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/__init__.py b/packages/jupyter-ai/jupyter_ai/personas/__init__.py index 6c0704f52..fb8dd1bfe 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/personas/__init__.py @@ -1,2 +1,2 @@ -from .base_persona import BasePersona, PersonaDefaults -from .persona_manager import PersonaManager +from .base_persona import * +from .persona_manager import * diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 310901b1c..40fdb14db 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -1,30 +1,27 @@ +from __future__ import annotations import asyncio import os from abc import ABC, ABCMeta, abstractmethod from dataclasses import asdict from logging import Logger -from time import time from typing import TYPE_CHECKING, Any, Optional from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User from jupyterlab_chat.ychat import YChat +from litellm import supports_function_calling +from litellm.utils import function_to_dict from pydantic import BaseModel from traitlets import MetaHasTraits from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness +from ..litellm_lib import ToolCallList, run_tools +from ..tools.default_toolkit import DEFAULT_TOOLKIT -# prevents a circular import -# types imported under this block have to be surrounded in single quotes on use if TYPE_CHECKING: - from collections.abc import AsyncIterator - - from litellm import ModelResponseStream - from .persona_manager import PersonaManager - class PersonaDefaults(BaseModel): """ Data structure that represents the default settings of a persona. 
Each persona @@ -235,93 +232,6 @@ def as_user_dict(self) -> dict[str, Any]: user = self.as_user() return asdict(user) - async def stream_message( - self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> None: - """ - Takes an async iterator, dubbed the 'reply stream', and streams it to a - new message by this persona in the YChat. The async iterator may yield - either strings or `litellm.ModelResponseStream` objects. Details: - - - Creates a new message upon receiving the first chunk from the reply - stream, then continuously updates it until the stream is closed. - - - Automatically manages its awareness state to show writing status. - """ - stream_id: Optional[str] = None - stream_interrupted = False - try: - self.awareness.set_local_state_field("isWriting", True) - async for chunk in reply_stream: - # Coerce LiteLLM stream chunk to a string delta - if not isinstance(chunk, str): - chunk = chunk.choices[0].delta.content - - # LiteLLM streams always terminate with an empty chunk, so we - # ignore and continue when this occurs. - if not chunk: - continue - - if ( - stream_id - and stream_id in self.message_interrupted.keys() - and self.message_interrupted[stream_id].is_set() - ): - try: - # notify the model provider that streaming was interrupted - # (this is essential to allow the model to stop generating) - await reply_stream.athrow( # type:ignore[attr-defined] - GenerationInterrupted() - ) - except GenerationInterrupted: - # do not let the exception bubble up in case if - # the provider did not handle it - pass - stream_interrupted = True - break - - if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) - ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - - assert stream_id - self.ychat.update_message( - Message( - id=stream_id, - body=chunk, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - except Exception as e: - self.log.error( - f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." - ) - self.log.exception(e) - finally: - self.awareness.set_local_state_field("isWriting", False) - if stream_id: - # if stream was interrupted, add a tombstone - if stream_interrupted: - stream_tombstone = "\n\n(AI response stopped by user)" - self.ychat.update_message( - Message( - id=stream_id, - body=stream_tombstone, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - if stream_id in self.message_interrupted.keys(): - del self.message_interrupted[stream_id] - def send_message(self, body: str) -> None: """ Sends a new message to the chat from this persona. @@ -361,7 +271,7 @@ def get_mcp_config(self) -> dict[str, Any]: Returns the MCP config for the current chat. """ return self.parent.get_mcp_config() - + def process_attachments(self, message: Message) -> Optional[str]: """ Process file attachments in the message and return their content as a string. 
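With `stream_message()` removed from `BasePersona`, personas are now expected to drive replies through the new default flow rather than streaming LiteLLM responses by hand. A minimal sketch of what a downstream persona might look like, assuming a hypothetical `get_weather` tool and omitting the persona's `defaults` declaration for brevity (the real Jupyternaut wiring follows in the next file):

```py
from jupyterlab_chat.models import Message

from jupyter_ai.default_flow import DefaultFlowParams, run_default_flow
from jupyter_ai.personas import BasePersona
from jupyter_ai.tools import Tool, Toolkit


def get_weather(city: str) -> str:
    """Hypothetical tool: return a canned weather report for `city`."""
    return f"It is sunny in {city}."


WEATHER_TOOLKIT = Toolkit(name="weather-toolkit")
WEATHER_TOOLKIT.add_tool(Tool(callable=get_weather))


class WeatherPersona(BasePersona):
    async def process_message(self, message: Message) -> None:
        # Delegate streaming, tool execution, and chat message updates to the
        # default flow provided by Jupyter AI.
        flow_params: DefaultFlowParams = {
            "persona_id": self.id,
            "model_id": self.config_manager.chat_model,
            "model_args": self.config_manager.chat_model_args,
            "ychat": self.ychat,
            "awareness": self.awareness,
            "system_prompt": "You are a weather assistant.",
            "toolkit": WEATHER_TOOLKIT,
            "logger": self.log,
        }
        await run_default_flow(flow_params)
```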
diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 05ec403e4..c89934a51 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,14 +1,12 @@ -from typing import Any, Optional - from jupyterlab_chat.models import Message -from litellm import acompletion from ..base_persona import BasePersona, PersonaDefaults -from ..persona_manager import SYSTEM_USERNAME +from ...default_flow import run_default_flow, DefaultFlowParams from .prompt_template import ( JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from ...tools import DEFAULT_TOOLKIT class JupyternautPersona(BasePersona): @@ -29,6 +27,7 @@ def defaults(self): ) async def process_message(self, message: Message) -> None: + # Return early if no chat model is configured if not self.config_manager.chat_model: self.send_message( "No chat model is configured.\n\n" @@ -36,67 +35,28 @@ async def process_message(self, message: Message) -> None: ) return - model_id = self.config_manager.chat_model - model_args = self.config_manager.chat_model_args - context_as_messages = self.get_context_as_messages(model_id, message) - response_aiter = await acompletion( - **model_args, - model=model_id, - messages=[ - *context_as_messages, - { - "role": "user", - "content": message.body, - }, - ], - stream=True, - ) - - await self.stream_message(response_aiter) - - def get_context_as_messages( - self, model_id: str, message: Message - ) -> list[dict[str, Any]]: - """ - Returns the current context, including attachments and recent messages, - as a list of messages accepted by `litellm.acompletion()`. - """ - system_msg_args = JupyternautSystemPromptArgs( - model_id=model_id, - persona_name=self.name, - context=self.process_attachments(message), - ).model_dump() - - system_msg = { - "role": "system", - "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), + # Build default flow params + system_prompt = self._build_system_prompt(message) + flow_params: DefaultFlowParams = { + "persona_id": self.id, + "model_id": self.config_manager.chat_model, + "model_args": self.config_manager.chat_model_args, + "ychat": self.ychat, + "awareness": self.awareness, + "system_prompt": system_prompt, + "toolkit": DEFAULT_TOOLKIT, + "logger": self.log, } - context_as_messages = [system_msg, *self._get_history_as_messages()] - return context_as_messages + # Run default agent flow + await run_default_flow(flow_params) - def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: - """ - Returns the current history as a list of messages accepted by - `litellm.acompletion()`. - """ - # TODO: consider bounding history based on message size (e.g. total - # char/token count) instead of message count. - all_messages = self.ychat.get_messages() - - # gather last k * 2 messages and return - # we exclude the last message since that is the human message just - # submitted by a user. 
- start_idx = 0 if k is None else -2 * k - 1 - recent_messages: list[Message] = all_messages[start_idx:-1] - - history: list[dict[str, Any]] = [] - for msg in recent_messages: - role = ( - "assistant" - if msg.sender.startswith("jupyter-ai-personas::") - else "system" if msg.sender == SYSTEM_USERNAME else "user" - ) - history.append({"role": role, "content": msg.body}) - - return history + def _build_system_prompt(self, message: Message) -> str: + context = self.process_attachments(message) + format_args = JupyternautSystemPromptArgs( + persona_name=self.name, + model_id=self.config_manager.chat_model, + context=context, + ) + system_prompt = JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(format_args.model_dump()) + return system_prompt \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/__init__.py b/packages/jupyter-ai/jupyter_ai/tools/__init__.py index 0252ac1a9..1f8e3afa3 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/tools/__init__.py @@ -1,5 +1,6 @@ """Tools package for Jupyter AI.""" from .models import Tool, Toolkit +from .default_toolkit import DEFAULT_TOOLKIT -__all__ = ["Tool", "Toolkit"] +__all__ = ["Tool", "Toolkit", "DEFAULT_TOOLKIT"] diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py new file mode 100644 index 000000000..d850775af --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -0,0 +1,303 @@ +import asyncio +import pathlib +import shlex +from typing import Optional + +from .models import Tool, Toolkit + + +def read(file_path: str, offset: int, limit: int) -> str: + """ + Read a subset of lines from a text file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be read. + offset : int + The line number at which to start reading (1-based indexing). + limit : int + Number of lines to read starting from *offset*. + If *offset + limit* exceeds the number of lines in the file, + all available lines after *offset* are returned. + + Returns + ------- + List[str] + List of lines (including line-ending characters) that were read. + + Examples + -------- + >>> # Suppose ``/tmp/example.txt`` contains 10 lines + >>> read('/tmp/example.txt', offset=3, limit=4) + ['third line\n', 'fourth line\n', 'fifth line\n', 'sixth line\n'] + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + # Normalize arguments + offset = max(1, int(offset)) + limit = max(0, int(limit)) + lines: list[str] = [] + + with path.open(encoding='utf-8', errors='replace') as f: + # Skip to offset + line_no = 0 + # Loop invariant: line_no := last read line + # After the loop exits, line_no == offset - 1, meaning the + # next line starts at `offset` + while line_no < offset - 1: + line = f.readline() + # Return early if offset exceeds number of lines in file + if line == "": + return "" + line_no += 1 + + # Append lines until limit is reached + while len(lines) < limit: + line = f.readline() + if line == "": + break + lines.append(line) + + return "".join(lines) + + +def edit( + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, +) -> None: + """ + Replace occurrences of a substring in a file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be edited. + old_string : str + Text that should be replaced. + new_string : str + Text that will replace *old_string*. 
+ replace_all : bool, optional + If ``True`` all occurrences of *old_string* are replaced. + If ``False`` (default), only the first occurrence in the file is replaced. + + Returns + ------- + None + + Raises + ------ + FileNotFoundError + If *file_path* does not exist. + ValueError + If *old_string* is empty (replacing an empty string is ambiguous). + + Notes + ----- + The file is overwritten atomically: it is first read into memory, + the substitution is performed, and the file is written back. + This keeps the operation safe for short to medium-sized files. + + Examples + -------- + >>> # Replace only the first occurrence + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=False) + >>> # Replace all occurrences + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=True) + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + if old_string == "": + raise ValueError("old_string must not be empty") + + # Read the entire file + content = path.read_text(encoding="utf-8", errors="replace") + + # Perform replacement + if replace_all: + new_content = content.replace(old_string, new_string) + else: + new_content = content.replace(old_string, new_string, 1) + + # Write back + path.write_text(new_content, encoding="utf-8") + + +def write(file_path: str, content: str) -> None: + """ + Write content to a file, creating it if it doesn't exist. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be written. + content : str + Content to write to the file. + + Returns + ------- + None + + Raises + ------ + OSError + If the file cannot be written (e.g., permission denied, invalid path). + + Notes + ----- + This function will overwrite the file if it already exists. + The parent directory must exist; this function does not create directories. + + Examples + -------- + >>> write('/tmp/example.txt', 'Hello, world!') + >>> write('/tmp/data.json', '{"key": "value"}') + """ + path = pathlib.Path(file_path) + + # Write the content to the file + path.write_text(content, encoding="utf-8") + + +async def search_grep(pattern: str, include: str = "*") -> str: + """ + Search for text patterns in files using ripgrep. + + This function uses ripgrep (rg) to perform fast regex-based text searching + across files, with optional file filtering based on glob patterns. + + Parameters + ---------- + pattern : str + A regular expression pattern to search for. Ripgrep uses Rust regex + syntax which supports: + - Basic regex features: ., *, +, ?, ^, $, [], (), | + - Character classes: \w, \d, \s, \W, \D, \S + - Unicode categories: \p{L}, \p{N}, \p{P}, etc. 
+ - Word boundaries: \b, \B + - Anchors: ^, $, \A, \z + - Quantifiers: {n}, {n,}, {n,m} + - Groups: (pattern), (?:pattern), (?Ppattern) + - Lookahead/lookbehind: (?=pattern), (?!pattern), (?<=pattern), (?>> search_grep(r"def\s+\w+", "*.py") + 'file.py:10:def my_function():' + + >>> search_grep(r"TODO|FIXME", "**/*.{py,js}") + 'app.py:25:# TODO: implement this + script.js:15:// FIXME: handle edge case' + + >>> search_grep(r"class\s+(\w+)", "src/**/*.py") + 'src/models.py:1:class User:' + """ + # Use bash tool to execute ripgrep + cmd_parts = ["rg", "--color=never", "--line-number", "--with-filename"] + + # Add glob pattern if specified + if include != "*": + cmd_parts.extend(["-g", include]) + + # Add the pattern (always quote it to handle special characters) + cmd_parts.append(pattern) + + # Join command with proper shell escaping + command = " ".join(f'"{part}"' if " " in part or any(c in part for c in "!*?[]{}()") else part for part in cmd_parts) + + try: + result = await bash(command) + return result + except Exception as e: + raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e + + +async def bash(command: str, timeout: Optional[int] = None) -> str: + """Executes a bash command and returns the result + + Args: + command: The bash command to execute + timeout: Optional timeout in seconds + + Returns: + The command output (stdout and stderr combined) + """ + # coerce `timeout` to the correct type. sometimes LLMs pass this as a string + if isinstance(timeout, str): + timeout = int(timeout) + + proc = await asyncio.create_subprocess_exec( + *shlex.split(command), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout) + stdout = stdout.decode("utf-8") + stderr = stderr.decode("utf-8") + + if proc.returncode != 0: + info = f"Command returned non-zero exit code {proc.returncode}. This usually indicates an error." + info += "\n\n" + fr"Original command: {command}" + if not (stdout or stderr): + info += "\n\nNo further information was given in stdout or stderr." + return info + if stdout: + info += f"stdout:\n\n```\n{stdout}\n```\n\n" + if stderr: + info += f"stderr:\n\n```\n{stderr}\n```\n\n" + return info + + if stdout: + return stdout + return "Command executed successfully with exit code 0. No stdout/stderr was returned." 
+ + except asyncio.TimeoutError: + proc.kill() + return f"Command timed out after {timeout} seconds" + + +DEFAULT_TOOLKIT = Toolkit(name="jupyter-ai-default-toolkit") +DEFAULT_TOOLKIT.add_tool(Tool(callable=bash)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=read)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=edit)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=write)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=search_grep)) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/models.py b/packages/jupyter-ai/jupyter_ai/tools/models.py index 5b95b6174..72b9c69ef 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/models.py +++ b/packages/jupyter-ai/jupyter_ai/tools/models.py @@ -1,5 +1,6 @@ import re from typing import Callable, Optional +from litellm.utils import function_to_dict from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -135,7 +136,7 @@ class Toolkit(BaseModel): name: str description: Optional[str] = None - tools: set = Field(default_factory=set) + tools: set[Tool] = Field(default_factory=set) model_config = ConfigDict(arbitrary_types_allowed=True) def add_tool(self, tool: Tool): @@ -146,6 +147,19 @@ def add_tool(self, tool: Tool): """ self.tools.add(tool) + def get_tool_unsafe(self, tool_name: str) -> Tool: + """ + (WIP) Gets a tool by its name. This is just a temporary method which is + used to make Jupyternaut agentic before we implement the + read/write/execute/delete permissions. + """ + for tool in self.tools: + if tool_name == tool.name: + return tool + + raise Exception(f"Tool not found: {tool_name}") + + def get_tools( self, read: Optional[bool] = None, @@ -202,3 +216,24 @@ def get_tools( toolset.add(tool) return toolset + + def to_json(self) -> list[dict]: + """ + Returns a list of tool descriptions in the type expected by LiteLLM. + """ + tool_descriptions = [] + + # Get all tools from the default toolkit and store their object descriptions + for tool in self.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. 
+ desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + tool_descriptions.append(desc) + + # Finally, return the tool descriptions + return tool_descriptions + diff --git a/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py new file mode 100644 index 000000000..9db82dc41 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py @@ -0,0 +1,595 @@ +"""Tests for default_toolkit.py functions and toolkit configuration.""" + +import pathlib +import tempfile +import pytest +from unittest.mock import patch, mock_open + +from .default_toolkit import read, edit, write, search_grep, DEFAULT_TOOLKIT +from .models import Tool, Toolkit + + +class TestReadFunction: + """Test the read function.""" + + def test_read_valid_file(self): + """Test reading lines from a valid file.""" + # Create a temporary file with known content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\nline 4\nline 5\n") + temp_path = f.name + + try: + # Test reading from offset 2, limit 3 + result = read(temp_path, offset=2, limit=3) + assert result == "line 2\nline 3\nline 4\n" + + # Test reading from offset 1, limit 2 + result = read(temp_path, offset=1, limit=2) + assert result == "line 1\nline 2\n" + + # Test reading all lines from beginning + result = read(temp_path, offset=1, limit=10) + assert result == "line 1\nline 2\nline 3\nline 4\nline 5\n" + + # Test reading from middle to end + result = read(temp_path, offset=4, limit=10) + assert result == "line 4\nline 5\n" + + finally: + # Clean up + pathlib.Path(temp_path).unlink() + + def test_read_empty_file(self): + """Test reading from an empty file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_single_line_file(self): + """Test reading from a file with one line.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("single line\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=1) + assert result == "single line\n" + + result = read(temp_path, offset=1, limit=5) + assert result == "single line\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_file_not_found(self): + """Test reading from a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + read("/nonexistent/path", offset=1, limit=5) + + def test_read_offset_beyond_file_length(self): + """Test reading with offset beyond file length.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\n") + temp_path = f.name + + try: + # Offset beyond file length should return empty string + result = read(temp_path, offset=10, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_negative_and_zero_values(self): + """Test read function with negative and zero offset/limit values.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\n") + temp_path = f.name + + try: + # Negative offset should be normalized to 1 + result = read(temp_path, offset=-5, limit=2) + assert result == "line 1\nline 2\n" + + # Zero offset should be normalized to 1 + result = read(temp_path, 
offset=0, limit=2) + assert result == "line 1\nline 2\n" + + # Zero limit should return empty string + result = read(temp_path, offset=1, limit=0) + assert result == "" + + # Negative limit should return empty string + result = read(temp_path, offset=1, limit=-5) + assert result == "" + + finally: + pathlib.Path(temp_path).unlink() + + def test_read_unicode_content(self): + """Test reading file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("línea 1 🚀\nlínea 2 ❤️\nlínea 3 🎉\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=2) + assert result == "línea 1 🚀\nlínea 2 ❤️\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_with_encoding_errors(self): + """Test reading file with encoding issues using replace errors handling.""" + # This test ensures the 'replace' error handling works properly + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("valid line\n") + temp_path = f.name + + try: + # The function should handle encoding errors gracefully + result = read(temp_path, offset=1, limit=1) + assert result == "valid line\n" + finally: + pathlib.Path(temp_path).unlink() + + +class TestEditFunction: + """Test the edit function.""" + + def test_edit_replace_first_occurrence(self): + """Test replacing the first occurrence of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=False) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar foo baz foo" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_all_occurrences(self): + """Test replacing all occurrences of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=True) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar qux baz qux" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_multiline_content(self): + """Test editing multiline content.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nold content\nline 3\nold content\nline 5") + temp_path = f.name + + try: + edit(temp_path, "old content", "new content", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line 1\nnew content\nline 3\nnew content\nline 5" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_string_not_found(self): + """Test editing when the target string is not found.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + # This should not raise an error, just leave the file unchanged + edit(temp_path, "nonexistent", "replacement", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hello world" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_file_not_found(self): + """Test editing a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + edit("/nonexistent/path", "old", "new") + + def 
test_edit_empty_old_string(self): + """Test editing with an empty old_string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="old_string must not be empty"): + edit(temp_path, "", "replacement") + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_unicode_content(self): + """Test editing file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("hola 🌟 mundo 🌟 adiós") + temp_path = f.name + + try: + edit(temp_path, "🌟", "⭐", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hola ⭐ mundo ⭐ adiós" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_newline_characters(self): + """Test editing with newline characters.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line1\nold\nline3") + temp_path = f.name + + try: + edit(temp_path, "\nold\n", "\nnew\n", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line1\nnew\nline3" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_with_empty_string(self): + """Test replacing content with empty string (deletion).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("keep this DELETE_ME keep this too") + temp_path = f.name + + try: + edit(temp_path, "DELETE_ME ", "", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "keep this keep this too" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_atomicity(self): + """Test that edit operation is atomic (file is either fully updated or unchanged).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + original_content = "original content" + f.write(original_content) + temp_path = f.name + + try: + # Mock pathlib.Path.write_text to raise an exception + with patch.object(pathlib.Path, 'write_text', side_effect=IOError("Disk full")): + with pytest.raises(IOError): + edit(temp_path, "original", "modified") + + # File should remain unchanged due to the error + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == original_content + + finally: + pathlib.Path(temp_path).unlink() + + +class TestWriteFunction: + """Test the write function.""" + + def test_write_new_file(self): + """Test writing content to a new file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "new_file.txt" + test_content = "Hello, world!\nThis is a test." 
+ + write(str(temp_path), test_content) + + # Verify the file was created and contains the correct content + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_overwrite_existing_file(self): + """Test overwriting an existing file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original content") + temp_path = f.name + + try: + new_content = "new content that replaces the old" + write(temp_path, new_content) + + # Verify the file was overwritten + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == new_content + finally: + pathlib.Path(temp_path).unlink() + + def test_write_empty_content(self): + """Test writing empty content to a file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "empty_file.txt" + + write(str(temp_path), "") + + # Verify the file exists and is empty + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == "" + + def test_write_multiline_content(self): + """Test writing multiline content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "multiline.txt" + test_content = "Line 1\nLine 2\nLine 3\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_unicode_content(self): + """Test writing unicode content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "unicode.txt" + test_content = "Hello 世界! 🌍 Café naïve résumé" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_large_content(self): + """Test writing large content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "large.txt" + # Create content with 10000 lines + test_content = "\n".join([f"Line {i}" for i in range(10000)]) + + write(str(temp_path), test_content) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_special_characters(self): + """Test writing content with special characters.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "special.txt" + test_content = 'Content with "quotes", \ttabs, and \nnewlines\r\n' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_invalid_directory(self): + """Test writing to a non-existent directory.""" + invalid_path = "/nonexistent/directory/file.txt" + + with pytest.raises(OSError): + write(invalid_path, "test content") + + def test_write_permission_denied(self): + """Test writing to a file without write permissions.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original") + temp_path = f.name + + try: + # Make file read-only + pathlib.Path(temp_path).chmod(0o444) + + with pytest.raises(OSError): + write(temp_path, "new content") + + finally: + # Restore write permissions and clean up + pathlib.Path(temp_path).chmod(0o644) + pathlib.Path(temp_path).unlink() + + def test_write_binary_like_content(self): + """Test writing content that looks like binary data.""" + with tempfile.TemporaryDirectory() as temp_dir: + 
temp_path = pathlib.Path(temp_dir) / "binary_like.txt" + # Content with null bytes and other control characters + test_content = "Normal text\x00null byte\x01control char\xff" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_json_content(self): + """Test writing JSON-like content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "data.json" + test_content = '{"name": "test", "value": 42, "nested": {"key": "value"}}' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_code_content(self): + """Test writing code content with proper indentation.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "code.py" + test_content = '''def hello(): + """Say hello.""" + print("Hello, world!") + + if True: + return "success" +''' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_preserves_line_endings(self): + """Test that write preserves different line endings.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "line_endings.txt" + test_content = "Unix\nWindows\r\nMac\rMixed\r\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + +class TestSearchGrepFunction: + """Test the search_grep function.""" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_integration(self, mock_bash): + """Test that search_grep correctly calls bash with proper arguments.""" + mock_bash.return_value = "test.py:1:def test():" + + result = await search_grep("def", "*.py") + + # Verify bash was called + mock_bash.assert_called_once() + call_args = mock_bash.call_args[0][0] + + # Check that the command contains expected parts + assert "rg" in call_args + assert "--color=never" in call_args + assert "--line-number" in call_args + assert "--with-filename" in call_args + assert "-g" in call_args + assert "*.py" in call_args + assert "def" in call_args + + assert result == "test.py:1:def test():" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_default_include(self, mock_bash): + """Test search_grep with default include pattern.""" + mock_bash.return_value = "" + + await search_grep("pattern") + + call_args = mock_bash.call_args[0][0] + # Should not contain -g flag when using default "*" pattern + assert "-g" not in call_args or "\"*\"" not in call_args + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_exception(self, mock_bash): + """Test search_grep handling of bash execution errors.""" + mock_bash.side_effect = Exception("Command failed") + + with pytest.raises(RuntimeError, match="Ripgrep search failed: Command failed"): + await search_grep("pattern", "*.txt") + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_basic_pattern(self, mock_bash): + """Test basic pattern searching.""" + mock_bash.return_value = "test1.py:1:def hello_world():\ntest2.py:2: def method(self):" + + result = await search_grep(r"def\s+\w+", "*.py") + + # Should find function definitions in both 
files + assert "test1.py" in result + assert "test2.py" in result + assert "def hello_world" in result + assert "def method" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_no_matches(self, mock_bash): + """Test search with no matches.""" + mock_bash.return_value = "" + + result = await search_grep("nonexistent_pattern", "*.txt") + assert result == "" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_with_include_pattern(self, mock_bash): + """Test search with file include pattern.""" + mock_bash.return_value = "script.py:1:import os" + + result = await search_grep("import", "*.py") + assert "script.py" in result + assert "readme.txt" not in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_special_characters(self, mock_bash): + """Test searching for patterns with special regex characters.""" + # Mock different return values for different calls + mock_bash.side_effect = [ + "special.txt:2:email: user@domain.com", + "special.txt:1:price: $10.99" + ] + + # Search for email pattern + result = await search_grep(r"\w+@\w+\.\w+", "*.txt") + assert "user@domain.com" in result + + # Search for price pattern + result = await search_grep(r"\$\d+\.\d+", "*.txt") + assert "$10.99" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_unicode_content(self, mock_bash): + """Test searching in files with unicode content.""" + mock_bash.return_value = "unicode.txt:1:Hello 世界" + + result = await search_grep("世界", "*.txt") + assert "世界" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_line_anchors(self, mock_bash): + """Test line anchor patterns (^ and $).""" + mock_bash.side_effect = [ + "anchors.txt:1:start of line", + "anchors.txt:3:line with end" + ] + + # Search for lines starting with specific text + result = await search_grep("^start", "*.txt") + assert "start of line" in result + + # Search for lines ending with specific text + result = await search_grep("end$", "*.txt") + assert "line with end" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_case_insensitive_pattern(self, mock_bash): + """Test case insensitive regex patterns.""" + mock_bash.return_value = "mixed_case.txt:1:TODO: fix this\nmixed_case.txt:2:todo: also this\nmixed_case.txt:3:ToDo: and this" + + # Case insensitive search + result = await search_grep("(?i)todo", "*.txt") + lines = result.strip().split('\n') if result.strip() else [] + assert len(lines) == 3 # Should match all three variants + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_complex_glob_patterns(self, mock_bash): + """Test various complex glob patterns.""" + mock_bash.return_value = "src/main.py:1:import sys\nsrc/utils.py:1:import os" + + # Test recursive search in src directory + result = await search_grep("import", "src/**/*.py") + assert "src/main.py" in result + assert "src/utils.py" in result + assert "test_main.py" not in result \ No newline at end of file diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index 19e54bbcc..c5e5fa523 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -79,6 +79,7 @@ "@lumino/widgets": "^2.3.2", "@mui/icons-material": "^5.11.0", "@mui/material": 
"^5.11.0", + "@r2wc/react-to-web-component": "^2.0.4", "react": "^18.2.0", "react-dom": "^18.2.0" }, diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index d82914103..cdd5b4dd2 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "litellm>=1.73,<2", "jinja2>=3.0,<4", "python_dotenv>=1,<2", + "pocketflow==0.0.3", ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index 2e2c5c5a4..eaea55c48 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -20,6 +20,7 @@ import { completionPlugin } from './completions'; import { StopButton } from './components/message-footer/stop-button'; import { statusItemPlugin } from './status'; import { IJaiCompletionProvider } from './tokens'; +import { webComponentsPlugin } from './web-components'; import { buildErrorWidget } from './widgets/chat-error'; import { buildAiSettings } from './widgets/settings-widget'; @@ -125,6 +126,7 @@ export default [ plugin, statusItemPlugin, completionPlugin, + webComponentsPlugin, stopStreaming, ...chatCommandPlugins ]; diff --git a/packages/jupyter-ai/src/web-components/index.ts b/packages/jupyter-ai/src/web-components/index.ts new file mode 100644 index 000000000..5f5e58107 --- /dev/null +++ b/packages/jupyter-ai/src/web-components/index.ts @@ -0,0 +1,2 @@ +export * from './web-components-plugin'; +export * from './jai-tool-call'; diff --git a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx new file mode 100644 index 000000000..d9530c620 --- /dev/null +++ b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx @@ -0,0 +1,119 @@ +import React, { useState } from 'react'; +import { + Box, + Typography, + Collapse, + IconButton, + CircularProgress +} from '@mui/material'; +import ExpandMore from '@mui/icons-material/ExpandMore'; +import CheckCircle from '@mui/icons-material/CheckCircle'; + +type JaiToolCallProps = { + id?: string; + type?: string; + function_name?: string; + function_args?: string; + index?: number; + output?: { + tool_call_id: string; + role: string; + name: string; + content: string | null; + }; +}; + +export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { + console.log({ + props + }); + const [expanded, setExpanded] = useState(false); + const toolComplete = !!(props.output && Object.keys(props.output).length > 0); + const hasOutput = !!(toolComplete && props.output?.content?.length); + + const handleExpandClick = () => { + setExpanded(!expanded); + }; + + const statusIcon: JSX.Element = toolComplete ? ( + + ) : ( + + ); + + const statusText: JSX.Element = ( + + {toolComplete ? 'Ran' : 'Running'}{' '} + + {props.function_name} + {' '} + tool + {toolComplete ? '.' : '...'} + + ); + + // const toolArgsJson = useMemo( + // () => JSON.stringify(props?.function_args ?? {}, null, 2), + // [props.function_args] + // ); + + const toolArgsSection: JSX.Element | null = props.function_args ? ( + + + Tool arguments + +
+        {props.function_args}
+      </pre>
+    </Box>
+  ) : null;
+
+  const toolOutputSection: JSX.Element | null = hasOutput ? (
+    <Box>
+      <Typography>Tool output</Typography>
+      <pre>{props.output?.content}</pre>
+    </Box>
+ ) : null; + + if (!props.id || !props.type || !props.function_name) { + return null; + } + + return ( + + + {statusIcon} + {statusText} + + + + + + + + + {toolArgsSection} + {toolOutputSection} + + + + ); +} diff --git a/packages/jupyter-ai/src/web-components/web-components-plugin.ts b/packages/jupyter-ai/src/web-components/web-components-plugin.ts new file mode 100644 index 000000000..9d368b057 --- /dev/null +++ b/packages/jupyter-ai/src/web-components/web-components-plugin.ts @@ -0,0 +1,75 @@ +import { + JupyterFrontEnd, + JupyterFrontEndPlugin +} from '@jupyterlab/application'; +import r2wc from '@r2wc/react-to-web-component'; + +import { JaiToolCall } from './jai-tool-call'; +import { ISanitizer, Sanitizer } from '@jupyterlab/apputils'; +import { IRenderMime } from '@jupyterlab/rendermime'; + +/** + * Plugin that registers custom web components for usage in AI responses. + */ +export const webComponentsPlugin: JupyterFrontEndPlugin = + { + id: '@jupyter-ai/core:web-components', + autoStart: true, + provides: ISanitizer, + activate: (app: JupyterFrontEnd) => { + // Define the JaiToolCall web component + // ['id', 'type', 'function', 'index', 'output'] + const JaiToolCallWebComponent = r2wc(JaiToolCall, { + props: { + id: 'string', + type: 'string', + function_name: 'string', + // this is deliberately not 'json' since `function_args` may be a + // partial JSON string. + function_args: 'string', + index: 'number', + output: 'json' + } + }); + + // Register the web component + customElements.define('jai-tool-call', JaiToolCallWebComponent); + console.log("Registered custom 'jai-tool-call' web component."); + + // Finally, override the default Rendermime sanitizer to allow custom web + // components in the output. + class CustomSanitizer + extends Sanitizer + implements IRenderMime.ISanitizer + { + sanitize( + dirty: string, + customOptions: IRenderMime.ISanitizerOptions + ): string { + const options: IRenderMime.ISanitizerOptions = { + // default sanitizer options + ...(this as any)._options, + // custom sanitizer options (variable per call) + ...customOptions + }; + + return super.sanitize(dirty, { + ...options, + allowedTags: [...(options?.allowedTags ?? 
[]), 'jai-tool-call'], + allowedAttributes: { + ...options?.allowedAttributes, + 'jai-tool-call': [ + 'id', + 'type', + 'function_name', + 'function_args', + 'index', + 'output' + ] + } + }); + } + } + return new CustomSanitizer(); + } + }; diff --git a/yarn.lock b/yarn.lock index 173eef6b3..d2ff65f5a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2256,6 +2256,7 @@ __metadata: "@lumino/widgets": ^2.3.2 "@mui/icons-material": ^5.11.0 "@mui/material": ^5.11.0 + "@r2wc/react-to-web-component": ^2.0.4 "@stylistic/eslint-plugin": ^3.0.1 "@types/jest": ^29 "@types/react-dom": ^18.2.0 @@ -4826,6 +4827,25 @@ __metadata: languageName: node linkType: hard +"@r2wc/core@npm:^1.0.0": + version: 1.2.0 + resolution: "@r2wc/core@npm:1.2.0" + checksum: e0dc23e8fd1f0d96193b67f5eb04b74b25b9f4609778e6ea2427c565eb590f458553cad307a2fdb3fc4614f6a576d7701b9bacf11775958bc560cc3b3b5aaae7 + languageName: node + linkType: hard + +"@r2wc/react-to-web-component@npm:^2.0.4": + version: 2.0.4 + resolution: "@r2wc/react-to-web-component@npm:2.0.4" + dependencies: + "@r2wc/core": ^1.0.0 + peerDependencies: + react: ^18.0.0 || ^19.0.0 + react-dom: ^18.0.0 || ^19.0.0 + checksum: 7b140ffd612173a30d74717d18efcf554774ef0ed0fe72f207ec21df707685ef5f4c34521e6840041665550c6461171dc32f12835f35beb1788ccac0c66c0e5c + languageName: node + linkType: hard + "@rjsf/core@npm:^5.13.4": version: 5.17.0 resolution: "@rjsf/core@npm:5.17.0"