feat(generator): add maximum number of words limit in generation
g-prz committed Jan 15, 2025
1 parent 088f439 commit cb65326
Showing 3 changed files with 186 additions and 8 deletions.
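The change threads a new `max_words` keyword through the public generation API, next to the existing `max_tokens` and `stop_at` arguments. A minimal usage sketch, assuming a transformers-backed model (the model name below is only an illustration, not part of this commit):

import outlines
from outlines import generate, samplers

# Hypothetical model choice; any model supported by outlines.models should work.
model = outlines.models.transformers("HuggingFaceTB/SmolLM-135M")
generator = generate.text(model, samplers.multinomial())

# Generation stops at whichever of max_words / max_tokens is reached first.
result = generator("Write a long sentence", max_words=5)
assert len(result.split()) <= 5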
95 changes: 87 additions & 8 deletions outlines/generate/api.py
@@ -1,4 +1,5 @@
import datetime
import re
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
@@ -81,9 +82,18 @@ def is_stop_sequence_found(
]
)

def strip_stop_sequences(
self, sequence: str, stop_sequences: Optional[List[str]]
) -> str:
@staticmethod
def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
if max_words is not None:
splits = sequence.split()
if len(splits) > max_words:
last_word = splits[-1]
sequence = sequence.rstrip(last_word).rstrip()

return sequence

@staticmethod
def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
"""Remove the stop sequences from the generated sequences.
Parameters
@@ -130,6 +140,7 @@ def __call__(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
max_words: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
@@ -147,7 +158,12 @@ def __call__(
generating the first token.
max_tokens
An integer representing the maximum number of tokens that will be generated
(per prompt)
(per prompt). If both `max_tokens` and `max_words` are passed, generation
stops at whichever limit is reached first.
max_words
An integer representing the maximum number of words that will be generated
(per prompt). If both `max_tokens` and `max_words` are passed, generation
stops at whichever limit is reached first.
stop_at
A string or list of strings at which the text generated will stop
rng
@@ -202,16 +218,29 @@ def __call__(
rng=rng,
)

# If we have max_words but no max_tokens, put a limit on the number of tokens
# so that we reduce generation time and do not exceed the context length if
# no stop token is met.
# A high estimate of the average number of tokens per word in a multilingual
# context is 2; to be safe, we increase it a bit to 3.
if max_words and max_tokens is None:
max_tokens = 3 * max_words

while True:
try:
last_state = next(states)
if max_tokens or stop_sequences:
if max_tokens or max_words or stop_sequences:
token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(
prompt_token_ids, token_ids
)
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break
if max_words and all(
len(sentence.split()) > max_words
for sentence in self.tokenizer.decode(generated_token_ids)
):
break
if stop_sequences and self.is_stop_sequence_found(
self.tokenizer.decode(generated_token_ids), stop_sequences
):
@@ -223,9 +252,13 @@ def __call__(
generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

generated = self.tokenizer.decode(generated_token_ids)
max_words_stripped = [
self.strip_max_words_sequences(sequence, max_words)
for sequence in generated
]
stripped = [
self.strip_stop_sequences(sequence, stop_sequences)
for sequence in generated
for sequence in max_words_stripped
]
formatted = [self.format_sequence(sequence) for sequence in stripped]

@@ -248,6 +281,7 @@ def stream(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
max_words: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
) -> Iterator[Union[List[str], str, List[List[str]]]]:
@@ -328,9 +362,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
] * num_samples
num_generated = 0
is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
is_max_words_at_reached = [False for _ in range(batch_size)] * num_samples
while True:
if (max_tokens and num_generated >= max_tokens) or all(
is_stop_at_reached
if (
(max_tokens and num_generated >= max_tokens)
or all(is_stop_at_reached)
or all(is_max_words_at_reached)
):
return
try:
@@ -340,6 +377,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
return
generated_token_ids = sequence.token_ids[:, -num_generated:]
generated_sequences = self.tokenizer.decode(generated_token_ids)
if max_words is not None:
is_max_words_at_reached = [
stop or len(generated_sequence.split()) > max_words
for generated_sequence, stop in zip(
generated_sequences, is_max_words_at_reached
)
]
generated_sequences = [
self.strip_max_words_sequences(sequence, max_words)
if stop
else sequence
for sequence, stop in zip(
generated_sequences, is_max_words_at_reached
)
]
if stop_sequences:
is_stop_at_reached = [
stop
@@ -473,16 +525,36 @@ def _format(self, sequences):
else:
return self.format_sequence(sequences)

@staticmethod
def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
if max_words is not None:
if len(sequence.split()) > max_words:
matches = re.findall(r"(\s*\S+)(\s*)", sequence)
return "".join(
word + whitespace for word, whitespace in matches[:max_words]
).rstrip()

return sequence

def __call__(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
max_words: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
**model_specific_params,
):
"""Generate text from a prompt of list of prompts."""

# If we have max_words but no max_tokens, put a limit on the number of tokens
# so that we reduce generation time and do not exceed the context length if
# no stop token is met.
# A high estimate of the average number of tokens per word in a multilingual
# context is 2; to be safe, we increase it a bit to 3.
if max_words and max_tokens is None:
max_tokens = 3 * max_words

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
@@ -495,6 +567,13 @@ def __call__(
**model_specific_params,
)

if isinstance(completions, str):
completions = self.reconstruct_till_max_words(completions, max_words)
else:
completions = [
self.reconstruct_till_max_words(seq, max_words) for seq in completions
]

return self._format(completions)

def stream(
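Both truncation helpers above rely on the same idea: count whitespace-separated words and, once the count exceeds max_words, rebuild only the allowed prefix. A self-contained sketch of the regex-based variant (the function name is mine, not the library's), which preserves the original spacing between the kept words:

import re
from typing import Optional

def truncate_to_max_words(sequence: str, max_words: Optional[int]) -> str:
    """Keep at most `max_words` whitespace-separated words of `sequence`."""
    if max_words is None or len(sequence.split()) <= max_words:
        return sequence
    # Each match is (leading whitespace + word, trailing whitespace), so joining
    # the first max_words matches keeps the spacing between retained words.
    matches = re.findall(r"(\s*\S+)(\s*)", sequence)
    return "".join(word + ws for word, ws in matches[:max_words]).rstrip()

print(truncate_to_max_words("one two  three four five six", 5))
# -> "one two  three four five"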
11 changes: 11 additions & 0 deletions tests/generate/test_generate.py
@@ -245,6 +245,17 @@ def test_generate_text(request, model_fixture, sampler_name):
assert isinstance(res, str)


@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_text_max_words(request, model_fixture, sampler_name):
max_words = 5
model = request.getfixturevalue(model_fixture)
generator = generate.text(model, getattr(samplers, sampler_name)())
with enforce_not_implemented(model_fixture, sampler_name):
res = generator("Write a long sentence", max_words=max_words)
assert len(res.split()) <= max_words


@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_regex(request, model_fixture, pattern):
88 changes: 88 additions & 0 deletions tests/generate/test_generator.py
@@ -495,3 +495,91 @@ def test_expand_attention_masks(attention_masks, ancestors, expected_result):
def test_bias_logits(logits, indices_to_mask, expected):
masked_logits = bias_logits(logits, indices_to_mask)
assert torch.equal(masked_logits, expected)


def test_generator_max_words():
class MockFSM:
first_state = 0

def get_next_state(self, state, next_token_ids):
return 4

def get_next_instruction(self, *_):
return Generate([4])

def is_final_state(self, _):
return False # let's generate tokens forever

def copy(self):
return self

class MockTokenizer:
def encode(self, _):
# Input: "test"
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])

def decode(self, tokens):
return [" ".join(["test" for _ in tokens[0]])]

class MockModel:
def __init__(self):
self.tokenizer = MockTokenizer()

def __call__(*_):
return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

class sampler:
def __init__(self):
self.samples = 1

def __call__(self, biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None

generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
result = generator("test", max_words=5)
assert result == "test test test test test"


def test_generator_max_tokens_from_max_words():
class MockFSM:
first_state = 0

def get_next_state(self, state, next_token_ids):
return 4

def get_next_instruction(self, *_):
return Generate([4])

def is_final_state(self, _):
return False # let's generate tokens forever

def copy(self):
return self

class MockTokenizer:
def encode(self, _):
# Input: "test"
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])

def decode(self, tokens):
return [
"123456789"[: len(tokens[0])]
] # not generating any words separated by whitespace

class MockModel:
def __init__(self):
self.tokenizer = MockTokenizer()

def __call__(*_):
return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

class sampler:
def __init__(self):
self.samples = 1

def __call__(self, biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None

generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
result = generator("test", max_words=2) # should generate max_words * 3 tokens
assert result == "123456"
