-
Notifications
You must be signed in to change notification settings - Fork 575
/
Copy pathapi.py
691 lines (588 loc) · 23.3 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
import datetime
import re
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union

from outlines.generate.generator import sequence_generator

if TYPE_CHECKING:
    import torch
# Every Python type a generated sequence may be returned as once
# `format_sequence` has been applied (the base implementation keeps `str`).
FormattedOutput = Union[
    str,
    int,
    float,
    bool,
    datetime.date,
    datetime.time,
    datetime.datetime,
]
class SequenceGenerator:
    """Generate text by driving `sequence_generator` over a batch of prompts.

    Each prompt is repeated `sampler.samples` times and every resulting
    sequence gets its own copy of `fsm`. Generation can be bounded by
    `max_tokens`, `max_words` and/or stop sequences (`stop_at`).
    """

    def __init__(
        self,
        fsm,
        model,
        sampler,
        device,
    ):
        self.fsm = fsm
        self.model = model
        self.sampler = sampler
        self.tokenizer = model.tokenizer
        self.device = device
        # Number of samples drawn for every prompt.
        self.num_samples = sampler.samples

    def get_generated_token_ids(
        self,
        prompt_token_ids: "torch.Tensor",
        token_ids: "torch.Tensor",
    ) -> List["torch.Tensor"]:
        """Get the tokens generated so far.

        Parameters
        ----------
        prompt_token_ids
            Tensor that contains the token ids of the sequences' prompts.
        token_ids
            The token ids of the full sequences (prompt followed by the
            generated tokens).

        Returns
        -------
        A list with, for every sequence, the token ids that follow its prompt.
        """
        prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
        return [
            cur_token_ids[length:]
            for cur_token_ids, length in zip(token_ids, prompt_lengths)
        ]

    def is_stop_sequence_found(
        self, generated_sequences: List[str], stop_sequences: List[str]
    ) -> bool:
        """Determine whether one of the stop sequences has been generated.

        Parameters
        ----------
        generated_sequences
            The list of sequences generated so far.
        stop_sequences
            The list that contains the sequences which stop the generation
            when found.

        Returns
        -------
        True if at least one of the stop sequences has been found in each
        generated sequence.
        """
        return all(
            any(seq in generated for seq in stop_sequences)
            for generated in generated_sequences
        )

    @staticmethod
    def strip_max_words_sequences(sequence: str, max_words: Optional[int]) -> str:
        """Truncate `sequence` so that it contains at most `max_words` words.

        Words are maximal runs of non-whitespace characters. Leading
        whitespace and the whitespace between kept words are preserved;
        trailing whitespace is removed when truncation happens. The sequence
        is returned unchanged when `max_words` is None or not exceeded.
        """
        if max_words is None or len(sequence.split()) <= max_words:
            return sequence
        # Each match is (leading whitespace + word, whitespace after the word).
        # Slicing by words (rather than `str.rstrip(last_word)`, which strips a
        # *character set*) removes every word past the limit, however many.
        matches = re.findall(r"(\s*\S+)(\s*)", sequence)
        return "".join(
            word + whitespace for word, whitespace in matches[:max_words]
        ).rstrip()

    @staticmethod
    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
        """Truncate the sequence right after the earliest stop sequence.

        The stop sequence itself is kept in the output; everything generated
        after it is dropped. When several stop sequences start at the same
        position, the one listed first in `stop_sequences` wins.

        Parameters
        ----------
        sequence
            One of the generated sequences.
        stop_sequences
            The list that contains the sequences which stop the generation
            when found.
        """
        if stop_sequences:
            found = [
                (index, position)
                for position, index in enumerate(
                    sequence.find(seq) for seq in stop_sequences
                )
                if index != -1
            ]
            if found:
                # Smallest match index first, then earliest stop sequence.
                index, position = min(found)
                sequence = sequence[: index + len(stop_sequences[position])]
        return sequence

    def format_sequence(self, sequence: str) -> "FormattedOutput":
        """Translate the generated sequence to another type.

        This method is for instance overridden when generating JSON to either
        return a dictionary or a Pydantic model.

        Parameters
        ----------
        sequence
            A generated sequence.

        Returns
        -------
        The formatted sequence.
        """
        return sequence

    @staticmethod
    def _reshape_output(items: list, batch_size: int, num_samples: int):
        """Reshape a flat list of `batch_size * num_samples` items.

        The list is first grouped to (batch_size, num_samples), then unit
        dimensions are dropped: a single item, a list of samples, or a list
        of per-prompt items is returned when possible.
        """
        output = [
            items[i : i + num_samples]
            for i in range(0, batch_size * num_samples, num_samples)
        ]
        if batch_size == 1 and num_samples == 1:
            return output[0][0]
        elif batch_size == 1:
            return output[0]
        elif num_samples == 1:
            return [samples[0] for samples in output]
        return output

    def __call__(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        max_words: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> "Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]":
        """Generate the full text sequence.

        Since `SequenceGenerator.stream` calls the tokenizer at every step this
        method loops over the generator returned by `sequence_generator` itself
        so the tokenizer is called only when a stopping criterion must be
        checked and once more after all token ids have been generated.

        Parameters
        ----------
        prompts
            A string or list of strings that are passed to the model before
            generating the first token.
        max_tokens
            An integer representing maximum number of tokens that will be generated
            (per prompt). If both `max_tokens` and `max_words` are passed, it will
            stop when the first one is reached.
        max_words
            An integer representing maximum number of words that will be generated
            (per prompt). If both `max_tokens` and `max_words` are passed, it will
            stop when the first one is reached.
        stop_at
            A string or list of strings at which the text generated will stop.
            The stop sequence itself is kept in the output.
        rng
            The random number generator. Defaults to a non-seeded `torch.Generator`
            instance.

        Returns
        -------
        The generation(s), potentially cast to another type.
        """
        import torch

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(stop_at, str):
            stop_at = [stop_at]
        stop_sequences = stop_at

        if rng is None:
            rng = torch.Generator(device=self.device)
            rng.seed()

        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
        prompt_token_ids = prompt_token_ids.to(self.device)
        attention_masks = attention_masks.to(self.device)

        # To draw multiple samples we repeat the prompt as many times
        # as there are samples. We copy the FSMs and initialize the
        # FSM states.
        num_samples = self.num_samples
        batch_size = len(prompts)
        num_sequences = batch_size * num_samples

        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)

        fsm_states = [0 for _ in range(num_sequences)]
        fsms = [self.fsm.copy() for _ in range(num_sequences)]
        weights = torch.zeros(num_sequences, dtype=torch.float, device=self.device)

        states = sequence_generator(
            self.model,
            self.sampler,
            fsms,
            prompt_token_ids,
            weights,
            attention_masks,
            fsm_states,
            rng=rng,
        )

        # If we have max_words but no max_tokens, let's put a limit on the number of tokens
        # so that we reduce the generation time and do not exceed context length if
        # no stop token is met.
        # A high estimation of average number of tokens per word in a multilanguage
        # context is 2, let's take some precaution and increase it a bit to 3
        if max_words and max_tokens is None:
            max_tokens = 3 * max_words

        for last_state in states:
            if not (max_tokens or max_words or stop_sequences):
                continue
            generated_token_ids = self.get_generated_token_ids(
                prompt_token_ids, last_state.token_ids
            )
            if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                break
            if max_words or stop_sequences:
                # Decode once per step and reuse it for both text-based checks.
                generated_sequences = self.tokenizer.decode(generated_token_ids)
                if max_words and all(
                    len(sentence.split()) > max_words
                    for sentence in generated_sequences
                ):
                    break
                if stop_sequences and self.is_stop_sequence_found(
                    generated_sequences, stop_sequences
                ):
                    break

        generated_token_ids = self.get_generated_token_ids(
            prompt_token_ids, last_state.token_ids
        )
        generated = self.tokenizer.decode(generated_token_ids)
        stripped = [
            self.strip_stop_sequences(
                self.strip_max_words_sequences(sequence, max_words), stop_sequences
            )
            for sequence in generated
        ]
        formatted = [self.format_sequence(sequence) for sequence in stripped]

        return self._reshape_output(formatted, batch_size, num_samples)

    def stream(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        max_words: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
        """Generate the text sequence one token at a time.

        Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
        choice but to decode the generated token ids at each step and compare the
        current decoded strings to the previously decoded strings.

        Parameters
        ----------
        prompts
            A string or list of strings that are passed to the model before
            generating the first token.
        max_tokens
            An integer representing maximum number of tokens that will be generated
            (per prompt).
        max_words
            An integer representing maximum number of words that will be generated
            (per prompt). When `max_tokens` is not given it defaults to
            `3 * max_words`.
        stop_at
            A string or list of strings at which the text generated will stop.
        rng
            The random number generator. Defaults to a non-seeded `torch.Generator`
            instance.

        Returns
        -------
        An iterator over the newly generated text at each step: a string or
        (nested) list of strings depending on the number of prompts and samples.
        """
        import torch

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(stop_at, str):
            stop_at = [stop_at]
        if max_words and max_tokens is None:
            max_tokens = 3 * max_words
        stop_sequences = stop_at

        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
        prompt_token_ids = prompt_token_ids.to(self.device)
        attention_masks = attention_masks.to(prompt_token_ids.device)

        # To draw multiple samples we repeat the prompt as many times
        # as there are samples. We copy the FSMs and initialize the
        # FSM states.
        num_samples = self.num_samples
        batch_size = len(prompts)
        num_sequences = batch_size * num_samples

        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)

        fsm_states = [0 for _ in range(num_sequences)]
        fsms = [self.fsm.copy() for _ in range(num_sequences)]
        weights = torch.zeros(
            num_sequences,
            dtype=torch.float,
            device=prompt_token_ids.device,
        )

        if rng is None:
            rng = torch.Generator(device=prompt_token_ids.device)
            rng.seed()

        states = sequence_generator(
            self.model,
            self.sampler,
            fsms,
            prompt_token_ids,
            weights,
            attention_masks,
            fsm_states,
            rng=rng,
        )

        def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
            previously_generated_sequences = ["" for _ in range(num_sequences)]
            num_generated = 0
            is_stop_at_reached = [False for _ in range(num_sequences)]
            is_max_words_at_reached = [False for _ in range(num_sequences)]
            while True:
                # Every sequence is finished once it has hit *either* a stop
                # sequence or the word limit; mixed reasons must also end the
                # stream instead of running on until `max_tokens`.
                all_finished = all(
                    stop_reached or words_reached
                    for stop_reached, words_reached in zip(
                        is_stop_at_reached, is_max_words_at_reached
                    )
                )
                if (max_tokens and num_generated >= max_tokens) or all_finished:
                    return
                try:
                    state = next(states)
                except StopIteration:
                    return
                num_generated += 1

                generated_token_ids = state.token_ids[:, -num_generated:]
                generated_sequences = self.tokenizer.decode(generated_token_ids)

                if max_words is not None:
                    is_max_words_at_reached = [
                        stop or len(generated_sequence.split()) > max_words
                        for generated_sequence, stop in zip(
                            generated_sequences, is_max_words_at_reached
                        )
                    ]
                    generated_sequences = [
                        self.strip_max_words_sequences(sequence, max_words)
                        if stop
                        else sequence
                        for sequence, stop in zip(
                            generated_sequences, is_max_words_at_reached
                        )
                    ]

                if stop_sequences:
                    is_stop_at_reached = [
                        stop
                        or self.is_stop_sequence_found(
                            [generated_sequence], stop_sequences
                        )
                        for generated_sequence, stop in zip(
                            generated_sequences, is_stop_at_reached
                        )
                    ]
                    generated_sequences = [
                        (
                            self.format_sequence(
                                self.strip_stop_sequences(sequence, stop_sequences)
                            )
                            if stop
                            else sequence
                        )
                        for sequence, stop in zip(
                            generated_sequences, is_stop_at_reached
                        )
                    ]

                # Emit only the text added since the previous step.
                next_tokens = [
                    sequence[len(previous) :]
                    for sequence, previous in zip(
                        generated_sequences, previously_generated_sequences
                    )
                ]
                previously_generated_sequences = generated_sequences

                yield self._reshape_output(next_tokens, batch_size, num_samples)

        return token_generator()
