Skip to content

Commit 4b28884

Browse files
committed
Review: move everything to diffusion-cli for now
1 parent 4a13243 commit 4b28884

File tree

4 files changed

+317
-338
lines changed

4 files changed

+317
-338
lines changed

examples/diffusion/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Standalone diffusion CLI example target.
set(TARGET llama-diffusion-cli)
# All diffusion logic lives in diffusion-cli.cpp (this commit folded the former
# diffusion.cpp into it, per the commit message).
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
# Link against the core llama library, the examples' common helpers, and the
# thread library CMake detected for this platform.
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
# The sources require C++17.
target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/diffusion/diffusion-cli.cpp

Lines changed: 316 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,310 @@
88
#include <limits.h>
99
#include <string>
1010
#include <vector>
11+
#include <algorithm>
12+
#include <cmath>
13+
#include <limits>
14+
#include <random>
15+
16+
// Per-step progress callback. Invoked once at the start of every diffusion step
// with the current (partially unmasked) token buffer; return false to stop
// generation early.
typedef bool (*diffusion_step_callback_t)(int32_t step,
                                          int32_t total_steps,
                                          const llama_token * tokens,
                                          int32_t n_tokens,
                                          void * user_data);

// Strategy for choosing which masked positions get committed at each step.
enum diffusion_alg {
    DIFFUSION_ALG_ORIGIN = 0,       // each masked position is committed independently with probability (1 - s/t)
    DIFFUSION_ALG_MASKGIT_PLUS = 1, // confidence = probability of the sampled token
    DIFFUSION_ALG_TOPK_MARGIN = 2,  // confidence = top-1 minus top-2 probability margin
    DIFFUSION_ALG_ENTROPY = 3,      // confidence = negative entropy of the token distribution
};

