Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update #149

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,16 +390,14 @@ def register(
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...
) -> Callable: ...

@overload
def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionDefinition:
...
) -> FunctionDefinition: ...

def register(
self,
Expand Down
8 changes: 6 additions & 2 deletions chatlab/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@

"""


from typing import Callable, Optional

from pydantic import BaseModel


class ChatlabMetadata(BaseModel):
"""ChatLab metadata for a function."""

expose_exception_to_llm: bool = True
render: Optional[Callable] = None
bubble_exceptions: bool = False


def bubble_exceptions(func):
if not hasattr(func, "chatlab_metadata"):
func.chatlab_metadata = ChatlabMetadata()
Expand All @@ -51,6 +53,7 @@ def bubble_exceptions(func):
func.chatlab_metadata.bubble_exceptions = True
return func


def expose_exception_to_llm(func):
"""Expose exceptions from calling the function to the LLM.

Expand Down Expand Up @@ -107,6 +110,7 @@ def store_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
chat.register(store_knowledge_graph)
'''


def incremental_display(render_func: Callable):
def decorator(func):
if not hasattr(func, "chatlab_metadata"):
Expand All @@ -118,5 +122,5 @@ def decorator(func):

func.chatlab_metadata.render = render_func
return func
return decorator

return decorator
3 changes: 1 addition & 2 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:


class HasGetToolArgumentsParameter(Protocol):
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
...
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam: ...


def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:
Expand Down
1 change: 1 addition & 0 deletions chatlab/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class ChatModel(Enum):
"""Models available for use with chatlab."""

GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview"
GPT_4_0125_PREVIEW = "gpt-4-0125-preview"
GPT_4_1106_PREVIEW = "gpt-4-1106-preview"
Expand Down
14 changes: 9 additions & 5 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,14 @@ def register(
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...
) -> Callable: ...

@overload
def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionDefinition:
...
) -> FunctionDefinition: ...

def register(
self,
Expand Down Expand Up @@ -438,7 +436,13 @@ def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManife

@property
def tools(self) -> Iterable[ChatCompletionToolParam]:
return [{"type": "function", "function": adapt_function_definition(f)} for f in self.__schemas.values()]
return [
ChatCompletionToolParam(
type="function",
function=adapt_function_definition(f), # type: ignore
)
for f in self.__schemas.values()
]

async def call(self, name: str, arguments: Optional[str] = None) -> Any:
"""Call a function by name with the given parameters."""
Expand Down
1 change: 0 additions & 1 deletion chatlab/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
"run_python",
"shell_functions",
]

1 change: 1 addition & 0 deletions chatlab/tools/_mediatypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Media types for rich output for LLMs and in-notebook."""

import json
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/colors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Let models pick and show color palettes to you."""

import hashlib
from typing import List, Optional
from pydantic import BaseModel, validator, Field
Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

You've been warned. Have fun and be safe!
"""

import asyncio
import os

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The in-IPython python code runner for ChatLab."""

from traceback import TracebackException
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/shell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Shell commands for ChatLab."""

import asyncio
import subprocess

Expand Down
7 changes: 2 additions & 5 deletions chatlab/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Views for ChatLab."""

from .assistant import AssistantMessageView
from .tools import ToolArguments, ToolCalled

__all__ = [
"AssistantMessageView",
"ToolArguments",
"ToolCalled"
]
__all__ = ["AssistantMessageView", "ToolArguments", "ToolCalled"]
27 changes: 16 additions & 11 deletions chatlab/views/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from IPython.display import display
from IPython.core.getipython import get_ipython

from instructor.dsl.partialjson import JSONParser

from jiter import from_json


class ToolArguments(AutoUpdate):
Expand Down Expand Up @@ -75,11 +74,11 @@ def update(self) -> None:

def render(self):
if self.custom_render is not None:
# We use the same definition as was in the original function
try:
parser = JSONParser()
possible_args = parser.parse(self.arguments)

possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
except Exception:
return None
try:
Model = extract_model_from_function(self.name, self.custom_render)
# model = Model.model_validate(possible_args)
model = Model(**possible_args)
Expand Down Expand Up @@ -110,13 +109,17 @@ def append_arguments(self, arguments: str):
def apply_result(self, result: str):
"""Replaces the existing display with a new one that shows the result of the tool being called."""
tc = ToolCalled(
id=self.id, name=self.name, arguments=self.arguments, result=result, display_id=self.display_id,
custom_render=self.custom_render
id=self.id,
name=self.name,
arguments=self.arguments,
result=result,
display_id=self.display_id,
custom_render=self.custom_render,
)
tc.update()
return tc

async def call(self, function_registry: FunctionRegistry) -> 'ToolCalled':
async def call(self, function_registry: FunctionRegistry) -> "ToolCalled":
"""Call the function and return a stack of messages for LLM and human consumption."""
function_name = self.name
function_args = self.arguments
Expand Down Expand Up @@ -185,9 +188,11 @@ def render(self):
if self.custom_render is not None:
# We use the same definition as was in the original function
try:
parser = JSONParser()
possible_args = parser.parse(self.arguments)
possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
except Exception:
return None

try:
Model = extract_model_from_function(self.name, self.custom_render)
# model = Model.model_validate(possible_args)
model = Model(**possible_args)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"outputs": [],
"source": [
"from datetime import datetime\n",
"from pytz import timezone, all_timezones, utc\n",
"from pytz import timezone, all_timezones\n",
"from typing import Optional\n",
"from pydantic import BaseModel\n",
"\n",
Expand Down
Loading
Loading