1111#define TIMESTEPS 1000
1212#define FLUX_TIMESTEPS 1000
1313
14- struct SigmaSchedule {
15- int version = 0 ;
14+ struct SigmaScheduler {
1615 typedef std::function<float (float )> t_to_sigma_t ;
1716
1817 virtual std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
1918};
2019
21- struct DiscreteSchedule : SigmaSchedule {
20+ struct DiscreteScheduler : SigmaScheduler {
2221 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
2322 std::vector<float > result;
2423
@@ -42,7 +41,7 @@ struct DiscreteSchedule : SigmaSchedule {
4241 }
4342};
4443
45- struct ExponentialSchedule : SigmaSchedule {
44+ struct ExponentialScheduler : SigmaScheduler {
4645 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
4746 std::vector<float > sigmas;
4847
@@ -149,7 +148,10 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
149148/*
150149https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
151150*/
152- struct AYSSchedule : SigmaSchedule {
151+ struct AYSScheduler : SigmaScheduler {
152+ SDVersion version;
153+ explicit AYSScheduler (SDVersion version)
154+ : version(version) {}
153155 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
154156 const std::vector<float > noise_levels[] = {
155157 /* SD1.5 */
@@ -169,19 +171,19 @@ struct AYSSchedule : SigmaSchedule {
169171 std::vector<float > results (n + 1 );
170172
171173 if (sd_version_is_sd2 ((SDVersion)version)) {
172- LOG_WARN (" AYS not designed for SD2.X models" );
174+ LOG_WARN (" AYS_SCHEDULER not designed for SD2.X models" );
173175 } /* fallthrough */
174176 else if (sd_version_is_sd1 ((SDVersion)version)) {
175- LOG_INFO (" AYS using SD1.5 noise levels" );
177+ LOG_INFO (" AYS_SCHEDULER using SD1.5 noise levels" );
176178 inputs = noise_levels[0 ];
177179 } else if (sd_version_is_sdxl ((SDVersion)version)) {
178- LOG_INFO (" AYS using SDXL noise levels" );
180+ LOG_INFO (" AYS_SCHEDULER using SDXL noise levels" );
179181 inputs = noise_levels[1 ];
180182 } else if (version == VERSION_SVD) {
181- LOG_INFO (" AYS using SVD noise levels" );
183+ LOG_INFO (" AYS_SCHEDULER using SVD noise levels" );
182184 inputs = noise_levels[2 ];
183185 } else {
184- LOG_ERROR (" Version not compatible with AYS scheduler" );
186+ LOG_ERROR (" Version not compatible with AYS_SCHEDULER scheduler" );
185187 return results;
186188 }
187189
@@ -203,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
203205/*
204206 * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
205207 */
206- struct GITSSchedule : SigmaSchedule {
208+ struct GITSScheduler : SigmaScheduler {
207209 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
208210 if (sigma_max <= 0 .0f ) {
209211 return std::vector<float >{};
@@ -232,7 +234,7 @@ struct GITSSchedule : SigmaSchedule {
232234 }
233235};
234236
235- struct SGMUniformSchedule : SigmaSchedule {
237+ struct SGMUniformScheduler : SigmaScheduler {
236238 std::vector<float > get_sigmas (uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
237239 std::vector<float > result;
238240 if (n == 0 ) {
@@ -251,7 +253,7 @@ struct SGMUniformSchedule : SigmaSchedule {
251253 }
252254};
253255
254- struct KarrasSchedule : SigmaSchedule {
256+ struct KarrasScheduler : SigmaScheduler {
255257 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
256258 // These *COULD* be function arguments here,
257259 // but does anybody ever bother to touch them?
@@ -270,7 +272,7 @@ struct KarrasSchedule : SigmaSchedule {
270272 }
271273};
272274
273- struct SimpleSchedule : SigmaSchedule {
275+ struct SimpleScheduler : SigmaScheduler {
274276 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
275277 std::vector<float > result_sigmas;
276278
@@ -299,8 +301,8 @@ struct SimpleSchedule : SigmaSchedule {
299301 }
300302};
301303
302- // Close to Beta Schedule , but increadably simple in code.
303- struct SmoothStepSchedule : SigmaSchedule {
304+ // Close to Beta Scheduler , but increadably simple in code.
305+ struct SmoothStepScheduler : SigmaScheduler {
304306 static constexpr float smoothstep (float x) {
305307 return x * x * (3 .0f - 2 .0f * x);
306308 }
@@ -329,7 +331,6 @@ struct SmoothStepSchedule : SigmaSchedule {
329331};
330332
331333struct Denoiser {
332- std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
333334 virtual float sigma_min () = 0;
334335 virtual float sigma_max () = 0;
335336 virtual float sigma_to_t (float sigma) = 0;
@@ -338,8 +339,47 @@ struct Denoiser {
338339 virtual ggml_tensor* noise_scaling (float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
339340 virtual ggml_tensor* inverse_noise_scaling (float sigma, ggml_tensor* latent) = 0;
340341
341- virtual std::vector<float > get_sigmas (uint32_t n) {
342+ virtual std::vector<float > get_sigmas (uint32_t n, scheduler_t scheduler_type, SDVersion version ) {
342343 auto bound_t_to_sigma = std::bind (&Denoiser::t_to_sigma, this , std::placeholders::_1);
344+ std::shared_ptr<SigmaScheduler> scheduler;
345+ switch (scheduler_type) {
346+ case DISCRETE_SCHEDULER:
347+ LOG_INFO (" get_sigmas with discrete scheduler" );
348+ scheduler = std::make_shared<DiscreteScheduler>();
349+ break ;
350+ case KARRAS_SCHEDULER:
351+ LOG_INFO (" get_sigmas with Karras scheduler" );
352+ scheduler = std::make_shared<KarrasScheduler>();
353+ break ;
354+ case EXPONENTIAL_SCHEDULER:
355+ LOG_INFO (" get_sigmas exponential scheduler" );
356+ scheduler = std::make_shared<ExponentialScheduler>();
357+ break ;
358+ case AYS_SCHEDULER:
359+ LOG_INFO (" get_sigmas with Align-Your-Steps scheduler" );
360+ scheduler = std::make_shared<AYSScheduler>(version);
361+ break ;
362+ case GITS_SCHEDULER:
363+ LOG_INFO (" get_sigmas with GITS scheduler" );
364+ scheduler = std::make_shared<GITSScheduler>();
365+ break ;
366+ case SGM_UNIFORM_SCHEDULER:
367+ LOG_INFO (" get_sigmas with SGM Uniform scheduler" );
368+ scheduler = std::make_shared<SGMUniformScheduler>();
369+ break ;
370+ case SIMPLE_SCHEDULER:
371+ LOG_INFO (" get_sigmas with Simple scheduler" );
372+ scheduler = std::make_shared<SimpleScheduler>();
373+ break ;
374+ case SMOOTHSTEP_SCHEDULER:
375+ LOG_INFO (" get_sigmas with SmoothStep scheduler" );
376+ scheduler = std::make_shared<SmoothStepScheduler>();
377+ break ;
378+ default :
379+ LOG_INFO (" get_sigmas with discrete scheduler (default)" );
380+ scheduler = std::make_shared<DiscreteScheduler>();
381+ break ;
382+ }
343383 return scheduler->get_sigmas (n, sigma_min (), sigma_max (), bound_t_to_sigma);
344384 }
345385};
@@ -426,7 +466,6 @@ struct EDMVDenoiser : public CompVisVDenoiser {
426466
427467 EDMVDenoiser (float min_sigma = 0.002 , float max_sigma = 120.0 )
428468 : min_sigma(min_sigma), max_sigma(max_sigma) {
429- scheduler = std::make_shared<ExponentialSchedule>();
430469 }
431470
432471 float t_to_sigma (float t) override {
@@ -1109,7 +1148,7 @@ static void sample_k_diffusion(sample_method_t method,
11091148 // end beta) (which unfortunately k-diffusion's data
11101149 // structure hides from the denoiser), and the sigmas are
11111150 // also needed to invert the behavior of CompVisDenoiser
1112- // (k-diffusion's LMSDiscreteScheduler )
1151+ // (k-diffusion's LMSDiscreteSchedulerr )
11131152 float beta_start = 0 .00085f ;
11141153 float beta_end = 0 .0120f ;
11151154 std::vector<double > alphas_cumprod;
@@ -1137,7 +1176,7 @@ static void sample_k_diffusion(sample_method_t method,
11371176
11381177 for (int i = 0 ; i < steps; i++) {
11391178 // The "trailing" DDIM timestep, see S. Lin et al.,
1140- // "Common Diffusion Noise Schedules and Sample Steps
1179+ // "Common Diffusion Noise Schedulers and Sample Steps
11411180 // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
11421181 // 2. Most variables below follow Diffusers naming
11431182 //
0 commit comments