Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of logit threshold sampler and confidence breaker #657

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Previous commit
Next commit
Simplified LTS implementation
anchortense committed Oct 23, 2024
commit a6df6635503e726ca9146ccafbdf4c8b94b5cd53
33 changes: 17 additions & 16 deletions exllamav2/exllamav2_ext/cpp/sampling.cpp
Original file line number Diff line number Diff line change
@@ -440,10 +440,10 @@ int sort_descending
}

AVX2_TARGET_OPTIONAL
int logit_threshold_restore
int logit_threshold_temperature
(
float logit_min_threshold,
float logit_temp_threshold,
float logit_high_temp,
const int maxlogit,
const int vocab_size,
const float* logits,
@@ -452,51 +452,52 @@ int logit_threshold_restore
int* temp_indices
)
{
profile_start("logit_threshold_restore");
profile_start("logit_threshold_temperature");

float esum = 0.0f;
float static_pmass = 0.0f;
float itemp = 1.0f / std::max(logit_high_temp, 0.01f);
int n = 0;
float maxl = logits[maxlogit];
float effective_min = std::min(maxl, logit_min_threshold);

for (int i = 0; i < vocab_size; i++)
{
float target_logit = logits[i];
if (target_logit < effective_min) continue;

float l = target_logit - maxl;
if (exponent == 2.0f)
l *= -l;
else if (exponent != 1.0f)
l = -powf(fabs(l), exponent);
float e = expf(l);

float e = expf(l * itemp);
esum += e;
if (target_logit < logit_temp_threshold)

if (target_logit >= logit_temp_threshold)
temp_probs[i] = e;
else static_pmass += temp_probs[i];

n++;
}

float isum = 1.0f / esum;
float diffsum = 0.0f;
float temp_pmass = 0.0f;

for (int i = 0; i < vocab_size; i++)
{
if (logits[i] < effective_min) continue;
if (logits[i] < logit_temp_threshold)
if (logits[i] >= logit_temp_threshold)
{
temp_probs[i] *= isum;
diffsum += temp_probs[i];
n++;
temp_pmass += temp_probs[i];
}
}

float adjfactor = 1.0f - diffsum;
float adjfactor = (1.0f - static_pmass) / temp_pmass;

for (int i = 0; i < vocab_size; i++)
{
if (logits[i] >= logit_temp_threshold)
{
temp_probs[i] *= adjfactor;
n++;
}
}

sort_descending(vocab_size, temp_probs, temp_indices, n);
4 changes: 2 additions & 2 deletions exllamav2/exllamav2_ext/cpp/sampling.h
Original file line number Diff line number Diff line change
@@ -55,10 +55,10 @@ int sort_descending
int max_index
);

int logit_threshold_restore
int logit_threshold_temperature
(
float logit_min_threshold,
float logit_temp_threshold,
float logit_high_temp,
const int maxlogit,
const int vocab_size,
const float* logits,
26 changes: 10 additions & 16 deletions exllamav2/exllamav2_ext/ext_sampling.cpp
Original file line number Diff line number Diff line change
@@ -85,8 +85,6 @@ std::vector<float> sample_basic
(
torch::Tensor logits, // shape [bsz, 1, vocab_size]
float temperature,
float logit_temp_threshold,
float logit_min_threshold,
int top_k,
float top_p,
float top_a,
@@ -111,6 +109,8 @@ std::vector<float> sample_basic
float max_temp = 0.0f,
float temp_exponent = 1.0f,
float smoothing_factor = 0.0f,
float logit_temp_threshold = 0.0f,
float logit_high_temp = 0.0f,
float skew = 0.0f
)
{
@@ -140,8 +140,11 @@ std::vector<float> sample_basic

if (temperature < 0.01)
{
temperature = 1.0f;
top_k = 1;
if (logit_temp_threshold == 0.0f)
{
temperature = 1.0f;
top_k = 1;
}
}

for (int i = 0; i < bsz; i++)
@@ -166,19 +169,10 @@ std::vector<float> sample_basic
for (int j = 0; j < vocab_size; j++) temp_indices[j] = j;
int num_candidates = vocab_size;

if ((logit_temp_threshold > logit_min_threshold) && logit_min_threshold > 0.0f)
if (logit_temp_threshold > 0.0f)
{
num_candidates = logit_threshold_restore
(
logit_min_threshold,
logit_temp_threshold,
maxlogit,
vocab_size,
logits_ptr + i * vocab_size,
exponent,
temp_probs,
temp_indices
);
num_candidates = logit_threshold_temperature(logit_temp_threshold, logit_high_temp, maxlogit, vocab_size, logits_ptr + i * vocab_size, exponent, temp_probs, temp_indices);
normalize_cpu(num_candidates, temp_probs);
}

if (num_candidates > top_k && top_k > 0 && top_k < vocab_size)
4 changes: 2 additions & 2 deletions exllamav2/exllamav2_ext/ext_sampling.h
Original file line number Diff line number Diff line change
@@ -16,8 +16,6 @@ std::vector<float> sample_basic
(
torch::Tensor logits, // shape [bsz, vocab_size]
float temperature,
float logit_temp_threshold,
float logit_min_threshold,
int top_k,
float top_p,
float top_a,
@@ -42,6 +40,8 @@ std::vector<float> sample_basic
float max_temp,
float temp_exponent,
float smoothing_factor,
float logit_temp_threshold,
float logit_high_temp,
float skew
);

47 changes: 3 additions & 44 deletions exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
@@ -73,9 +73,8 @@ class Settings:

temperature_last: bool = False

logit_threshold_stats: bool = False
logit_temp_threshold: float = 0.0
logit_min_threshold: float = 0.0
logit_high_temp: float = 0.0

confidence_breaker: int = 0
confidence_breaker_debug: bool = False
@@ -535,18 +534,9 @@ def prep_logit_filter(lf):
output_ktokens = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.long)
output_kprobs = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.float)

