8
8
#include < limits.h>
9
9
#include < string>
10
10
#include < vector>
11
+ #include < algorithm>
12
+ #include < cmath>
13
+ #include < limits>
14
+ #include < random>
15
+
16
+ typedef bool (*diffusion_step_callback_t )(int32_t step,
17
+ int32_t total_steps,
18
+ const llama_token * tokens,
19
+ int32_t n_tokens,
20
+ void * user_data);
21
+
22
// Strategy used to decide which masked positions get committed at each step.
enum diffusion_alg {
    DIFFUSION_ALG_ORIGIN       = 0, // independent random commit per position (original formulation)
    DIFFUSION_ALG_MASKGIT_PLUS = 1, // confidence = probability of the sampled token
    DIFFUSION_ALG_TOPK_MARGIN  = 2, // confidence = margin between top-1 and top-2 probabilities
    DIFFUSION_ALG_ENTROPY      = 3, // confidence = negative entropy of the distribution
};
28
+
29
+ struct diffusion_params {
30
+ int32_t steps;
31
+ float eps;
32
+ float temperature;
33
+ float top_p;
34
+ int32_t top_k;
35
+ llama_token mask_token_id;
36
+ enum diffusion_alg algorithm;
37
+ float alg_temp;
38
+ diffusion_step_callback_t step_callback;
39
+ void * step_callback_user_data;
40
+ int32_t seed;
41
+ };
42
+
43
+
44
+ static diffusion_params diffusion_default_params () {
45
+ diffusion_params params = {};
46
+ params.steps = 64 ;
47
+ params.eps = 1e-3f ;
48
+ params.temperature = 0 .2f ;
49
+ params.top_p = 0 .95f ;
50
+ params.top_k = 0 ;
51
+ params.mask_token_id = LLAMA_TOKEN_NULL;
52
+ params.algorithm = DIFFUSION_ALG_ORIGIN;
53
+ params.alg_temp = 0 .0f ;
54
+ params.step_callback = nullptr ;
55
+ params.step_callback_user_data = nullptr ;
56
+ params.seed = 0 ;
57
+ return params;
58
+ }
59
+
60
+ static void diffusion_generate (llama_context * ctx,
61
+ const llama_token * input_tokens,
62
+ llama_token * output_tokens,
63
+ int32_t n_input,
64
+ int32_t max_length,
65
+ struct diffusion_params params,
66
+ int32_t & n_generated) {
67
+
68
+ n_generated = 0 ;
69
+ if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
70
+ return ;
71
+ }
72
+
73
+ const llama_model * model = llama_get_model (ctx);
74
+
75
+ // Initialize with input and pad with mask tokens
76
+ std::copy (input_tokens, input_tokens + n_input, output_tokens);
77
+ std::fill (output_tokens + n_input, output_tokens + max_length, params.mask_token_id );
78
+
79
+ std::mt19937 rng (params.seed );
80
+
81
+ std::vector<float > timesteps (params.steps + 1 );
82
+ for (int32_t i = 0 ; i <= params.steps ; i++) {
83
+ timesteps[i] = 1 .0f - (float ) i / params.steps * (1 .0f - params.eps );
84
+ }
85
+
86
+ llama_set_causal_attn (ctx, false );
87
+
88
+ int32_t n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model));
89
+
90
+ std::vector<llama_token_data> candidates (n_vocab);
91
+
92
+ std::vector<llama_token_data> conf_candidates;
93
+ conf_candidates.reserve (max_length);
94
+
95
+ std::vector<int32_t > mask_positions;
96
+ mask_positions.reserve (max_length);
97
+
98
+ struct llama_sampler * sampler = llama_sampler_chain_init (llama_sampler_chain_default_params ());
99
+ if (params.top_k > 0 ) {
100
+ llama_sampler_chain_add (sampler, llama_sampler_init_top_k (params.top_k ));
101
+ }
102
+ if (params.top_p < 1 .0f ) {
103
+ llama_sampler_chain_add (sampler, llama_sampler_init_top_p (params.top_p , 1 ));
104
+ }
105
+ if (params.temperature > 0 .0f ) {
106
+ llama_sampler_chain_add (sampler, llama_sampler_init_temp (params.temperature ));
107
+ }
108
+ llama_sampler_chain_add (sampler, llama_sampler_init_dist (params.seed ));
109
+
110
+ struct llama_sampler * dist_sampler = llama_sampler_init_dist (params.seed );
111
+
112
+ llama_batch batch = llama_batch_init (max_length, 0 , 1 );
113
+ batch.n_tokens = max_length;
114
+
115
+ int64_t total_sampling_time = 0 ;
116
+ int64_t total_time = 0 ;
117
+
118
+ int64_t time_start = ggml_time_us ();
119
+ for (int32_t step = 0 ; step < params.steps ; step++) {
120
+ if (params.step_callback ) {
121
+ if (!params.step_callback (step, params.steps , output_tokens, max_length, params.step_callback_user_data )) {
122
+ break ;
123
+ }
124
+ }
125
+
126
+ for (int32_t i = 0 ; i < max_length; i++) {
127
+ batch.token [i] = output_tokens[i];
128
+ batch.pos [i] = i;
129
+ batch.n_seq_id [i] = 1 ;
130
+ batch.seq_id [i][0 ] = 0 ;
131
+ batch.logits [i] = 1 ;
132
+ }
133
+
134
+ int ret = llama_decode (ctx, batch);
135
+ if (ret != 0 ) {
136
+ LOG_ERR (" %s: failed to decode at step %d, ret = %d\n " , __func__, step, ret);
137
+ break ;
138
+ }
139
+
140
+ float * raw_logits = llama_get_logits (ctx);
141
+ if (!raw_logits) {
142
+ LOG_ERR (" %s: failed to get logits at step %d\n " , __func__, step);
143
+ break ;
144
+ }
145
+
146
+ auto get_logits_for_pos = [&](int32_t pos) -> const float * {
147
+ return pos == 0 ? raw_logits : raw_logits + (pos - 1 ) * n_vocab;
148
+ };
149
+
150
+ int64_t time_start_sampling = ggml_time_us ();
151
+
152
+ mask_positions.clear ();
153
+ for (int32_t i = 0 ; i < max_length; i++) {
154
+ if (output_tokens[i] == params.mask_token_id ) {
155
+ mask_positions.push_back (i);
156
+ }
157
+ }
158
+
159
+ if (mask_positions.empty ()) {
160
+ break ;
161
+ }
162
+
163
+ float t = timesteps[step];
164
+ float s = timesteps[step + 1 ];
165
+
166
+ if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
167
+ float p_transfer = (step < params.steps - 1 ) ? (1 .0f - s / t) : 1 .0f ;
168
+
169
+ for (int32_t pos : mask_positions) {
170
+ if (std::uniform_real_distribution<float >(0 .0f , 1 .0f )(rng) < p_transfer) {
171
+ const float * pos_logits = get_logits_for_pos (pos);
172
+ for (int32_t token_id = 0 ; token_id < n_vocab; token_id++) {
173
+ candidates[token_id].id = token_id;
174
+ candidates[token_id].logit = pos_logits[token_id];
175
+ candidates[token_id].p = 0 .0f ;
176
+ }
177
+
178
+ llama_token_data_array cur_p = {
179
+ /* .data = */ candidates.data (),
180
+ /* .size = */ (size_t ) n_vocab, // Reset size to full vocab
181
+ /* .selected = */ -1 ,
182
+ /* .sorted = */ false ,
183
+ };
184
+
185
+ llama_sampler_apply (sampler, &cur_p);
186
+ output_tokens[pos] = cur_p.data [cur_p.selected ].id ;
187
+ }
188
+ }
189
+ } else {
190
+ std::vector<std::pair<float , int32_t >> confidences;
191
+ std::vector<llama_token> sampled_tokens (mask_positions.size ());
192
+
193
+ for (size_t i = 0 ; i < mask_positions.size (); i++) {
194
+ int32_t pos = mask_positions[i];
195
+ const float * pos_logits = get_logits_for_pos (pos);
196
+
197
+ for (int32_t token_id = 0 ; token_id < n_vocab; token_id++) {
198
+ candidates[token_id].logit = pos_logits[token_id];
199
+ candidates[token_id].p = 0 .0f ;
200
+ candidates[token_id].id = token_id;
201
+ }
202
+
203
+ llama_token_data_array cur_p = {
204
+ /* .data = */ candidates.data (),
205
+ /* .size = */ candidates.size (),
206
+ /* .selected = */ -1 ,
207
+ /* .sorted = */ false ,
208
+ };
209
+
210
+ llama_sampler_apply (sampler, &cur_p);
211
+
212
+ llama_token sampled_token = cur_p.data [cur_p.selected ].id ;
213
+
214
+ float confidence = 0 .0f ;
215
+ if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
216
+ const float epsilon = 1e-10f ;
217
+ for (size_t j = 0 ; j < cur_p.size ; j++) {
218
+ float prob = cur_p.data [j].p ;
219
+ confidence += prob * logf (prob + epsilon);
220
+ }
221
+ } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
222
+ confidence = cur_p.data [0 ].p - cur_p.data [1 ].p ;
223
+ } else {
224
+ confidence = cur_p.data [cur_p.selected ].p ;
225
+ }
226
+
227
+ sampled_tokens[i] = sampled_token;
228
+ confidences.emplace_back (confidence, i);
229
+ }
230
+
231
+ int32_t num_transfer =
232
+ (step < params.steps - 1 ) ? (int32_t ) (mask_positions.size () * (1 .0f - s / t)) : mask_positions.size ();
233
+
234
+ if (num_transfer > 0 ) {
235
+ if (params.alg_temp == 0 .0f ) {
236
+ std::partial_sort (confidences.begin (), confidences.begin () + num_transfer, confidences.end (),
237
+ [](const std::pair<float , int32_t > & a, const std::pair<float , int32_t > & b) {
238
+ if (a.first != b.first ) {
239
+ return a.first > b.first ;
240
+ }
241
+ return a.second < b.second ;
242
+ });
243
+ } else {
244
+ conf_candidates.clear ();
245
+
246
+ for (int32_t pos = 0 ; pos < max_length; pos++) {
247
+ float conf_logit = -std::numeric_limits<float >::infinity ();
248
+
249
+ auto it = std::find (mask_positions.begin (), mask_positions.end (), pos);
250
+ if (it != mask_positions.end ()) {
251
+ size_t mask_idx = std::distance (mask_positions.begin (), it);
252
+ conf_logit = confidences[mask_idx].first / params.alg_temp ; // Apply temperature scaling
253
+ }
254
+
255
+ conf_candidates.emplace_back (llama_token_data{ pos, conf_logit, 0 .0f });
256
+ }
257
+
258
+ llama_token_data_array conf_array = {
259
+ /* .data = */ conf_candidates.data (),
260
+ /* .size = */ conf_candidates.size (),
261
+ /* .selected = */ -1 ,
262
+ /* .sorted = */ false ,
263
+ };
264
+
265
+ for (int32_t i = 0 ; i < num_transfer; i++) {
266
+ // Apply distribution sampler to get selected index
267
+ llama_sampler_apply (dist_sampler, &conf_array);
268
+ int selected_idx = conf_array.selected ;
269
+ confidences[i].second = conf_candidates[selected_idx].id ;
270
+
271
+ conf_candidates[selected_idx].p = 0 .0f ;
272
+ conf_array.selected = -1 ;
273
+ }
274
+ }
275
+
276
+ if (params.alg_temp == 0 .0f ) {
277
+ // Deterministic - use confidence order
278
+ for (int32_t i = 0 ; i < num_transfer; i++) {
279
+ int32_t mask_idx = confidences[i].second ;
280
+ int32_t pos = mask_positions[mask_idx];
281
+ llama_token token = sampled_tokens[mask_idx];
282
+ output_tokens[pos] = token;
283
+ }
284
+ } else {
285
+ for (int32_t i = 0 ; i < num_transfer; i++) {
286
+ int32_t pos = confidences[i].second ;
287
+ auto it = std::find (mask_positions.begin (), mask_positions.end (), pos);
288
+ if (it != mask_positions.end ()) {
289
+ int32_t mask_idx = std::distance (mask_positions.begin (), it);
290
+ output_tokens[pos] = sampled_tokens[mask_idx];
291
+ }
292
+ }
293
+ }
294
+ }
295
+ }
296
+ int64_t time_end_sampling = ggml_time_us ();
297
+ total_sampling_time += time_end_sampling - time_start_sampling;
298
+ }
299
+ int64_t time_end = ggml_time_us ();
300
+ total_time += time_end - time_start;
301
+
302
+ LOG_INF (" \n total time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n " ,
303
+ total_time / 1000.0 , total_time / 1000.0 / params.steps , total_sampling_time / 1000.0 / params.steps );
304
+
305
+
306
+ llama_batch_free (batch);
307
+ llama_sampler_free (sampler);
308
+ llama_sampler_free (dist_sampler);
309
+
310
+ n_generated = max_length;
311
+ }
312
+
313
+
314
+
11
315
12
316
static std::string format_input_text (const std::string & prompt, bool use_chat_template, llama_model * model) {
13
317
if (!use_chat_template) {
@@ -34,24 +338,24 @@ struct callback_data {
34
338
int32_t n_input;
35
339
};
36
340
37
- static bool diffusion_step_callback (int32_t step
38
- , int32_t total_steps
39
- , const llama_token * tokens
40
- , int32_t n_tokens
41
- , void * user_data) {
341
+ static bool diffusion_step_callback (int32_t step,
342
+ int32_t total_steps,
343
+ const llama_token * tokens,
344
+ int32_t n_tokens,
345
+ void * user_data) {
42
346
(void )user_data;
43
347
44
348
callback_data * data = static_cast <callback_data *>(user_data);
45
349
46
350
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
47
351
int progress_percent = (step * 100 ) / total_steps;
48
352
int progress_bars = (step * 50 ) / total_steps;
49
- LOG_INF (" \r diffusion step: %d/%d [%s%s] %d%%"
50
- , step
51
- , total_steps
52
- , std::string (progress_bars, ' =' ).c_str ()
53
- , std::string (50 - progress_bars, ' ' ).c_str ()
54
- , progress_percent);
353
+ LOG_INF (" \r diffusion step: %d/%d [%s%s] %d%%" ,
354
+ step,
355
+ total_steps,
356
+ std::string (progress_bars, ' =' ).c_str (),
357
+ std::string (50 - progress_bars, ' ' ).c_str (),
358
+ progress_percent);
55
359
};
56
360
57
361
if (data->diff_params ->visual_mode ) {
@@ -157,7 +461,7 @@ int main(int argc, char ** argv) {
157
461
ldiff_params.temperature = params.sampling .temp ;
158
462
ldiff_params.top_p = params.sampling .top_p ;
159
463
ldiff_params.top_k = params.sampling .top_k ;
160
- ldiff_params.algorithm = static_cast <enum diffusion_algorithm >(params.diffusion .algorithm );
464
+ ldiff_params.algorithm = static_cast <enum diffusion_alg >(params.diffusion .algorithm );
161
465
ldiff_params.alg_temp = params.diffusion .alg_temp ;
162
466
ldiff_params.seed = params.sampling .seed ;
163
467
0 commit comments