@@ -378,6 +378,7 @@ def __init__(
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}
 
         self.draft_model = draft_model
 
@@ -409,10 +410,33 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
@@ -423,30 +447,12 @@ def __init__(
                 if self.verbose:
                     print(f"Guessed chat format: {chat_format}", file=sys.stderr)
             else:
-                template = self.metadata["tokenizer.chat_template"]
-                try:
-                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-                except:
-                    eos_token_id = self.token_eos()
-                try:
-                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-                except:
-                    bos_token_id = self.token_bos()
-
-                eos_token = self._model.token_get_text(eos_token_id)
-                bos_token = self._model.token_get_text(bos_token_id)
-
                 if self.verbose:
-                    print(f"Using gguf chat template: {template}", file=sys.stderr)
+                    print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
-                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token,
-                    stop_token_ids=[eos_token_id],
-                ).to_chat_handler()
+                self.chat_format = "chat_template.default"
 
         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
@@ -1719,7 +1725,7 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
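
Usage sketch (not part of the diff): assuming a local GGUF file at ./model.gguf whose metadata carries a "tokenizer.chat_template" key, the default template now drives chat completions with no extra arguments; the path and messages below are placeholder values:

    from llama_cpp import Llama

    # The "tokenizer.chat_template" metadata is registered under the name
    # "chat_template.default", and __init__ sets self.chat_format to that
    # name when neither chat_format nor chat_handler is supplied.
    llm = Llama(model_path="./model.gguf", verbose=True)  # hypothetical path

    # create_chat_completion() now resolves the handler from
    # self._chat_handlers before falling back to the registered formats.
    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]
    )
    print(response["choices"][0]["message"]["content"])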
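For models that ship extra named templates (e.g. a hypothetical "tokenizer.chat_template.tool_use" metadata entry), the registration loop strips the leading "tokenizer." prefix (name[10:]), so a specific template can be selected by name, assuming that key actually exists in the model's metadata:

    from llama_cpp import Llama

    # "chat_template.tool_use" only exists if the GGUF metadata contains
    # a "tokenizer.chat_template.tool_use" entry; otherwise the lookup in
    # self._chat_handlers misses and the name is treated as a built-in format.
    llm = Llama(model_path="./model.gguf", chat_format="chat_template.tool_use")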