Commit 01a010b

Fix llama_cpp and Llama type signatures. Closes abetlen#221
1 parent fb57b94 commit 01a010b

3 files changed: +58 -64 lines

llama_cpp/llama.py  (+35 -41)
@@ -15,9 +15,7 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[
-            Tuple[llama_cpp.llama_token, ...], "LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
         self.capacity_bytes = capacity_bytes
 
     @property
@@ -26,8 +24,8 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ def _find_longest_prefix_key(
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
         self.cache_state.move_to_end(_key)
         return value
 
-    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+    def __contains__(self, key: Sequence[int]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None
 
-    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
         if key in self.cache_state:
             del self.cache_state[key]
@@ -63,7 +61,7 @@ def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState")
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
@@ -141,7 +139,7 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
@@ -176,9 +174,7 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
@@ -197,7 +193,7 @@ def tokenize(
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
@@ -216,7 +212,7 @@ def tokenize(
             )
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -228,7 +224,9 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output
 
     def set_cache(self, cache: Optional[LlamaCache]):
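Note on the pattern in the tokenize and detokenize hunks above: the public methods accept and return plain Python ints, and conversion to ctypes values (llama_cpp.c_int, llama_cpp.llama_token) now happens only at the point where the C function is called. A minimal standalone sketch of that ctypes idiom, not taken from this repository (it wraps libc's abs and assumes a POSIX system where ctypes.CDLL(None) exposes the C runtime):

    import ctypes

    # Load the C runtime; POSIX-only convenience, not how llama.cpp itself is loaded.
    libc = ctypes.CDLL(None)
    libc.abs.argtypes = [ctypes.c_int]
    libc.abs.restype = ctypes.c_int

    def absolute(value: int) -> int:
        """Public signature uses plain int; wrap into c_int only at the FFI call."""
        return int(libc.abs(ctypes.c_int(value)))

    print(absolute(-42))  # -> 42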
@@ -244,7 +242,7 @@ def reset(self):
         self.eval_tokens.clear()
         self.eval_logits.clear()
 
-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
        """Evaluate a list of tokens.
 
         Args:
@@ -458,7 +456,7 @@ def sample(
 
     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
         Examples:
@@ -617,14 +613,14 @@ def _create_completion(
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
@@ -724,7 +720,9 @@ def _create_completion(
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
-                    if token_end_position >= (remaining_length - first_stop_position - 1):
+                    if token_end_position >= (
+                        remaining_length - first_stop_position - 1
+                    ):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
@@ -744,7 +742,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            self.detokenize([i]).decode(
                                 "utf-8", errors="ignore"
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
@@ -822,9 +820,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
-                                "utf-8", errors="ignore"
-                            ): logprob
+                            self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
                         }
                         top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ def load_state(self, state: LlamaState) -> None:
             raise RuntimeError("Failed to set llama state data")
 
     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()
 
     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()
 
@@ -1317,9 +1313,7 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         return [math.log(x / sum_exps) for x in exps]
 
     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b:
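Taken together, the signature changes shown above mean the high-level Llama API now deals in plain Python ints rather than llama_cpp.llama_token values. A rough usage sketch (not part of the commit; the model path is a placeholder for a local GGML model file):

    from llama_cpp import Llama

    llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path

    tokens = llama.tokenize(b" Hello, world!")  # List[int], not List[llama_token]
    assert all(isinstance(t, int) for t in tokens)

    text = llama.detokenize(tokens)             # accepts plain ints
    print(text.decode("utf-8"))

    print(Llama.token_eos())                    # static token helpers also return int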
