@@ -545,13 +545,12 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
545
545
class _LlamaTokenDataArray :
546
546
def __init__ (self , * , n_vocab : int ):
547
547
self .n_vocab = n_vocab
548
- self .candidates_data = np .array (
549
- [] ,
548
+ self .candidates_data = np .recarray (
549
+ ( self . n_vocab ,) ,
550
550
dtype = np .dtype (
551
551
[("id" , np .intc ), ("logit" , np .single ), ("p" , np .single )], align = True
552
552
),
553
553
)
554
- self .candidates_data .resize (3 , self .n_vocab , refcheck = False )
555
554
self .candidates = llama_cpp .llama_token_data_array (
556
555
data = self .candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p ),
557
556
size = self .n_vocab ,
@@ -561,14 +560,11 @@ def __init__(self, *, n_vocab: int):
561
560
self .default_candidates_data_p = np .zeros (self .n_vocab , dtype = np .single )
562
561
563
562
def copy_logits (self , logits : npt .NDArray [np .single ]):
564
- self .candidates_data ["id" ][:] = self .default_candidates_data_id
565
- self .candidates_data ["logit" ][:] = logits
566
- self .candidates_data ["p" ][:] = self .default_candidates_data_p
567
- self .candidates .data = self .candidates_data .ctypes .data_as (
568
- llama_cpp .llama_token_data_p
569
- )
570
- self .candidates .sorted = ctypes .c_bool (False )
571
- self .candidates .size = ctypes .c_size_t (self .n_vocab )
563
+ self .candidates_data .id [:] = self .default_candidates_data_id
564
+ self .candidates_data .logit [:] = logits
565
+ self .candidates_data .p [:] = self .default_candidates_data_p
566
+ self .candidates .sorted = False
567
+ self .candidates .size = self .n_vocab
572
568
573
569
574
570
# Python wrappers over common/common
@@ -759,14 +755,14 @@ def sample(
759
755
self .params .penalty_present ,
760
756
)
761
757
if not self .params .penalize_nl :
762
- token_data_array .candidates_data [ "logit" ] [nl_token ] = nl_logit
758
+ token_data_array .candidates_data . logit [nl_token ] = nl_logit
763
759
764
760
if self .grammar is not None :
765
761
ctx_main .sample_grammar (token_data_array , self .grammar )
766
762
767
763
if self .params .temp < 0 :
768
764
ctx_main .sample_softmax (token_data_array )
769
- id = token_data_array .candidates_data [ "id" ] [0 ]
765
+ id = token_data_array .candidates_data . id [0 ]
770
766
elif self .params .temp == 0 :
771
767
id = ctx_main .sample_token_greedy (token_data_array )
772
768
else :
0 commit comments