@dataclass(frozen=True)
class GenerationParameters:
    """Generation parameters used in Outlines' public API.

    Instances are immutable (`frozen=True`); fields are positional in the
    order below.
    """

    # Upper bound on the number of generated tokens, or None.
    max_tokens: Optional[int]
    # A stop string, a list of stop strings, or None.
    stop_at: Optional[Union[str, List[str]]]
    # Seed for the random number generator, or None.
    seed: Optional[int]
@dataclass(frozen=True)
class SamplingParameters:
    """Sampling parameters available in Outlines.

    Immutable; only `sampler` is required, every other field has a default.
    """

    # Name of the sampling algorithm.
    sampler: str
    # How many samples to draw per prompt.
    num_samples: int = 1
    # Optional nucleus-sampling threshold.
    top_p: Optional[float] = None
    # Optional top-k cutoff.
    top_k: Optional[int] = None
    # Optional sampling temperature.
    temperature: Optional[float] = None
class SequenceGeneratorAdapter:
    """Class used to unify the interface to the model providers'
    generation functions.

    Attributes
    ----------
    model
        The wrapped model.
    logits_processor
        The logits processor to use to generate text.
    sampling_params
        The sampling parameters, read from the `sampler` passed at
        construction time.
    """

    def __init__(self, model, logits_processor, sampler):
        self.model = model
        self.logits_processor = logits_processor
        self.sampling_params = sampler.sampling_params

    def prepare_generation_parameters(
        self,
        max_tokens: Optional[int],
        stop_at: Optional[Union[str, List[str]]],
        seed: Optional[int],
    ):
        """Bundle the generation options into a `GenerationParameters`.

        A single stop string is wrapped in a list so that the model always
        receives either None or a list of stop sequences.
        """
        if isinstance(stop_at, str):
            stop_at = [stop_at]

        return GenerationParameters(
            max_tokens,
            stop_at,
            seed,
        )

    def format_sequence(self, sequence: str) -> "FormattedOutput":
        """Translate the generated sequence to another type.

        This method is for instance overridden when generating JSON to either
        return a dictionary or a Pydantic model.

        Parameters
        ----------
        sequence
            A generated sequence.

        Returns
        -------
        The formatted sequence.
        """
        return sequence

    def _format(self, sequences):
        """Apply formatting to every string in a (possibly nested) completion."""
        if isinstance(sequences, list):
            return [self._format(sequence) for sequence in sequences]
        else:
            return self.format_sequence(sequences)

    @staticmethod
    def reconstruct_till_max_words(sequence: str, max_words: Optional[int]) -> str:
        """Truncate `sequence` to its first `max_words` words.

        Leading whitespace and the whitespace between kept words are
        preserved; trailing whitespace is removed when truncation happens.
        The sequence is returned unchanged when `max_words` is None or not
        exceeded.
        """
        if max_words is not None:
            if len(sequence.split()) > max_words:
                # Each match is (leading whitespace + word, following whitespace).
                matches = re.findall(r"(\s*\S+)(\s*)", sequence)
                return "".join(
                    word + whitespace for word, whitespace in matches[:max_words]
                ).rstrip()
        return sequence

    def _truncate_to_max_words(self, completions, max_words: Optional[int]):
        """Apply `reconstruct_till_max_words` to a string or nested list of them.

        Models return a string, a list of strings, or (several prompts with
        several samples each) a list of lists of strings; every level of
        nesting is handled so batched multi-sample output does not crash.
        """
        if isinstance(completions, str):
            return self.reconstruct_till_max_words(completions, max_words)
        return [
            self._truncate_to_max_words(completion, max_words)
            if isinstance(completion, list)
            else self.reconstruct_till_max_words(completion, max_words)
            for completion in completions
        ]

    def __call__(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        max_words: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Generate text from a prompt or list of prompts.

        Parameters
        ----------
        prompts
            A prompt string or a list of prompt strings.
        max_tokens
            Maximum number of tokens to generate per sequence.
        max_words
            Maximum number of words kept per sequence. When `max_tokens` is
            not given it defaults to `3 * max_words`.
        stop_at
            A string or list of strings at which generation stops.
        seed
            Seed for the random number generator.
        **model_specific_params
            Forwarded unchanged to `self.model.generate`.

        Returns
        -------
        The completion(s), truncated to `max_words` and passed through
        `format_sequence`.
        """
        # If we have max_words but no max_tokens, let's put a limit on the number of tokens
        # so that we reduce the generation time and do not exceed context length if
        # no stop token is met.
        # A high estimation of average number of tokens per word in a multilanguage
        # context is 2, let's take some precaution and increase it a bit to 3
        if max_words and max_tokens is None:
            max_tokens = 3 * max_words

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            generation_params,
            # Each call gets its own copy of the logits processor.
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        completions = self._truncate_to_max_words(completions, max_words)

        return self._format(completions)

    def stream(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts.

        The options are bundled as in `__call__` and the model's own
        `stream` result is returned unchanged (no `max_words` support).
        """
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )
class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter):
    """Adapter for models that take images alongside the text prompts."""

    def __call__(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: Union[str, Any],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Generate text from a prompt or list of prompts.

        Parameters
        ----------
        prompts
            A prompt string, or a list of prompt strings.
        media
            For a single prompt, an iterable of `PIL.Image.Image`; for a list
            of prompts, a list holding one such iterable per prompt.
        max_tokens, stop_at, seed, **model_specific_params
            Same meaning as for `SequenceGeneratorAdapter.__call__`.

        Raises
        ------
        TypeError
            If `prompts` and `media` do not have the expected types.
        """
        prompts, media = self._validate_prompt_media_types(prompts, media)

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: List[Union[str, Any, List[Union[str, Any]]]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts.

        `media` follows the same rules as in `__call__`; a TypeError is raised
        when it does not match `prompts`.
        """
        prompts, media = self._validate_prompt_media_types(prompts, media)
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

    @classmethod
    def _validate_prompt_media_types(
        cls,
        prompts: Union[str, List[str]],
        media: Union[str, Any, List[Union[str, Any]]],
    ) -> Tuple[Union[str, List[str]], Any]:
        """Check that every prompt string comes with a collection of PIL images.

        Accepted combinations are (str, iterable of Image) and
        (List[str], List[iterable of Image]) with one entry per prompt.
        The inputs are returned unchanged; no conversion is performed.

        Raises
        ------
        TypeError
            If the combination of `prompts` and `media` is not accepted.
        """

        def valid_types(prompts, media):
            from PIL import Image  # type: ignore

            try:
                if isinstance(prompts, list):
                    if not isinstance(media, list) or len(prompts) != len(media):
                        return False
                    for subprompt, submedia in zip(prompts, media):
                        if not isinstance(subprompt, str) or not all(
                            isinstance(m, Image.Image) for m in submedia
                        ):
                            return False
                elif isinstance(prompts, str):
                    if not all(isinstance(m, Image.Image) for m in media):
                        return False
            except TypeError:
                # `media` (or one of its entries) is not iterable, e.g. a bare
                # image passed where a list of images was expected; report it
                # with the descriptive error below instead of a bare TypeError.
                return False
            return True

        if not valid_types(prompts, media):
            raise TypeError(
                "Expected (prompts, media) to be of type "
                "(str, List[Image]) or (List[str], List[List[Image]]), "
                f"instead got prompts={prompts}, media={media}"
            )

        return prompts, media