
Commit 5ab40e6

CISC and abetlen authored
feat: Support multiple chat templates - step 1 (abetlen#1396)
* Support multiple chat templates - step 1

  As a first step, allow the user to select a template from metadata with the chat_format parameter in the form of `chat_template.name`.

* register chat templates to self.chat_formats instead of globally

* Don't expose internal chat handlers yet

---------

Co-authored-by: Andrei <[email protected]>
1 parent bf66a28 commit 5ab40e6
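
As described in the commit message, a named template can now be selected through the existing chat_format parameter. A minimal usage sketch follows; the model path and the template name "chat_template.tool_use" are hypothetical, assuming a GGUF whose metadata carries an extra named template under tokenizer.chat_template.*:

```python
from llama_cpp import Llama

# Hypothetical model path; any GGUF that ships additional named templates under
# "tokenizer.chat_template.<name>" exposes them as chat_format "chat_template.<name>".
llm = Llama(
    model_path="./models/example-model.gguf",
    chat_format="chat_template.tool_use",  # hypothetical template name from metadata
    verbose=True,  # logs "Available chat formats from metadata: ..." at load time
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}]
)
print(response["choices"][0]["message"]["content"])
```

Omitting chat_format still falls back to the metadata default or to the guessed/built-in formats, per the diff below.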

File tree: 1 file changed (+28 / -22 lines)


llama_cpp/llama.py (+28 / -22)
@@ -378,6 +378,7 @@ def __init__(
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}
 
         self.draft_model = draft_model
 
@@ -409,10 +410,33 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
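
The hunk above derives handler names by stripping the leading "tokenizer." (10 characters) from each metadata key, and registers the unnamed template under the reserved name "chat_template.default". A standalone sketch of that key derivation, using a made-up metadata dict (the template strings and the "rag" name are fabricated for illustration):

```python
# Fake metadata standing in for self.metadata; real values come from the GGUF file.
metadata = {
    "tokenizer.chat_template": "{% for m in messages %}...{% endfor %}",
    "tokenizer.chat_template.rag": "{% for d in documents %}...{% endfor %}",
    "tokenizer.ggml.eos_token_id": "2",
}

# len("tokenizer.") == 10, so name[10:] maps
# "tokenizer.chat_template.rag" -> "chat_template.rag".
template_choices = {
    name[10:]: template
    for name, template in metadata.items()
    if name.startswith("tokenizer.chat_template.")
}

# The unnamed template gets the reserved name used as the default chat_format.
if "tokenizer.chat_template" in metadata:
    template_choices["chat_template.default"] = metadata["tokenizer.chat_template"]

print(sorted(template_choices))  # ['chat_template.default', 'chat_template.rag']
```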
@@ -423,30 +447,12 @@ def __init__(
             if self.verbose:
                 print(f"Guessed chat format: {chat_format}", file=sys.stderr)
             else:
-                template = self.metadata["tokenizer.chat_template"]
-                try:
-                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-                except:
-                    eos_token_id = self.token_eos()
-                try:
-                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-                except:
-                    bos_token_id = self.token_bos()
-
-                eos_token = self._model.token_get_text(eos_token_id)
-                bos_token = self._model.token_get_text(bos_token_id)
-
                 if self.verbose:
-                    print(f"Using gguf chat template: {template}", file=sys.stderr)
+                    print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
-                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token,
-                    stop_token_ids=[eos_token_id],
-                ).to_chat_handler()
+                self.chat_format = "chat_template.default"
 
         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
@@ -1719,7 +1725,7 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
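
The last hunk changes how create_chat_completion resolves its handler. Restated as a small standalone helper (a sketch of the lookup order, not the library's API; resolve_handler is an illustrative name):

```python
from llama_cpp import llama_chat_format

def resolve_handler(chat_handler, chat_handlers, chat_format):
    # 1. An explicitly supplied chat_handler always wins.
    # 2. Otherwise chat_format may name a handler registered from GGUF metadata,
    #    e.g. "chat_template.default" (dict.get returns None if it is absent).
    # 3. Otherwise fall back to the globally registered format of that name,
    #    e.g. "llama-2".
    return (
        chat_handler
        or chat_handlers.get(chat_format)
        or llama_chat_format.get_chat_completion_handler(chat_format)
    )
```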
