Skip to content

Commit 4aa4ebd

Browse files
committed
Add ExLlamaV2Sampler.Settings.logits_processor
1 parent 10a8842 commit 4aa4ebd

File tree

6 files changed

+469
-8
lines changed

6 files changed

+469
-8
lines changed

examples/json_schema_outlines.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Install Outlines:
2+
# pip install outlines
3+
4+
# Download Model:
5+
# huggingface-cli download bartowski/Phi-3.1-mini-4k-instruct-exl2 --revision 6_5 --local-dir Phi-3.1-mini-4k-instruct-exl2-6_5
6+
7+
import sys, os
8+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9+
10+
11+
12+
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
13+
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
14+
15+
from outlines.processors import JSONLogitsProcessor
16+
from outlines.models.exllamav2 import patch_tokenizer as patch_exl2_tokenizer_for_outlines
17+
18+
from pydantic import BaseModel, Field, RootModel
19+
from typing import Optional, Union, Literal
20+
from datetime import time
21+
22+
23+
################################################
24+
# Create Structured JSON Generator With Outlines
25+
################################################
26+
27+
# Additional Examples: https://outlines-dev.github.io/outlines/cookbook/
28+
# JSON Generation Docs: https://outlines-dev.github.io/outlines/reference/json/
29+
# `outlines.processors` also supports guaranteed regex patterns and lark grammars
30+
31+
# Example: Home Assistant extension for natural language commands -> actions
32+
class LightAction(BaseModel):
    """Structured command targeting a light.

    `entity` is a fixed discriminator so the JSON generator can tell
    light commands apart from oven commands in the Union below.
    """
    entity: Literal["light"] = "light"
    action: Literal["turn_on", "turn_off", "set_brightness"]
    # Brightness is only meaningful for set_brightness; constrained to 0-100.
    brightness: Optional[int] = Field(None, ge=0, le=100)
    # Optional scheduled time-of-day; None means "execute now".
    execute_at: Optional[time] = None
37+
38+
39+
class OvenAction(BaseModel):
    """Structured command targeting an oven.

    Mirrors LightAction: `entity` is a fixed discriminator and
    `execute_at` is an optional scheduled time-of-day.
    """
    entity: Literal["oven"] = "oven"
    action: Literal["turn_on", "turn_off", "set_temperature"]
    # Temperature only applies to set_temperature; constrained to 50-300.
    temperature: Optional[float] = Field(None, ge=50, le=300)
    execute_at: Optional[time] = None
44+
45+
46+
class HomeAssistantAction(BaseModel):
    """Top-level schema the model must emit: exactly one instruction,
    which is either a LightAction or an OvenAction (discriminated by
    each sub-model's fixed `entity` field)."""
    instruction: Union[LightAction, OvenAction]
48+
49+
50+
def create_generator(model_dir="/mnt/str/models/mistral-7b-exl2/4.0bpw"):
    """Load an exl2 model and return a ready ExLlamaV2DynamicGenerator.

    :param model_dir:
        Path to an exl2-quantized model directory.
    :return:
        ExLlamaV2DynamicGenerator wrapping the loaded model, cache and tokenizer.
    """
    config = ExLlamaV2Config(model_dir)
    config.arch_compat_overrides()

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, max_seq_len=32768, lazy=True)
    model.load_autosplit(cache, progress=True)

    print("Loading tokenizer...")
    tokenizer = ExLlamaV2Tokenizer(config)
    # Outlines looks for a `vocabulary` mapping on the tokenizer object;
    # expose exllamav2's piece->id table under that attribute name.
    tokenizer.vocabulary = tokenizer.extended_piece_to_id

    # All generator parameters left at their defaults.
    return ExLlamaV2DynamicGenerator(model=model, cache=cache, tokenizer=tokenizer)
67+
68+
69+
generator = create_generator("./Phi-3.1-mini-4k-instruct-exl2-6_5")

# Route sampling through Outlines' JSON-schema processor: at every step the
# processor masks logits so only tokens that keep the output a valid
# HomeAssistantAction JSON document remain possible.
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.logits_processor = JSONLogitsProcessor(
    HomeAssistantAction,
    patch_exl2_tokenizer_for_outlines(generator.tokenizer),
)

rules = "JSON for an instruction with an entity (light or oven) and action (turn_on, turn_off, set_brightness, set temperature). *Optionally* you may set an execute_at time-of-day if the user specifies, otherwise set to null"

# Phi-3 chat template, one prompt per natural-language command.
prompts = [
    f"<|user|> {rules} Turn the lights lower please!<|end|><|assistant|>",
    f"<|user|> {rules} I need the oven set for homemade pizza when I get home from work at 6PM.<|end|><|assistant|>",
    f"<|user|> {rules} Oh no the lights are off and I can't find the switch!<|end|><|assistant|>",
]

outputs = generator.generate(
    prompt=prompts,
    gen_settings=gen_settings,
    max_new_tokens=2048,
    completion_only=True,
    encode_special_tokens=False,
    stop_conditions=[generator.tokenizer.eos_token_id],
)

# Raw JSON, exactly as constrained-decoded.
for output in outputs:
    print(output)
# Output:
# {"instruction": {"entity": "light", "action": "set_brightness", "execute_at": null}}
# {"instruction": {"entity": "oven", "action": "set_temperature", "execute_at": "18:00:00"} }
# {"instruction": {"entity": "light", "action": "turn_on"}}

# Same outputs re-parsed into validated pydantic models.
for output in outputs:
    print(repr(HomeAssistantAction.parse_raw(output)))