// Configuration for diffusion_generate().
struct diffusion_params {
    int32_t steps;                  // number of denoising steps
    float eps;                      // terminal timestep: timesteps interpolate from 1.0 down to eps
    float temperature;              // softmax temperature; a temp sampler is added only when > 0
    float top_p;                    // nucleus sampling threshold; applied only when < 1.0
    int32_t top_k;                  // top-k cutoff; applied only when > 0
    llama_token mask_token_id;      // token id used to pad/mark still-masked positions
    enum diffusion_alg algorithm;   // unmasking strategy (see diffusion_alg)
    float alg_temp;                 // >0 enables stochastic, temperature-scaled selection of positions to commit
    diffusion_step_callback_t step_callback;  // optional progress/abort hook (may be null)
    void * step_callback_user_data; // opaque pointer forwarded to step_callback
    int32_t seed;                   // seeds both the std::mt19937 RNG and the llama dist samplers
};
42+
43+
44+
// Returns the default diffusion configuration: 64 steps, eps = 1e-3,
// temperature 0.2, top-p 0.95, top-k disabled, ORIGIN algorithm, no callback.
// The mask token is left as LLAMA_TOKEN_NULL and must be set by the caller.
static diffusion_params diffusion_default_params() {
    // Aggregate-initialize in member declaration order (see struct diffusion_params).
    diffusion_params defaults = {
        /* .steps                   = */ 64,
        /* .eps                     = */ 1e-3f,
        /* .temperature             = */ 0.2f,
        /* .top_p                   = */ 0.95f,
        /* .top_k                   = */ 0,
        /* .mask_token_id           = */ LLAMA_TOKEN_NULL,
        /* .algorithm               = */ DIFFUSION_ALG_ORIGIN,
        /* .alg_temp                = */ 0.0f,
        /* .step_callback           = */ nullptr,
        /* .step_callback_user_data = */ nullptr,
        /* .seed                    = */ 0,
    };
    return defaults;
}
59+
60+
// Masked-diffusion text generation.
//
// Fills output_tokens[0..max_length) by copying the n_input prompt tokens and
// padding the rest with params.mask_token_id, then runs params.steps denoising
// steps. Each step decodes the full sequence with non-causal attention and
// commits tokens at some masked positions, chosen per params.algorithm.
//
// On success n_generated is set to max_length; it is 0 only for invalid
// arguments. NOTE(review): early exits inside the loop (callback abort, decode
// failure) still fall through to n_generated = max_length — callers cannot
// distinguish a partial result; confirm this is intended.
static void diffusion_generate(llama_context * ctx,
                               const llama_token * input_tokens,
                               llama_token * output_tokens,
                               int32_t n_input,
                               int32_t max_length,
                               struct diffusion_params params,
                               int32_t & n_generated) {

    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);

    // Host-side RNG used only by the ORIGIN algorithm's per-position coin flip.
    std::mt19937 rng(params.seed);

    // timesteps[i] decreases linearly from 1.0 (i = 0) to eps (i = steps).
    std::vector<float> timesteps(params.steps + 1);
    for (int32_t i = 0; i <= params.steps; i++) {
        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
    }

    // Diffusion decodes the whole sequence at once: disable causal masking.
    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    // Scratch buffer reused for every position's candidate list.
    std::vector<llama_token_data> candidates(n_vocab);

    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(max_length);

    std::vector<int32_t> mask_positions;
    mask_positions.reserve(max_length);

    // Token sampler chain: optional top-k -> top-p -> temperature, then a
    // dist sampler that performs the final selection.
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    // Separate dist sampler used to pick *positions* (not tokens) when
    // alg_temp > 0.
    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(max_length, 0, 1);
    batch.n_tokens = max_length;

    int64_t total_sampling_time = 0;
    int64_t total_time = 0;

    int64_t time_start = ggml_time_us();
    for (int32_t step = 0; step < params.steps; step++) {
        // Give the caller a chance to render progress or abort.
        if (params.step_callback) {
            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
                break;
            }
        }

        // Rebuild the batch with the current sequence; logits are requested
        // for every position (logits[i] = 1).
        for (int32_t i = 0; i < max_length; i++) {
            batch.token[i] = output_tokens[i];
            batch.pos[i] = i;
            batch.n_seq_id[i] = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i] = 1;
        }

        int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
            break;
        }

        float * raw_logits = llama_get_logits(ctx);
        if (!raw_logits) {
            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
            break;
        }

        // NOTE(review): positions > 0 read row (pos - 1), i.e. the logits are
        // taken one row earlier than pos * n_vocab, while pos 0 reads row 0.
        // Presumably this matches a shifted output layout of the diffusion
        // model — confirm against the model's logit ordering.
        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
        };

        int64_t time_start_sampling = ggml_time_us();

        // Collect positions that are still masked.
        mask_positions.clear();
        for (int32_t i = 0; i < max_length; i++) {
            if (output_tokens[i] == params.mask_token_id) {
                mask_positions.push_back(i);
            }
        }

        // Fully unmasked: nothing left to do.
        if (mask_positions.empty()) {
            break;
        }

        // Current (t) and next (s) noise levels; s < t.
        float t = timesteps[step];
        float s = timesteps[step + 1];

        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
            // ORIGIN:每 position is independently committed with probability
            // (1 - s/t); the last step commits everything.
            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;

            for (int32_t pos : mask_positions) {
                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                    const float * pos_logits = get_logits_for_pos(pos);
                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].id = token_id;
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p = 0.0f;
                    }

                    llama_token_data_array cur_p = {
                        /* .data = */ candidates.data(),
                        /* .size = */ (size_t) n_vocab, // Reset size to full vocab
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    // The chain ends in a dist sampler, so cur_p.selected is
                    // set after this call.
                    llama_sampler_apply(sampler, &cur_p);
                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
                }
            }
        } else {
            // Confidence-based algorithms: sample a token for every masked
            // position, score it, then commit only the num_transfer most
            // confident positions.
            std::vector<std::pair<float, int32_t>> confidences; // (confidence, index into mask_positions)
            std::vector<llama_token> sampled_tokens(mask_positions.size());

            for (size_t i = 0; i < mask_positions.size(); i++) {
                int32_t pos = mask_positions[i];
                const float * pos_logits = get_logits_for_pos(pos);

                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                    candidates[token_id].logit = pos_logits[token_id];
                    candidates[token_id].p = 0.0f;
                    candidates[token_id].id = token_id;
                }

                llama_token_data_array cur_p = {
                    /* .data = */ candidates.data(),
                    /* .size = */ candidates.size(),
                    /* .selected = */ -1,
                    /* .sorted = */ false,
                };

                llama_sampler_apply(sampler, &cur_p);

                llama_token sampled_token = cur_p.data[cur_p.selected].id;

                float confidence = 0.0f;
                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
                    // Negative entropy: peakier distributions score higher.
                    const float epsilon = 1e-10f; // guards logf(0)
                    for (size_t j = 0; j < cur_p.size; j++) {
                        float prob = cur_p.data[j].p;
                        confidence += prob * logf(prob + epsilon);
                    }
                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
                    // NOTE(review): assumes cur_p is sorted descending by p
                    // after the sampler chain and that size >= 2 — confirm.
                    confidence = cur_p.data[0].p - cur_p.data[1].p;
                } else {
                    // MASKGIT_PLUS (and any other value): probability of the
                    // token that was actually sampled.
                    confidence = cur_p.data[cur_p.selected].p;
                }

                sampled_tokens[i] = sampled_token;
                confidences.emplace_back(confidence, i);
            }

            // Fraction (1 - s/t) of the remaining masked positions; the last
            // step commits all of them.
            int32_t num_transfer =
                (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();

            if (num_transfer > 0) {
                if (params.alg_temp == 0.0f) {
                    // Deterministic: keep the num_transfer highest-confidence
                    // entries at the front (ties broken by lower index).
                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                          if (a.first != b.first) {
                                              return a.first > b.first;
                                          }
                                          return a.second < b.second;
                                      });
                } else {
                    // Stochastic: treat temperature-scaled confidences as
                    // logits over sequence positions and sample positions
                    // without replacement.
                    conf_candidates.clear();

                    for (int32_t pos = 0; pos < max_length; pos++) {
                        // Non-masked positions get -inf so they can never be picked.
                        float conf_logit = -std::numeric_limits<float>::infinity();

                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            size_t mask_idx = std::distance(mask_positions.begin(), it);
                            conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
                        }

                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
                    }

                    llama_token_data_array conf_array = {
                        /* .data = */ conf_candidates.data(),
                        /* .size = */ conf_candidates.size(),
                        /* .selected = */ -1,
                        /* .sorted = */ false,
                    };

                    for (int32_t i = 0; i < num_transfer; i++) {
                        // Apply distribution sampler to get selected index
                        llama_sampler_apply(dist_sampler, &conf_array);
                        int selected_idx = conf_array.selected;
                        // Repurpose .second to hold the chosen sequence position.
                        confidences[i].second = conf_candidates[selected_idx].id;

                        // NOTE(review): zeroing p is meant to prevent
                        // re-selection, but the next llama_sampler_apply may
                        // recompute p from the (unchanged) logit — confirm the
                        // dist sampler honors pre-set probabilities.
                        conf_candidates[selected_idx].p = 0.0f;
                        conf_array.selected = -1;
                    }
                }

                if (params.alg_temp == 0.0f) {
                    // Deterministic - use confidence order
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t mask_idx = confidences[i].second;
                        int32_t pos = mask_positions[mask_idx];
                        llama_token token = sampled_tokens[mask_idx];
                        output_tokens[pos] = token;
                    }
                } else {
                    // Stochastic path: .second holds sequence positions; map
                    // each back to its index in mask_positions.
                    for (int32_t i = 0; i < num_transfer; i++) {
                        int32_t pos = confidences[i].second;
                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
                        if (it != mask_positions.end()) {
                            int32_t mask_idx = std::distance(mask_positions.begin(), it);
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    }
                }
            }
        }
        int64_t time_end_sampling = ggml_time_us();
        total_sampling_time += time_end_sampling - time_start_sampling;
    }
    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = max_length;
}
312+
313+
314+
11315

