Commit 3b3c1d6

Add beam search. Invoke by adding "beam_width": 2 (for example) to the /v1/completions POST body.
1 parent d644199 commit 3b3c1d6
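
For reference, a minimal client-side sketch of invoking the new option, assuming the server is running locally on its default address (http://localhost:8000) and that the requests package is installed; the /v1/completions endpoint and the beam_width field come from the diff below, everything else here is illustrative:

import requests

# POST a completion request with beam search enabled (beam_width > 0).
# beam_width is the field added to CreateCompletionRequest in this commit;
# 0 (the default) keeps the ordinary sampling path.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The Fibonacci sequence begins",
        "max_tokens": 32,
        "beam_width": 2,
    },
)
print(response.json()["choices"][0]["text"])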

File tree: 4 files changed, +106 -3 lines changed


llama_cpp/llama.py (+71 -1)
@@ -17,6 +17,7 @@
     Callable,
 )
 from collections import deque, OrderedDict
+from dataclasses import dataclass
 
 import diskcache
 import ctypes

@@ -199,6 +200,42 @@ class StoppingCriteriaList(List[StoppingCriteria]):
     def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
 
+# Custom data that is accessible to the beam_search_callback() function.
+@dataclass
+class beam_search_callback_data:
+    ctx: llama_cpp.llama_context_p
+    response_tokens: List[int]
+
+def beam_view_to_string(ctx, beam_view):
+    string = f"p({beam_view.p}): "
+    for i in range(beam_view.n_tokens):
+        string += str(llama_cpp.llama_token_to_str(ctx, beam_view.tokens[i]))
+    return string
+
+def is_at_eos(tokens, n_tokens):
+    return 0 < n_tokens and tokens[n_tokens - 1] == llama_cpp.llama_token_eos()
+
+# beam_search_callback requires a global dictionary to pass data via the object id.
+beam_search_dictionary = {}
+
+# beam_search_callback() must flag beams when they reach end-of-sentence.
+# TODO: Use stop_sequences.
+def beam_search_callback(callback_data_id, beams_state):
+    for i in range(beams_state.n_beams):
+        beam_view = beams_state.beam_views[i]
+        if not beam_view.eos and is_at_eos(beam_view.tokens, beam_view.n_tokens):
+            beam_view.eos = True  # Flag beams as EOS as required.
+    callback_data = beam_search_dictionary[callback_data_id]
+    # Collect tokens into callback_data.response_tokens
+    if 0 < beams_state.common_prefix_length:
+        assert 0 < beams_state.n_beams
+        tokens = ctypes.cast(beams_state.beam_views[0].tokens, ctypes.POINTER(ctypes.c_int * beams_state.common_prefix_length)).contents
+        callback_data.response_tokens.extend(tokens)
+
+    # DEBUG print beams and their relative probabilities
+    # print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
+    # for i in range(beams_state.n_beams):
+    #     print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""

@@ -475,6 +512,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
+
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]

@@ -719,6 +757,7 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 

@@ -760,6 +799,28 @@
         if grammar is not None:
             grammar.reset()
 
+        if 0 < beam_width:
+            print("beam_width=", beam_width)
+            self.eval(tokens)
+            callback_data = beam_search_callback_data(self.ctx, [])
+            beam_search_dictionary[id(callback_data)] = callback_data
+            callback = llama_cpp.llama_beam_search_callback(beam_search_callback)
+            n_remain = llama_cpp.llama_n_ctx(self.ctx) - self.n_tokens
+            llama_cpp.llama_beam_search(self.ctx, callback, id(callback_data),
+                                        beam_width,
+                                        self.n_tokens,
+                                        n_remain,
+                                        self.n_threads)
+            beam_search_dictionary.pop(id(callback_data))
+            # Ideally we would yield from within the callback, but that is impossible.
+            for token in callback_data.response_tokens:
+                string = llama_cpp.llama_token_to_str(self.ctx, token)
+                np.append(self.input_ids, [token])
+                np.append(self.scores, [0.0])
+                self.n_tokens += 1
+                yield token
+            return
+
         while True:
             self.eval(tokens)
             token = self.sample(

@@ -776,6 +837,7 @@
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
+
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):

@@ -878,6 +940,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 

@@ -956,6 +1019,7 @@ def _create_completion(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)

@@ -1301,6 +1365,7 @@ def create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 

@@ -1316,6 +1381,7 @@ def create_completion(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.

@@ -1345,7 +1411,8 @@
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
-            grammar=grammar
+            grammar=grammar,
+            beam_width=beam_width,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks

@@ -1376,6 +1443,7 @@ def __call__(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        beam_width: int = 0,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 

@@ -1391,6 +1459,7 @@ def __call__(
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling.
             stream: Whether to stream the results.
+            beam_width: Number of beams to use in beam search. 0 disables.
 
         Raises:
             ValueError: If the requested tokens exceed the context window.

@@ -1421,6 +1490,7 @@ def __call__(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            beam_width=beam_width,
         )
 
     def _convert_text_completion_to_chat(
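
As a quick illustration of the high-level path, the sketch below passes the new beam_width argument through __call__()/create_completion(); the model path is a placeholder, but the parameter itself is the one added in the diff above:

from llama_cpp import Llama

# Placeholder model path; any model file that Llama() can load will do.
llm = Llama(model_path="./models/7B/ggml-model.bin")

# beam_width > 0 routes generate() through llama_cpp.llama_beam_search()
# instead of the usual token-by-token sampling loop; 0 disables beam search.
output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    beam_width=2,
)
print(output["choices"][0]["text"])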

llama_cpp/llama_cpp.py (+32)
@@ -1312,6 +1312,38 @@ def llama_grammar_accept_token(
 ]
 _lib.llama_grammar_accept_token.restype = None
 
+# Beam search types and function
+class llama_beam_view(Structure):
+    _fields_ = [
+        ("tokens", POINTER(c_int)),
+        ("n_tokens", c_size_t),
+        ("p", c_float),
+        ("eos", c_bool)
+    ]
+
+class llama_beams_state(Structure):
+    _fields_ = [
+        ("beam_views", POINTER(llama_beam_view)),
+        ("n_beams", c_size_t),
+        ("common_prefix_length", c_size_t),
+        ("last_call", c_bool)
+    ]
+
+# typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
+llama_beam_search_callback = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
+
+def llama_beam_search(ctx: llama_context_p,
+                      callback: llama_beam_search_callback,
+                      callback_data: c_void_p,
+                      n_beams: c_size_t,
+                      n_past: c_int,
+                      n_predict: c_int,
+                      n_threads: c_int):
+    return _lib.llama_beam_search(ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads)
+
+_lib.llama_beam_search.argtypes = [llama_context_p, llama_beam_search_callback, c_void_p, c_size_t, c_int, c_int, c_int]
+_lib.llama_beam_search.restype = None
+
 # Performance information
 
 
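To show how these low-level pieces fit together, here is an illustrative sketch of wiring a Python callback through the new binding; the callback body is only an example (the real one used by the wrapper is beam_search_callback() in llama.py above), and the final call is commented out because it needs an already-created llama_context_p:

import llama_cpp

def print_beams(callback_data, beams_state):
    # Inspect each candidate beam; a real callback should also set
    # beam_view.eos when a beam reaches end-of-sentence.
    for i in range(beams_state.n_beams):
        view = beams_state.beam_views[i]
        print(f"beam {i}: p={view.p}, n_tokens={view.n_tokens}, eos={view.eos}")

# Wrap the Python function in the CFUNCTYPE declared above so it can be
# handed to the C library.
callback = llama_cpp.llama_beam_search_callback(print_beams)

# With a loaded context (ctx), the search would be started roughly like this:
# llama_cpp.llama_beam_search(ctx, callback, None, n_beams=2,
#                             n_past=n_tokens, n_predict=n_remain, n_threads=4)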

llama_cpp/server/app.py (+2 -1)
@@ -70,7 +70,7 @@ class Settings(BaseSettings):
         ge=0,
         description="Last n tokens to keep for repeat penalty calculation.",
     )
-    logits_all: bool = Field(default=True, description="Whether to return logits.")
+    logits_all: bool = Field(default=False, description="Whether to return logits.")
     cache: bool = Field(
         default=False,
         description="Use a cache to reduce processing times for evaluated prompts.",

@@ -525,6 +525,7 @@ class CreateCompletionRequest(BaseModel):
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
+    beam_width: int = 0
 
     model_config = {
         "json_schema_extra": {

vendor/llama.cpp (submodule update)
