@@ -15,9 +15,7 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[
-            Tuple[llama_cpp.llama_token, ...], "LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
         self.capacity_bytes = capacity_bytes
 
     @property
@@ -26,8 +24,8 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ def _find_longest_prefix_key(
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
48
46
self .cache_state .move_to_end (_key )
49
47
return value
50
48
51
- def __contains__ (self , key : Sequence [llama_cpp . llama_token ]) -> bool :
49
+ def __contains__ (self , key : Sequence [int ]) -> bool :
52
50
return self ._find_longest_prefix_key (tuple (key )) is not None
53
51
54
- def __setitem__ (self , key : Sequence [llama_cpp . llama_token ], value : "LlamaState" ):
52
+ def __setitem__ (self , key : Sequence [int ], value : "LlamaState" ):
55
53
key = tuple (key )
56
54
if key in self .cache_state :
57
55
del self .cache_state [key ]
@@ -63,7 +61,7 @@ def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState")
63
61
class LlamaState :
64
62
def __init__ (
65
63
self ,
66
- eval_tokens : Deque [llama_cpp . llama_token ],
64
+ eval_tokens : Deque [int ],
67
65
eval_logits : Deque [List [float ]],
68
66
llama_state , # type: llama_cpp.Array[llama_cpp.c_uint8]
69
67
llama_state_size : int ,
@@ -141,7 +139,7 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
@@ -176,9 +174,7 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
@@ -197,7 +193,7 @@ def tokenize(
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
@@ -216,7 +212,7 @@ def tokenize(
             )
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -228,7 +224,9 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output
 
     def set_cache(self, cache: Optional[LlamaCache]):
@@ -244,7 +242,7 @@ def reset(self):
         self.eval_tokens.clear()
         self.eval_logits.clear()
 
-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
 
         Args:
@@ -458,7 +456,7 @@ def sample(
 
     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
         Examples:
@@ -617,14 +613,14 @@ def _create_completion(
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
@@ -724,7 +720,9 @@ def _create_completion(
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
-                    if token_end_position >= (remaining_length - first_stop_position - 1):
+                    if token_end_position >= (
+                        remaining_length - first_stop_position - 1
+                    ):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
@@ -744,7 +742,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            self.detokenize([i]).decode(
                                 "utf-8", errors="ignore"
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
@@ -822,9 +820,7 @@ def _create_completion(
                     )
                 )
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ def load_state(self, state: LlamaState) -> None:
             raise RuntimeError("Failed to set llama state data")
 
     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()
 
     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()
 
@@ -1317,9 +1313,7 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         return [math.log(x / sum_exps) for x in exps]
 
     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b:
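
For context, a minimal usage sketch of what this change means for callers (not part of the diff): token lists are now plain Python ints at the API boundary, with the llama_token conversion handled internally. The model path below is a hypothetical placeholder.

# Sketch only: assumes llama-cpp-python is installed and a local GGML model exists.
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model-q4_0.bin")  # placeholder path

# tokenize() now returns List[int] rather than List[llama_cpp.llama_token]
tokens = llm.tokenize(b"Hello, world!")
assert all(isinstance(t, int) for t in tokens)

# detokenize() accepts the same plain ints; wrapping them back into
# llama_cpp.llama_token happens inside the method before calling llama.cpp
print(llm.detokenize(tokens))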