Skip to content

Commit 869d023

Browse files
authored
refactor: optimize the handling of scheduler (#998)
1 parent e9bc3b6 commit 869d023

File tree

4 files changed

+104
-122
lines changed

4 files changed

+104
-122
lines changed

denoiser.hpp

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
#define TIMESTEPS 1000
1212
#define FLUX_TIMESTEPS 1000
1313

14-
struct SigmaSchedule {
15-
int version = 0;
14+
struct SigmaScheduler {
1615
typedef std::function<float(float)> t_to_sigma_t;
1716

1817
virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
1918
};
2019

21-
struct DiscreteSchedule : SigmaSchedule {
20+
struct DiscreteScheduler : SigmaScheduler {
2221
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
2322
std::vector<float> result;
2423

@@ -42,7 +41,7 @@ struct DiscreteSchedule : SigmaSchedule {
4241
}
4342
};
4443

45-
struct ExponentialSchedule : SigmaSchedule {
44+
struct ExponentialScheduler : SigmaScheduler {
4645
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
4746
std::vector<float> sigmas;
4847

@@ -149,7 +148,10 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
149148
/*
150149
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
151150
*/
152-
struct AYSSchedule : SigmaSchedule {
151+
struct AYSScheduler : SigmaScheduler {
152+
SDVersion version;
153+
explicit AYSScheduler(SDVersion version)
154+
: version(version) {}
153155
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
154156
const std::vector<float> noise_levels[] = {
155157
/* SD1.5 */
@@ -169,19 +171,19 @@ struct AYSSchedule : SigmaSchedule {
169171
std::vector<float> results(n + 1);
170172

171173
if (sd_version_is_sd2((SDVersion)version)) {
172-
LOG_WARN("AYS not designed for SD2.X models");
174+
LOG_WARN("AYS_SCHEDULER not designed for SD2.X models");
173175
} /* fallthrough */
174176
else if (sd_version_is_sd1((SDVersion)version)) {
175-
LOG_INFO("AYS using SD1.5 noise levels");
177+
LOG_INFO("AYS_SCHEDULER using SD1.5 noise levels");
176178
inputs = noise_levels[0];
177179
} else if (sd_version_is_sdxl((SDVersion)version)) {
178-
LOG_INFO("AYS using SDXL noise levels");
180+
LOG_INFO("AYS_SCHEDULER using SDXL noise levels");
179181
inputs = noise_levels[1];
180182
} else if (version == VERSION_SVD) {
181-
LOG_INFO("AYS using SVD noise levels");
183+
LOG_INFO("AYS_SCHEDULER using SVD noise levels");
182184
inputs = noise_levels[2];
183185
} else {
184-
LOG_ERROR("Version not compatible with AYS scheduler");
186+
LOG_ERROR("Version not compatible with AYS_SCHEDULER scheduler");
185187
return results;
186188
}
187189

@@ -203,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
203205
/*
204206
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
205207
*/
206-
struct GITSSchedule : SigmaSchedule {
208+
struct GITSScheduler : SigmaScheduler {
207209
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
208210
if (sigma_max <= 0.0f) {
209211
return std::vector<float>{};
@@ -232,7 +234,7 @@ struct GITSSchedule : SigmaSchedule {
232234
}
233235
};
234236

235-
struct SGMUniformSchedule : SigmaSchedule {
237+
struct SGMUniformScheduler : SigmaScheduler {
236238
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
237239
std::vector<float> result;
238240
if (n == 0) {
@@ -251,7 +253,7 @@ struct SGMUniformSchedule : SigmaSchedule {
251253
}
252254
};
253255

254-
struct KarrasSchedule : SigmaSchedule {
256+
struct KarrasScheduler : SigmaScheduler {
255257
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
256258
// These *COULD* be function arguments here,
257259
// but does anybody ever bother to touch them?
@@ -270,7 +272,7 @@ struct KarrasSchedule : SigmaSchedule {
270272
}
271273
};
272274

273-
struct SimpleSchedule : SigmaSchedule {
275+
struct SimpleScheduler : SigmaScheduler {
274276
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
275277
std::vector<float> result_sigmas;
276278

@@ -299,8 +301,8 @@ struct SimpleSchedule : SigmaSchedule {
299301
}
300302
};
301303

302-
// Close to Beta Schedule, but increadably simple in code.
303-
struct SmoothStepSchedule : SigmaSchedule {
304+
// Close to Beta Scheduler, but increadably simple in code.
305+
struct SmoothStepScheduler : SigmaScheduler {
304306
static constexpr float smoothstep(float x) {
305307
return x * x * (3.0f - 2.0f * x);
306308
}
@@ -329,7 +331,6 @@ struct SmoothStepSchedule : SigmaSchedule {
329331
};
330332

331333
struct Denoiser {
332-
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
333334
virtual float sigma_min() = 0;
334335
virtual float sigma_max() = 0;
335336
virtual float sigma_to_t(float sigma) = 0;
@@ -338,8 +339,47 @@ struct Denoiser {
338339
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
339340
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
340341

341-
virtual std::vector<float> get_sigmas(uint32_t n) {
342+
virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
342343
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
344+
std::shared_ptr<SigmaScheduler> scheduler;
345+
switch (scheduler_type) {
346+
case DISCRETE_SCHEDULER:
347+
LOG_INFO("get_sigmas with discrete scheduler");
348+
scheduler = std::make_shared<DiscreteScheduler>();
349+
break;
350+
case KARRAS_SCHEDULER:
351+
LOG_INFO("get_sigmas with Karras scheduler");
352+
scheduler = std::make_shared<KarrasScheduler>();
353+
break;
354+
case EXPONENTIAL_SCHEDULER:
355+
LOG_INFO("get_sigmas exponential scheduler");
356+
scheduler = std::make_shared<ExponentialScheduler>();
357+
break;
358+
case AYS_SCHEDULER:
359+
LOG_INFO("get_sigmas with Align-Your-Steps scheduler");
360+
scheduler = std::make_shared<AYSScheduler>(version);
361+
break;
362+
case GITS_SCHEDULER:
363+
LOG_INFO("get_sigmas with GITS scheduler");
364+
scheduler = std::make_shared<GITSScheduler>();
365+
break;
366+
case SGM_UNIFORM_SCHEDULER:
367+
LOG_INFO("get_sigmas with SGM Uniform scheduler");
368+
scheduler = std::make_shared<SGMUniformScheduler>();
369+
break;
370+
case SIMPLE_SCHEDULER:
371+
LOG_INFO("get_sigmas with Simple scheduler");
372+
scheduler = std::make_shared<SimpleScheduler>();
373+
break;
374+
case SMOOTHSTEP_SCHEDULER:
375+
LOG_INFO("get_sigmas with SmoothStep scheduler");
376+
scheduler = std::make_shared<SmoothStepScheduler>();
377+
break;
378+
default:
379+
LOG_INFO("get_sigmas with discrete scheduler (default)");
380+
scheduler = std::make_shared<DiscreteScheduler>();
381+
break;
382+
}
343383
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
344384
}
345385
};
@@ -426,7 +466,6 @@ struct EDMVDenoiser : public CompVisVDenoiser {
426466

427467
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
428468
: min_sigma(min_sigma), max_sigma(max_sigma) {
429-
scheduler = std::make_shared<ExponentialSchedule>();
430469
}
431470

432471
float t_to_sigma(float t) override {
@@ -1109,7 +1148,7 @@ static void sample_k_diffusion(sample_method_t method,
11091148
// end beta) (which unfortunately k-diffusion's data
11101149
// structure hides from the denoiser), and the sigmas are
11111150
// also needed to invert the behavior of CompVisDenoiser
1112-
// (k-diffusion's LMSDiscreteScheduler)
1151+
// (k-diffusion's LMSDiscreteSchedulerr)
11131152
float beta_start = 0.00085f;
11141153
float beta_end = 0.0120f;
11151154
std::vector<double> alphas_cumprod;
@@ -1137,7 +1176,7 @@ static void sample_k_diffusion(sample_method_t method,
11371176

11381177
for (int i = 0; i < steps; i++) {
11391178
// The "trailing" DDIM timestep, see S. Lin et al.,
1140-
// "Common Diffusion Noise Schedules and Sample Steps
1179+
// "Common Diffusion Noise Schedulers and Sample Steps
11411180
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
11421181
// 2. Most variables below follow Diffusers naming
11431182
//

examples/cli/main.cpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -912,34 +912,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
912912
return 1;
913913
};
914914

915-
auto on_schedule_arg = [&](int argc, const char** argv, int index) {
915+
auto on_scheduler_arg = [&](int argc, const char** argv, int index) {
916916
if (++index >= argc) {
917917
return -1;
918918
}
919919
const char* arg = argv[index];
920-
params.sample_params.scheduler = str_to_schedule(arg);
921-
if (params.sample_params.scheduler == SCHEDULE_COUNT) {
920+
params.sample_params.scheduler = str_to_scheduler(arg);
921+
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
922922
fprintf(stderr, "error: invalid scheduler %s\n",
923923
arg);
924924
return -1;
925925
}
926926
return 1;
927927
};
928928

929-
auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
930-
if (++index >= argc) {
931-
return -1;
932-
}
933-
const char* arg = argv[index];
934-
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
935-
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
936-
fprintf(stderr, "error: invalid high noise scheduler %s\n",
937-
arg);
938-
return -1;
939-
}
940-
return 1;
941-
};
942-
943929
auto on_prediction_arg = [&](int argc, const char** argv, int index) {
944930
if (++index >= argc) {
945931
return -1;
@@ -1212,7 +1198,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
12121198
{"",
12131199
"--scheduler",
12141200
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
1215-
on_schedule_arg},
1201+
on_scheduler_arg},
12161202
{"",
12171203
"--skip-layers",
12181204
"layers to skip for SLG steps (default: [7,8,9])",
@@ -1222,10 +1208,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
12221208
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
12231209
" default: euler for Flux/SD3/Wan, euler_a otherwise",
12241210
on_high_noise_sample_method_arg},
1225-
{"",
1226-
"--high-noise-scheduler",
1227-
"(high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
1228-
on_high_noise_schedule_arg},
12291211
{"",
12301212
"--high-noise-skip-layers",
12311213
"(high noise) layers to skip for SLG steps (default: [7,8,9])",
@@ -1442,8 +1424,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
14421424
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
14431425
}
14441426
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
1445-
if (params.sample_params.scheduler != DEFAULT) {
1446-
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
1427+
if (params.sample_params.scheduler != SCHEDULER_COUNT) {
1428+
parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
14471429
}
14481430
parameter_string += ", ";
14491431
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
@@ -1924,6 +1906,10 @@ int main(int argc, const char* argv[]) {
19241906
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
19251907
}
19261908

1909+
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
1910+
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
1911+
}
1912+
19271913
if (params.mode == IMG_GEN) {
19281914
sd_img_gen_params_t img_gen_params = {
19291915
params.prompt.c_str(),

0 commit comments

Comments
 (0)