From f33dde30a1597b0e9d62bc7f35cb42a2e9910593 Mon Sep 17 00:00:00 2001 From: kossum <127719370+kossum@users.noreply.github.com> Date: Mon, 31 Mar 2025 04:15:39 +0900 Subject: [PATCH 1/4] feat: Add Gemma3 chat handler (#1976) --- llama_cpp/llama_chat_format.py | 89 ++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 17575c700..0d6d39cb8 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3373,6 +3373,95 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +class Gemma3ChatHandler(Llava15ChatHandler): + # Chat Format: + # 'user\n{system_prompt}\n\n{prompt}\nmodel\n' + + DEFAULT_SYSTEM_MESSAGE = None + + CHAT_FORMAT = ( + "{{ '' }}" + "{%- if messages[0]['role'] == 'system' -%}" + "{%- if messages[0]['content'] is string -%}" + "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}" + "{%- else -%}" + "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}" + "{%- endif -%}" + "{%- set loop_messages = messages[1:] -%}" + "{%- else -%}" + "{%- set first_user_prefix = \"\" -%}" + "{%- set loop_messages = messages -%}" + "{%- endif -%}" + "{%- for message in loop_messages -%}" + "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}" + "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}" + "{%- endif -%}" + "{%- if (message['role'] == 'assistant') -%}" + "{%- set role = \"model\" -%}" + "{%- else -%}" + "{%- set role = message['role'] -%}" + "{%- endif -%}" + "{{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}" + "{%- if message['content'] is string -%}" + "{{ message['content'] | trim }}" + "{%- elif message['content'] is iterable -%}" + "{%- for item in message['content'] -%}" + "{%- if item['type'] == 'image' -%}" + "{{ '' }}" + "{%- elif item['type'] == 'text' -%}" + "{{ item['text'] | trim }}" + "{%- endif -%}" + "{%- endfor -%}" + "{%- else -%}" + "{{ raise_exception(\"Invalid content type\") }}" + "{%- endif -%}" + "{{ '\n' }}" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}" + "{{ 'model\n' }}" + "{%- endif -%}" + ) + + @staticmethod + def split_text_on_image_urls(text: str, image_urls: List[str]): + split_text: List[Tuple[Literal["text", "image_url"], str]] = [] + copied_urls = image_urls[:] + remaining = text + image_placeholder = "" + + while remaining: + # Find placeholder + pos = remaining.find(image_placeholder) + if pos != -1: + assert len(copied_urls) > 0 + if pos > 0: + split_text += [("text", remaining[:pos])] + split_text += [("text", "\n\n")] + split_text += [("image_url", copied_urls.pop(0))] + split_text += [("text", "\n\n")] + remaining = remaining[pos + len(image_placeholder):] + else: + assert len(copied_urls) == 0 + split_text.append(("text", remaining)) + remaining = "" + return split_text + + @staticmethod + def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): + image_urls: List[str] = [] + for message in messages: + if message["role"] == "user": + if message.get("content") is None: + continue + for content in message["content"]: + if isinstance(content, dict) and content.get("type") == "image": + if isinstance(content.get("image"), dict) and isinstance(content["image"].get("url"), str): + image_urls.append(content["image"]["url"]) + elif isinstance(content.get("url"), str): + image_urls.append(content["url"]) + return image_urls + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, From 25b2f8fe0d92cb27e364d3c9601dde77e50446bf Mon Sep 17 00:00:00 2001 From: kossum <127719370+kossum@users.noreply.github.com> Date: Thu, 3 Apr 2025 06:25:21 +0900 Subject: [PATCH 2/4] resolve the image embedding issue in gemma3 --- llama_cpp/llama_chat_format.py | 101 ++++++++++++++++++++++------- llama_cpp/llava_cpp.py | 112 +++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+), 22 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 0d6d39cb8..7ac0f4016 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2835,24 +2835,7 @@ def __call__( ) llama.eval(tokens) else: - image_bytes = self.load_image(value) - embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) - if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): - raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" - ) - n_past = ctypes.c_int(llama.n_tokens) - n_past_p = ctypes.pointer(n_past) - with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.llava_eval_image_embed( - llama.ctx, - embed, - llama.n_batch, - n_past_p, - ) - # Required to avoid issues with hf tokenizer - llama.input_ids[llama.n_tokens : n_past.value] = -1 - llama.n_tokens = n_past.value + self.eval_image(llama, value) # Get prompt tokens to avoid a cache miss prompt = llama.input_ids[: llama.n_tokens].tolist() @@ -2938,6 +2921,26 @@ def __call__( ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) + def eval_image(self, llama: llama.Llama, image_url: str): + image_bytes = self.load_image(image_url) + embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) + if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" + ) + n_past = ctypes.c_int(llama.n_tokens) + n_past_p = ctypes.pointer(n_past) + with suppress_stdout_stderr(disable=self.verbose): + self._llava_cpp.llava_eval_image_embed( + llama.ctx, + embed, + llama.n_batch, + n_past_p, + ) + # Required to avoid issues with hf tokenizer + llama.input_ids[llama.n_tokens : n_past.value] = -1 + llama.n_tokens = n_past.value + @staticmethod def _load_image(image_url: str) -> bytes: # TODO: Add Pillow support for other image formats beyond (jpg, png) @@ -3435,10 +3438,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]): if pos != -1: assert len(copied_urls) > 0 if pos > 0: - split_text += [("text", remaining[:pos])] - split_text += [("text", "\n\n")] - split_text += [("image_url", copied_urls.pop(0))] - split_text += [("text", "\n\n")] + split_text.append(("text", remaining[:pos])) + split_text.append(("text", "\n\n")) + split_text.append(("image_url", copied_urls.pop(0))) + split_text.append(("text", "\n\n")) remaining = remaining[pos + len(image_placeholder):] else: assert len(copied_urls) == 0 @@ -3461,6 +3464,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): image_urls.append(content["url"]) return image_urls + def eval_image(self, llama: llama.Llama, image_url: str): + import llama_cpp + + img_bytes = self.load_image(image_url) + img_u8_p = self._llava_cpp.clip_image_u8_init() + if not self._llava_cpp.clip_image_load_from_bytes( + ctypes.create_string_buffer(img_bytes, len(img_bytes)), + ctypes.c_size_t(len(img_bytes)), + img_u8_p, + ): + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to load image.") + + img_f32 = self._llava_cpp.clip_image_f32_batch() + img_f32_p = ctypes.byref(img_f32) + if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p): + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to preprocess image.") + + n_embd = llama_cpp.llama_model_n_embd(llama._model.model) + n_tokens = 256 + embed = (ctypes.c_float * (n_tokens * n_embd))() + if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed): + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to encode image.") + + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + llama_cpp.llama_set_causal_attn(llama.ctx, False) + + seq_id_0 = (ctypes.c_int32 * 1)() + seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))() + for i in range(n_tokens): + seq_ids[i] = seq_id_0 + + batch = llama_cpp.llama_batch() + batch.n_tokens = n_tokens + batch.token = None + batch.embd = embed + batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)]) + batch.seq_id = seq_ids + batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens)) + batch.logits = (ctypes.c_int8 * n_tokens)() + + if llama_cpp.llama_decode(llama.ctx, batch): + raise ValueError("Failed to decode image.") + + llama_cpp.llama_set_causal_attn(llama.ctx, True) + # Required to avoid issues with hf tokenizer + llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1 + llama.n_tokens += n_tokens + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py index d9dfaf5fd..46ac5087f 100644 --- a/llama_cpp/llava_cpp.py +++ b/llama_cpp/llava_cpp.py @@ -7,6 +7,7 @@ c_int, c_uint8, c_float, + c_size_t, c_void_p, POINTER, _Pointer, # type: ignore @@ -141,6 +142,28 @@ def llava_eval_image_embed( ################################################ +# struct clip_image_u8_batch { +# struct clip_image_u8 * data; +# size_t size; +# }; +class clip_image_u8_batch(Structure): + _fields_ = [ + ("data", c_void_p), + ("size", c_size_t), + ] + + +# struct clip_image_f32_batch { +# struct clip_image_f32 * data; +# size_t size; +# }; +class clip_image_f32_batch(Structure): + _fields_ = [ + ("data", c_void_p), + ("size", c_size_t), + ] + + # /** load mmproj model */ # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes) @@ -156,3 +179,92 @@ def clip_model_load( def clip_free(ctx: clip_ctx_p, /): ... + +# CLIP_API struct clip_image_u8 * clip_image_u8_init (); +@ctypes_function("clip_image_u8_init", [], c_void_p) +def clip_image_u8_init() -> Optional[c_void_p]: + ... + + +# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); +@ctypes_function("clip_image_u8_free", [c_void_p], None) +def clip_image_u8_free(img: c_void_p, /): + ... + + +# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); +@ctypes_function("clip_image_f32_free", [c_void_p], None) +def clip_image_f32_free(img: c_void_p, /): + ... + + +# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); +@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None) +def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /): + ... + + +# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None) +def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /): + ... + + +# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ +# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); +@ctypes_function( + "clip_image_preprocess", + [ + clip_ctx_p_ctypes, + c_void_p, + POINTER(clip_image_f32_batch), + ], + c_bool, +) +def clip_image_preprocess( + ctx: clip_ctx_p, + img: c_void_p, + res_imgs: "_Pointer[clip_image_f32_batch]", + /, +) -> bool: + ... + + +# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); +@ctypes_function( + "clip_image_batch_encode", + [ + clip_ctx_p_ctypes, + c_int, + POINTER(clip_image_f32_batch), + POINTER(c_float), + ], + c_bool, +) +def clip_image_batch_encode( + ctx: clip_ctx_p, + n_threads: c_int, + imgs: "_Pointer[clip_image_f32_batch]", + vec: c_void_p +) -> bool: + ... + + +# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ +# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); +@ctypes_function( + "clip_image_load_from_bytes", + [ + c_void_p, + c_size_t, + c_void_p, + ], + c_bool, +) +def clip_image_load_from_bytes( + bytes: c_void_p, + bytes_length: c_size_t, + img: c_void_p, + /, +) -> bool: + ... From 1b455888d40aa2f64ace593ddeb7c54a3087d631 Mon Sep 17 00:00:00 2001 From: kossum <127719370+kossum@users.noreply.github.com> Date: Thu, 3 Apr 2025 19:43:58 +0900 Subject: [PATCH 3/4] fix: added n_ctx check for prompt requirements when embedding images in Gemma3ChatHandler --- llama_cpp/llama_chat_format.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 7ac0f4016..cbac975bd 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3467,6 +3467,12 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): def eval_image(self, llama: llama.Llama, image_url: str): import llama_cpp + n_tokens = 256 + if llama.n_tokens + n_tokens > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}" + ) + img_bytes = self.load_image(image_url) img_u8_p = self._llava_cpp.clip_image_u8_init() if not self._llava_cpp.clip_image_load_from_bytes( @@ -3485,7 +3491,6 @@ def eval_image(self, llama: llama.Llama, image_url: str): raise ValueError("Failed to preprocess image.") n_embd = llama_cpp.llama_model_n_embd(llama._model.model) - n_tokens = 256 embed = (ctypes.c_float * (n_tokens * n_embd))() if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed): self._llava_cpp.clip_image_f32_batch_free(img_f32_p) From 025e7fa44bfd071eb36b5641448c4e80a0b29917 Mon Sep 17 00:00:00 2001 From: kossum <127719370+kossum@users.noreply.github.com> Date: Fri, 4 Apr 2025 20:17:26 +0900 Subject: [PATCH 4/4] fix: modify the gemma3 chat template to be compatible with openai api --- llama_cpp/llama_chat_format.py | 17 +---------------- llama_cpp/llava_cpp.py | 3 ++- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index cbac975bd..4e1aad381 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3409,7 +3409,7 @@ class Gemma3ChatHandler(Llava15ChatHandler): "{{ message['content'] | trim }}" "{%- elif message['content'] is iterable -%}" "{%- for item in message['content'] -%}" - "{%- if item['type'] == 'image' -%}" + "{%- if item['type'] == 'image_url' -%}" "{{ '' }}" "{%- elif item['type'] == 'text' -%}" "{{ item['text'] | trim }}" @@ -3449,21 +3449,6 @@ def split_text_on_image_urls(text: str, image_urls: List[str]): remaining = "" return split_text - @staticmethod - def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): - image_urls: List[str] = [] - for message in messages: - if message["role"] == "user": - if message.get("content") is None: - continue - for content in message["content"]: - if isinstance(content, dict) and content.get("type") == "image": - if isinstance(content.get("image"), dict) and isinstance(content["image"].get("url"), str): - image_urls.append(content["image"]["url"]) - elif isinstance(content.get("url"), str): - image_urls.append(content["url"]) - return image_urls - def eval_image(self, llama: llama.Llama, image_url: str): import llama_cpp diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py index 46ac5087f..8a382b4d9 100644 --- a/llama_cpp/llava_cpp.py +++ b/llama_cpp/llava_cpp.py @@ -245,7 +245,8 @@ def clip_image_batch_encode( ctx: clip_ctx_p, n_threads: c_int, imgs: "_Pointer[clip_image_f32_batch]", - vec: c_void_p + vec: c_void_p, + /, ) -> bool: ...