feat: Add Gemma3 chat handler (#1976) #1989

Open · wants to merge 4 commits into base: main
172 changes: 154 additions & 18 deletions llama_cpp/llama_chat_format.py
@@ -2835,24 +2835,7 @@ def __call__(
                     )
                 llama.eval(tokens)
             else:
-                image_bytes = self.load_image(value)
-                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
-                if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
-                    raise ValueError(
-                        f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
-                    )
-                n_past = ctypes.c_int(llama.n_tokens)
-                n_past_p = ctypes.pointer(n_past)
-                with suppress_stdout_stderr(disable=self.verbose):
-                    self._llava_cpp.llava_eval_image_embed(
-                        llama.ctx,
-                        embed,
-                        llama.n_batch,
-                        n_past_p,
-                    )
-                # Required to avoid issues with hf tokenizer
-                llama.input_ids[llama.n_tokens : n_past.value] = -1
-                llama.n_tokens = n_past.value
+                self.eval_image(llama, value)
 
         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
@@ -2938,6 +2921,26 @@ def __call__(
        )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    def eval_image(self, llama: llama.Llama, image_url: str):
        image_bytes = self.load_image(image_url)
        embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
        if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
            raise ValueError(
                f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
            )
        n_past = ctypes.c_int(llama.n_tokens)
        n_past_p = ctypes.pointer(n_past)
        with suppress_stdout_stderr(disable=self.verbose):
            self._llava_cpp.llava_eval_image_embed(
                llama.ctx,
                embed,
                llama.n_batch,
                n_past_p,
            )
        # Required to avoid issues with hf tokenizer
        llama.input_ids[llama.n_tokens : n_past.value] = -1
        llama.n_tokens = n_past.value

    @staticmethod
    def _load_image(image_url: str) -> bytes:
        # TODO: Add Pillow support for other image formats beyond (jpg, png)
@@ -3373,6 +3376,139 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
    )


class Gemma3ChatHandler(Llava15ChatHandler):
    # Chat Format:
    # '<bos><start_of_turn>user\n{system_prompt}\n\n{prompt}<end_of_turn>\n<start_of_turn>model\n'

    DEFAULT_SYSTEM_MESSAGE = None

    CHAT_FORMAT = (
        "{{ '<bos>' }}"
        "{%- if messages[0]['role'] == 'system' -%}"
        "{%- if messages[0]['content'] is string -%}"
        "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}"
        "{%- else -%}"
        "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}"
        "{%- endif -%}"
        "{%- set loop_messages = messages[1:] -%}"
        "{%- else -%}"
        "{%- set first_user_prefix = \"\" -%}"
        "{%- set loop_messages = messages -%}"
        "{%- endif -%}"
        "{%- for message in loop_messages -%}"
        "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}"
        "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}"
        "{%- endif -%}"
        "{%- if (message['role'] == 'assistant') -%}"
        "{%- set role = \"model\" -%}"
        "{%- else -%}"
        "{%- set role = message['role'] -%}"
        "{%- endif -%}"
        "{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}"
        "{%- if message['content'] is string -%}"
        "{{ message['content'] | trim }}"
        "{%- elif message['content'] is iterable -%}"
        "{%- for item in message['content'] -%}"
        "{%- if item['type'] == 'image_url' -%}"
        "{{ '<start_of_image>' }}"
        "{%- elif item['type'] == 'text' -%}"
        "{{ item['text'] | trim }}"
        "{%- endif -%}"
        "{%- endfor -%}"
        "{%- else -%}"
        "{{ raise_exception(\"Invalid content type\") }}"
        "{%- endif -%}"
        "{{ '<end_of_turn>\n' }}"
        "{%- endfor -%}"
        "{%- if add_generation_prompt -%}"
        "{{ '<start_of_turn>model\n' }}"
        "{%- endif -%}"
    )
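
    # Rendered example of the template above (a sketch, assuming
    # add_generation_prompt=True, a plain-string system message
    # "You are helpful." and a user message "Hi there"):
    #
    #   <bos><start_of_turn>user
    #   You are helpful.
    #
    #   Hi there<end_of_turn>
    #   <start_of_turn>model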

    @staticmethod
    def split_text_on_image_urls(text: str, image_urls: List[str]):
        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
        copied_urls = image_urls[:]
        remaining = text
        image_placeholder = "<start_of_image>"

        while remaining:
            # Find the next image placeholder
            pos = remaining.find(image_placeholder)
            if pos != -1:
                assert len(copied_urls) > 0
                if pos > 0:
                    split_text.append(("text", remaining[:pos]))
                split_text.append(("text", "\n\n<start_of_image>"))
                split_text.append(("image_url", copied_urls.pop(0)))
                split_text.append(("text", "<end_of_image>\n\n"))
                remaining = remaining[pos + len(image_placeholder):]
            else:
                assert len(copied_urls) == 0
                split_text.append(("text", remaining))
                remaining = ""
        return split_text
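
    # Example of the split above (hypothetical URL), shown as a comment:
    #
    #   Gemma3ChatHandler.split_text_on_image_urls(
    #       "Describe this image:<start_of_image>", ["https://example.com/cat.png"]
    #   )
    #   # -> [("text", "Describe this image:"),
    #   #     ("text", "\n\n<start_of_image>"),
    #   #     ("image_url", "https://example.com/cat.png"),
    #   #     ("text", "<end_of_image>\n\n")]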

    def eval_image(self, llama: llama.Llama, image_url: str):
        import llama_cpp

        # Gemma 3's vision projector emits a fixed 256 image tokens per image
        n_tokens = 256
        if llama.n_tokens + n_tokens > llama.n_ctx():
            raise ValueError(
                f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}"
            )

        img_bytes = self.load_image(image_url)
        img_u8_p = self._llava_cpp.clip_image_u8_init()
        if not self._llava_cpp.clip_image_load_from_bytes(
            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
            ctypes.c_size_t(len(img_bytes)),
            img_u8_p,
        ):
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to load image.")

        img_f32 = self._llava_cpp.clip_image_f32_batch()
        img_f32_p = ctypes.byref(img_f32)
        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to preprocess image.")

        # Encode the preprocessed image into n_tokens embeddings of size n_embd
        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
        embed = (ctypes.c_float * (n_tokens * n_embd))()
        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
            self._llava_cpp.clip_image_u8_free(img_u8_p)
            raise ValueError("Failed to encode image.")

        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
        self._llava_cpp.clip_image_u8_free(img_u8_p)
        # Gemma 3 attends to image tokens bidirectionally, so disable causal
        # attention while the image embeddings are decoded
        llama_cpp.llama_set_causal_attn(llama.ctx, False)

        # All image tokens share sequence id 0; the extra trailing entry stays
        # NULL to match the llama_batch layout
        seq_id_0 = (ctypes.c_int32 * 1)()
        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
        for i in range(n_tokens):
            seq_ids[i] = seq_id_0

        # Build a batch that feeds raw embeddings (batch.embd) instead of token ids
        batch = llama_cpp.llama_batch()
        batch.n_tokens = n_tokens
        batch.token = None
        batch.embd = embed
        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
        batch.seq_id = seq_ids
        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
        batch.logits = (ctypes.c_int8 * n_tokens)()

        if llama_cpp.llama_decode(llama.ctx, batch):
            raise ValueError("Failed to decode image.")

        llama_cpp.llama_set_causal_attn(llama.ctx, True)
        # Required to avoid issues with hf tokenizer
        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
        llama.n_tokens += n_tokens


@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
    llama: llama.Llama,
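
A usage sketch for the new handler (paths and model files are placeholders; the sketch assumes Gemma3ChatHandler is constructed with the same clip_model_path argument as the other Llava15ChatHandler subclasses):

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Gemma3ChatHandler

    chat_handler = Gemma3ChatHandler(clip_model_path="path/to/mmproj.gguf")
    llm = Llama(
        model_path="path/to/gemma-3-4b-it.gguf",
        chat_handler=chat_handler,
        n_ctx=4096,  # leave room for the 256 tokens each image adds
    )
    response = llm.create_chat_completion(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                ],
            }
        ],
    )
    print(response["choices"][0]["message"]["content"])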
113 changes: 113 additions & 0 deletions llama_cpp/llava_cpp.py
@@ -7,6 +7,7 @@
    c_int,
    c_uint8,
    c_float,
    c_size_t,
    c_void_p,
    POINTER,
    _Pointer, # type: ignore
@@ -141,6 +142,28 @@ def llava_eval_image_embed(
################################################


# struct clip_image_u8_batch {
#     struct clip_image_u8 * data;
#     size_t size;
# };
class clip_image_u8_batch(Structure):
    _fields_ = [
        ("data", c_void_p),
        ("size", c_size_t),
    ]


# struct clip_image_f32_batch {
#     struct clip_image_f32 * data;
#     size_t size;
# };
class clip_image_f32_batch(Structure):
    _fields_ = [
        ("data", c_void_p),
        ("size", c_size_t),
    ]


# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
@@ -156,3 +179,93 @@ def clip_model_load(
def clip_free(ctx: clip_ctx_p, /):
    ...


# CLIP_API struct clip_image_u8 * clip_image_u8_init ();
@ctypes_function("clip_image_u8_init", [], c_void_p)
def clip_image_u8_init() -> Optional[c_void_p]:
    ...


# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
@ctypes_function("clip_image_u8_free", [c_void_p], None)
def clip_image_u8_free(img: c_void_p, /):
    ...


# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
@ctypes_function("clip_image_f32_free", [c_void_p], None)
def clip_image_f32_free(img: c_void_p, /):
    ...


# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None)
def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /):
    ...


# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None)
def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /):
    ...


# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
@ctypes_function(
    "clip_image_preprocess",
    [
        clip_ctx_p_ctypes,
        c_void_p,
        POINTER(clip_image_f32_batch),
    ],
    c_bool,
)
def clip_image_preprocess(
    ctx: clip_ctx_p,
    img: c_void_p,
    res_imgs: "_Pointer[clip_image_f32_batch]",
    /,
) -> bool:
    ...


# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
@ctypes_function(
    "clip_image_batch_encode",
    [
        clip_ctx_p_ctypes,
        c_int,
        POINTER(clip_image_f32_batch),
        POINTER(c_float),
    ],
    c_bool,
)
def clip_image_batch_encode(
    ctx: clip_ctx_p,
    n_threads: c_int,
    imgs: "_Pointer[clip_image_f32_batch]",
    vec: "_Pointer[c_float]",
    /,
) -> bool:
    ...


# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
@ctypes_function(
    "clip_image_load_from_bytes",
    [
        c_void_p,
        c_size_t,
        c_void_p,
    ],
    c_bool,
)
def clip_image_load_from_bytes(
    bytes: c_void_p,
    bytes_length: c_size_t,
    img: c_void_p,
    /,
) -> bool:
    ...
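
The new bindings compose into the init/load/preprocess/encode/free sequence used by Gemma3ChatHandler.eval_image above. A minimal sketch of that lifecycle (the helper name encode_image_bytes is hypothetical; clip_ctx is assumed to come from clip_model_load, and n_tokens/n_embd must match the model; the try/finally blocks mirror the manual frees in eval_image):

    import ctypes
    from llama_cpp import llava_cpp

    def encode_image_bytes(clip_ctx, img_bytes: bytes, n_threads: int, n_tokens: int, n_embd: int):
        # RGB buffer -> preprocessed f32 batch -> flat embedding vector
        img_u8_p = llava_cpp.clip_image_u8_init()
        try:
            if not llava_cpp.clip_image_load_from_bytes(
                ctypes.create_string_buffer(img_bytes, len(img_bytes)),
                ctypes.c_size_t(len(img_bytes)),
                img_u8_p,
            ):
                raise ValueError("Failed to load image.")
            img_f32 = llava_cpp.clip_image_f32_batch()
            img_f32_p = ctypes.byref(img_f32)
            try:
                if not llava_cpp.clip_image_preprocess(clip_ctx, img_u8_p, img_f32_p):
                    raise ValueError("Failed to preprocess image.")
                embed = (ctypes.c_float * (n_tokens * n_embd))()
                if not llava_cpp.clip_image_batch_encode(clip_ctx, n_threads, img_f32_p, embed):
                    raise ValueError("Failed to encode image.")
                return embed
            finally:
                llava_cpp.clip_image_f32_batch_free(img_f32_p)
        finally:
            llava_cpp.clip_image_u8_free(img_u8_p)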