Skip to content
51 changes: 51 additions & 0 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3519,6 +3519,57 @@ def __call__(self, **kwargs):
return super().__call__(**kwargs)


class Gemma3ChatHandler(Llava15ChatHandler):
    """Multimodal chat handler for Google's Gemma 3 models.

    Renders OpenAI-style ``messages`` into Gemma's turn-based prompt
    format via the Jinja2 template in ``CHAT_FORMAT`` below.
    """

    # Chat Format:
    # '<bos><start_of_turn>user\n{system_prompt}\n\n{prompt}<end_of_turn>\n<start_of_turn>model\n'

    # Gemma has no dedicated system role: a leading system message is
    # folded into the first user turn by the template below, so no
    # default system message is injected by the handler.
    DEFAULT_SYSTEM_MESSAGE = None

    # Jinja2 chat template (string built by implicit literal
    # concatenation — every literal below is part of the rendered
    # template, so the exact characters matter).
    CHAT_FORMAT = (
        # If the first message is a system message, capture its text
        # (plus a separating blank line) as a prefix for the first user
        # turn, and iterate over the remaining messages only.
        "{% if messages[0]['role'] == 'system' %}"
        "{% if messages[0]['content'] is string %}"
        "{% set first_user_prefix = messages[0]['content'] + '\n\n' %}"
        "{% else %}"
        # Structured (list-form) system content: only the first item's
        # text is used.
        "{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}"
        "{% endif %}"
        "{% set loop_messages = messages[1:] %}"
        "{% else %}"
        "{% set first_user_prefix = \"\" %}"
        "{% set loop_messages = messages %}"
        "{% endif %}"
        "{% for message in loop_messages %}"
        # Enforce strict alternation: user turns on even indices,
        # assistant turns on odd indices.
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}"
        "{% endif %}"
        # Gemma names the assistant role 'model'.
        "{% if (message['role'] == 'assistant') %}"
        "{% set role = \"model\" %}"
        "{% else %}"
        "{% set role = message['role'] %}"
        "{% endif %}"
        # Open the turn; the captured system prefix is emitted only on
        # the first looped message.
        "{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}"
        "{% if message['content'] is string %}"
        "{{ message['content'] | trim }}"
        "{% elif message['content'] is iterable %}"
        "{% for item in message['content'] %}"
        # Image parts: the URL is emitted surrounded by blank lines.
        # NOTE(review): presumably the Llava15ChatHandler machinery later
        # locates these URLs in the rendered prompt to splice in image
        # embeddings — confirm against the base class.
        "{% if item['type'] == 'image_url' and item['image_url'] is string %}"
        "{{ '\n\n' + item['image_url'] + '\n\n' }}"
        "{% elif item['type'] == 'image_url' and item['image_url'] is mapping %}"
        "{{ '\n\n' + item['image_url']['url'] + '\n\n' }}"
        "{% elif item['type'] == 'text' %}"
        "{{ item['text'] | trim }}"
        "{% endif %}"
        "{% endfor %}"
        "{% else %}"
        "{{ raise_exception(\"Invalid content type\") }}"
        "{% endif %}"
        "{{ '<end_of_turn>\n' }}"
        "{% endfor %}"
        # Open a fresh 'model' turn when the caller requests generation.
        "{% if add_generation_prompt %}"
        "{{ '<start_of_turn>model\n' }}"
        "{% endif %}"
    )


@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
llama: llama.Llama,
Expand Down