WIP: Parallel generation implementation #1209

Open · wants to merge 3 commits into base: main
11 changes: 11 additions & 0 deletions llama_cpp/_internals.py
@@ -530,6 +530,17 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True

def set_batch_parallel(self, batch: Sequence[int], position: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
self.batch.n_tokens = n_tokens
for i in range(n_tokens):
self.batch.token[i] = batch[i]
self.batch.pos[i] = position
self.batch.seq_id[i][0] = i
self.batch.n_seq_id[i] = 1
self.batch.logits[i] = logits_all

def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
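
For orientation, a minimal sketch (not part of the diff) of the layout set_batch_parallel builds: unlike set_batch, which appends a run of consecutive tokens to a single sequence, this helper places one token per parallel sequence, all at the same position, so one decode call advances every stream by a single token.

# Illustrative only: mimic the batch layout with plain Python containers
# instead of the llama.cpp batch struct.
def sketch_parallel_batch(tokens, position, logits_all=True):
    n = len(tokens)
    return {
        "n_tokens": n,
        "token": list(tokens),              # one pending token per stream
        "pos": [position] * n,              # every stream sits at the same position
        "seq_id": [[i] for i in range(n)],  # stream i decodes into sequence i
        "n_seq_id": [1] * n,
        "logits": [logits_all] * n,         # request logits for every row
    }

# Three streams, one token each, all at position 7:
print(sketch_parallel_batch([101, 102, 103], position=7))
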
133 changes: 133 additions & 0 deletions llama_cpp/llama.py
@@ -523,6 +523,27 @@ def eval(self, tokens: Sequence[int]):
# Update n_tokens
self.n_tokens += n_tokens

def eval_parallel(self, tokens: List[int]):
"""Evaluate one token per parallel sequence, all at the same batch position.

Args:
tokens: The list of tokens to evaluate, one per parallel sequence.
"""
assert self._ctx.ctx is not None
assert self._batch.batch is not None
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

n_past = self.n_tokens
n_tokens = len(tokens)
self._batch.set_batch_parallel(batch=tokens, position=n_past, logits_all=True)
self._ctx.decode(self._batch)
# Save logits
size = n_tokens * self._n_vocab
self._scores.reshape(-1)[:size] = self._ctx.get_logits()[:size]
# Update n_tokens
self.n_tokens += 1

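A hedged usage sketch for eval_parallel (placeholder model path, made-up token ids). It mirrors the calls generate_parallel makes further down in this diff: the shared prefix is decoded into sequence 0 first and its KV cache is copied to every parallel sequence before the per-stream tokens are evaluated.

import llama_cpp
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path
prefix = llm.tokenize(b"Once upon a time")
llm.eval(prefix)  # the shared prefix lives in sequence 0

n_parallel = 3
for i in range(n_parallel):
    # Mirror the prefix KV cache into each parallel sequence, as generate_parallel does.
    llama_cpp.llama_kv_cache_seq_cp(llm.ctx, 0, i, 0, len(prefix) - 1)

# One pending token per stream; a single call advances every stream by one position.
llm.eval_parallel([1234, 5678, 910])  # made-up token ids
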
def sample(
self,
top_k: int = 40,
@@ -714,6 +735,94 @@ def generate(
]
)

@staticmethod
def longest_common_prefix(vecs):
if (max_len := min([len(s) for s in vecs], default=0)) == 0:
return []
for i in range(max_len):
if len(set([s[i] for s in vecs])) > 1:
return vecs[0][:i]
else:
return vecs[0][:max_len]

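A quick illustration of the helper on made-up token id lists:

from llama_cpp import Llama

vecs = [
    [1, 15043, 29871, 3186],   # arbitrary ids; the first three entries are shared
    [1, 15043, 29871, 11647],
]
assert Llama.longest_common_prefix(vecs) == [1, 15043, 29871]
assert Llama.longest_common_prefix([]) == []                      # no prompts -> empty prefix
assert Llama.longest_common_prefix([[7, 8], [7, 8]]) == [7, 8]    # identical prompts -> full prompt
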
def generate_parallel(
self,
tokens: List[List[int]],
max_tokens: Optional[int] = None,
**kwargs
) -> Iterator[List[int]]:
# get prompt and token counts
n_parallel = len(tokens)
n_tokens = [len(toks) for toks in tokens]

# set default max_tokens
max_prompt = max(n_tokens)
if max_tokens is None:
max_tokens = self._n_ctx // n_parallel - max_prompt
max_length = max_prompt + max_tokens

# check for overflows
if max_tokens <= 0:
raise ValueError("Maximum number of tokens exceeded")

# Run longest prefix in serial to populate kv cache. In the simplest case we look for the
# longest common prefix, but in general we could look for prefixes that are shared by certain
# subsets of the prompts.

# find the longest common prefix
prefix_tokens = self.longest_common_prefix(tokens)
prefix_len = len(prefix_tokens)

# reset batch and run prefix eval
self.reset()
self.eval(prefix_tokens)

# copy the kv_cache to other streams
for i in range(n_parallel):
llama_cpp.llama_kv_cache_seq_cp(self.ctx, 0, i, 0, prefix_len - 1)

# remember the batch index of the last token for each parallel sequence
i_batch = [prefix_len - 1 for _ in range(n_parallel)]

# since the prompts may be of different lengths, just yield the common prefix
for i in range(prefix_len):
result = [tokens[j][i] for j in range(n_parallel)]
yield result

# run the decoding loop
for k in range(prefix_len, max_length):
# sample the next token for each parallel sequence / stream
new_ids = []
for i in range(n_parallel):
# if the stream has already finished
if i_batch[i] < 0:
continue

# see if we're still in the prompt
if k < n_tokens[i]:
new_id = tokens[i][k]
else:
# get last logits and sample a new token
new_id = self.sample(idx=i_batch[i], **kwargs)

# is it an end of stream? -> mark the stream as finished
if new_id == self._token_eos:
i_batch[i] = -1
continue

# increment counters
i_batch[i] = len(new_ids)
new_ids.append(new_id)

# check for done or run next eval
if len(new_ids) == 0:
break
else:
self.eval_parallel(new_ids)

# yield new tokens
yield [new_ids[j] if j >= 0 else None for j in i_batch]

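A hedged end-to-end sketch of driving generate_parallel (placeholder model path, illustrative prompts). Note that the yields include the prompt tokens themselves, and a finished stream yields None in its slot:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path
prompts = [b"The capital of France is", b"The capital of Spain is"]
token_lists = [llm.tokenize(p) for p in prompts]

outputs = [[] for _ in prompts]
for step in llm.generate_parallel(token_lists, max_tokens=16):
    for i, tok in enumerate(step):
        if tok is not None:  # None marks a stream that has already hit end-of-sequence
            outputs[i].append(tok)

for toks in outputs:
    print(llm.detokenize(toks).decode("utf-8", errors="ignore"))
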
def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
@@ -1460,6 +1569,30 @@ def create_completion(
completion: Completion = next(completion_or_chunks) # type: ignore
return completion

def _create_completion_parallel(
self,
prompts: List[str],
**kwargs
) -> Iterator[List[str]]:
tokens: List[List[int]] = [self.tokenize(p.encode("utf-8")) for p in prompts]
for toks in self.generate_parallel(tokens, **kwargs):
yield [
self.detokenize([tok]).decode("utf-8") if tok is not None else ""
for tok in toks
]

def create_completion_parallel(
self,
prompts: List[str],
stream: bool = False,
**kwargs
) -> Union[List[str], Iterator[List[str]]]:
genpar: Iterator[List[str]] = self._create_completion_parallel(prompts, **kwargs)
if stream:
return genpar
else:
return ["".join(toks) for toks in zip(*genpar)]

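A hedged sketch of the two calling modes (placeholder model path). With stream=False the per-step pieces are joined into one string per prompt; with stream=True the caller receives one list of decoded pieces per decoding step:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path
prompts = ["Q: Name a color. A:", "Q: Name an animal. A:"]

# Batched mode: one joined string per prompt (the prompt text is echoed back as well).
for text in llm.create_completion_parallel(prompts, max_tokens=16):
    print(text)

# Streaming mode: one decoded piece per prompt for every decoding step.
for pieces in llm.create_completion_parallel(prompts, max_tokens=16, stream=True):
    print(pieces)
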
def __call__(
self,
prompt: str,