diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 5d3c52a8a..edf4d0fbb 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -1,4 +1,5 @@
 import datetime
+import re
 from copy import copy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
@@ -81,9 +82,18 @@ def is_stop_sequence_found(
             ]
         )

-    def strip_stop_sequences(
-        self, sequence: str, stop_sequences: Optional[List[str]]
-    ) -> str:
+    @staticmethod
+    def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            splits = sequence.split()
+            if len(splits) > max_words:
+                last_word = splits[-1]
+                sequence = sequence.rstrip(last_word).rstrip()
+
+        return sequence
+
+    @staticmethod
+    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
         """Remove the stop sequences from the generated sequences.

         Parameters
@@ -130,6 +140,7 @@ def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
@@ -147,7 +158,12 @@ def __call__(
             generating the first token.
         max_tokens
             An integer representing maximum number of tokens that will be generated
-            (per prompt)
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops when the first limit is reached
+        max_words
+            An integer representing maximum number of words that will be generated
+            (per prompt). If both `max_tokens` and `max_words` are passed,
+            generation stops when the first limit is reached
         stop_at
             A string or list of strings at which the text generated will stop
         rng
@@ -202,16 +218,29 @@ def __call__(
             rng=rng,
         )

+        # If we have max_words but no max_tokens, set a limit on the number of
+        # tokens so that we reduce generation time and do not exceed the context
+        # length if no stop token is met.
+        # A high estimate of the average number of tokens per word in a
+        # multilingual context is 2; we add some margin and use 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         while True:
             try:
                 last_state = next(states)
-                if max_tokens or stop_sequences:
+                if max_tokens or max_words or stop_sequences:
                     token_ids = last_state.token_ids
                     generated_token_ids = self.get_generated_token_ids(
                         prompt_token_ids, token_ids
                     )
                     if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
+                    if max_words and all(
+                        len(sentence.split()) > max_words
+                        for sentence in self.tokenizer.decode(generated_token_ids)
+                    ):
+                        break
                     if stop_sequences and self.is_stop_sequence_found(
                         self.tokenizer.decode(generated_token_ids), stop_sequences
                     ):
@@ -223,9 +252,13 @@ def __call__(
         generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
         generated = self.tokenizer.decode(generated_token_ids)

+        max_words_stripped = [
+            self.strip_max_words_sequences(sequence, max_words)
+            for sequence in generated
+        ]
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
-            for sequence in generated
+            for sequence in max_words_stripped
         ]
         formatted = [self.format_sequence(sequence) for sequence in stripped]

@@ -248,6 +281,7 @@ def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
@@ -284,6 +318,9 @@ def stream(
         if isinstance(stop_at, str):
             stop_at = [stop_at]

+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         stop_sequences = stop_at
         num_samples = self.num_samples

@@ -328,9 +365,12 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ] * num_samples
             num_generated = 0
             is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
+            is_max_words_at_reached = [False for _ in range(batch_size)] * num_samples
             while True:
-                if (max_tokens and num_generated >= max_tokens) or all(
-                    is_stop_at_reached
+                if (
+                    (max_tokens and num_generated >= max_tokens)
+                    or all(is_stop_at_reached)
+                    or all(is_max_words_at_reached)
                 ):
                     return
                 try:
@@ -340,6 +380,21 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
                     return
                 generated_token_ids = sequence.token_ids[:, -num_generated:]
                 generated_sequences = self.tokenizer.decode(generated_token_ids)
+                if max_words is not None:
+                    is_max_words_at_reached = [
+                        stop or len(generated_sequence.split()) > max_words
+                        for generated_sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
+                    generated_sequences = [
+                        self.strip_max_words_sequences(sequence, max_words)
+                        if stop
+                        else sequence
+                        for sequence, stop in zip(
+                            generated_sequences, is_max_words_at_reached
+                        )
+                    ]
                 if stop_sequences:
                     is_stop_at_reached = [
                         stop
@@ -473,16 +528,36 @@ def _format(self, sequences):
         else:
             return self.format_sequence(sequences)

+    @staticmethod
+    def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
+        if max_words is not None:
+            if len(sequence.split()) > max_words:
+                matches = re.findall(r"(\s*\S+)(\s*)", sequence)
+                return "".join(
+                    word + whitespace for word, whitespace in matches[:max_words]
+                ).rstrip()
+
+        return sequence
+
     def __call__(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
+        max_words: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
         **model_specific_params,
     ):
         """Generate text from a prompt or list of prompts."""

+        # If we have max_words but no max_tokens, set a limit on the number of
+        # tokens so that we reduce generation time and do not exceed the context
+        # length if no stop token is met.
+        # A high estimate of the average number of tokens per word in a
+        # multilingual context is 2; we add some margin and use 3.
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
@@ -495,6 +570,13 @@ def __call__(
             **model_specific_params,
         )

+        if isinstance(completions, str):
+            completions = self.reconstruct_till_max_words(completions, max_words)
+        else:
+            completions = [
+                self.reconstruct_till_max_words(seq, max_words) for seq in completions
+            ]
+
         return self._format(completions)

     def stream(
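
Note: `reconstruct_till_max_words` preserves the original inter-word whitespace when truncating, because each regex match pairs a word with the whitespace that follows it. A minimal sketch of that behaviour (illustrative only, not part of the patch):

    import re

    matches = re.findall(r"(\s*\S+)(\s*)", "one  two\tthree four")
    # [('one', '  '), ('two', '\t'), ('three', ' '), ('four', '')]
    print("".join(word + ws for word, ws in matches[:2]).rstrip())
    # one  two   <- the double space between the words is kept verbatim
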
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 3d36faa5d..326c4f47d 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -245,6 +245,17 @@ def test_generate_text(request, model_fixture, sampler_name):
     assert isinstance(res, str)


+@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text_max_words(request, model_fixture, sampler_name):
+    max_words = 5
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model, getattr(samplers, sampler_name)())
+    with enforce_not_implemented(model_fixture, sampler_name):
+        res = generator("Write a long sentence", max_words=max_words)
+        assert len(res.split()) <= max_words
+
+
 @pytest.mark.parametrize("pattern", REGEX_PATTERNS)
 @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
 def test_generate_regex(request, model_fixture, pattern):
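
Note: `strip_max_words_sequences` in api.py relies on `str.rstrip(chars)` stripping a trailing *set of characters*, not a suffix. It still removes exactly the last word, because `splits[-1]` comes from `sequence.split()` and therefore contains no whitespace, so the strip stops at the separator before it. A quick sketch (illustrative only, not part of the patch):

    s = "test test test test"     # one word over max_words=3
    s.rstrip("test")              # 'test test test '  (stops at the space)
    s.rstrip("test").rstrip()     # 'test test test'
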
diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py
index 5a2edf8dc..f83ee62b7 100644
--- a/tests/generate/test_generator.py
+++ b/tests/generate/test_generator.py
@@ -495,3 +495,97 @@ def test_expand_attention_masks(attention_masks, ancestors, expected_result):
 def test_bias_logits(logits, indices_to_mask, expected):
     masked_logits = bias_logits(logits, indices_to_mask)
     assert torch.equal(masked_logits, expected)
+
+
+def test_generator_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [" ".join(["test" for _ in tokens[0]])]
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=3)
+    assert result == "test test test"
+
+    sequence = generator.stream("test", max_words=3)
+    assert "".join(sequence) == "test test test"
+
+
+def test_generator_max_tokens_from_max_words():
+    class MockFSM:
+        first_state = 0
+
+        def get_next_state(self, state, next_token_ids):
+            return 4
+
+        def get_next_instruction(self, *_):
+            return Generate([4])
+
+        def is_final_state(self, _):
+            return False  # let's generate tokens forever
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            # Input: "test"
+            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
+
+        def decode(self, tokens):
+            return [
+                "123456789"[: len(tokens[0])]
+            ]  # not generating any words separated by whitespace
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
+
+    class sampler:
+        def __init__(self):
+            self.samples = 1
+
+        def __call__(self, biased_logits, *_):
+            return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
+
+    generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
+    result = generator("test", max_words=2)  # should generate max_words * 3 tokens
+    assert result == "123456"
+
+    sequence = generator.stream("test", max_words=2)
+    assert "".join(sequence) == "123456"
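
Note: end to end, the new parameter is exercised as in test_generate.py above; a hypothetical usage sketch against the outlines public API (the model choice is illustrative):

    import outlines

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.text(model)
    result = generator("Write a long sentence", max_words=5)
    assert len(result.split()) <= 5
    # With no explicit max_tokens, generation is also capped at
    # 3 * max_words = 15 tokens as a context-length safety margin.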