diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index ea8d07feb..be4e0c835 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -351,10 +351,49 @@ def _convert_completion_to_chat_function(
     ],
     stream: bool,
 ):
+    def _completion_text_to_tool_calls(
+        tool_name: str,
+        completion_text: str,
+        completion_id: str,
+        stream: bool,
+    ) -> Union[
+        llama_types.ChatCompletionMessageToolCalls, List[llama_types.ChatCompletionMessageToolCallChunk]
+    ]:
+        try:
+            function_calls = json.loads(completion_text)
+            assert isinstance(function_calls, list)
+        except Exception as e:
+            function_calls = []
+
+        i = 0
+        tool_calls = []
+        for function_call in function_calls:
+            function_name = function_call.get("name")
+            function_arguments = function_call.get("arguments")
+            if function_name == tool_name and function_arguments:
+                tool_id = f'call__{i}_{tool_name}_{completion_id}'
+                tool_call = {
+                    "id": tool_id,
+                    "type": "function",
+                    "function": {
+                        "name": tool_name,
+                        "arguments": json.dumps(function_arguments, ensure_ascii=False),
+                    },
+                }
+                if stream:
+                    tool_call["index"] = i
+                    typed_call: llama_types.ChatCompletionMessageToolCallChunk = tool_call
+                else:
+                    typed_call: llama_types.ChatCompletionMessageToolCall = tool_call
+                tool_calls.append(typed_call)
+                i += 1
+
+        return tool_calls
+
     if not stream:
         completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
         assert "usage" in completion
-        tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
+        tool_calls: llama_types.ChatCompletionMessageToolCalls = _completion_text_to_tool_calls(tool_name, completion["choices"][0]["text"], completion["id"], stream)  # type: ignore
         # TODO: Fix for legacy function calls
         chat_completion: llama_types.CreateChatCompletionResponse = {
             "id": "chat" + completion["id"],
@@ -366,24 +405,12 @@ def _convert_completion_to_chat_function(
                     "index": 0,
                     "message": {
                         "role": "assistant",
-                        "content": None,
-                        "function_call": {
-                            "name": tool_name,
-                            "arguments": completion["choices"][0]["text"],
-                        },
-                        "tool_calls": [
-                            {
-                                "id": tool_id,
-                                "type": "function",
-                                "function": {
-                                    "name": tool_name,
-                                    "arguments": completion["choices"][0]["text"],
-                                },
-                            }
-                        ],
+                        "content": None if tool_calls else completion["choices"][0]["text"],
+                        "function_call": tool_calls[0]["function"] if tool_calls else None,
+                        "tool_calls": tool_calls or None,
                     },
                     "logprobs": completion["choices"][0]["logprobs"],
-                    "finish_reason": "tool_calls",
+                    "finish_reason": "tool_calls" if tool_calls else completion["choices"][0]["finish_reason"],
                 }
             ],
             "usage": completion["usage"],
@@ -400,13 +427,15 @@ def _stream_response_to_function_stream(
             id_ = None
             created = None
             model = None
-            tool_id = None
+            finish = None
+            tools_called = ""
             for chunk in chunks:
+                tools_called += chunk["choices"][0]["text"]
+                finish = chunk["choices"][0]["finish_reason"]
                 if first:
                     id_ = "chat" + chunk["id"]
                     created = chunk["created"]
                     model = chunk["model"]
-                    tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
                     yield {
                         "id": id_,
                         "object": "chat.completion.chunk",
@@ -438,31 +467,15 @@
                                 "logprobs": chunk["choices"][0]["logprobs"],
                                 "delta": {
                                     "role": None,
-                                    "content": None,
-                                    "function_call": {
-                                        "name": tool_name,
-                                        "arguments": chunk["choices"][0]["text"],
-                                    },
-                                    "tool_calls": [
-                                        {
-                                            "index": 0,
-                                            "id": tool_id,
-                                            "type": "function",
-                                            "function": {
-                                                "name": tool_name,
-                                                "arguments": chunk["choices"][0][
-                                                    "text"
-                                                ],
-                                            },
-                                        }
-                                    ],
+                                    "content": chunk["choices"][0]["text"],
+                                    "function_call": None,
+                                    "tool_calls": None,
                                 },
                             }
                         ],
                     }
                     first = False
                     continue
-                assert tool_id is not None
                 yield {
                     "id": "chat" + chunk["id"],
                     "object": "chat.completion.chunk",
@@ -475,28 +488,16 @@
                             "logprobs": chunk["choices"][0]["logprobs"],
                             "delta": {
                                 "role": None,
-                                "content": None,
-                                "function_call": {
-                                    "name": tool_name,
-                                    "arguments": chunk["choices"][0]["text"],
-                                },
-                                "tool_calls": [
-                                    {
-                                        "index": 0,
-                                        "id": tool_id,
-                                        "type": "function",
-                                        "function": {
-                                            "name": tool_name,
-                                            "arguments": chunk["choices"][0]["text"],
-                                        },
-                                    }
-                                ],
+                                "content": chunk["choices"][0]["text"],
+                                "function_call": None,
+                                "tool_calls": None,
                             },
                         }
                     ],
                 }

             if id_ is not None and created is not None and model is not None:
+                tool_calls: List[llama_types.ChatCompletionMessageToolCallChunk] = _completion_text_to_tool_calls(tool_name, tools_called, id_, stream)  # type: ignore
                 yield {
                     "id": id_,
                     "object": "chat.completion.chunk",
@@ -505,13 +506,13 @@
                     "choices": [
                         {
                             "index": 0,
-                            "finish_reason": "tool_calls",
+                            "finish_reason": "tool_calls" if tool_calls else finish,
                             "logprobs": None,
                             "delta": {
                                 "role": None,
                                 "content": None,
-                                "function_call": None,
-                                "tool_calls": None,
+                                "function_call": tool_calls[0]["function"] if tool_calls else None,
+                                "tool_calls": tool_calls or None,
                             },
                         }
                     ],
@@ -621,7 +622,22 @@ def chat_completion_handler(
             tool = next((t for t in tools if t["function"]["name"] == name), None)
             if tool is None:
                 raise ValueError(f"Tool choice '{name}' not found in tools.")
-            schema = tool["function"]["parameters"]
+            schema = {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {
+                            "type": "string"
+                        },
+                        "arguments": tool["function"]["parameters"]
+                    },
+                    "required": [
+                        "name",
+                        "arguments"
+                    ]
+                }
+            }
             try:
                 # create grammar from json schema
                 grammar = llama_grammar.LlamaGrammar.from_json_schema(
@@ -3486,9 +3502,25 @@ def chatml_function_calling(
             add_generation_prompt=True,
         )
         prompt += f"functions.{tool_name}:\n"
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": {
+                        "type": "string"
+                    },
+                    "arguments": tool["function"]["parameters"]
+                },
+                "required": [
+                    "name",
+                    "arguments"
+                ]
+            }
+        }
         try:
             grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+                json.dumps(schema), verbose=llama.verbose
            )
         except Exception as e:
             grammar = llama_grammar.LlamaGrammar.from_string(