diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 17575c700..4e1aad381 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2835,24 +2835,7 @@ def __call__( ) llama.eval(tokens) else: - image_bytes = self.load_image(value) - embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) - if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): - raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" - ) - n_past = ctypes.c_int(llama.n_tokens) - n_past_p = ctypes.pointer(n_past) - with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.llava_eval_image_embed( - llama.ctx, - embed, - llama.n_batch, - n_past_p, - ) - # Required to avoid issues with hf tokenizer - llama.input_ids[llama.n_tokens : n_past.value] = -1 - llama.n_tokens = n_past.value + self.eval_image(llama, value) # Get prompt tokens to avoid a cache miss prompt = llama.input_ids[: llama.n_tokens].tolist() @@ -2938,6 +2921,26 @@ def __call__( ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) + def eval_image(self, llama: llama.Llama, image_url: str): + image_bytes = self.load_image(image_url) + embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) + if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" + ) + n_past = ctypes.c_int(llama.n_tokens) + n_past_p = ctypes.pointer(n_past) + with suppress_stdout_stderr(disable=self.verbose): + self._llava_cpp.llava_eval_image_embed( + llama.ctx, + embed, + llama.n_batch, + n_past_p, + ) + # Required to avoid issues with hf tokenizer + llama.input_ids[llama.n_tokens : n_past.value] = -1 + llama.n_tokens = n_past.value + @staticmethod def _load_image(image_url: str) -> bytes: # TODO: Add Pillow support for other image formats beyond (jpg, png) @@ -3373,6 +3376,139 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +class Gemma3ChatHandler(Llava15ChatHandler): + # Chat Format: + # 'user\n{system_prompt}\n\n{prompt}\nmodel\n' + + DEFAULT_SYSTEM_MESSAGE = None + + CHAT_FORMAT = ( + "{{ '' }}" + "{%- if messages[0]['role'] == 'system' -%}" + "{%- if messages[0]['content'] is string -%}" + "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}" + "{%- else -%}" + "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}" + "{%- endif -%}" + "{%- set loop_messages = messages[1:] -%}" + "{%- else -%}" + "{%- set first_user_prefix = \"\" -%}" + "{%- set loop_messages = messages -%}" + "{%- endif -%}" + "{%- for message in loop_messages -%}" + "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}" + "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}" + "{%- endif -%}" + "{%- if (message['role'] == 'assistant') -%}" + "{%- set role = \"model\" -%}" + "{%- else -%}" + "{%- set role = message['role'] -%}" + "{%- endif -%}" + "{{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}" + "{%- if message['content'] is string -%}" + "{{ message['content'] | trim }}" + "{%- elif message['content'] is iterable -%}" + "{%- for item in message['content'] -%}" + "{%- if item['type'] == 'image_url' -%}" + "{{ '' }}" + "{%- elif item['type'] == 'text' -%}" + "{{ item['text'] | trim }}" + "{%- endif -%}" + "{%- endfor -%}" + "{%- else -%}" + "{{ raise_exception(\"Invalid content type\") }}" + "{%- endif -%}" + "{{ '\n' }}" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}" + "{{ 'model\n' }}" + "{%- endif -%}" + ) + + @staticmethod + def split_text_on_image_urls(text: str, image_urls: List[str]): + split_text: List[Tuple[Literal["text", "image_url"], str]] = [] + copied_urls = image_urls[:] + remaining = text + image_placeholder = "" + + while remaining: + # Find placeholder + pos = remaining.find(image_placeholder) + if pos != -1: + assert len(copied_urls) > 0 + if pos > 0: + split_text.append(("text", remaining[:pos])) + split_text.append(("text", "\n\n")) + split_text.append(("image_url", copied_urls.pop(0))) + split_text.append(("text", "\n\n")) + remaining = remaining[pos + len(image_placeholder):] + else: + assert len(copied_urls) == 0 + split_text.append(("text", remaining)) + remaining = "" + return split_text + + def eval_image(self, llama: llama.Llama, image_url: str): + import llama_cpp + + n_tokens = 256 + if llama.n_tokens + n_tokens > llama.n_ctx(): + raise ValueError( + f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}" + ) + + img_bytes = self.load_image(image_url) + img_u8_p = self._llava_cpp.clip_image_u8_init() + if not self._llava_cpp.clip_image_load_from_bytes( + ctypes.create_string_buffer(img_bytes, len(img_bytes)), + ctypes.c_size_t(len(img_bytes)), + img_u8_p, + ): + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to load image.") + + img_f32 = self._llava_cpp.clip_image_f32_batch() + img_f32_p = ctypes.byref(img_f32) + if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p): + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to preprocess image.") + + n_embd = llama_cpp.llama_model_n_embd(llama._model.model) + embed = (ctypes.c_float * (n_tokens * n_embd))() + if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed): + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + raise ValueError("Failed to encode image.") + + self._llava_cpp.clip_image_f32_batch_free(img_f32_p) + self._llava_cpp.clip_image_u8_free(img_u8_p) + llama_cpp.llama_set_causal_attn(llama.ctx, False) + + seq_id_0 = (ctypes.c_int32 * 1)() + seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))() + for i in range(n_tokens): + seq_ids[i] = seq_id_0 + + batch = llama_cpp.llama_batch() + batch.n_tokens = n_tokens + batch.token = None + batch.embd = embed + batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)]) + batch.seq_id = seq_ids + batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens)) + batch.logits = (ctypes.c_int8 * n_tokens)() + + if llama_cpp.llama_decode(llama.ctx, batch): + raise ValueError("Failed to decode image.") + + llama_cpp.llama_set_causal_attn(llama.ctx, True) + # Required to avoid issues with hf tokenizer + llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1 + llama.n_tokens += n_tokens + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py index d9dfaf5fd..8a382b4d9 100644 --- a/llama_cpp/llava_cpp.py +++ b/llama_cpp/llava_cpp.py @@ -7,6 +7,7 @@ c_int, c_uint8, c_float, + c_size_t, c_void_p, POINTER, _Pointer, # type: ignore @@ -141,6 +142,28 @@ def llava_eval_image_embed( ################################################ +# struct clip_image_u8_batch { +# struct clip_image_u8 * data; +# size_t size; +# }; +class clip_image_u8_batch(Structure): + _fields_ = [ + ("data", c_void_p), + ("size", c_size_t), + ] + + +# struct clip_image_f32_batch { +# struct clip_image_f32 * data; +# size_t size; +# }; +class clip_image_f32_batch(Structure): + _fields_ = [ + ("data", c_void_p), + ("size", c_size_t), + ] + + # /** load mmproj model */ # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes) @@ -156,3 +179,93 @@ def clip_model_load( def clip_free(ctx: clip_ctx_p, /): ... + +# CLIP_API struct clip_image_u8 * clip_image_u8_init (); +@ctypes_function("clip_image_u8_init", [], c_void_p) +def clip_image_u8_init() -> Optional[c_void_p]: + ... + + +# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); +@ctypes_function("clip_image_u8_free", [c_void_p], None) +def clip_image_u8_free(img: c_void_p, /): + ... + + +# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); +@ctypes_function("clip_image_f32_free", [c_void_p], None) +def clip_image_f32_free(img: c_void_p, /): + ... + + +# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); +@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None) +def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /): + ... + + +# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None) +def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /): + ... + + +# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ +# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); +@ctypes_function( + "clip_image_preprocess", + [ + clip_ctx_p_ctypes, + c_void_p, + POINTER(clip_image_f32_batch), + ], + c_bool, +) +def clip_image_preprocess( + ctx: clip_ctx_p, + img: c_void_p, + res_imgs: "_Pointer[clip_image_f32_batch]", + /, +) -> bool: + ... + + +# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); +@ctypes_function( + "clip_image_batch_encode", + [ + clip_ctx_p_ctypes, + c_int, + POINTER(clip_image_f32_batch), + POINTER(c_float), + ], + c_bool, +) +def clip_image_batch_encode( + ctx: clip_ctx_p, + n_threads: c_int, + imgs: "_Pointer[clip_image_f32_batch]", + vec: c_void_p, + /, +) -> bool: + ... + + +# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ +# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); +@ctypes_function( + "clip_image_load_from_bytes", + [ + c_void_p, + c_size_t, + c_void_p, + ], + c_bool, +) +def clip_image_load_from_bytes( + bytes: c_void_p, + bytes_length: c_size_t, + img: c_void_p, + /, +) -> bool: + ...