Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of logit threshold sampler and confidence breaker #657

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,66 @@ int sort_descending
return pre;
}

AVX2_TARGET_OPTIONAL
void logit_threshold_temperature
(
    const int num_candidates,
    float logit_temp_threshold,
    float logit_high_temp,
    const int maxlogit,
    const float* logits,
    const float exponent,
    float* temp_probs
)
{
    // Re-temper only the "confident" part of the distribution: every candidate whose raw
    // logit is at or above logit_temp_threshold has its probability replaced by a softmax
    // taken at logit_high_temp, rescaled so that the total probability mass of the
    // below-threshold candidates (left untouched in temp_probs) is preserved.
    //
    // Parameters:
    //   num_candidates        - number of entries in logits / temp_probs
    //   logit_temp_threshold  - raw-logit cutoff; entries >= cutoff are re-tempered
    //   logit_high_temp       - temperature applied to the above-threshold entries
    //                           (clamped below to 0.01 to avoid a huge 1/temp)
    //   maxlogit              - index of the maximum logit (for softmax stabilization)
    //   logits                - raw logits, read-only
    //   exponent              - temperature exponent, same convention as the other
    //                           samplers here (1.0 = linear, 2.0 = squared, else powf)
    //   temp_probs            - in/out probabilities; only above-threshold entries are
    //                           overwritten

    profile_start("logit_threshold_temperature");

    float esum = 0.0f;
    float static_pmass = 0.0f;
    float itemp = 1.0f / std::max(logit_high_temp, 0.01f);
    float maxl = logits[maxlogit];
    std::vector<int> above_threshold_indices;

    for (int i = 0; i < num_candidates; i++)
    {
        float target_logit = logits[i];

        // Shift by the max logit so expf() cannot overflow, then apply the exponent
        float l = target_logit - maxl;
        if (exponent == 2.0f)
            l *= -l;
        else if (exponent != 1.0f)
            l = -powf(fabs(l), exponent);

        // NOTE: esum intentionally accumulates over *all* candidates; its exact scale
        // cancels out in the adjfactor renormalization below
        float e = expf(l * itemp);
        esum += e;

        if (target_logit >= logit_temp_threshold)
        {
            temp_probs[i] = e;
            above_threshold_indices.push_back(i);
        }
        else static_pmass += temp_probs[i];
    }

    // Guard strictly against zero: esum == 0 is possible only with no candidates, but
    // ">= 0" would make the fallback dead code and divide by zero
    float isum = (esum > 0.0f) ? (1.0f / esum) : 1024.0f;
    float temp_pmass = 0.0f;

    for (int i : above_threshold_indices)
    {
        temp_probs[i] *= isum;
        temp_pmass += temp_probs[i];
    }

    // Rescale the re-tempered entries to fill exactly the mass not held by the static
    // (below-threshold) entries. Strict ">" for the same dead-fallback reason as above
    // (temp_pmass == 0 when no candidate clears the threshold).
    float adjfactor = (temp_pmass > 0.0f) ? ((1.0f - static_pmass) / temp_pmass) : 1024.0f;

    for (int i : above_threshold_indices)
    {
        temp_probs[i] *= adjfactor;
    }

    profile_stop();
}

AVX2_TARGET_OPTIONAL
int top_k_cpu
(
Expand Down
11 changes: 11 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ int sort_descending
int max_index
);

// Re-temper the above-threshold portion of a probability distribution: candidates whose
// raw logit is >= logit_temp_threshold get probabilities from a softmax at
// logit_high_temp, rescaled so the total mass of the untouched below-threshold entries
// in temp_probs is preserved. maxlogit is the index of the maximum logit (softmax
// stabilization); exponent follows the same convention as the other temperature
// functions in this module.
void logit_threshold_temperature
(
    const int num_candidates,       // number of entries in logits / temp_probs
    float logit_temp_threshold,     // raw-logit cutoff for re-tempering
    float logit_high_temp,          // temperature applied above the cutoff
    const int maxlogit,             // index of max logit
    const float* logits,            // raw logits (read-only)
    const float exponent,           // temperature exponent (1.0 = linear)
    float* temp_probs               // in/out probabilities
);

