     Callable,
 )
 from collections import deque, OrderedDict
+from dataclasses import dataclass
 
 import diskcache
 import ctypes
@@ -199,6 +200,42 @@ class StoppingCriteriaList(List[StoppingCriteria]):
     def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
 
+# Custom data that is accessible to the beam_search_callback() function.
+@dataclass
+class beam_search_callback_data:
+    ctx: llama_cpp.llama_context_p
+    response_tokens: List[int]
+
+def beam_view_to_string(ctx, beam_view):
+    string = f"p({beam_view.p}): "
+    for i in range(beam_view.n_tokens):
+        string += str(llama_cpp.llama_token_to_str(ctx, beam_view.tokens[i]))
+    return string
+
+def is_at_eos(tokens, n_tokens):
+    return 0 < n_tokens and tokens[n_tokens - 1] == llama_cpp.llama_token_eos()
+
+# beam_search_callback requires a global dictionary to pass data via object id.
+beam_search_dictionary = {}
+
+# beam_search_callback() must flag beams when they reach end-of-sentence.
+# TODO: Use stop_sequences.
+def beam_search_callback(callback_data_id, beams_state):
+    for i in range(beams_state.n_beams):
+        beam_view = beams_state.beam_views[i]
+        if not beam_view.eos and is_at_eos(beam_view.tokens, beam_view.n_tokens):
+            beam_view.eos = True  # Flag beams as EOS as required.
+    callback_data = beam_search_dictionary[callback_data_id]
+    # Collect tokens into callback_data.response_tokens
+    if 0 < beams_state.common_prefix_length:
+        assert 0 < beams_state.n_beams
+        tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
+        callback_data.response_tokens.extend(tokens)
+
+    # DEBUG print beams and their relative probabilities
+    # print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
+    # for i in range(beams_state.n_beams):
+    #     print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -475,6 +512,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
+
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
@@ -719,6 +757,7 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -760,6 +799,28 @@ def generate(
         if grammar is not None:
             grammar.reset()
 
+        if 0 < beam_width:
+            print("beam_width=", beam_width)
+            self.eval(tokens)
+            callback_data = beam_search_callback_data(self.ctx, [])
+            beam_search_dictionary[id(callback_data)] = callback_data
+            callback = llama_cpp.llama_beam_search_callback(beam_search_callback)
+            n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
+            llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
+                                        beam_width,
+                                        self.n_tokens,
+                                        n_remain,
+                                        self.n_threads)
+            beam_search_dictionary.pop(id(callback_data))
+            # Ideally we would yield from within the callback, but that is impossible.
+            for token in callback_data.response_tokens:
+                string = llama_cpp.llama_token_to_str(self.ctx, token)
+                np.append(self.input_ids, [token])
+                np.append(self.scores, [0.0])
+                self.n_tokens += 1
+                yield token
+            return
+
         while True:
             self.eval(tokens)
             token = self.sample(
@@ -776,6 +837,7 @@ def generate(
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
+
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
@@ -878,6 +940,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -956,6 +1019,7 @@ def _create_completion(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1365,7 @@ def create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1316,6 +1381,7 @@ def create_completion(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1345,7 +1411,8 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
-            grammar=grammar
+            grammar=grammar,
+            beam_width=beam_width,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1376,6 +1443,7 @@ def __call__(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1391,6 +1459,7 @@ def __call__(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1421,6 +1490,7 @@ def __call__(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         )
 
     def _convert_text_completion_to_chat(
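
For reference, a minimal usage sketch of the new parameter (not part of the diff; the model path below is a placeholder). On this branch, passing beam_width > 0 through __call__ / create_completion routes generation through the llama.cpp beam-search path added above, while the default beam_width=0 keeps the existing sampling loop:

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder model path
    output = llm(
        "Q: Name the planets in the solar system. A: ",
        max_tokens=64,
        beam_width=2,  # 0 (default) disables beam search
    )
    print(output["choices"][0]["text"])

The global beam_search_dictionary exists because the C-level callback only receives an opaque integer id, so the Python-side callback_data is stashed under id(callback_data) for the duration of the llama_cpp.llama_beam_search() call and popped afterwards.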