if settings.logit_temp_threshold > 0.0 or settings.logit_min_threshold > 0.0:
logit_filter = prep_logit_filter(logit_filter)
effective_filter = max(settings.logit_temp_threshold, settings.logit_min_threshold)
logit_filter[logits.squeeze(1) < effective_filter] = False
if not torch.any(logit_filter):
logit_filter.view(-1)[torch.argmax(logits.squeeze(1))] = True

m = ext_c.sample_basic(
logits,
1.0 if settings.temperature_last else settings.temperature,
settings.logit_temp_threshold,
settings.logit_min_threshold,
settings.top_k,
settings.top_p,
settings.top_a,
@@ -571,6 +561,8 @@ def prep_logit_filter(lf):
settings.max_temp,
settings.temp_exponent,
settings.smoothing_factor,
settings.logit_temp_threshold,
settings.logit_high_temp,
settings.skew
)

@@ -596,39 +588,6 @@ def prep_logit_filter(lf):
else:
confidence_flag = None

if settings.logit_threshold_stats:
selected_token = output_tokens
batch_logits_squeezed = logits[0, 0, :]
token_logit = batch_logits_squeezed[selected_token]
min_logit_threshold = min(torch.max(batch_logits_squeezed).item(),
settings.logit_min_threshold if settings.logit_min_threshold > 0
else settings.logit_temp_threshold)
filtered_indices_mask = batch_logits_squeezed >= min_logit_threshold
filtered_logits = batch_logits_squeezed[filtered_indices_mask]
probs = F.softmax(batch_logits_squeezed, dim=-1)
filtered_probs = probs[filtered_indices_mask]

# Calculate the statistics for filtered_logits
min_filtered = filtered_logits.min().item() if len(filtered_logits) > 0 else float('nan')
mean_filtered = filtered_logits.mean().item() if len(filtered_logits) > 0 else float('nan')
max_filtered = filtered_logits.max().item() if len(filtered_logits) > 0 else float('nan')
std_filtered = filtered_logits.std().item() if len(filtered_logits) > 0 else float('nan')
min_p_equivalent = filtered_probs[filtered_logits.argmin()].item()

debug_string = (
f"total logits: {batch_logits_squeezed.size(0):<7} "
f"filtered to: {filtered_logits.size(0):<4} "
f"min: {min_filtered:>5.2f} "
f"mean: {mean_filtered:>5.2f} "
f"max: {max_filtered:>5.2f} "
f"std: {std_filtered:>5.2f} "
f"selected logit: {token_logit.item():>5.2f} "
f"selected token: {selected_token.item():<7} "
f"min_p: {min_p_equivalent:>6.5f}"
)
print(debug_string, flush=True)


if settings.mirostat: settings.mirostat_mu = m

# Stop condition from filters