Skip to content

Commit af3ed50

Browse files
committed
fix: Use numpy recarray for candidates data, fixes bug with temp < 0
1 parent 2d89964 commit af3ed50

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

llama_cpp/_internals.py

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -545,13 +545,12 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
545545
class _LlamaTokenDataArray:
546546
def __init__(self, *, n_vocab: int):
547547
self.n_vocab = n_vocab
548-
self.candidates_data = np.array(
549-
[],
548+
self.candidates_data = np.recarray(
549+
(self.n_vocab,),
550550
dtype=np.dtype(
551551
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
552552
),
553553
)
554-
self.candidates_data.resize(3, self.n_vocab, refcheck=False)
555554
self.candidates = llama_cpp.llama_token_data_array(
556555
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
557556
size=self.n_vocab,
@@ -561,14 +560,11 @@ def __init__(self, *, n_vocab: int):
561560
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
562561

563562
def copy_logits(self, logits: npt.NDArray[np.single]):
564-
self.candidates_data["id"][:] = self.default_candidates_data_id
565-
self.candidates_data["logit"][:] = logits
566-
self.candidates_data["p"][:] = self.default_candidates_data_p
567-
self.candidates.data = self.candidates_data.ctypes.data_as(
568-
llama_cpp.llama_token_data_p
569-
)
570-
self.candidates.sorted = ctypes.c_bool(False)
571-
self.candidates.size = ctypes.c_size_t(self.n_vocab)
563+
self.candidates_data.id[:] = self.default_candidates_data_id
564+
self.candidates_data.logit[:] = logits
565+
self.candidates_data.p[:] = self.default_candidates_data_p
566+
self.candidates.sorted = False
567+
self.candidates.size = self.n_vocab
572568

573569

574570
# Python wrappers over common/common
@@ -759,14 +755,14 @@ def sample(
759755
self.params.penalty_present,
760756
)
761757
if not self.params.penalize_nl:
762-
token_data_array.candidates_data["logit"][nl_token] = nl_logit
758+
token_data_array.candidates_data.logit[nl_token] = nl_logit
763759

764760
if self.grammar is not None:
765761
ctx_main.sample_grammar(token_data_array, self.grammar)
766762

767763
if self.params.temp < 0:
768764
ctx_main.sample_softmax(token_data_array)
769-
id = token_data_array.candidates_data["id"][0]
765+
id = token_data_array.candidates_data.id[0]
770766
elif self.params.temp == 0:
771767
id = ctx_main.sample_token_greedy(token_data_array)
772768
else:

0 commit comments

Comments
 (0)