3939)
4040from ._logging import log_tool_error
4141from ._provider import Provider
42- from ._tools import Tool
42+ from ._tools import Stringable , Tool , ToolResult
4343from ._turn import Turn , user_turn
4444from ._typing_extensions import TypedDict
4545from ._utils import html_escape , wrap_async
@@ -96,6 +96,9 @@ def __init__(
9696 "rich_console" : {},
9797 "css_styles" : {},
9898 }
99+ self ._on_tool_request_default : Optional [
100+ Callable [[ContentToolRequest ], Stringable ]
101+ ] = None
99102
100103 def get_turns (
101104 self ,
@@ -658,7 +661,7 @@ def stream(
658661 kwargs = kwargs ,
659662 )
660663
661- def wrapper () -> Generator [str , None , None ]:
664+ def wrapper () -> Generator [Stringable , None , None ]:
662665 with display :
663666 for chunk in generator :
664667 yield chunk
@@ -695,7 +698,7 @@ async def stream_async(
695698
696699 display = self ._markdown_display (echo = echo )
697700
698- async def wrapper () -> AsyncGenerator [str , None ]:
701+ async def wrapper () -> AsyncGenerator [Stringable , None ]:
699702 with display :
700703 async for chunk in self ._chat_impl_async (
701704 turn ,
@@ -831,6 +834,7 @@ def register_tool(
831834 self ,
832835 func : Callable [..., Any ] | Callable [..., Awaitable [Any ]],
833836 * ,
837+ on_request : Optional [Callable [[ContentToolRequest ], Stringable ]] = None ,
834838 model : Optional [type [BaseModel ]] = None ,
835839 ):
836840 """
@@ -900,16 +904,49 @@ def add(a: int, b: int) -> int:
900904 ----------
901905 func
902906 The function to be invoked when the tool is called.
907+ on_request
908+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
909+ when the tool is requested. If defined, and the callable returns a
910+ stringable object, that value will be yielded to the chat as a part
911+ of the response.
903912 model
904913 A Pydantic model that describes the input parameters for the function.
905914 If not provided, the model will be inferred from the function's type hints.
906915 The primary reason why you might want to provide a model in
907916 Note that the name and docstring of the model takes precedence over the
908917 name and docstring of the function.
909918 """
910- tool = Tool (func , model = model )
919+ tool = Tool (func , on_request = on_request , model = model )
911920 self ._tools [tool .name ] = tool
912921
922+ def on_tool_request (
923+ self ,
924+ func : Callable [[ContentToolRequest ], Stringable ],
925+ ):
926+ """
927+ Register a default function to be invoked when a tool is requested.
928+
929+ This function will be invoked if a tool is requested that does not have
930+ a specific `on_request` function defined.
931+
932+ Parameters
933+ ----------
934+ func
935+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
936+ when the tool is requested. If defined, and the callable returns a
937+ stringable object, that value will be yielded to the chat as a part
938+ of the response.
939+ """
940+ self ._on_tool_request_default = func
941+
942+ def _on_tool_request (self , req : ContentToolRequest ) -> Stringable | None :
943+ tool_def = self ._tools .get (req .name , None )
944+ if tool_def and tool_def .on_request :
945+ return tool_def .on_request (req )
946+ if self ._on_tool_request_default :
947+ return self ._on_tool_request_default (req )
948+ return None
949+
913950 def export (
914951 self ,
915952 filename : str | Path ,
@@ -1040,7 +1077,7 @@ def _chat_impl(
10401077 display : MarkdownDisplay ,
10411078 stream : bool ,
10421079 kwargs : Optional [SubmitInputArgsT ] = None ,
1043- ) -> Generator [str , None , None ]:
1080+ ) -> Generator [Stringable , None , None ]:
10441081 user_turn_result : Turn | None = user_turn
10451082 while user_turn_result is not None :
10461083 for chunk in self ._submit_turns (
@@ -1051,7 +1088,24 @@ def _chat_impl(
10511088 kwargs = kwargs ,
10521089 ):
10531090 yield chunk
1054- user_turn_result = self ._invoke_tools ()
1091+
1092+ turn = self .get_last_turn (role = "assistant" )
1093+ assert turn is not None
1094+ user_turn_result = None
1095+
1096+ results : list [ContentToolResult ] = []
1097+ for x in turn .contents :
1098+ if isinstance (x , ContentToolRequest ):
1099+ req = self ._on_tool_request (x )
1100+ if req is not None :
1101+ yield req
1102+ result , output = self ._invoke_tool_request (x )
1103+ if output is not None :
1104+ yield output
1105+ results .append (result )
1106+
1107+ if results :
1108+ user_turn_result = Turn ("user" , results )
10551109
10561110 async def _chat_impl_async (
10571111 self ,
@@ -1060,7 +1114,7 @@ async def _chat_impl_async(
10601114 display : MarkdownDisplay ,
10611115 stream : bool ,
10621116 kwargs : Optional [SubmitInputArgsT ] = None ,
1063- ) -> AsyncGenerator [str , None ]:
1117+ ) -> AsyncGenerator [Stringable , None ]:
10641118 user_turn_result : Turn | None = user_turn
10651119 while user_turn_result is not None :
10661120 async for chunk in self ._submit_turns_async (
@@ -1071,7 +1125,24 @@ async def _chat_impl_async(
10711125 kwargs = kwargs ,
10721126 ):
10731127 yield chunk
1074- user_turn_result = await self ._invoke_tools_async ()
1128+
1129+ turn = self .get_last_turn (role = "assistant" )
1130+ assert turn is not None
1131+ user_turn_result = None
1132+
1133+ results : list [ContentToolResult ] = []
1134+ for x in turn .contents :
1135+ if isinstance (x , ContentToolRequest ):
1136+ req = self ._on_tool_request (x )
1137+ if req is not None :
1138+ yield req
1139+ result , output = await self ._invoke_tool_request_async (x )
1140+ if output is not None :
1141+ yield output
1142+ results .append (result )
1143+
1144+ if results :
1145+ user_turn_result = Turn ("user" , results )
10751146
10761147 def _submit_turns (
10771148 self ,
@@ -1085,7 +1156,7 @@ def _submit_turns(
10851156 if any (x ._is_async for x in self ._tools .values ()):
10861157 raise ValueError ("Cannot use async tools in a synchronous chat" )
10871158
1088- def emit (text : str | Content ):
1159+ def emit (text : Stringable ):
10891160 display .update (str (text ))
10901161
10911162 emit ("<br>\n \n " )
@@ -1148,7 +1219,7 @@ async def _submit_turns_async(
11481219 data_model : type [BaseModel ] | None = None ,
11491220 kwargs : Optional [SubmitInputArgsT ] = None ,
11501221 ) -> AsyncGenerator [str , None ]:
1151- def emit (text : str | Content ):
1222+ def emit (text : Stringable ):
11521223 display .update (str (text ))
11531224
11541225 emit ("<br>\n \n " )
@@ -1202,88 +1273,62 @@ def emit(text: str | Content):
12021273
12031274 self ._turns .extend ([user_turn , turn ])
12041275
1205- def _invoke_tools (self ) -> Turn | None :
1206- turn = self .get_last_turn ()
1207- if turn is None :
1208- return None
1209-
1210- results : list [ContentToolResult ] = []
1211- for x in turn .contents :
1212- if isinstance (x , ContentToolRequest ):
1213- tool_def = self ._tools .get (x .name , None )
1214- func = tool_def .func if tool_def is not None else None
1215- results .append (self ._invoke_tool (func , x .arguments , x .id ))
1216-
1217- if not results :
1218- return None
1276+ def _invoke_tool_request (
1277+ self , x : ContentToolRequest
1278+ ) -> tuple [ContentToolResult , Stringable ]:
1279+ tool_def = self ._tools .get (x .name , None )
1280+ func = tool_def .func if tool_def is not None else None
12191281
1220- return Turn ("user" , results )
1221-
1222- async def _invoke_tools_async (self ) -> Turn | None :
1223- turn = self .get_last_turn ()
1224- if turn is None :
1225- return None
1226-
1227- results : list [ContentToolResult ] = []
1228- for x in turn .contents :
1229- if isinstance (x , ContentToolRequest ):
1230- tool_def = self ._tools .get (x .name , None )
1231- func = None
1232- if tool_def :
1233- if tool_def ._is_async :
1234- func = tool_def .func
1235- else :
1236- func = wrap_async (tool_def .func )
1237- results .append (await self ._invoke_tool_async (func , x .arguments , x .id ))
1238-
1239- if not results :
1240- return None
1241-
1242- return Turn ("user" , results )
1243-
1244- @staticmethod
1245- def _invoke_tool (
1246- func : Callable [..., Any ] | None ,
1247- arguments : object ,
1248- id_ : str ,
1249- ) -> ContentToolResult :
12501282 if func is None :
1251- return ContentToolResult (id_ , value = None , error = "Unknown tool" )
1283+ return ContentToolResult (x . id , value = None , error = "Unknown tool" ), None
12521284
12531285 name = func .__name__
12541286
12551287 try :
1256- if isinstance (arguments , dict ):
1257- result = func (** arguments )
1288+ if isinstance (x . arguments , dict ):
1289+ result = func (** x . arguments )
12581290 else :
1259- result = func (arguments )
1291+ result = func (x . arguments )
12601292
1261- return ContentToolResult (id_ , value = result , error = None , name = name )
1293+ value , output = (result , None )
1294+ if isinstance (result , ToolResult ):
1295+ value , output = (result .assistant , result .output )
1296+
1297+ return ContentToolResult (x .id , value = value , error = None , name = name ), output
12621298 except Exception as e :
1263- log_tool_error (name , str (arguments ), e )
1264- return ContentToolResult (id_ , value = None , error = str (e ), name = name )
1299+ log_tool_error (name , str (x .arguments ), e )
1300+ return ContentToolResult (x .id , value = None , error = str (e ), name = name ), None
1301+
1302+ async def _invoke_tool_request_async (
1303+ self , x : ContentToolRequest
1304+ ) -> tuple [ContentToolResult , Stringable ]:
1305+ tool_def = self ._tools .get (x .name , None )
1306+ func = None
1307+ if tool_def :
1308+ if tool_def ._is_async :
1309+ func = tool_def .func
1310+ else :
1311+ func = wrap_async (tool_def .func )
12651312
1266- @staticmethod
1267- async def _invoke_tool_async (
1268- func : Callable [..., Awaitable [Any ]] | None ,
1269- arguments : object ,
1270- id_ : str ,
1271- ) -> ContentToolResult :
12721313 if func is None :
1273- return ContentToolResult (id_ , value = None , error = "Unknown tool" )
1314+ return ContentToolResult (x . id , value = None , error = "Unknown tool" ), None
12741315
12751316 name = func .__name__
12761317
12771318 try :
1278- if isinstance (arguments , dict ):
1279- result = await func (** arguments )
1319+ if isinstance (x . arguments , dict ):
1320+ result = await func (** x . arguments )
12801321 else :
1281- result = await func (arguments )
1322+ result = await func (x .arguments )
1323+
1324+ value , output = (result , None )
1325+ if isinstance (result , ToolResult ):
1326+ value , output = (result .assistant , result .output )
12821327
1283- return ContentToolResult (id_ , value = result , error = None , name = name )
1328+ return ContentToolResult (x . id , value = value , error = None , name = name ), output
12841329 except Exception as e :
1285- log_tool_error (func .__name__ , str (arguments ), e )
1286- return ContentToolResult (id_ , value = None , error = str (e ), name = name )
1330+ log_tool_error (func .__name__ , str (x . arguments ), e )
1331+ return ContentToolResult (x . id , value = None , error = str (e ), name = name ), None
12871332
12881333 def _markdown_display (
12891334 self , echo : Literal ["text" , "all" , "none" ]
@@ -1378,15 +1423,15 @@ class ChatResponse:
13781423 still be retrieved (via the `content` attribute).
13791424 """
13801425
1381- def __init__ (self , generator : Generator [str , None ]):
1426+ def __init__ (self , generator : Generator [Stringable , None ]):
13821427 self ._generator = generator
13831428 self .content : str = ""
13841429
13851430 def __iter__ (self ) -> Iterator [str ]:
13861431 return self
13871432
13881433 def __next__ (self ) -> str :
1389- chunk = next (self ._generator )
1434+ chunk = str ( next (self ._generator ) )
13901435 self .content += chunk # Keep track of accumulated content
13911436 return chunk
13921437
@@ -1430,15 +1475,15 @@ class ChatResponseAsync:
14301475 still be retrieved (via the `content` attribute).
14311476 """
14321477
1433- def __init__ (self , generator : AsyncGenerator [str , None ]):
1478+ def __init__ (self , generator : AsyncGenerator [Stringable , None ]):
14341479 self ._generator = generator
14351480 self .content : str = ""
14361481
14371482 def __aiter__ (self ) -> AsyncIterator [str ]:
14381483 return self
14391484
14401485 async def __anext__ (self ) -> str :
1441- chunk = await self ._generator .__anext__ ()
1486+ chunk = str ( await self ._generator .__anext__ () )
14421487 self .content += chunk # Keep track of accumulated content
14431488 return chunk
14441489
0 commit comments