Speculative decoding potential for running big LLMs on consumer grade GPUs efficiently #10466
Replies: 21 comments 48 replies
-
Would there be any benefit in pruning a 0.5B model down to be even smaller? From your examples above, it looks like the speculative model's size reduction has the biggest effect. You could prune the later layers like this: https://arxiv.org/abs/2403.17887, but with a calibration dataset you could probably prune down the width of the MLP hidden state quite significantly too... The  I think you could even apply L1 regularisation during fine-tuning to sparsify the weights and then remove all those close to zero, but the effectiveness of this would depend on whether the induced sparsity was evenly distributed across the corresponding tensors in each layer (which, judging from the paper above, I doubt is the case). It would be interesting to see where the balance point between "tiny and fast/dumb" and "small but slower/less dumb" actually is. If using greedy speculation then it won't make any difference, but if you have to actually apply the softmax (instead of just finding the maximum logit), then for stuff like coding using only English, it would be perfectly valid to remove a lot (most) of the tokens and prune down the
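The calibration-based width pruning idea can be sketched roughly as below. This is a minimal illustration under loud assumptions: a toy two-matrix ReLU MLP stored as nested lists (real Qwen2.5 blocks use a gated SwiGLU MLP, where the same scoring would apply to the shared hidden dimension), and "importance" taken as the mean absolute hidden activation over the calibration inputs. The function names are hypothetical, not from any library.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, a) for a in v]


def matvec(m: Matrix, v: List[float]) -> List[float]:
    return [sum(w * a for w, a in zip(row, v)) for row in m]


def hidden_importance(w_up: Matrix, calib_inputs: List[List[float]]) -> List[float]:
    """Mean absolute activation of each hidden unit over the calibration set."""
    totals = [0.0] * len(w_up)
    for x in calib_inputs:
        for j, h in enumerate(relu(matvec(w_up, x))):
            totals[j] += abs(h)
    return [t / len(calib_inputs) for t in totals]


def prune_width(w_up: Matrix, w_down: Matrix, calib_inputs: List[List[float]],
                keep: int) -> Tuple[Matrix, Matrix, List[int]]:
    """Keep the `keep` most active hidden units: rows of w_up and columns of w_down."""
    scores = hidden_importance(w_up, calib_inputs)
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    kept = sorted(ranked[:keep])  # preserve the original unit order
    new_up = [w_up[j] for j in kept]
    new_down = [[row[j] for j in kept] for row in w_down]
    return new_up, new_down, kept
```

The pruned block only approximates the original (the dropped units' contributions are simply lost), so checking its outputs on held-out text, and possibly a short fine-tune afterwards, would be the obvious next steps; neither is shown here.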
-
Just thinking about this some more and wondered how feasible it would be:
I'm thinking along the lines of using the draft model to create a tree (with probabilities on the edges and tokens in the nodes), and then using it to decide on a set of batches for the larger model to generate in parallel. If we constrain the branching factor to a fixed k, then we can again use hinge loss to try to pick the top-k using k-vs-all. I don't have a good idea of how the cost of batch processing grows, though, and it all depends on this.
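A rough sketch of the tree-building half of this idea, assuming a toy draft "model" that is just a function from a token prefix to a {token: probability} dict. It grows the tree best-first by path probability (the product of edge probabilities from the root), keeps at most k children per node, and stops once a node budget (e.g. the batch size the main model can verify cheaply) is reached. The selection rule here is plain best-first, not the hinge-loss/k-vs-all idea, and all names are hypothetical.

```python
import heapq
from typing import Callable, Dict, List, Tuple

# Hypothetical interface: a draft "model" maps a token prefix to {token: probability}.
DraftFn = Callable[[Tuple[int, ...]], Dict[int, float]]


def build_draft_tree(draft: DraftFn, budget: int, k: int) -> List[Tuple[Tuple[int, ...], float]]:
    """Best-first draft tree: returns up to `budget` (path, path_probability) nodes."""
    selected: List[Tuple[Tuple[int, ...], float]] = []
    frontier: list = []  # heap of (-path_p, insertion_order, path)
    order = 0

    def push_children(prefix: Tuple[int, ...], prefix_p: float) -> None:
        nonlocal order
        # Branching factor constraint: only the top-k next tokens of each node.
        top = sorted(draft(prefix).items(), key=lambda kv: kv[1], reverse=True)[:k]
        for tok, p in top:
            heapq.heappush(frontier, (-prefix_p * p, order, prefix + (tok,)))
            order += 1

    push_children((), 1.0)
    while frontier and len(selected) < budget:
        neg_p, _, path = heapq.heappop(frontier)
        selected.append((path, -neg_p))
        push_children(path, -neg_p)
    return selected


def toy_draft(prefix: Tuple[int, ...]) -> Dict[int, float]:
    return {0: 0.7, 1: 0.2, 2: 0.1}


print(build_draft_tree(toy_draft, budget=5, k=2))
```

Because a child's path probability can never exceed its parent's, every selected node's parent has already been selected, so the result is always a connected tree that could be flattened into one batch (with a tree-shaped attention mask) for the main model.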
-
Testing my server rebase for regressions after all the recent changes, along with a few new "LRM" models (Marco-o1 and QwQ) and RPC mode too. The spec algorithm I implemented is greedy matching with a fixed-size draft block and no probability computation. Hardware: RTX 4070 GOLDCOIN:
HUMANEVAL 1ST PROBLEM
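For reference, the verification step of a "greedy match, fixed-size block, no probs" scheme can be sketched as below. This is a minimal illustration of the general technique, not steampunque's server code: the main model's argmax token at each of the n + 1 positions is assumed to come from one batched forward pass, and the function name is hypothetical.

```python
from typing import List, Tuple


def greedy_verify(draft: List[int], target_argmax: List[int]) -> Tuple[List[int], int]:
    """Greedy-match verification of a fixed-size draft block.

    draft:          the n tokens proposed by the draft model
    target_argmax:  the target model's greedy (argmax) token at each of the
                    n + 1 positions, computed in one batched forward pass
    Returns the tokens to emit and how many draft tokens were accepted.
    """
    assert len(target_argmax) == len(draft) + 1
    n_accepted = 0
    for d, t in zip(draft, target_argmax):
        if d != t:
            break
        n_accepted += 1
    # The target's own token at the first mismatch (or after a fully accepted
    # draft) is always valid, so every verification step emits at least 1 token.
    return target_argmax[: n_accepted + 1], n_accepted


print(greedy_verify([5, 6, 7, 8], [5, 6, 9, 1, 2]))
```

No probabilities are needed on either side, which is what makes this variant cheap; the price is that the block size n has to be chosen up front.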
-
@steampunque any update on this?
-
What's the common wisdom on quantising the speculative model? I can see one argument for not quantising it, as the errors will accumulate over the sequence, but there is also the argument that a quantised model will generate tokens faster (due to decoding being memory-bound) and so add less latency?
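The "faster because memory-bound" side of the argument can be put into rough numbers. A minimal sketch, assuming every weight is read once per generated token and bandwidth is the only limit (it ignores compute, KV-cache reads, kernel overheads and, importantly, any drop in acceptance rate). Q8_0 and Q4_0 store 8.5 and 4.5 bits per weight (34 and 18 bytes per block of 32 weights); the 0.5B parameter count and 500 GB/s bandwidth are hypothetical.

```python
def decode_speed_bound(n_params: float, bits_per_weight: float, bandwidth_gb_s: float) -> float:
    """Rough upper bound on batch-1 tokens/s if every weight is read once per token
    and memory bandwidth is the only limit (ignores compute, KV cache and overheads)."""
    bytes_per_token = n_params * bits_per_weight / 8.0
    return bandwidth_gb_s * 1e9 / bytes_per_token


# Hypothetical numbers: a 0.5B-parameter draft model on a 500 GB/s GPU.
for name, bpw in [("Q8_0", 8.5), ("Q4_0", 4.5)]:
    print(f"{name}: <= {decode_speed_bound(0.5e9, bpw, 500.0):.0f} tokens/s")
```

This gives bounds of about 941 vs 1778 tokens/s. A tiny draft model is often nowhere near this bound in practice (fixed per-token overheads dominate), so the real gap between quant levels may be much smaller, and it has to be weighed against any loss in acceptance rate, which this calculation does not capture.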
-
This problem seems to be a perfect target for Bayesian filtering. I also wonder if we should have two probability thresholds:
At least for 
Finally, I think the estimation errors and batch costs may not be static throughout the generation:
Again, filtering could use some second-order terms to predict something akin to acceleration here. This seems like a useful collection:
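The filtering idea can be sketched with something much simpler than full Bayesian filtering: a fixed-gain alpha-beta-gamma filter (a steady-state simplification of a Kalman filter with a constant-acceleration model) that tracks a level, its rate of change and its "acceleration". Here it tracks a per-step quantity such as the draft acceptance rate or a measured batch cost; the gains are arbitrary assumptions chosen to lie inside the filter's stability region, and the class name is hypothetical.

```python
from typing import Optional


class AlphaBetaGammaFilter:
    """Fixed-gain tracker of a drifting quantity with first- and second-order terms.
    One time step = one decoding step (dt = 1)."""

    def __init__(self, alpha: float = 0.5, beta: float = 0.4, gamma: float = 0.05):
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.x: Optional[float] = None  # level estimate
        self.v = 0.0                    # rate of change per step
        self.a = 0.0                    # change of the rate per step ("acceleration")

    def update(self, z: float) -> None:
        if self.x is None:              # the first measurement initialises the level
            self.x = z
            return
        # Predict one step ahead with the constant-acceleration model...
        x_p = self.x + self.v + 0.5 * self.a
        v_p = self.v + self.a
        r = z - x_p                     # ...then correct by the prediction error
        self.x = x_p + self.alpha * r
        self.v = v_p + self.beta * r
        self.a = self.a + 2.0 * self.gamma * r

    def predict(self) -> float:
        """Estimate of the next measurement."""
        return self.x + self.v + 0.5 * self.a


f = AlphaBetaGammaFilter()
for t in range(100):
    f.update(0.8 - 0.002 * t)  # e.g. an acceptance rate slowly drifting down
print(round(f.predict(), 4))
```

Because the model includes velocity and acceleration terms, a steady linear drift is tracked without lag once the initial transient dies out; a proper Kalman filter would additionally adapt the gains to the measured noise.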
-
Just remembered this post:
It uses a quant of the same model for speculative decoding:
Interestingly, this shows little difference between the quant levels (ignoring the problems with Q3 he's trying to highlight).
-
Do you think it is possible to compute the KL divergence of the DRAFT model against the TARGET one?
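Per position it is straightforward as long as both models share a vocabulary (true for a Qwen2.5 target and draft): convert both logit vectors to probabilities and compute KL(target || draft), then average over the positions of some shared text. A minimal sketch; the direction KL(target || draft) is an assumption (it penalises the draft for missing probability mass the target cares about), and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)                       # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(P || Q) in nats; terms with p == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def draft_vs_target_kl(target_logits: Sequence[float], draft_logits: Sequence[float]) -> float:
    """Per-position KL(target || draft); both logit vectors must share a vocabulary."""
    return kl_divergence(softmax(target_logits), softmax(draft_logits))


# Mean over positions (toy 3-token vocabulary, two positions).
target = [[1.0, 2.0, 3.0], [0.0, 0.0, 5.0]]
draft = [[1.0, 2.5, 2.5], [0.0, 1.0, 4.0]]
print(sum(draft_vs_target_kl(t, d) for t, d in zip(target, draft)) / len(target))
```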
-
I've been looking into this the last couple of days and have identified 
So, as a proof of concept to see if we can improve on this, I've written some hacky code to test the potential for improvements. First you have to run this script to calculate the sequence probability thresholds where a draft is +EV:

```bash
#!/bin/bash
max_pp=64
num_repeats=10
# Generate comma-separated PP and NPL lists
pp_list=$(seq -s ',' 1 $max_pp)
npl_list=$(printf '1%.0s,' $(seq 1 $num_repeats) | sed 's/,$//')
# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null
# Ask for permission to drop caches
read -p "Do you want to drop caches? (y/n) " -n 1 -r
echo    # Move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
    echo "Dropping caches..."
    echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
fi
# Temporary file for JSONL output
temp_file=$(mktemp)
jsonl_file=$(mktemp)
# Run the benchmark and save full output
echo "Running benchmark..."
CUDA_VISIBLE_DEVICES=1 ~/llama.cpp/build/bin/llama-batched-bench \
    --model ~/models/gguf/deepseek-v3-0324-Q4_K_XL.gguf \
    --n-gpu-layers 99 \
    --numa distribute \
    --threads 80 \
    --override-tensor exps=CPU \
    --flash-attn \
    -c 2048 -b 2048 -ub 512 \
    -npp "$pp_list" \
    -ntg 0 \
    -npl "$npl_list" \
    --output-format jsonl | tee "$temp_file"
# NOTE: The first result always seems to be bogus, so skip over it.
echo -n "Extracting results..."
count=$(grep '^{' "$temp_file" | tail -n +2 | tee "$jsonl_file" | wc -l)
echo " Done ($count results extracted)"
# Process the extracted JSONL
jq -s --raw-output '
    # Calculate max_pp from actual data
    (map(.pp) | max) as $max_pp |
    # Create dictionary with {sum, count} for each PP
    reduce .[] as $item (
        {};
        ($item.pp | tostring) as $pp |
        .[$pp].sum = (.[$pp].sum + $item.speed) |
        .[$pp].count = (.[$pp].count + 1)
    ) |
    # Calculate averages in natural PP order
    [range(1; $max_pp + 1) as $pp |
        (.[($pp|tostring)]).sum / .[($pp|tostring)].count
    ] as $averages |
    # Normalize relative to PP=1
    $averages[0] as $base |
    [1] + [$averages[1:][] | $base / . ] |
    # Format with 3 decimal places
    map(. * 1000 | round | . / 1000) |
    "const std::vector<double> p_mins = { " + join(", ") + " };"
' "$jsonl_file"
# Clean up
rm "$temp_file" "$jsonl_file"
```

(Obviously you will need to change the parameters to those you expect to use with your own model...) This will output something that looks like this:

```cpp
const std::vector<double> p_mins = { 1, 1.554, 1.046, 0.837, 0.702, 0.609, 0.548, 0.502, 0.471, 0.444, 0.413, 0.393, 0.378, 0.365, 0.352, 0.34, 0.333, 0.325, 0.316, 0.309, 0.303, 0.298, 0.291, 0.285, 0.283, 0.279, 0.274, 0.27, 0.269, 0.265, 0.262, 0.26, 0.258, 0.255, 0.252, 0.25, 0.245, 0.243, 0.241, 0.24, 0.238, 0.237, 0.235, 0.234, 0.232, 0.231, 0.231, 0.23, 0.229, 0.228, 0.227, 0.226, 0.225, 0.225, 0.224, 0.223, 0.222, 0.221, 0.221, 0.22, 0.219, 0.219, 0.219, 0.218 };
```
So after the script above has been run, you need to replace the code below line 242 of llama.cpp/common/speculative.cpp (at commit f470bc3) with this:

```cpp
    // ??? CAN THIS EVER BE ANYTHING BUT 0 HERE ???
    printf("%d ", (int) result.size());
    // calculated empirically using llama-batched-bench
    const std::vector<double> p_mins = { 1, 1.554, 1.046, 0.837, 0.702, 0.609, 0.548, 0.502, 0.471, 0.444, 0.413, 0.393, 0.378, 0.365, 0.352, 0.34, 0.333, 0.325, 0.316, 0.309, 0.303, 0.298, 0.291, 0.285, 0.283, 0.279, 0.274, 0.27, 0.269, 0.265, 0.262, 0.26, 0.258, 0.255, 0.252, 0.25, 0.245, 0.243, 0.241, 0.24, 0.238, 0.237, 0.235, 0.234, 0.232, 0.231, 0.231, 0.23, 0.229, 0.228, 0.227, 0.226, 0.225, 0.225, 0.224, 0.223, 0.222, 0.221, 0.221, 0.22, 0.219, 0.219, 0.219, 0.218 };
    // used to re-calibrate the probabilities if needed (ie: 1: none, <1: sharpen, >1: flatten)
    // note: also acts as a minimum edge threshold, so likely wants to be slightly >1 even for a well-calibrated draft model...
    const float recalibration_power = 1.02f;
    // this allows the heuristic to break earlier than testing all against p_mins.back()
    // note: assumes that strings of close to p=1.0 tokens occur rarely... can be optimised by looking for the largest gap printout with max_lookahead=MAX_INT
    const int max_lookahead = 5;
    int best_draft_size = 0;
    float sequence_p = 1.0;
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);
        common_sampler_sample(smpl, ctx, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl);
        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
        }
        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        result.push_back(id);
        if (params.n_draft <= (int) result.size()) {
            best_draft_size = result.size();
            break;
        }
        // re-calibrate if necessary
        sequence_p *= pow(cur_p->data[0].p, recalibration_power);
        // only collect draft tokens with positive expected values
        if (sequence_p >= p_mins[(int) result.size()]) {
            best_draft_size = result.size();
        }
        // break as soon as we are fairly confident we can't improve on the best found so far
        if (sequence_p < p_mins[std::min(best_draft_size + max_lookahead, (int) p_mins.size() - 1)]) {
            break;
        }
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
        // evaluate the drafted tokens on the draft model
        llama_decode(ctx, batch);
        prompt.push_back(id);
    }
    printf("%d %d [%d] %.3f\n", (int) result.size(), best_draft_size, (best_draft_size > 0 ? (int) result.size() - best_draft_size : 0), sequence_p);
    // note: this truncates to the token *after* the last that was seen to be +EV!
    result.resize(best_draft_size);
    return result;
```

Then test on different extremes of prompts:
and so on... It seems the above method now works on all these different extremes, and so long as the draft model is cheap to run (e.g.: 
So the question now is: 
If we don't want to calculate on the fly, then perhaps we could make it so that 
Again, sorry the code is such a hacky mess, but I just wanted to see if there was any potential in this method before proceeding... It does appear to be quite a significant improvement, and a lot less complex than the old PR that was removed... I am still not sure how the "reuse" code above works, or if it can ever get to the 
@ggerganov @steampunque Is this worth trying to tidy up and make into a proper PR? Does it universally improve on the existing algorithm for other models, and for models run on other back-ends like the Mac?
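One way to read the p_mins table (my interpretation of the jq normalisation, not something stated in the thread): suppose a draft of d tokens is verified in a single batch of d + 1 tokens, is either fully accepted with probability p or wasted, and the draft model's own cost is ignored. Then drafting yields roughly p * s[d] tokens/s, where s[i] is the measured tokens/s for a batch of i + 1 tokens, versus s[0] without drafting, so it is +EV only when p > s[0] / s[d]. That is exactly the ratio the script outputs. A minimal sketch with made-up speeds (function names hypothetical):

```python
from typing import List


def p_min_thresholds(batch_speeds: List[float]) -> List[float]:
    """batch_speeds[i] = measured tokens/s when processing a batch of i + 1 tokens.
    Returns p_mins[d] = batch_speeds[0] / batch_speeds[d], rounded to 3 d.p."""
    base = batch_speeds[0]
    return [1.0] + [round(base / s, 3) for s in batch_speeds[1:]]


def draft_is_plus_ev(sequence_p: float, n_drafted: int, p_mins: List[float]) -> bool:
    """Break-even test used by the heuristic: keep a draft of n_drafted tokens only
    if its whole-sequence probability reaches the threshold for that size."""
    return sequence_p >= p_mins[n_drafted]


p_mins = p_min_thresholds([10.0, 8.0, 20.0, 40.0])  # hypothetical speeds
print(p_mins)
```

An entry above 1 (like the 1.554 for a single drafted token in the table above) means that draft length can never be +EV, whatever the draft model predicts.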
-
It's no better or worse for 
but interestingly it gets the same performance (when drafted by 
For the  but nothing will fit the 
It's certainly interesting, and I think probably worth looking into more.
-
Here is the version that uses the rational approximation for V3/R1 and can (in theory) use any value of 

```cpp
    // ??? CAN THIS EVER BE ANYTHING BUT 0 HERE ???
    printf("%d ", (int) result.size());
    static constexpr auto rationalFit = [](int x, double a = 2.6288, double b = 3.996, double c = 0.1761) {
        return (x < 3) ? 1.0 : (a / (static_cast<double>(x - 3) + b) + c);
    };

    // used to re-calibrate the probabilities if needed (ie: 1: none, <1: sharpen, >1: flatten)
    // note: also acts as a minimum edge threshold, so likely wants to be slightly >1 even for a well-calibrated draft model...
    const float recalibration_power = 1.02f;

    // this allows the heuristic to break earlier than testing all against p_mins.back()
    // note: assumes that strings of close to p=1.0 tokens occur rarely... can be optimised by looking for the largest gap printout with max_lookahead=MAX_INT
    const int max_lookahead = 5;

    int best_draft_size = 0;

    float sequence_p = 1.0;

    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

        common_sampler_sample(smpl, ctx, 0, true);

        const auto * cur_p = common_sampler_get_candidates(smpl);
        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
        }
        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        result.push_back(id);
        if (params.n_draft <= (int) result.size()) {
            best_draft_size = result.size();
            break;
        }
        // re-calibrate if necessary
        sequence_p *= pow(cur_p->data[0].p, recalibration_power);
        // only collect draft tokens with positive expected values
        if (sequence_p >= rationalFit((int) result.size())) {
            best_draft_size = result.size();
        }
        // break as soon as we are fairly confident we can't improve on the best found so far
        if (sequence_p < rationalFit(best_draft_size + max_lookahead)) {
            break;
        }
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
        // evaluate the drafted tokens on the draft model
        llama_decode(ctx, batch);
        prompt.push_back(id);
    }
    // (the rest of the function is unchanged from the previous version)
```

And some hacky Python code to try to fit the three different classes of approximation:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
# Data points: p_mins[3:] from the benchmark output (the first three values, which are >= 1, are left out)
y_data = np.array([0.837, 0.702, 0.609, 0.548, 0.502, 0.471, 0.444,
                   0.413, 0.393, 0.378, 0.365, 0.352, 0.34, 0.333, 0.325, 0.316, 0.309,
                   0.303, 0.298, 0.291, 0.285, 0.283, 0.279, 0.274, 0.27, 0.269, 0.265,
                   0.262, 0.26, 0.258, 0.255, 0.252, 0.25, 0.245, 0.243, 0.241, 0.24,
                   0.238, 0.237, 0.235, 0.234, 0.232, 0.231, 0.231, 0.23, 0.229, 0.228,
                   0.227, 0.226, 0.225, 0.225, 0.224, 0.223, 0.222, 0.221, 0.221, 0.22,
                   0.219, 0.219, 0.219, 0.218])
x_data = np.arange(len(y_data))
# Define some candidate functions
def exp_decay(x, a, b, c):
    return a * np.exp(-b * x) + c
def power_decay(x, a, b, c):
    return a * (x + 1)**(-b) + c
def rational(x, a, b, c):
    return a / (x + b) + c
# Fit each function
try:
    popt_exp, _ = curve_fit(exp_decay, x_data, y_data, p0=[0.5, 0.1, 0.2])
    popt_power, _ = curve_fit(power_decay, x_data, y_data, p0=[0.5, 0.5, 0.2])
    popt_rat, _ = curve_fit(rational, x_data, y_data, p0=[0.5, 1, 0.2])

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.scatter(x_data, y_data, label='Data')
    plt.plot(x_data, exp_decay(x_data, *popt_exp), label=f'Exponential: {popt_exp.round(4)}')
    plt.plot(x_data, power_decay(x_data, *popt_power), label=f'Power: {popt_power.round(4)}')
    plt.plot(x_data, rational(x_data, *popt_rat), label=f'Rational: {popt_rat.round(4)}')
    plt.legend()
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.title('Function Fitting Comparison')
    plt.show()

    # Calculate and print RMSE for each fit
    def rmse(y_true, y_pred):
        return np.sqrt(np.mean((y_true - y_pred)**2))

    print("RMSE for exponential fit:", rmse(y_data, exp_decay(x_data, *popt_exp)))
    print("RMSE for power fit:", rmse(y_data, power_decay(x_data, *popt_power)))
    print("RMSE for rational fit:", rmse(y_data, rational(x_data, *popt_rat)))

except Exception as e:
    print("Error during fitting:", e)
```

(You will possibly need to fudge the initial values that are >1 to get a good fit, like I did here...)
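A quick way to sanity-check the fitted constants is to mirror rationalFit in Python and compare it with the measured table. The shift by 3 reflects that the fit was done on p_mins[3:] (the values from 0.837 onwards); sizes below 3 are clamped to 1.0, which only a draft with sequence probability exactly 1.0 can reach. P_MINS_HEAD is just the first eleven values copied from the benchmark output above, and the function name is a hypothetical Python mirror, not part of llama.cpp.

```python
def rational_fit(x: int, a: float = 2.6288, b: float = 3.996, c: float = 0.1761) -> float:
    """Python mirror of the rationalFit lambda: threshold for a draft of x tokens."""
    return 1.0 if x < 3 else a / ((x - 3) + b) + c


# First eleven entries of the measured p_mins table above.
P_MINS_HEAD = [1, 1.554, 1.046, 0.837, 0.702, 0.609, 0.548, 0.502, 0.471, 0.444, 0.413]

for x, measured in enumerate(P_MINS_HEAD):
    print(f"{x:2d}  measured={measured:.3f}  fitted={rational_fit(x):.3f}")
```

Over this range the fit stays within about 0.006 of the measured thresholds, and it also lands close to the tail of the table (about 0.217 vs 0.218 at x = 63).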
-
After sleeping on this, I think it's basically gonna be far too much hassle to try to implement the ideas from these tests, but the key point they show is that the marginal cost of adding one more token to a batch should somehow be taken into account.
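The marginal cost can be read straight off per-batch-size speed measurements like the ones above: a batch of i + 1 tokens takes (i + 1) / speed[i] seconds, so adding the (i + 1)-th token costs the difference between consecutive batch times. A minimal sketch (function names hypothetical, speeds made up):

```python
from typing import List


def batch_times(batch_speeds: List[float]) -> List[float]:
    """Seconds to process a batch of i + 1 tokens, given tokens/s for each batch size."""
    return [(i + 1) / s for i, s in enumerate(batch_speeds)]


def marginal_costs(batch_speeds: List[float]) -> List[float]:
    """Extra seconds for adding one more token to the batch (first entry: cost of size 1)."""
    t = batch_times(batch_speeds)
    return [t[0]] + [t[i] - t[i - 1] for i in range(1, len(t))]


print(marginal_costs([10.0, 20.0, 20.0]))
```

A marginal cost near zero means the extra draft token is almost free to verify, so even a low-probability extension can be worth adding; a large jump means the next token needs a much more confident draft to pay for itself.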
-
I'm getting some really good results now:
To use this, first you have to run this script to generate the timing data:

```bash
#!/bin/bash
# Environment variables
export CUDA_VISIBLE_DEVICES=0
# Configuration variables
BENCH_EXE="/home/juk/llama.cpp/build/bin/llama-batched-bench"
MODEL_PATH="/home/juk/models/gguf/qwen-2.5-coder-Q6_K.gguf"
#MODEL_PATH="/Users/juk/models/gguf/qwen-2.5-coder-Q6_K.gguf"
#MODEL_PATH="/home/juk/models/gguf/draft_models/Qwen2.5-Coder-DRAFT-0.6B-Q4_0.gguf"
#MODEL_PATH="/Users/juk/models/gguf/draft_models/Qwen2.5-Coder-DRAFT-0.6B-Q4_0.gguf"
# Benchmark parameters
PROMPT_SIZE=1024
MAX_BATCH_SIZE=32
NUM_SAMPLES=5
# Model-specific parameters
MODEL_PARAMS="--n-gpu-layers 99 \
              --flash-attn"
# Generate comma-separated PROMPT_SIZE and BATCH_SIZE lists (NOTE: Process 2x before 1x to help with warmup)
PROMPT_SIZE_LIST="$((PROMPT_SIZE * 2)),${PROMPT_SIZE}"
BATCH_SIZE_LIST=$(printf "%s," $(for i in $(seq 1 $NUM_SAMPLES); do seq 1 $MAX_BATCH_SIZE; done) | sed 's/,$//')
# Output files
LOG_FILE="benchmark_results.log"
JSONL_PP1X="results_pp1x.jsonl"
JSONL_PP2X="results_pp2x.jsonl"
# Clean previous files
rm -f "$LOG_FILE" "$JSONL_PP1X" "$JSONL_PP2X"
# Run the benchmark
echo "- Running benchmark..."
$BENCH_EXE \
    --model "$MODEL_PATH" \
    $MODEL_PARAMS \
    --ctx_size "$((PROMPT_SIZE * 2 + MAX_BATCH_SIZE))" \
    -pps \
    -npp "$PROMPT_SIZE_LIST" \
    -npl "$BATCH_SIZE_LIST" \
    -ntg 1 \
    --output-format jsonl | tee "$LOG_FILE"
# Extract JSONL lines from the log (NOTE: Skip first set of samples as seems to need a warmup to get accurate stats)
echo -n "- Extracting results..."
grep '^{' "$LOG_FILE" | tail -n "+$((NUM_SAMPLES + 1))" | grep "\"pp\": ${PROMPT_SIZE}" > "$JSONL_PP1X"
grep '^{' "$LOG_FILE" | tail -n "+$((NUM_SAMPLES + 1))" | grep "\"pp\": $((PROMPT_SIZE * 2))" > "$JSONL_PP2X"
COUNT1=$(wc -l < "$JSONL_PP1X")
COUNT2=$(wc -l < "$JSONL_PP2X")
echo " Done ($COUNT1 1xPP results, $COUNT2 2xPP results)"
# Function to extract values as bash array
extract_values() {
    local jsonl_file=$1

    jq -s --raw-output '
        (map(.pl) | max) as $max_pl |
        reduce .[] as $item (
            {};
            ($item.pl | tostring) as $pl |
            .[$pl].sum = (.[$pl].sum + $item.speed_tg) |
            .[$pl].count = (.[$pl].count + 1)
        ) |
        [range(1; $max_pl + 1) as $pl |
            (.[($pl|tostring)]).sum / .[($pl|tostring)].count
        ] |
        map(. * 1000 | round | . / 1000) |
        join(" ")
    ' "$jsonl_file"
}
# Function to process JSONL and output python/C++ vectors
process_results() {
    local jsonl_file_1x=$1
    local jsonl_file_2x=$2

    # Extract values as arrays
    local values_1x=($(extract_values "$jsonl_file_1x"))
    local values_2x=($(extract_values "$jsonl_file_2x"))

    # Assume TG speed drops linearly with context (v1x at PROMPT_SIZE, v2x at 2*PROMPT_SIZE) and solve:
    #   speed_tg (extrapolated to an empty context) = 2*v1x - v2x, speed_pp (drop per PROMPT_SIZE of context) = v1x - v2x
    local speed_tg=()
    local speed_pp=()

    for i in "${!values_1x[@]}"; do
        if [[ -n "${values_2x[i]}" ]]; then
            local v1x="${values_1x[i]}"
            local v2x="${values_2x[i]}"

            # Calculate using awk for floating point arithmetic
            local tg=$(awk "BEGIN {printf \"%.3f\", 2 * $v1x - $v2x}")
            local pp=$(awk "BEGIN {printf \"%.3f\", $v1x - $v2x}")

            speed_tg+=("$tg")
            speed_pp+=("$pp")
        fi
    done
    # Output raw vectors
    local raw_1x_str=$(IFS=', '; echo "${values_1x[*]}")
    local raw_2x_str=$(IFS=', '; echo "${values_2x[*]}")
    echo "----------------------------------------"
    echo "tg_at_pp_1x = np.array([$raw_1x_str])"
    echo "tg_at_pp_2x = np.array([$raw_2x_str])"
    # Output solution vectors
    local tg_str=$(IFS=', '; echo "${speed_tg[*]}")
    local pp_str=$(IFS=', '; echo "${speed_pp[*]}")
    echo "pp_overhead = np.array([$pp_str])"
    echo "tg_at_pp_0 = np.array([$tg_str])"
    echo "----------------------------------------"
    echo "const std::vector<float> model_batch_speeds = { $tg_str };"
    echo "----------------------------------------"
}
# Process the extracted JSONL
echo "- Generating data vectors:"
process_results "$JSONL_PP1X" "$JSONL_PP2X"
# Clean up log, but leave JSONL files
rm "$LOG_FILE"
```

Then paste the generated vector of 

```cpp
    // *****************************************************
    // *** The main model's tokens/s for each batch size ***
    // *****************************************************
    // - RTX 5000 Ada
    /*
    const std::vector<float> model_batch_speeds = {
            17.956,36.097,53.640,69.224,78.671,81.642,81.894,84.374,
            134.441,148.602,162.406,176.664,190.502,204.359,218.614,232.686,
            238.893,251.720,265.238,278.409,292.450,305.247,318.849,331.799,
            347.519,361.104,373.258,386.453,399.033,411.275,422.836,438.136
    };
    */
    // - M1 Ultra 64GB
    const std::vector<float> model_batch_speeds = {
            14.178,15.084,15.667,26.851,28.725,29.529,28.106,31.060,
            23.468,26.027,28.514,31.074,33.542,36.004,38.530,41.095,
            43.549,45.945,48.379,52.055,54.575,57.107,59.433,61.969,
            64.455,67.023,69.414,71.868,74.377,76.834,79.221,81.995
    };
    // ******************************************************
    // *** The draft model's tokens/s for batch size of 1 ***
    // ******************************************************
    // - RTX 5000 Ada
    /*
    const float draft_tg_speed = 353.540;
    */

    // - M1 Ultra 64GB
    const float draft_tg_speed = 231.948;

    // ==========================================================================================================

    // The estimated lookahead cost per token, in terms of the main model's token generation speed
    const float lookahead_cost_estimate = model_batch_speeds[0] / draft_tg_speed;

    // The maximum lookahead relative to the best we have seen so far
    const int max_lookahead = 5;

    // ==========================================================================================================

    // The best draft size and its associated expected value so far (ie: init to the main model's TG speed)
    int best_draft_size = 0;
    float best_draft_ev = model_batch_speeds[0];

    // The current sequence probability, as predicted by the draft model
    float current_sequence_p = 1.0;

    GGML_ASSERT((int) model_batch_speeds.size() == params.n_draft);
    for (int i = 0; i < params.n_draft; ++i) {

        // Sample a draft token
        common_batch_clear(batch);
        common_sampler_sample(smpl, ctx, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl);
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        // Save the sampled token id
        result.push_back(id);

        // Get the current draft size we are looking at
        const int current_draft_size = result.size();
        // If we have enough tokens already, then stop
        if (current_draft_size >= params.n_draft) {
            best_draft_size = result.size();
            break;
        }
        // Update the sequence probability using the sampled token's predicted probability
        current_sequence_p *= cur_p->data[0].p;
        // Calculate the expected value (in terms of the main model's tokens/s) for this sequence
        const float current_sequence_ev = current_sequence_p * model_batch_speeds[current_draft_size];

        // Is this token clearly +EV compared to what we have so far?
        if (current_sequence_ev > best_draft_ev) {
            best_draft_size = current_draft_size;
            best_draft_ev = current_sequence_ev;
        }
        // Otherwise we have to decide if we might see a +EV draft in the future, or stop now
        else {

            bool stop_now = true;

            for (int j = current_draft_size + 1; j < (int) model_batch_speeds.size(); j++) {
                // Don't bother looking too far relative to the best we have seen so far
                if (j - best_draft_size > max_lookahead) {
                    break;
                }

                // This approximates the cost of the lookahead in terms of the model's tokens/s.
                const float lookahead_cost = lookahead_cost_estimate * (float) (j - current_draft_size);
                // Calculate the discounted potential EV of this lookahead depth
                // NOTE: This assumes the worst case of the draft predicting p=1.0 for all future tokens...
                const float potential_ev = current_sequence_p * (model_batch_speeds[j] - lookahead_cost);
                // Is this lookahead depth potentially +EV compared to the best we have so far?
                if (potential_ev > best_draft_ev) {
                    stop_now = false;
                    break;
                }
            }

            // If no chance to improve the EV, then stop now
            if (stop_now) {
                break;
            }
        }
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
        // evaluate the drafted tokens on the draft model
        llama_decode(ctx, batch);
        prompt.push_back(id);
    }

    // NOTE: This truncates to the token *after* the last that was seen to be +EV!
    //       - The reason for this is because the main model will generate all tokens
    //         at this final position, and if we get there successfully; we can use
    //         whichever it finds regardless (ie: this last token isn't part of the draft).
    result.resize(best_draft_size);
    return result;
```

and also find the draft model's token-generation speed (using the above script, or just running it on its own) and set 
This code then needs to replace all the code below line 242 of llama.cpp/common/speculative.cpp (at commit f470bc3). Then make sure to run 
On the 
On the 
For "low draftability" prompts like "Tell me the rules of chess", I'm still getting ~1.3x for the 
To put this in perspective: the absolute very best hand-tuned settings using the existing algorithm on the 
It will take me all night to generate the data for 
I really think this sort of "profile-guided drafting" could be huge! I've no idea if I can make it into a proper PR yet, though...
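For experimenting with the stopping rule offline (e.g. replaying logged draft probabilities against different speed profiles), the decision logic of the C++ above can be transcribed to Python roughly as follows. This is a sketch under an assumption: the per-token draft probabilities are given up front as a list, whereas the real code samples them one at a time and simply stops sampling when it breaks. The function name is hypothetical.

```python
from typing import List


def choose_draft_size(token_ps: List[float], speeds: List[float],
                      draft_speed: float, max_lookahead: int = 5) -> int:
    """token_ps[i]: draft model's probability for its i-th drafted token.
    speeds[i]: main model's tokens/s for a batch of i + 1 tokens (len == n_draft).
    Returns the number of draft tokens to keep."""
    lookahead_cost_estimate = speeds[0] / draft_speed
    best_size, best_ev = 0, speeds[0]  # baseline: plain decoding, no draft
    seq_p = 1.0
    for i, p in enumerate(token_ps):
        size = i + 1
        if size >= len(speeds):        # hit n_draft: keep everything
            return size
        seq_p *= p
        ev = seq_p * speeds[size]
        if ev > best_ev:
            best_size, best_ev = size, ev
            continue
        # Keep drafting only if some deeper size (within max_lookahead of the best
        # so far) could still beat best_ev, optimistically assuming p = 1.0 from here.
        could_improve = False
        for j in range(size + 1, len(speeds)):
            if j - best_size > max_lookahead:
                break
            cost = lookahead_cost_estimate * (j - size)
            if seq_p * (speeds[j] - cost) > best_ev:
                could_improve = True
                break
        if not could_improve:
            break
    return best_size


speeds = [10.0, 20.0, 30.0, 40.0]  # hypothetical, perfectly linear batch scaling
print(choose_draft_size([0.9, 0.9, 0.9], speeds, draft_speed=100.0))
print(choose_draft_size([0.3, 0.3, 0.3], speeds, draft_speed=100.0))
```

With confident drafts the whole block is kept; with p = 0.3 per token the rule gives up after the second token and falls back to plain decoding.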
-
I reran some spec benches on Qwen2.5 32B Coder. It looks like something in the backend got a lot faster than it used to be (most likely the combined RPC + CUDA optimizations). I also think there is potentially a simple heuristic which can be used to adapt the spec block length without computing any probs, as I will discuss after the results.
HW: 4070 + 1 RPC 4070
HUMANEVAL 1ST PROBLEM:
DEFS: DN = drafted tokens, DA = accepted tokens, TG = token gen t/s
The interesting result from this table is the DA/DN ratio. When this ratio is well above 0.5, it suggests the block size is too short and higher speed can be obtained by increasing it. As the block size is increased, DA/DN decreases monotonically. Once DA/DN drops below the critical point of 0.5, returns diminish. This suggests a very simple heuristic: monitor DA/DN during generation and increase the spec block length until the ratio is just above 0.5 to get optimal spec speed. No probs compute is required, just monitoring of the draft acceptance ratio. Now a harder spec test, the chess prompt with the same model:
Here diminishing returns occur at DA/DN < 0.33. So a threshold of DA/DN = 0.5 would not work here, since it would never increase the block length above 2, and some kind of scheme which modifies the threshold as a function of block length is needed (simple: go from 0.3 up to 0.5 as the block length varies from 1 to 8). However, I would not trust such a heuristic in practice and still feel more comfortable running with either a fixed 4 or 8. I am only 5 t/s below max at block length 8 on code, and I think block length 4 is a good general-purpose value for my particular spec algorithm. Since I would not be running Qwen Coder on the chess prompt but a general model, I don't have to worry about specifying the draft length, since it defaults to 4 for me on all general models.
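The proposed (and, as noted, not fully trusted) heuristic could look roughly like this: keep a smoothed DA/DN ratio, compare it with a threshold that ramps linearly from 0.3 at block length 1 to 0.5 at block length 8, and grow or shrink the block by one accordingly. The exponential smoothing weight and all names are assumptions added for illustration; the comment above only specifies the threshold ramp.

```python
from typing import Optional


def da_dn_threshold(block_len: int, lo: float = 0.3, hi: float = 0.5,
                    min_len: int = 1, max_len: int = 8) -> float:
    """Acceptance-ratio threshold ramped linearly from `lo` at min_len to `hi` at max_len."""
    frac = (block_len - min_len) / (max_len - min_len)
    return lo + (hi - lo) * min(max(frac, 0.0), 1.0)


class AdaptiveBlockLength:
    """Grow the draft block while the smoothed DA/DN ratio is above the threshold
    for the current length; shrink it when the ratio falls to or below it."""

    def __init__(self, start: int = 4, min_len: int = 1, max_len: int = 8, ema: float = 0.2):
        self.n, self.min_len, self.max_len, self.ema = start, min_len, max_len, ema
        self.ratio: Optional[float] = None  # smoothed DA/DN

    def observe(self, n_drafted: int, n_accepted: int) -> int:
        r = n_accepted / n_drafted
        self.ratio = r if self.ratio is None else (1 - self.ema) * self.ratio + self.ema * r
        thr = da_dn_threshold(self.n, min_len=self.min_len, max_len=self.max_len)
        if self.ratio > thr:
            self.n = min(self.n + 1, self.max_len)
        else:
            self.n = max(self.n - 1, self.min_len)
        return self.n


ctl = AdaptiveBlockLength()
for _ in range(10):
    ctl.observe(ctl.n, ctl.n)  # every draft fully accepted
print(ctl.n)
```

Smoothing over recent blocks (rather than the whole generation) lets the length react when a prompt moves between "draftable" and "undraftable" stretches, at the cost of some oscillation around the threshold.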
-
I think sadly this is doomed to fail; check out these graphs:
Adaptive block length would work well here, as the cost for each block size scales almost linearly. The single down-tick at the start for the CUDA MLA kernel can still be avoided by using the 
But then you get stuff like this from the Metal flash-attention kernels:
Wherever there is a downward step, we should never move up to the next batch size, as even if the draft model were 100% correct, it would still be -EV to do so! The huge uptick between sizes 3 and 4 means we should also accept a far lower probability there than at the sizes before and after!
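The "never step up past a downward step" rule can be computed directly from a speed profile. Since the sequence probability can only shrink as the draft grows, a batch size whose tokens/s is no higher than that of some smaller size can never have a higher expected value (p * speed), so it should never be drafted to. A minimal sketch using the first 16 M1 Ultra values from the model_batch_speeds vector posted earlier in the thread; the function names are hypothetical, and partial acceptance and draft cost are ignored, as in the break-even reasoning above.

```python
from typing import List


def dominated_batch_sizes(speeds: List[float]) -> List[int]:
    """1-based batch sizes whose tokens/s is no higher than that of some smaller batch."""
    dominated, best = [], float("-inf")
    for size, s in enumerate(speeds, start=1):
        if s <= best:
            dominated.append(size)
        else:
            best = s
    return dominated


def break_even_p(speeds: List[float]) -> List[float]:
    """Minimum sequence probability for each batch size to beat plain decoding."""
    return [speeds[0] / s for s in speeds]


# First 16 entries of the M1 Ultra model_batch_speeds vector above.
M1_ULTRA_SPEEDS = [14.178, 15.084, 15.667, 26.851, 28.725, 29.529, 28.106, 31.060,
                   23.468, 26.027, 28.514, 31.074, 33.542, 36.004, 38.530, 41.095]

print(dominated_batch_sizes(M1_ULTRA_SPEEDS))
print([round(p, 3) for p in break_even_p(M1_ULTRA_SPEEDS)[:5]])
```

This flags sizes 7, 9, 10 and 11 as never worth drafting to, and shows the uptick from size 3 to 4 dropping the break-even probability from about 0.905 to about 0.528.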
-
Looking at (the gradient of) all these together gives the best overall picture, I think:
The "
The " I've tried to tidy up and explain the code a little better (and fixed a bug to do with the 

    // ??? CAN THIS EVER BE ANYTHING BUT 0 HERE ???
    //printf("[%d", (int) result.size());
    //GGML_ASSERT((int) result.size() == 0);
    // RTX 5000 Ada: Qwen2.5-Coder-32B-Instruct-Q6_K.gguf + Qwen2.5-Coder-0.5B-Instruct-Q4_0.gguf
    const std::vector<float> main_batch_speeds = { 18.39, 37.40, 55.68, 73.54, 90.26, 106.11, 114.92, 117.37, 140.73, 156.29, 171.05, 186.18, 201.16, 215.75, 230.52, 245.79, 254.85, 268.25, 282.39, 296.24, 311.07, 323.96, 338.24, 351.62, 374.63, 387.71, 401.52, 415.38, 429.15, 443.62, 455.55, 470.03 };
    const std::vector<float> draft_batch_speeds = { 324.59, 756.20, 1108.57, 1435.62, 1663.25, 1934.04, 2153.58, 2235.32, 1978.88, 2126.60, 2309.37, 2492.45, 2678.03, 2839.02, 3026.49, 3183.50, 3398.03, 3395.50, 3567.67, 3692.83, 3857.95, 4010.42, 4153.14, 4285.82, 4452.40, 4428.92, 4555.89, 4716.58, 4800.15, 4920.78, 5048.93, 5162.57 };
    // ==========================================================================================================
        
    // The estimated lookahead cost per token, in terms of the main model's token generation speed
    const float lookahead_cost_estimate = main_batch_speeds[0] / draft_batch_speeds[0];
    
    // The maximum lookahead relative to the best we have seen so far
    const int max_lookahead = 5;
    
    // ==========================================================================================================
    
    // The best draft size and its associated expected value so far (ie: init to the main model's TG speed)
    int best_draft_size = 0;
    float best_draft_ev = main_batch_speeds[0];
    
    // The current sequence probability, as predicted by the draft model
    float current_sequence_p = 1.0;
   
    GGML_ASSERT((int) main_batch_speeds.size() == params.n_draft);
    GGML_ASSERT((int) draft_batch_speeds.size() == params.n_draft);
    for (int i = 0; i < params.n_draft; ++i) {
        
        // Sample a draft token
        common_batch_clear(batch);
        common_sampler_sample(smpl, ctx, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl);
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        // Save the sampled token id
        result.push_back(id);
        
        // Get the current draft size we are looking at
        const int current_draft_size = result.size();
        
        // If we have enough tokens already, then stop
        if (current_draft_size >= params.n_draft) {
            break;
        }
        // Update the sequence probability using the sampled token's predicted probability
        current_sequence_p *= cur_p->data[0].p;
        // Calculate the expected value (in terms of the main model's tokens/s) for this sequence
        const float current_sequence_ev = current_sequence_p * main_batch_speeds[current_draft_size];
        
        // Is this token clearly +EV compared to what we have so far?
        if (current_sequence_ev > best_draft_ev) {
            best_draft_size = current_draft_size;
            best_draft_ev = current_sequence_ev;
        }
        // Otherwise we have to decide if we might see a +EV draft in the future, or stop now
        else {
            
            bool stop_now = true;
            
            for (int j = current_draft_size + 1; j < (int) main_batch_speeds.size(); j++) {
                // Don't bother looking too far relative to the best we have seen so far
                if (j - best_draft_size > max_lookahead) {
                    break;
                }
                
                // This approximates the cost of the lookahead in terms of the main model's tokens/s.
                const float lookahead_cost = lookahead_cost_estimate * (float) (j - current_draft_size);
                // Calculate the discounted potential EV of this lookahead depth
                // NOTE: This optimistically assumes the draft predicts p=1.0 for all future tokens (ie: an upper bound on the EV)...
                const float potential_ev = current_sequence_p * (main_batch_speeds[j] - lookahead_cost);
                // Is this lookahead depth potentially +EV compared to the best we have so far?
                if (potential_ev > best_draft_ev) {
                    stop_now = false;
                    break;
                }
            }
            
            // If no chance to improve the EV, then stop now
            if (stop_now) {
                break;
            }
        }
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
        // evaluate the drafted tokens on the draft model
        llama_decode(ctx, batch);
        prompt.push_back(id);
    }   
    //printf(", %d] {%d} %d %.2f (+%.2f)\n",(int) result.size(), (best_draft_size > 0 ? (int) result.size() - best_draft_size : 0), best_draft_size + 1, best_draft_ev, best_draft_ev - main_batch_speeds[0]);
    // NOTE: The main model should also generate the next token after the most +EV size we found.
    //       This is because if we successfully get to this token, the main model will see the
    //       full distribution and cannot be wrong (ie: essentially it's free if we get to it).
    result.resize(best_draft_size + 1);
    return result;
but I don't really see much potential for large improvements and don't want to complicate it any more. The code I used to generate the
#!/bin/bash
# Environment variables
export CUDA_VISIBLE_DEVICES=0
# Configuration variables
# NOTE: Use $HOME rather than ~ as tilde is not expanded inside double quotes
BATCHED_BENCH_EXE="$HOME/llama.cpp/build/bin/llama-batched-bench"
MAIN_MODEL_PATH="$HOME/models/gguf/Qwen2.5-Coder-32B-Instruct-Q6_K.gguf"
#MAIN_MODEL_PATH="$HOME/models/gguf/Deepseek-V3-0324-Q4_K_XL.gguf"
DRAFT_MODEL_PATH="$HOME/models/gguf/draft_models/Qwen2.5-Coder-0.5B-Instruct-Q4_0.gguf"
#DRAFT_MODEL_PATH="$HOME/models/gguf/draft_models/DeepSeek-V3-0324-CODER-DRAFT-0.6B-Q4_0.gguf"
# Benchmark parameters
PROMPT_SIZE=512
MAX_DRAFT_SIZE=32
NUM_SAMPLES=5
# Model-specific parameters
MODEL_PARAMS="--n-gpu-layers 99 \
              --flash-attn"
#MODEL_PARAMS="--n-gpu-layers 99 \
#              --flash-attn \
#              --numa distribute \
#              --threads 80 \
#              --override-tensor exps=CPU"
# Generate PROMPT_SIZE_LIST and BATCH_SIZE_LIST (NOTE: Process 2x before 1x to help with warmup)
PROMPT_SIZE_LIST="$((PROMPT_SIZE * 2)),${PROMPT_SIZE}"
BATCH_SIZE_LIST=$(printf "%s," $(for i in $(seq 1 $NUM_SAMPLES); do seq 1 $MAX_DRAFT_SIZE; done) | sed 's/,$//')
# Function to run benchmark for a model
run_benchmark() {
    local model_path=$1
    local log_file=$2
    
    echo "- Running benchmark for $(basename "$model_path")..."
    $BATCHED_BENCH_EXE \
        --model "$model_path" \
        $MODEL_PARAMS \
        --ctx_size "$((PROMPT_SIZE * 2 + MAX_DRAFT_SIZE))" \
        -pps \
        -npp "$PROMPT_SIZE_LIST" \
        -npl "$BATCH_SIZE_LIST" \
        -ntg 1 \
        --output-format jsonl | tee "$log_file"
}
# Function to extract and process results
extract_and_process() {
    local log_file=$1
    local model_name=$2
    local jsonl_pp1x="${model_name}_pp1x.jsonl"
    local jsonl_pp2x="${model_name}_pp2x.jsonl"
    
    # Extract JSONL lines from the log (NOTE: Skip the first NUM_SAMPLES lines, as the first run seems to need a warmup to get accurate stats)
    echo -n "- Extracting results for $model_name..." >&2
    grep '^{' "$log_file" | tail -n "+$((NUM_SAMPLES + 1))" | grep "\"pp\": ${PROMPT_SIZE}" > "$jsonl_pp1x"
    grep '^{' "$log_file" | tail -n "+$((NUM_SAMPLES + 1))" | grep "\"pp\": $((PROMPT_SIZE * 2))" > "$jsonl_pp2x"
    COUNT1=$(wc -l < "$jsonl_pp1x")
    COUNT2=$(wc -l < "$jsonl_pp2x")
    echo " Done ($COUNT1 1xPP results, $COUNT2 2xPP results)" >&2
    
    # Process results and capture output
    process_results "$jsonl_pp1x" "$jsonl_pp2x" "$model_name"
    
    # Clean up
    rm "$log_file" "$jsonl_pp1x" "$jsonl_pp2x"
}
# Function to extract values as bash array
extract_values() {
    local jsonl_file=$1
    
    jq -s --raw-output '
        (map(.pl) | max) as $max_pl |
        reduce .[] as $item (
            {};
            ($item.pl | tostring) as $pl |
            .[$pl].sum = (.[$pl].sum + $item.speed_tg) |
            .[$pl].count = (.[$pl].count + 1)
        ) |
        [range(1; $max_pl + 1) as $pl |
            (.[($pl|tostring)]).sum / .[($pl|tostring)].count
        ] |
        map(. * 1000 | round | . / 1000) |
        join(" ")
    ' "$jsonl_file"
}
# Function to process JSONL and output C++ vector we will need
process_results() {
    local jsonl_file_1x=$1
    local jsonl_file_2x=$2
    local model_name=$3
    
    # Extract values as arrays
    local values_1x=($(extract_values "$jsonl_file_1x"))
    local values_2x=($(extract_values "$jsonl_file_2x"))
       
    # Solve the pair of simultaneous equations (assuming speed falls linearly with prompt size):
    # - net_batch_pl = 2*v1x - v2x  (ie: Batch speed without PP overhead)
    # - overhead_pp = v1x - v2x     (ie: Extra PP overhead)
    local net_batch_pl=()
    local overhead_pp=()
    
    for i in "${!values_1x[@]}"; do
        if [[ -n "${values_2x[i]}" ]]; then
            local v1x="${values_1x[i]}"
            local v2x="${values_2x[i]}"
            
            # Calculate using awk for floating point arithmetic
            local tg=$(awk "BEGIN {printf \"%.2f\", 2 * $v1x - $v2x}")
            local pp=$(awk "BEGIN {printf \"%.2f\", $v1x - $v2x}")
            
            net_batch_pl+=("$tg")
            overhead_pp+=("$pp")
        fi
    done
    # Return the C++ formatted vector
    local batch_speeds_str=$(printf '%s, ' "${net_batch_pl[@]}")
    batch_speeds_str=${batch_speeds_str%, }  # Remove trailing ", "
    echo "const std::vector<float> ${model_name}_batch_speeds = { $batch_speeds_str };"
}
# Clean previous files
rm -f main_*.log main_*.jsonl draft_*.log draft_*.jsonl
# Run benchmarks for both models and capture C++ output
run_benchmark "$MAIN_MODEL_PATH" "main_benchmark.log"
main_output=$(extract_and_process "main_benchmark.log" "main")
run_benchmark "$DRAFT_MODEL_PATH" "draft_benchmark.log"
draft_output=$(extract_and_process "draft_benchmark.log" "draft")
# Output both C++ vectors at the end
echo "----------------------------------------"
echo "$main_output"
echo "$draft_output"
echo "----------------------------------------"
but I'm not convinced the use of
 I've left it in though, as it could just be that my setup isn't making it useful, or perhaps with more patience it could be run with  You could even run more than 2 and then fit a linear regression line to try to predict  But I really just want to keep this minimal, so that hopefully it's understandable and clearly shows the problem it solves related to the existing  I doubt I'll be able to do much more for this, but hopefully somebody found it interesting! :) I'm not sure how easy it would be to add as a proper PR either... It would probably need to export the vectors as JSON and then import them somehow (sadly, the "jumpiness" phenomenon killed any idea of parametrising the lines). | 
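On the idea of running more than 2 prompt sizes and fitting a line: a minimal sketch, assuming the measured speed falls roughly linearly with prompt size (the same assumption behind the `2*v1x - v2x` formula in the script). `LinearFit` and `fit_speed_vs_prompt` are hypothetical names, not part of llama.cpp. The intercept is the extrapolated speed with zero prompt; with exactly two points at `pp` and `2*pp` it reduces to `2*v1x - v2x`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Result of an ordinary least-squares fit: speed ~= intercept + slope * pp
struct LinearFit {
    double intercept;  // extrapolated speed at zero prompt size (net batch speed)
    double slope;      // change in speed per prompt token (the PP overhead)
};

// Hypothetical helper: fit averaged speed_tg values measured at several prompt sizes
// for one batch size. Needs at least two distinct prompt sizes.
LinearFit fit_speed_vs_prompt(const std::vector<double> & pp, const std::vector<double> & speed) {
    assert(pp.size() == speed.size() && pp.size() >= 2);
    const double n = static_cast<double>(pp.size());
    double sx = 0.0, sy = 0.0, sxx = 0.0, sxy = 0.0;
    for (std::size_t i = 0; i < pp.size(); ++i) {
        sx  += pp[i];
        sy  += speed[i];
        sxx += pp[i] * pp[i];
        sxy += pp[i] * speed[i];
    }
    const double denom = n * sxx - sx * sx;  // zero only if all prompt sizes are equal
    assert(denom != 0.0);
    const double slope = (n * sxy - sx * sy) / denom;
    const double intercept = (sy - slope * sx) / n;
    return { intercept, slope };
}
```

Calling it once per batch size, with the averaged speeds for that batch size, would give one entry of the `*_batch_speeds` vector.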
-
| I've tracked down what is causing this: Line 1093 in 5fce5f9
        if (v_mla) {
#if 0
            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
            cur = ggml_mul_mat(ctx0, v_mla, cur);
#else
            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
            // The permutations are noops and only change how the tensor data is interpreted.
            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
            cur = ggml_mul_mat(ctx0, v_mla, cur);
            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
#endif
        }
using @fairydreaming's original method of only permuting if
        if (v_mla) {
            if (n_tokens <= n_head) {
                // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
                // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
                cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
                cur = ggml_mul_mat(ctx0, v_mla, cur);
            } else {
                // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
                // The permutations are noops and only change how the tensor data is interpreted.
                cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
                cur = ggml_mul_mat(ctx0, v_mla, cur);
                cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
                cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
            } 
        }
and the nasty jump is gone: (rerunning with the Linux caches dropped for NUMA and with the full  @JohannesGaessler It looks here like the crossover point may be  | 
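Why the permutes in the code above are "noops": a ggml tensor is described by per-dimension sizes (`ne`) and strides (`nb`, in bytes), so permuting dims 1 and 2 only swaps metadata, and `ggml_cont` is what actually copies the data into a contiguous layout. The toy below models this with strides in elements rather than bytes; `View4D` and its helpers are hypothetical and not the ggml API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Toy 4-D view: ne = elements per dim, nb = stride per dim (in elements, dim 0 fastest)
struct View4D {
    const float * data;
    std::array<std::size_t, 4> ne;
    std::array<std::size_t, 4> nb;

    float at(std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) const {
        return data[i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3]];
    }
};

// Contiguous view over an existing buffer
View4D make_contiguous_view(const float * data, const std::array<std::size_t, 4> & ne) {
    assert(data != nullptr);
    View4D v{data, ne, {}};
    v.nb[0] = 1;
    for (int d = 1; d < 4; ++d) {
        v.nb[d] = v.nb[d - 1] * ne[d - 1];
    }
    return v;
}

// Analogue of permute(0, 2, 1, 3): swap dims 1 and 2 by swapping metadata only (O(1))
View4D swap_dims_1_2(View4D v) {
    std::swap(v.ne[1], v.ne[2]);
    std::swap(v.nb[1], v.nb[2]);
    return v;
}

// True if the strides match a packed layout (size-1 dims are ignored)
bool is_contiguous(const View4D & v) {
    std::size_t expected = 1;
    for (int d = 0; d < 4; ++d) {
        if (v.ne[d] > 1 && v.nb[d] != expected) {
            return false;
        }
        expected *= v.ne[d];
    }
    return true;
}

// Analogue of a "cont" op: copy the view into a new packed buffer
std::vector<float> materialize(const View4D & v) {
    std::vector<float> out;
    out.reserve(v.ne[0] * v.ne[1] * v.ne[2] * v.ne[3]);
    for (std::size_t i3 = 0; i3 < v.ne[3]; ++i3)
        for (std::size_t i2 = 0; i2 < v.ne[2]; ++i2)
            for (std::size_t i1 = 0; i1 < v.ne[1]; ++i1)
                for (std::size_t i0 = 0; i0 < v.ne[0]; ++i0)
                    out.push_back(v.at(i0, i1, i2, i3));
    return out;
}
```

The permuted view reads the same buffer in a different order; only `materialize` costs a full pass over the data.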
-
| I've simplified this now (basically back to what I started with!): 1. Run this to generate the  | 
-
| I made some progress on this investigation, testing the dynamic op offload capability already available in ggml to see if it gives any benefit. To test this on the CUDA backend, the following source file change is required: ggml/src/ggml-cuda/ggml-cuda.c
This change enables controlling the minimum batch size for offload in the CUDA backend with the env var GGML_OFFLOAD_OP_MIN_BATCH_SIZE. Then set this env var to <= the minimum spec block size before loading:
GGML_OFFLOAD_OP_MIN_BATCH_SIZE=3 llama-server ....
When using speculation with a fully offloaded draft, this results in all computations being done on the GPU: the layers that are on the CPU are dynamically shipped to the GPU over PCIe, and the hidden states are shipped back over PCIe to the CPU layers when done.
Hardware: 4070 + 9900k, with a PCIe 4.0 x16 and DDR4 motherboard (Shuttle XPC)
Tests: Qwen 2.5 Coder 32B target, Qwen 2.5 Coder 0.5B draft (IQ4_XS), 38 of 64 layers offloaded to the 4070, rest on CPU
humaneval 1st problem, fixed spec block of 15
fully offloaded speedup: 3.3x
nvtop shows 90% GPU use but a low power draw of only 60/200W, i.e. about 1/3 use. This most likely means that PCIe transfers are being counted as GPU activity by nvtop and that there is no pipelining in the tensor scheduling, i.e. the ggml scheduler is not overlapping the transfer of the next tensor with the compute of the current one; instead it's most likely a serial ship tensor -> wait for GPU results -> ship next tensor. This will be slow and underutilize the hardware by more than a factor of 2, since batch compute time and tensor transfer time will be similar. However, even with this inefficiency a speedup of 3.3x is still found with only a PCIe gen 4 bus on "easy" speculation (code). 
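The patch itself isn't shown above. As a rough idea of the kind of change involved, here is a hypothetical, self-contained sketch of reading a minimum offload batch size from `GGML_OFFLOAD_OP_MIN_BATCH_SIZE` with a fallback default. As far as I know, the upstream CUDA backend compares the op's batch size against a hard-coded minimum of 32 when deciding whether to offload; the function names below are made up and this is not the ggml code.

```cpp
#include <cassert>
#include <cerrno>
#include <climits>
#include <cstdlib>

// Hypothetical: read the minimum batch size for op offload from the environment,
// falling back to default_value when the variable is unset or not a positive integer.
static int offload_min_batch_size(int default_value) {
    assert(default_value >= 1);
    const char * env = std::getenv("GGML_OFFLOAD_OP_MIN_BATCH_SIZE");
    if (env == nullptr || *env == '\0') {
        return default_value;
    }
    char * end = nullptr;
    errno = 0;
    const long v = std::strtol(env, &end, 10);
    if (errno != 0 || *end != '\0' || v < 1 || v > INT_MAX) {
        return default_value;  // reject garbage, non-positive or out-of-range values
    }
    return static_cast<int>(v);
}

// Hypothetical offload decision: ship a CPU-resident op's weights to the GPU
// only when the batch is large enough to be worth the PCIe transfer.
static bool should_offload(int batch_size, int min_batch_size) {
    return batch_size >= min_batch_size;
}
```

With the variable set to 3 and a fixed spec block of 15, every verification batch passes the check, which is what sends all of the target's layers through the GPU.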
On the chess prompt with a fixed spec block size of 4, results get much worse, as expected:
fully offloaded speedup: 1.13x
So it's still worth offloading to the GPU, but token generation remains unusably slow there, while with efficient speculation usable token generation on the order of 10 t/s is achieved with the single 4070. In conclusion, I think this idea does have good potential. It can be used with the codebase as is, albeit suboptimally, with the simple change of adding the GGML_OFFLOAD_OP_MIN_BATCH_SIZE env var, giving a ~3x tg boost on code tasks. To go further and push GPU utilization up, the backend tensor scheduling would need to be updated to pipeline PCIe tensor transfers with GPU compute. | 
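To put a number on the pipelining point, a simple timing model (an assumption, not a measurement): each layer has a transfer time `t_xfer` and a compute time `t_compute`, and the layers are either processed serially or double-buffered so the next layer's transfer overlaps the current layer's compute. The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Serial: transfer layer i, wait, compute layer i, then start the next transfer.
double serial_time(int n_layers, double t_xfer, double t_compute) {
    assert(n_layers >= 1);
    return n_layers * (t_xfer + t_compute);
}

// Double-buffered (ping-pong): while layer i is computed from one buffer, layer i+1
// is transferred into the other. Only the first transfer and the last compute are
// not overlapped with anything.
double pipelined_time(int n_layers, double t_xfer, double t_compute) {
    assert(n_layers >= 1);
    return t_xfer + (n_layers - 1) * std::max(t_xfer, t_compute) + t_compute;
}
```

With equal times over 64 layers, the serial schedule takes 128 units against 65 pipelined (about 1.97x); when the transfer dominates (4 vs 1 per layer) it is 320 vs 257. In this model, overlap alone gives at most just under 2x, reached when the two times are equal.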
-
| I've been working some more on the "profile guided speculation" and I'm pretty sure master has a bad bug in it: llama.cpp/common/speculative.cpp Line 332 in f8f071f
I'm not sure if it has always been here or got added recently, but if you follow this loop through:
    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);
        common_sampler_sample(smpl, ctx_dft, 0, true);
        const auto * cur_p = common_sampler_get_candidates(smpl, true);
        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
        }
        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        result.push_back(id);
        if (params.n_draft <= (int) result.size()) {
            break;
        }
        // only collect very high-confidence draft tokens
        if (cur_p->data[0].p < params.p_min) {
            break;
        }
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
        // evaluate the drafted tokens on the draft model
        llama_decode(ctx_dft, batch);
        prompt_dft.push_back(id);
    }
    if (!spec->vocab_dft_compatible) {
        std::string detokenized = common_detokenize(ctx_dft, result, true);
        detokenized = replace_to_tgt(spec, detokenized);
        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
        result = common_tokenize(ctx_tgt, detokenized, false, true);
        if (result.size() > (size_t)params.n_draft) {
            result.resize(params.n_draft);
        }
    }
    return result;
}
and assume the first token has a really low probability like 0.1%, then it ends up in  This causes the function to always return 1 token regardless, and in turn the calling code to run a batch of 2 (which nearly always fails to successfully speculate 2 tokens due to the 1st token's low probability!). I'm fairly sure the  If you follow the code through the  This same logic can be followed for the second token having a low probability (causing a batch of 3 where the 3rd token is almost never successfully speculated), and so on... | 
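A reduced mock of the two orderings makes the effect concrete. It is not llama.cpp code: `probs` stands in for `cur_p->data[0].p` at each draft position (hypothetical values), and the counter stands in for `result.size()`. The second function is one possible reordering, not necessarily the fix the post had in mind; in the real loop `common_sampler_accept` has already been called on the token, so dropping it would also require handling the sampler state.

```cpp
#include <cassert>
#include <vector>

// Order used in the quoted loop: keep the token first, then test p_min.
int drafted_push_then_check(const std::vector<float> & probs, int n_draft, float p_min) {
    assert(n_draft >= 1);
    int n = 0;
    for (int i = 0; i < n_draft && i < (int) probs.size(); ++i) {
        ++n;                          // result.push_back(id)
        if (n_draft <= n) break;
        if (probs[i] < p_min) break;  // the low-probability token is already kept
    }
    return n;
}

// Hypothetical reordering: test p_min before keeping the token.
int drafted_check_then_push(const std::vector<float> & probs, int n_draft, float p_min) {
    assert(n_draft >= 1);
    int n = 0;
    for (int i = 0; i < n_draft && i < (int) probs.size(); ++i) {
        if (probs[i] < p_min) break;  // drop the low-probability token
        ++n;
    }
    return n;
}
```

With a first-token probability of 0.001 and `p_min = 0.75`, the first version still drafts 1 token (so the target runs a batch of 2), while the second drafts none.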
-
I recently added an efficient greedy-only spec decode to my downstream server patch (a completely different implementation from the current spec decode PR). I then evaluated tg performance for two cases: 1) solving the first humaneval problem with a coding model and 2) solving the goldcoin problem with a general model. I used Qwen2.5 14B for the target and 0.5B, 1.5B, and 3B for the drafts. I evaluated tg vs. draft token length on a 4070 with both target and draft weights fully offloaded, using an IQ4_XS quant for the target and Q6_K quants for the drafts.
HUMANEVAL first problem:
TARGET Qwen2.5-Coder-14B-Instruct
DRAFTS Qwen2.5-Coder-0.5B-Instruct, Qwen2.5-Coder-1.5B-Instruct, Qwen2.5-Coder-3B-Instruct
TPS vs draft tokens:
GOLDCOIN
I have 10 apples. I find 3 gold coins in the bottom of a river. The river runs near a big city that has something to do with what I can spend the coins on. I then lose 4 apples but gain a gold coin. Three birds run into my path and drop 6 apples each. I play an online game and win 6 gold coins but I have to share them equally with my 2 teammates. I buy apples for all the coins I have. The price of an apple is 0.5 coins. How many apples do I have? And where is the river? Use step-by-step reasoning to solve this problem.
TARGET Qwen2.5-14B-Instruct
DRAFTS Qwen2.5-0.5B-Instruct, Qwen2.5-1.5B-Instruct, Qwen2.5-3B-Instruct
TPS vs draft tokens:
TARGET Llama 3.1 8B Instruct
DRAFT Llama 3.2 1B Instruct
TPS vs draft tokens:
TARGET Gemma 2 9B it IQ4_XS
DRAFT Gemma 2 2B it IQ4_XS
TPS vs draft tokens:
Results Summary:
Coding shows a max speedup of 2.5x tg at 10 draft tokens using the 0.5B draft. With the 1.5B draft the max speedup is 1.63x at 4 draft tokens. With the 3B draft the max speedup is 1.33x at 4 draft tokens. The efficiency crossover (where draft+target tg equals tg with no draft) is >32 draft tokens for 0.5B, >16 draft tokens for 1.5B, and 11 draft tokens for 3B.
Goldcoin shows a max speedup of 1.4x tg at 4 draft tokens using the 0.5B draft. With the 1.5B draft the max speedup is 1.17x at 4 draft tokens. With the 3B draft the max speedup is 1.08x at 1 draft token. The efficiency crossover is 12 tokens for 0.5B, 6 tokens for 1.5B, and 3 tokens for 3B.
With Llama 3.1 8B Instruct drafted by Llama 3.2 1B Instruct, a tg speedup of 1.83x is found at 5 draft tokens.
With Gemma 2 9B it drafted by Gemma 2 2B it, there is never any speculative decoding speedup. My guess is that the 2B was not distilled from the 9B at all, but was trained on a completely different dataset.
Conclusions and potential for running big LLMs on consumer grade GPUs:
A small draft model is needed (sine qua non). The 0.5B size seems to work well. Any model in the range of 8G or above can benefit from distilling a 0.5B draft and using it for speculation. Returns fall off rapidly as the draft gets bigger: already questionable at 1.5B and not really useful at 3B. Coding is far more efficient than general text generation with speculation. The Qwen 2.5 series is perfect for exploiting the potential of speculation.
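For a rough feel of why small drafts win, the textbook speculative decoding estimate (Leviathan et al., 2023) can be coded up directly. It assumes each draft token is accepted independently with probability `alpha`, the draft costs `c` target forward passes per token, and verifying a block costs the same as one target pass. The batch speed measurements elsewhere in this thread show that last assumption only holds roughly, so treat this as a guide rather than a prediction. Function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Expected tokens produced per target forward pass with k draft tokens
// (accepted prefix plus the one token the target always contributes).
double expected_tokens_per_step(double alpha, int k) {
    assert(alpha >= 0.0 && alpha < 1.0 && k >= 0);
    return (1.0 - std::pow(alpha, k + 1)) / (1.0 - alpha);
}

// Speedup over plain decoding when drafting costs c (relative to one target pass) per token.
double expected_speedup(double alpha, int k, double c) {
    return expected_tokens_per_step(alpha, k) / (k * c + 1.0);
}

// Draft length in [0, k_max] that maximizes the estimated speedup.
int best_draft_length(double alpha, double c, int k_max) {
    int best_k = 0;
    double best = expected_speedup(alpha, 0, c);
    for (int k = 1; k <= k_max; ++k) {
        const double s = expected_speedup(alpha, k, c);
        if (s > best) {
            best = s;
            best_k = k;
        }
    }
    return best_k;
}
```

With `alpha = 0.8`, going from `c = 0.05` to `c = 0.2` moves the best draft length from 8 to 4 and the estimated speedup from about 3.1x to about 1.87x, the same qualitative trend as the larger drafts above. In this model the efficiency crossover is the largest `k` for which `expected_speedup` is still at least 1.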
For running big LLMs on consumer grade GPUs with limited memory, it is desirable to avoid storing all the model weights and the output layer in VRAM, because there is not enough room. Most of the model weights sit idle most of the time: in a 32-layer model, 31 layers are occupying VRAM doing nothing at any given moment, i.e. each layer is idle 31/32 of the time. To get around this problem it is necessary to dynamically swap layers into VRAM as they are needed from CPU RAM, which normally has much higher capacity. If the draft size at the efficiency crossover is big enough, there may (emphasis on may; it needs to be investigated for feasibility) be enough time to compute the target batch (say 8 to 10 samples) and simultaneously transfer the next layer into the GPU. The GPU then needs one working layer allocation and one transfer allocation (two model layers in total, ping-ponged between compute and transfer) plus a fully offloaded speculator. The KV caches for both the speculator and the target should also be in GPU memory. Even if it is necessary to go above the efficiency crossover, dynamic layer loading to the GPU can still be more efficient, because offloading to the CPU is an immediate 10x or higher slowdown due to memory bandwidth limits.
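Whether the ping-pong scheme can hide the transfers could be estimated from the same kind of batch speed vectors produced by the batched-bench script in this thread. A hypothetical sketch, assuming compute time splits evenly across layers and that speeds measured with all weights in VRAM still apply while a transfer runs in parallel:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Smallest batch size b for which computing one layer on b tokens takes at least as
// long as transferring the next layer, so a double-buffered scheme could hide it.
//   batch_speeds[b-1] : whole-model tokens/s for a batch of b tokens (measured in VRAM)
//   n_layers          : number of layers (compute assumed evenly split across them)
//   layer_bytes       : size of one layer's weights
//   link_bytes_per_s  : effective host->device bandwidth
// Returns 0 if no measured batch size is large enough.
int min_batch_to_hide_transfer(const std::vector<double> & batch_speeds, int n_layers,
                               double layer_bytes, double link_bytes_per_s) {
    assert(n_layers > 0 && layer_bytes > 0.0 && link_bytes_per_s > 0.0);
    const double t_xfer = layer_bytes / link_bytes_per_s;  // seconds per layer transfer
    for (std::size_t i = 0; i < batch_speeds.size(); ++i) {
        const double b = static_cast<double>(i + 1);
        const double t_batch = b / batch_speeds[i];         // seconds for one pass of b tokens
        const double t_layer = t_batch / n_layers;          // per-layer share of that pass
        if (t_layer >= t_xfer) {
            return static_cast<int>(i + 1);
        }
    }
    return 0;
}
```

One counter-intuitive consequence: the faster the GPU processes a given batch, the larger the batch (and so the draft) needed to hide each transfer.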