# Output:
# HomeAssistantAction(instruction=LightAction(entity='light', action='set_brightness', brightness=None, execute_at=None))
# HomeAssistantAction(instruction=OvenAction(entity='oven', action='set_temperature', temperature=None, execute_at=datetime.time(18, 0)))
# HomeAssistantAction(instruction=LightAction(entity='light', action='turn_on', brightness=None, execute_at=None))

exllamav2/generator/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def generate_simple(
193193
return_offsets = True,
194194
add_bos = add_bos)
195195

196+
pre_ids = torch.empty(*ids.shape[:-1], 0)
197+
196198
if prompts_identical:
197199
position_offsets = None
198200

@@ -268,6 +270,7 @@ def generate_simple(
268270
ExLlamaV2Sampler.sample(
269271
logits,
270272
gen_settings,
273+
pre_ids,
271274
self.sequence_ids,
272275
random.random(),
273276
self.tokenizer,

exllamav2/generator/dynamic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class CachePage:
7474
kv_position: int
7575
kv_position_revert: int
7676
# Specific tokens for which KV is valid assuming prev_hash
77-
sequence: torch.Tensor
77+
sequence: torch.Tensor
7878
can_revert: bool
7979
# Used by defragmenter
8080
new_page_index: int
@@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
525525
self.current_loras = loras
526526
else:
527527
self.current_loras = [loras]
528-
528+
529529

530530
def generate(
531531
self,
@@ -1766,6 +1766,7 @@ def receive_logits(
17661766
ExLlamaV2Sampler.sample(
17671767
logits,
17681768
self.gen_settings,
1769+
self.sequences[0].input_ids.torch(),
17691770
self.sequences[0].sequence_ids.torch(),
17701771
self.rng.random(),
17711772
self.generator.tokenizer,

exllamav2/generator/sampler.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from dataclasses import dataclass, field
3+
from typing import Optional, Callable
34
import torch
45
import torch.nn.functional as F
56
from exllamav2 import ExLlamaV2Tokenizer
@@ -71,6 +72,8 @@ class Settings:
7172
typical: float = 0
7273
skew: float = 0
7374

75+
logits_processor: Optional[Callable] = None
76+
7477
temperature_last: bool = False
7578

7679
mirostat: bool = False
@@ -269,6 +272,7 @@ def apply_dry(
269272
def sample(
270273
logits: torch.tensor,
271274
settings: Settings,
275+
input_ids: torch.tensor,
272276
sequence_ids: torch.tensor,
273277
random: float,
274278
tokenizer: ExLlamaV2Tokenizer,
@@ -289,6 +293,9 @@ def sample(
289293
:param settings:
290294
ExLlamaV2Sampler.Settings
291295
296+
:param input_ids:
297+
The prompt portion of sequence_ids, shape (batch_size, seq_len)
298+
292299
:param sequence_ids:
293300
Past token IDs to consider for repetition penalty etc., shape (batch_size, seq_len)
294301
@@ -354,6 +361,12 @@ def sample(
354361
logits = logits.unsqueeze(0)
355362
batch_size = 1
356363

364+
# Apply logits processor
365+
366+
if settings.logits_processor:
367+
generated_ids = sequence_ids[:, input_ids.shape[1]:]
368+
logits = settings.logits_processor(generated_ids, logits)
369+
357370
# Prepare filter
358371

359372
logit_filter = None

exllamav2/generator/streaming.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def begin_stream_ex(
327327
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
328328
if input_ids.shape[0] == 2:
329329
assert gen_settings.cfg_scale is not None, "No CFG scale set"
330+
self.input_ids = input_ids
330331

331332
self.position_offsets = position_offsets
332333
self.input_mask = input_mask
@@ -500,7 +501,7 @@ def stream(self, **kwargs) -> Union[Tuple[str, bool, torch.Tensor],
500501

501502
if self.return_logits:
502503
ret.append(logits)
503-
504+
504505
return tuple(ret)
505506

506507

@@ -819,6 +820,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
819820
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
820821
logits,
821822
gen_settings,
823+
self.input_ids,
822824
self.sequence_ids[:1, :],
823825
random.random(),
824826
self.tokenizer,
@@ -854,12 +856,12 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
854856
for f in self.filters: f.feed(token)
855857

856858
# Accept token
857-
859+
858860
if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
859861
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
860862
else:
861863
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
862-
864+
863865
return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits
864866

865867

@@ -881,7 +883,15 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
881883
self.draft_cache,
882884
input_mask = self.input_mask,
883885
position_offsets = self.position_offsets).float().cpu()
884-
token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
886+
token, _, _, prob, _ = ExLlamaV2Sampler.sample(
887+
logits,
888+
draft_gen_settings,
889+
self.input_ids,
890+
draft_sequence_ids,
891+
random.random(),
892+
self.tokenizer,
893+
prefix_token if k == 0 else None
894+
)
885895

886896
if prob < self.speculative_prob_threshold:
887897
self.draft_cache.current_seq_len -= 1
@@ -918,6 +928,7 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
918928
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
919929
logits,
920930
gen_settings,
931+
self.input_ids,
921932
self.sequence_ids[:1, :], random.random(),
922933
self.tokenizer,
923934
prefix_token,
@@ -980,6 +991,7 @@ def _gen_single_token_ngram(self, gen_settings, prefix_token = None):
980991
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
981992
logits,
982993
gen_settings,
994+
self.input_ids,
983995
self.sequence_ids[:1, :],
984996
random.random(),
985997
self.tokenizer,
@@ -1038,5 +1050,3 @@ def ngram_preload(self,
10381050

10391051
self.ngram_preloaded = NgramCache(self.speculative_ngram_min, self.speculative_ngram_max, None)
10401052
self.ngram_preloaded.update(input_ids)
1041-
1042-

0 commit comments

Comments
 (0)