int top_k_cpu
(
const int num_candidates,
Expand Down
17 changes: 14 additions & 3 deletions exllamav2/exllamav2_ext/ext_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ std::vector<float> sample_basic
float max_temp = 0.0f,
float temp_exponent = 1.0f,
float smoothing_factor = 0.0f,
float logit_temp_threshold = 0.0f,
float logit_high_temp = 0.0f,
float skew = 0.0f
)
{
Expand Down Expand Up @@ -138,8 +140,11 @@ std::vector<float> sample_basic

if (temperature < 0.01)
{
temperature = 1.0f;
top_k = 1;
if (logit_temp_threshold == 0.0f)
{
temperature = 1.0f;
top_k = 1;
}
}

for (int i = 0; i < bsz; i++)
Expand All @@ -164,7 +169,13 @@ std::vector<float> sample_basic
for (int j = 0; j < vocab_size; j++) temp_indices[j] = j;
int num_candidates = vocab_size;

if (top_k > 0 && top_k < vocab_size)
if (logit_temp_threshold > 0.0f)
{
logit_threshold_temperature(num_candidates, logit_temp_threshold, logit_high_temp, maxlogit, logits_ptr + i * vocab_size, exponent, temp_probs);
normalize_cpu(num_candidates, temp_probs);
}

if (num_candidates > top_k && top_k > 0 && top_k < vocab_size)
{
num_candidates = top_k_cpu(num_candidates, temp_probs, temp_indices, top_k, maxlogit);
normalize_cpu(num_candidates, temp_probs);
Expand Down
2 changes: 2 additions & 0 deletions exllamav2/exllamav2_ext/ext_sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ std::vector<float> sample_basic
float max_temp,
float temp_exponent,
float smoothing_factor,
float logit_temp_threshold,
float logit_high_temp,
float skew
);

Expand Down
85 changes: 67 additions & 18 deletions exllamav2/generator/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
self.current_loras = loras
else:
self.current_loras = [loras]


def generate(
self,
Expand Down Expand Up @@ -1170,10 +1170,10 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
for i in range(batch_logits.shape[1]):
job_logits = batch_logits[a:b, i:i+1, :]
if i == 0 and mt_sample:
next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
futures.popleft().result()
else:
next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
job.receive_logits(job_logits)

eos, sampled_token = job.receive_sample(
Expand All @@ -1183,6 +1183,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
next_k_probs,
next_prob,
filter_eos,
confidence_flag,
results
)

Expand Down Expand Up @@ -1661,6 +1662,12 @@ def __init__(

self.checkpoint = None

# Confidence breaker

self.confidence_breaker = gen_settings.confidence_breaker
self.confidence_breaker_debug = gen_settings.confidence_breaker_debug
self.confidence_flag_sequence = [False] * self.confidence_breaker

# Measurement

self.time_enqueue = None
Expand Down Expand Up @@ -1762,7 +1769,7 @@ def receive_logits(
else:
blocked_tokens = self.stop_tokens_list

next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
ExLlamaV2Sampler.sample(
logits,
self.gen_settings,
Expand All @@ -1777,7 +1784,7 @@ def receive_logits(
# sync = True
)

return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos
return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag


def receive_sample(
Expand All @@ -1788,6 +1795,7 @@ def receive_sample(
next_k_probs: torch.Tensor | None,
next_prob: torch.Tensor | None,
filter_eos: bool | None,
confidence_flag: bool | None,
results: list
):
page_size = self.generator.page_size
Expand All @@ -1801,6 +1809,14 @@ def receive_sample(
if f.use_background_worker():
self.generator.filter_queue.append(f)

# Update confidence_flag_sequence

if confidence_flag is not None:
self.confidence_flag_sequence.append(confidence_flag)
# Limit the size of the sequence to prevent it from growing indefinitely
if len(self.confidence_flag_sequence) > self.confidence_breaker + 1:
self.confidence_flag_sequence.pop(0)

# Accept token

self.new_tokens += 1
Expand Down Expand Up @@ -1986,7 +2002,7 @@ def emit(
# End on stop tokens

if next_token.item() in self.stop_tokens:
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_token", stop_token = next_token.item())

# Stop if we reach max_new_tokens
# TODO: Auto-extend option
Expand All @@ -2011,7 +2027,8 @@ def emit(
else:
return emit(results)

# Hold text as long as it contains part of a banned string
# Hold text as long as it contains part of a banned string,
# or until we know a confidence breaker will not be triggered

def unset_checkpoint():
self.checkpoint = None
Expand All @@ -2026,6 +2043,7 @@ def set_checkpoint():
"held_k_tokens": self.held_k_tokens.clone(1),
"held_k_probs": self.held_k_probs.clone(1),
"held_logits": self.held_logits.clone(1),
"flag_sequence": self.confidence_flag_sequence[:-1].copy(),
"explored_tokens": [next_token.item()],
}
else:
Expand Down Expand Up @@ -2054,28 +2072,59 @@ def rewind_checkpoint():
off_tokens = self.held_tokens.slice(len(self.checkpoint["held_tokens"]), None)
off_text = self.held_text[len(self.checkpoint["held_text"]):]
self.held_text = self.checkpoint["held_text"]
self.held_token = self.checkpoint["held_tokens"]
self.held_tokens = self.checkpoint["held_tokens"]
self.held_probs = self.checkpoint["held_probs"]
self.held_k_tokens = self.checkpoint["held_k_tokens"]
self.held_k_probs = self.checkpoint["held_k_probs"]
self.held_logits = self.checkpoint["held_logits"]
self.confidence_flag_sequence = self.checkpoint["flag_sequence"]
self.checkpoint["offset"] = 0
return off_tokens, off_text

if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
match = ext_c.partial_strings_match(
np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
self.banned_strings_utf32_offsets,
self.banned_strings_utf32_buffer
)
if match >= 0:
# Handle banned strings and confidence flags using checkpointing

if self.new_tokens > 0:
# Check for banned strings
banned_string_match = -1
if self.banned_strings_utf32_offsets is not None:
banned_string_match = ext_c.partial_strings_match(
np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
self.banned_strings_utf32_offsets,
self.banned_strings_utf32_buffer
)

confidence_breaker_match = -1
if self.confidence_breaker > 0:
# Check for confidence_flag sequence
if confidence_flag is not None:
last_n_flags = self.confidence_flag_sequence[-self.confidence_breaker:]
if not confidence_flag:
confidence_breaker_match = -1 # False flag
elif all(last_n_flags):
confidence_breaker_match = 1 # Match
else:
confidence_breaker_match = -2 # Partial match, wait and see
elif self.confidence_flag_sequence[-1]:
confidence_breaker_match = -2 # Pause current sequence without resetting partial match
else:
confidence_breaker_match = -1 # Treat None as False flag, following previous False flag

if confidence_breaker_match >= 0: # Match confidence breaker
set_checkpoint()
if self.confidence_breaker_debug:
print(f'[Confidence breaker activated on text: "{self.held_text}"]', flush=True)
offending_tokens, offending_text = rewind_checkpoint()
return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
elif banned_string_match >= 0:
set_checkpoint()
offending_tokens, offending_text = rewind_checkpoint()
return emit(results, emit_held = True, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
elif match == -2:
return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
elif banned_string_match == -2 or confidence_breaker_match == -2: # Partial match
set_checkpoint()
return emit(results)
else:
else: # Reset and permit text passthrough
if len(self.full_completion) > 0:
set_checkpoint()
unset_checkpoint()

# End on stop strings
Expand Down
35 changes: 34 additions & 1 deletion exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class Settings:

temperature_last: bool = False

logit_temp_threshold: float = 0.0
logit_high_temp: float = 0.0

confidence_breaker: int = 0
confidence_breaker_debug: bool = False
cb_mid_threshold: float = 0.0
cb_high_threshold: float = 0.0

mirostat: bool = False
mirostat_tau: float = 1.5
mirostat_eta: float = 0.1
Expand Down Expand Up @@ -420,6 +428,7 @@ def prep_logit_filter(lf):
# Temporarily ban individual tokens

if blocked_tokens:
saved_logits = logits[:, :, blocked_tokens].clone()
logits[:, :, blocked_tokens] = -1e30

# Token bias
Expand Down Expand Up @@ -552,9 +561,33 @@ def prep_logit_filter(lf):
settings.max_temp,
settings.temp_exponent,
settings.smoothing_factor,
settings.logit_temp_threshold,
settings.logit_high_temp,
settings.skew
)

if settings.confidence_breaker > 0:
if blocked_tokens and 'saved_logits' in locals():
# Restore the saved logits values for the blocked tokens
logits[:, :, blocked_tokens] = saved_logits

squeezed_logits = logits.squeeze(0).squeeze(0)
probs = F.softmax(squeezed_logits, dim=-1)
token_prob = probs[output_tokens]
token_logit = squeezed_logits[output_tokens]
if settings.cb_mid_threshold <= 1.0:
confidence_flag = (token_prob >= settings.cb_mid_threshold).item()
else:
confidence_flag = (token_logit >= settings.cb_mid_threshold).item()
if settings.cb_high_threshold <= 1.0:
if (token_prob > settings.cb_high_threshold).item():
confidence_flag = None
else:
if (token_logit > settings.cb_high_threshold).item():
confidence_flag = None
else:
confidence_flag = None

if settings.mirostat: settings.mirostat_mu = m

# Stop condition from filters
Expand All @@ -563,4 +596,4 @@ def prep_logit_filter(lf):
if len(filters) > 0 and end_tokens is not None and output_tokens[0].item() in end_tokens:
end_filter = True

return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter, confidence_flag