Add ExLlamaV2Sampler.Settings.logits_processor #634

Open
wants to merge 2 commits into master
Changes from 1 commit
Add ExLlamaV2Sampler.Settings.logits_processor
lapp0 committed Sep 23, 2024
commit 4aa4ebdda53809a0f159627010164fbb13fb02c6
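In short, this commit lets a sampler Settings object carry an arbitrary callable that rewrites the logits before ExLlamaV2's own filters run. A minimal sketch of the contract, based on how the hook is invoked in sampler.py below: the processor receives the generated-token IDs and the logits tensor and returns logits of the same shape. ForbidTokenProcessor and its banned_ids argument are illustrative names, not part of the library.

import torch
from exllamav2.generator import ExLlamaV2Sampler

class ForbidTokenProcessor:
    # Illustrative processor: masks a fixed set of token IDs at every step.
    def __init__(self, banned_ids: list[int]):
        self.banned_ids = torch.tensor(banned_ids, dtype=torch.long)

    def __call__(self, generated_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        # generated_ids: tokens produced so far (prompt excluded); logits: next-token scores
        logits.index_fill_(-1, self.banned_ids.to(logits.device), float("-inf"))
        return logits

settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = ForbidTokenProcessor(banned_ids=[0])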
108 changes: 108 additions & 0 deletions examples/json_schema_outlines.py
@@ -0,0 +1,108 @@
# Install Outlines:
# pip install outlines

# Download Model:
# huggingface-cli download bartowski/Phi-3.1-mini-4k-instruct-exl2 --revision 6_5 --local-dir Phi-3.1-mini-4k-instruct-exl2-6_5

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))



from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler

from outlines.processors import JSONLogitsProcessor
from outlines.models.exllamav2 import patch_tokenizer as patch_exl2_tokenizer_for_outlines

from pydantic import BaseModel, Field, RootModel
from typing import Optional, Union, Literal
from datetime import time


################################################
# Create Structured JSON Generator With Outlines
################################################

# Additional Examples: https://outlines-dev.github.io/outlines/cookbook/
# JSON Generation Docs: https://outlines-dev.github.io/outlines/reference/json/
# `outlines.processors` also supports guaranteed regex patterns and lark grammars

# Example: Home Assistant extension for natural language commands -> actions
class LightAction(BaseModel):
entity: Literal["light"] = "light"
action: Literal["turn_on", "turn_off", "set_brightness"]
brightness: Optional[int] = Field(None, ge=0, le=100)
execute_at: Optional[time] = None


class OvenAction(BaseModel):
entity: Literal["oven"] = "oven"
action: Literal["turn_on", "turn_off", "set_temperature"]
temperature: Optional[float] = Field(None, ge=50, le=300)
execute_at: Optional[time] = None


class HomeAssistantAction(BaseModel):
instruction: Union[LightAction, OvenAction]


def create_generator(model_dir="/mnt/str/models/mistral-7b-exl2/4.0bpw"):
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len=32768, lazy=True)
model.load_autosplit(cache, progress=True)

print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
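# Expose the full piece-to-id mapping as `vocabulary` for the Outlines tokenizer patch applied below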
tokenizer.vocabulary = tokenizer.extended_piece_to_id

# Initialize the generator with all default parameters
return ExLlamaV2DynamicGenerator(
model=model,
cache=cache,
tokenizer=tokenizer,
)


generator = create_generator("./Phi-3.1-mini-4k-instruct-exl2-6_5")

gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.logits_processor = JSONLogitsProcessor(
HomeAssistantAction,
patch_exl2_tokenizer_for_outlines(generator.tokenizer)
)


rules = "JSON for an instruction with an entity (light or oven) and an action (turn_on, turn_off, set_brightness, set_temperature). *Optionally* you may set an execute_at time of day if the user specifies one; otherwise set it to null."
prompts = [
f"<|user|> {rules} Turn the lights lower please!<|end|><|assistant|>",
f"<|user|> {rules} I need the oven set for homemade pizza when I get home from work at 6PM.<|end|><|assistant|>",
f"<|user|> {rules} Oh no the lights are off and I can't find the switch!<|end|><|assistant|>",
]

outputs = generator.generate(
prompt=prompts,
gen_settings=gen_settings,
max_new_tokens=2048,
completion_only=True,
encode_special_tokens=False,
stop_conditions=[generator.tokenizer.eos_token_id],
)

# raw json format
for idx, output in enumerate(outputs):
print(output)
# Output:
# {"instruction": {"entity": "light", "action": "set_brightness", "execute_at": null}}
# {"instruction": {"entity": "oven", "action": "set_temperature", "execute_at": "18:00:00"} }
# {"instruction": {"entity": "light", "action": "turn_on"}}

# pydantic model format
for idx, output in enumerate(outputs):
print(repr(HomeAssistantAction.parse_raw(output)))
# Output:
# HomeAssistantAction(instruction=LightAction(entity='light', action='set_brightness', brightness=None, execute_at=None))
# HomeAssistantAction(instruction=OvenAction(entity='oven', action='set_temperature', temperature=None, execute_at=datetime.time(18, 0)))
# HomeAssistantAction(instruction=LightAction(entity='light', action='turn_on', brightness=None, execute_at=None))
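The same settings hook accepts Outlines' other processors as well. Continuing the script above, a rough sketch using a regex constraint instead of a JSON schema; this assumes the installed Outlines version exports RegexLogitsProcessor alongside JSONLogitsProcessor, and the prompt and pattern are illustrative.

from outlines.processors import RegexLogitsProcessor

time_pattern = r"([01][0-9]|2[0-3]):[0-5][0-9]"

regex_settings = ExLlamaV2Sampler.Settings()
regex_settings.logits_processor = RegexLogitsProcessor(
    time_pattern,
    patch_exl2_tokenizer_for_outlines(generator.tokenizer),
)

print(generator.generate(
    prompt="<|user|> Reply with a 24-hour time only: when is solar noon?<|end|><|assistant|>",
    gen_settings=regex_settings,
    max_new_tokens=16,
    completion_only=True,
    stop_conditions=[generator.tokenizer.eos_token_id],
))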
3 changes: 3 additions & 0 deletions exllamav2/generator/base.py
@@ -193,6 +193,8 @@ def generate_simple(
return_offsets = True,
add_bos = add_bos)

pre_ids = torch.empty(*ids.shape[:-1], 0)

if prompts_identical:
position_offsets = None

@@ -268,6 +270,7 @@ def generate_simple(
ExLlamaV2Sampler.sample(
logits,
gen_settings,
pre_ids,
self.sequence_ids,
random.random(),
self.tokenizer,
5 changes: 3 additions & 2 deletions exllamav2/generator/dynamic.py
@@ -74,7 +74,7 @@ class CachePage:
kv_position: int
kv_position_revert: int
# Specific tokens for which KV is valid assuming prev_hash
sequence: torch.Tensor
can_revert: bool
# Used by defragmenter
new_page_index: int
@@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
self.current_loras = loras
else:
self.current_loras = [loras]


def generate(
self,
@@ -1766,6 +1766,7 @@ def receive_logits(
ExLlamaV2Sampler.sample(
logits,
self.gen_settings,
self.sequences[0].input_ids.torch(),
self.sequences[0].sequence_ids.torch(),
self.rng.random(),
self.generator.tokenizer,
13 changes: 13 additions & 0 deletions exllamav2/generator/sampler.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Callable
import torch
import torch.nn.functional as F
from exllamav2 import ExLlamaV2Tokenizer
@@ -71,6 +72,8 @@ class Settings:
typical: float = 0
skew: float = 0

logits_processor: Optional[Callable] = None

temperature_last: bool = False

mirostat: bool = False
@@ -269,6 +272,7 @@ def apply_dry(
def sample(
logits: torch.tensor,
settings: Settings,
input_ids: torch.tensor,
sequence_ids: torch.tensor,
random: float,
tokenizer: ExLlamaV2Tokenizer,
@@ -289,6 +293,9 @@ def sample(
:param settings:
ExLlamaV2Sampler.Settings

:param input_ids:
The prompt portion of sequence_ids, shape (batch_size, seq_len)

:param sequence_ids:
Past token IDs to consider for repetition penalty etc., shape (batch_size, seq_len)

@@ -354,6 +361,12 @@ def sample(
logits = logits.unsqueeze(0)
batch_size = 1

# Apply logits processor

if settings.logits_processor:
generated_ids = sequence_ids[:, input_ids.shape[1]:]
logits = settings.logits_processor(generated_ids, logits)

# Prepare filter

logit_filter = None
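As the new docstring notes, the processor only sees the suffix of sequence_ids that follows the prompt, so it can key its behaviour on generation progress. A minimal sketch of a processor that uses this, with the class name and the EOS handling as illustrative assumptions rather than library code:

import torch

class MaxNewTokensProcessor:
    # Forces EOS once a fixed number of tokens have been generated.
    def __init__(self, max_new_tokens: int, eos_token_id: int):
        self.max_new_tokens = max_new_tokens
        self.eos_token_id = eos_token_id

    def __call__(self, generated_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        if generated_ids.shape[-1] >= self.max_new_tokens:
            logits.fill_(float("-inf"))
            logits[..., self.eos_token_id] = 0.0
        return logits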
22 changes: 16 additions & 6 deletions exllamav2/generator/streaming.py
@@ -327,6 +327,7 @@ def begin_stream_ex(
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
if input_ids.shape[0] == 2:
assert gen_settings.cfg_scale is not None, "No CFG scale set"
self.input_ids = input_ids

self.position_offsets = position_offsets
self.input_mask = input_mask
@@ -500,7 +501,7 @@ def stream(self, **kwargs) -> Union[Tuple[str, bool, torch.Tensor],

if self.return_logits:
ret.append(logits)

return tuple(ret)


@@ -819,6 +820,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :],
random.random(),
self.tokenizer,
@@ -854,12 +856,12 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
for f in self.filters: f.feed(token)

# Accept token

if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
else:
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)

return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits


@@ -881,7 +883,15 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
self.draft_cache,
input_mask = self.input_mask,
position_offsets = self.position_offsets).float().cpu()
token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
token, _, _, prob, _ = ExLlamaV2Sampler.sample(
logits,
draft_gen_settings,
self.input_ids,
draft_sequence_ids,
random.random(),
self.tokenizer,
prefix_token if k == 0 else None
)

if prob < self.speculative_prob_threshold:
self.draft_cache.current_seq_len -= 1
@@ -918,6 +928,7 @@ def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :], random.random(),
self.tokenizer,
prefix_token,
@@ -980,6 +991,7 @@ def _gen_single_token_ngram(self, gen_settings, prefix_token = None):
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
logits,
gen_settings,
self.input_ids,
self.sequence_ids[:1, :],
random.random(),
self.tokenizer,
@@ -1038,5 +1050,3 @@ def ngram_preload(self,

self.ngram_preloaded = NgramCache(self.speculative_ngram_min, self.speculative_ngram_max, None)
self.ngram_preloaded.update(input_ids)


326 changes: 326 additions & 0 deletions tests/test_logits_processor.py
@@ -0,0 +1,326 @@

import sys, os, gc, time, random
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2CacheBase,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
ExLlamaV2BaseGenerator,
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
)


model: ExLlamaV2
config: ExLlamaV2Config
tokenizer: ExLlamaV2Tokenizer
cache: ExLlamaV2CacheBase


class SamplerLogitsProcessor:
def __init__(
self,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
top_a: float = 0.0,
disallow_tokens: list = [],
):
if temperature <= 0:
raise ValueError("Temperature must be > 0.")
if top_k < 0:
raise ValueError("top_k must be >= 0.")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be between 0 and 1.")
if top_a < 0:
raise ValueError("top_a must be >= 0.")
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.top_a = top_a
self.disallow_tokens = torch.tensor(disallow_tokens)

@torch.no_grad()
def __call__(self, input_ids, logits):
# Apply temperature scaling
if self.temperature != 1.0:
logits /= self.temperature

# Initialize mask
mask = torch.zeros_like(logits, dtype=torch.bool)

if self.top_p > 0.0 or self.top_a > 0.0:
probs = torch.nn.functional.softmax(logits, dim=-1)
sorted_probs, sorted_indices = probs.sort(descending=True, dim=-1)

# Apply top-p filtering
if self.top_p > 0.0:
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_mask = cumulative_probs > self.top_p
sorted_mask = sorted_mask.roll(shifts=1, dims=-1)
sorted_mask[..., 0] = False
mask.scatter_(-1, sorted_indices, sorted_mask)

# Apply top-a filtering
if self.top_a > 0.0:
max_probs = sorted_probs[:, 0].unsqueeze(-1)
mask |= probs < (self.top_a * max_probs.square())  # keep tokens with prob >= top_a * max_prob^2

# top-k: logits > kth largest value's logits
if self.top_k > 0:
threshold = logits.topk(self.top_k, dim=-1, largest=True).values[..., -1, None]
mask |= logits < threshold # Compare logits directly with the threshold

# Filter disallowed tokens
if self.disallow_tokens is not None and self.disallow_tokens.numel() > 0:
self.disallow_tokens = self.disallow_tokens.to(device=logits.device)
mask.index_fill_(-1, self.disallow_tokens, True)

# Apply the mask
logits.masked_fill_(mask, -float("inf"))
return logits



def unload():
global model, config, tokenizer, cache

model.unload()
model = None
config = None
cache = None
tokenizer = None

gc.collect()
torch.cuda.empty_cache()


def load_model(model_dir, split = None, cache_8bit = False):
global model, config, tokenizer, cache

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()

model = ExLlamaV2(config)
print(" -- Loading model: " + model_dir)

model.load(split)

tokenizer = ExLlamaV2Tokenizer(config)

if cache_8bit:
print(" -- Creating 8-bit cache")
cache = ExLlamaV2Cache_8bit(model, batch_size = 4)
else:
print(" -- Creating 16-bit cache")
cache = ExLlamaV2Cache(model, batch_size = 4)


def test_gen_normal(prompt, max_new_tokens):
global model, config, tokenizer, cache

print("--------------------------------")
print("Generating, normal")
print()

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = SamplerLogitsProcessor(
temperature=0.85,
top_k=50,
top_p=0.8,
top_a=0.0,
disallow_tokens=[tokenizer.eos_token_id],
)

generator.warmup()
time_begin = time.time()

output = generator.generate_simple(prompt, settings, max_new_tokens, seed=1234)

time_end = time.time()
time_total = time_end - time_begin

print(output)
print()
print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, {max_new_tokens / time_total:.2f} tokens/second")


def test_gen_streaming(prompt, max_new_tokens):
global model, config, tokenizer, cache

print("--------------------------------")
print("Generating, streaming")
print()

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = SamplerLogitsProcessor(
temperature=0.85,
top_k=50,
top_p=0.8,
top_a=0.0,
disallow_tokens=[tokenizer.eos_token_id],
)

input_ids = tokenizer.encode(prompt)
prompt_tokens = input_ids.shape[-1]

print(prompt, end = "")
sys.stdout.flush()

time_begin_prompt = time.time()

generator.set_stop_conditions([])
generator.begin_stream(input_ids, settings)

time_begin_stream = time.time()
generated_tokens = 0

while True:
chunk, eos, _ = generator.stream()
generated_tokens += 1
print(chunk, end = "")
sys.stdout.flush()
if eos or generated_tokens == max_new_tokens: break

time_end = time.time()

time_prompt = time_begin_stream - time_begin_prompt
time_tokens = time_end - time_begin_stream

print()
print()
print(f"Prompt processed in {time_prompt:.2f} seconds, {prompt_tokens} tokens, {prompt_tokens / time_prompt:.2f} tokens/second")
print(f"Response generated in {time_tokens:.2f} seconds, {generated_tokens} tokens, {generated_tokens / time_tokens:.2f} tokens/second")


def test_gen_batch(max_new_tokens):

print("--------------------------------")
print("Generating, batched")
print()

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = SamplerLogitsProcessor(
temperature=0.85,
top_k=50,
top_p=0.8,
top_a=0.0,
disallow_tokens=[tokenizer.eos_token_id],
)

generator.warmup()
time_begin = time.time()

prompts = ["Here's how to create a powerful love potio",
"For once,",
"The events of the American Civil W",
"A bird in the hand is worth"]

output = generator.generate_simple(prompts, settings, max_new_tokens, seed = 1234, token_healing = True)

time_end = time.time()
time_total = time_end - time_begin

for o in output:
print(o)
print("---")
print()
print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second")


def test_multicache(max_new_tokens):

print("--------------------------------")
print("Generating, batched multi cache")
print()

settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = SamplerLogitsProcessor(
temperature=0.85,
top_k=50,
top_p=0.8,
top_a=0.0,
disallow_tokens=[tokenizer.eos_token_id],
)

prompts = ["Here's how to create a powerful love potion",
"For once,",
"The events of the American Civil War",
"A bird in the hand is worth"]

caches = [ExLlamaV2Cache(model, max_seq_len = 256) for _ in range(len(prompts))]
input_ids = []

for i in range(len(prompts)):

input_ids.append(tokenizer.encode(prompts[i]))
model.forward(input_ids[i][:, :-1], caches[i], input_mask = None, preprocess_only = True)

time_begin = time.time()

for i in range(max_new_tokens):

inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0)
logits = model.forward(inputs, caches, input_mask = None).float().cpu()

r = random.random()
for j in range(len(input_ids)):
# The updated sample() signature takes prompt ids plus full sequence ids and returns five values;
# the same tensor serves as both here, since this processor ignores generated_ids.
token, _, _, _, _ = ExLlamaV2Sampler.sample(logits[j:j + 1, :, :], settings, input_ids[j], input_ids[j], r, tokenizer)
input_ids[j] = torch.cat([input_ids[j], token], dim = 1)

output = [tokenizer.decode(ids)[0] for ids in input_ids]

time_end = time.time()
time_total = time_end - time_begin

for o in output:
print(o)
print("---")
print()
print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second")


def tests(model_dir, cache_8bit, use_split):

if use_split: split = [1, 24]
else: split = None
print("--------------------------------")
print(f" -- Split: {split}")
load_model(model_dir, split = split, cache_8bit = cache_8bit)

test_gen_normal("Our story begins in the Scottish town of Auchtermuchty, where once", 150)
test_gen_streaming("Our story begins in the Scottish town of Auchtermuchty, where once", 150)
test_gen_batch(40)
if model.is_quant(): test_multicache(40)

unload()


q_model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/4.0bpw/"
f_model_directory = "/mnt/str/models/tinyllama-1b-ckpt503/"

tests(q_model_directory, False, False)
tests(q_model_directory, False, True)
tests(q_model_directory, True, False)
tests(q_model_directory, True, True)
tests(f_model_directory, False, False)
tests(f_model_directory, False, True)
tests(f_model_directory, True, False)
tests(f_model_directory, True, True)