12316
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
13317
if (!use_chat_template) {
@@ -34,24 +338,24 @@ struct callback_data {
34338
int32_t n_input;
35339
};
36340

37-
static bool diffusion_step_callback(int32_t step
38-
, int32_t total_steps
39-
, const llama_token * tokens
40-
, int32_t n_tokens
41-
, void * user_data) {
341+
static bool diffusion_step_callback(int32_t step,
342+
int32_t total_steps,
343+
const llama_token * tokens,
344+
int32_t n_tokens,
345+
void * user_data) {
42346
(void)user_data;
43347

44348
callback_data * data = static_cast<callback_data *>(user_data);
45349

46350
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
47351
int progress_percent = (step * 100) / total_steps;
48352
int progress_bars = (step * 50) / total_steps;
49-
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%"
50-
, step
51-
, total_steps
52-
, std::string(progress_bars, '=').c_str()
53-
, std::string(50 - progress_bars, ' ').c_str()
54-
, progress_percent);
353+
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
354+
step,
355+
total_steps,
356+
std::string(progress_bars, '=').c_str(),
357+
std::string(50 - progress_bars, ' ').c_str(),
358+
progress_percent);
55359
};
56360

57361
if (data->diff_params->visual_mode) {
@@ -157,7 +461,7 @@ int main(int argc, char ** argv) {
157461
ldiff_params.temperature = params.sampling.temp;
158462
ldiff_params.top_p = params.sampling.top_p;
159463
ldiff_params.top_k = params.sampling.top_k;
160-
ldiff_params.algorithm = static_cast<enum diffusion_algorithm>(params.diffusion.algorithm);
464+
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
161465
ldiff_params.alg_temp = params.diffusion.alg_temp;
162466
ldiff_params.seed = params.sampling.seed;
163467

0 commit comments

Comments (0)