transformers_logits.py
# Implementation taken from "outlines":
# https://github.com/outlines-dev/outlines
#
# License: Apache License 2.0:
# https://github.com/outlines-dev/outlines/blob/68b71ae810e0d6815a83df525da6d707cd4e971a/LICENSE
from typing import Optional, Type, Union

import torch
from outlines.fsm.guide import Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str
from pydantic import BaseModel
from transformers import LogitsProcessor, PreTrainedTokenizerBase
from typing_extensions import override


class FsmLogitsProcessor(LogitsProcessor):
    """Masks logits so that only tokens allowed by the FSM guide can be sampled."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase, fsm: Guide):
        self.fsm = fsm
        self._tokenizer = tokenizer
        self._fsm_state = 0
        self._is_first_token = True

    @override
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        is_first_token = self._is_first_token
        if self._is_first_token:
            self._is_first_token = False

        # Start from an all-disallowed mask, then zero out the allowed tokens.
        mask = torch.full_like(scores, -float("inf"))
        for i in range(len(input_ids)):
            # On the first call no token has been generated yet, so the FSM
            # stays in its initial state; on later calls, advance it with the
            # token sampled in the previous step.
            if not is_first_token:
                last_token = int(input_ids[i][-1].item())
                self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token)

            allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens
            mask[i][allowed_tokens] = 0

        # Allowed tokens keep their original score; everything else becomes -inf.
        biased_scores = scores + mask
        return biased_scores  # type: ignore

    def copy(self) -> "FsmLogitsProcessor":
        # The processor is stateful (one FSM state per instance), so use a
        # fresh copy for each new generation rather than reusing an instance.
        return FsmLogitsProcessor(tokenizer=self._tokenizer, fsm=self.fsm.copy())


class RegexLogitsProcessor(FsmLogitsProcessor):
    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
        assert isinstance(tokenizer, PreTrainedTokenizerBase)
        # Compile the regex into a token-level FSM guide over the tokenizer's
        # vocabulary. RegexGuide expects an outlines-adapted tokenizer;
        # JSONLogitsProcessor below handles this via adapt_tokenizer.
        fsm = RegexGuide(regex_string, tokenizer)
        super().__init__(tokenizer=tokenizer, fsm=fsm)


class JSONLogitsProcessor(RegexLogitsProcessor):
    def __init__(
        self,
        schema: Union[dict, Type[BaseModel], str],
        tokenizer: PreTrainedTokenizerBase,
        whitespace_pattern: Optional[str] = None,
    ):
        # Accepts a dict, a pydantic model class, or a raw JSON schema string,
        # converts it to a regex, and constrains generation to that regex.
        schema_str = convert_json_schema_to_str(json_schema=schema)
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        tokenizer = adapt_tokenizer(tokenizer=tokenizer)
        super().__init__(regex_string=regex_string, tokenizer=tokenizer)
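

# --- Usage sketch (not part of the original file) ---
# A minimal example of plugging JSONLogitsProcessor into transformers
# generation. The model name, prompt, and schema below are illustrative
# assumptions; any causal LM checkpoint should work. Since each processor
# instance is stateful, build (or copy()) a fresh one per generation.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    class Character(BaseModel):  # hypothetical example schema
        name: str
        age: int

    model_name = "gpt2"  # assumption: replace with the model you actually use
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    processor = JSONLogitsProcessor(schema=Character, tokenizer=tokenizer)
    inputs = tokenizer("Generate a character as JSON: ", return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        logits_processor=LogitsProcessorList([processor]),
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))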