Skip to content

Commit 2034588

Browse files
authored
refactor: optimize the handling of sample method (#999)
1 parent 490c51d commit 2034588

File tree

4 files changed

+64
-52
lines changed

4 files changed

+64
-52
lines changed

denoiser.hpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -640,7 +640,7 @@ static void sample_k_diffusion(sample_method_t method,
640640
size_t steps = sigmas.size() - 1;
641641
// sample_euler_ancestral
642642
switch (method) {
643-
case EULER_A: {
643+
case EULER_A_SAMPLE_METHOD: {
644644
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
645645
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
646646

@@ -693,7 +693,7 @@ static void sample_k_diffusion(sample_method_t method,
693693
}
694694
}
695695
} break;
696-
case EULER: // Implemented without any sigma churn
696+
case EULER_SAMPLE_METHOD: // Implemented without any sigma churn
697697
{
698698
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
699699

@@ -726,7 +726,7 @@ static void sample_k_diffusion(sample_method_t method,
726726
}
727727
}
728728
} break;
729-
case HEUN: {
729+
case HEUN_SAMPLE_METHOD: {
730730
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
731731
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
732732

@@ -776,7 +776,7 @@ static void sample_k_diffusion(sample_method_t method,
776776
}
777777
}
778778
} break;
779-
case DPM2: {
779+
case DPM2_SAMPLE_METHOD: {
780780
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
781781
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
782782

@@ -828,7 +828,7 @@ static void sample_k_diffusion(sample_method_t method,
828828
}
829829

830830
} break;
831-
case DPMPP2S_A: {
831+
case DPMPP2S_A_SAMPLE_METHOD: {
832832
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
833833
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
834834

@@ -892,7 +892,7 @@ static void sample_k_diffusion(sample_method_t method,
892892
}
893893
}
894894
} break;
895-
case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
895+
case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022)
896896
{
897897
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
898898

@@ -931,7 +931,7 @@ static void sample_k_diffusion(sample_method_t method,
931931
}
932932
}
933933
} break;
934-
case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
934+
case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
935935
{
936936
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
937937

@@ -974,7 +974,7 @@ static void sample_k_diffusion(sample_method_t method,
974974
}
975975
}
976976
} break;
977-
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
977+
case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
978978
{
979979
int max_order = 4;
980980
ggml_tensor* x_next = x;
@@ -1049,7 +1049,7 @@ static void sample_k_diffusion(sample_method_t method,
10491049
}
10501050
}
10511051
} break;
1052-
case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
1052+
case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
10531053
{
10541054
int max_order = 4;
10551055
std::vector<ggml_tensor*> buffer_model;
@@ -1123,7 +1123,7 @@ static void sample_k_diffusion(sample_method_t method,
11231123
d_cur = ggml_dup_tensor(work_ctx, x_next);
11241124
}
11251125
} break;
1126-
case LCM: // Latent Consistency Models
1126+
case LCM_SAMPLE_METHOD: // Latent Consistency Models
11271127
{
11281128
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
11291129
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@@ -1158,8 +1158,8 @@ static void sample_k_diffusion(sample_method_t method,
11581158
}
11591159
}
11601160
} break;
1161-
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
1162-
// with the "trailing" timestep spacing
1161+
case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models
1162+
// with the "trailing" timestep spacing
11631163
{
11641164
// See J. Song et al., "Denoising Diffusion Implicit
11651165
// Models", arXiv:2010.02502 [cs.LG]
@@ -1352,8 +1352,8 @@ static void sample_k_diffusion(sample_method_t method,
13521352
// factor c_in.
13531353
}
13541354
} break;
1355-
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
1356-
// Trajectory Consistency Distillation
1355+
case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in
1356+
// Trajectory Consistency Distillation
13571357
{
13581358
// See J. Zheng et al., "Trajectory Consistency
13591359
// Distillation: Improved Latent Consistency Distillation

examples/cli/main.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1902,10 +1902,14 @@ int main(int argc, const char* argv[]) {
19021902
return 1;
19031903
}
19041904

1905-
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1905+
if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
19061906
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
19071907
}
19081908

1909+
if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
1910+
params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1911+
}
1912+
19091913
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
19101914
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
19111915
}

stable-diffusion.cpp

Lines changed: 33 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
4747
};
4848

4949
const char* sampling_methods_str[] = {
50-
"default",
5150
"Euler",
51+
"Euler A",
5252
"Heun",
5353
"DPM2",
5454
"DPM++ (2s)",
@@ -59,7 +59,6 @@ const char* sampling_methods_str[] = {
5959
"LCM",
6060
"DDIM \"trailing\"",
6161
"TCD",
62-
"Euler A",
6362
};
6463

6564
/*================================================== Helper Functions ================================================*/
@@ -2228,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
22282227
}
22292228

22302229
const char* sample_method_to_str[] = {
2231-
"default",
22322230
"euler",
2231+
"euler_a",
22332232
"heun",
22342233
"dpm2",
22352234
"dpm++2s_a",
@@ -2240,7 +2239,6 @@ const char* sample_method_to_str[] = {
22402239
"lcm",
22412240
"ddim_trailing",
22422241
"tcd",
2243-
"euler_a",
22442242
};
22452243

22462244
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2469,7 +2467,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
24692467
sample_params->guidance.slg.layer_end = 0.2f;
24702468
sample_params->guidance.slg.scale = 0.f;
24712469
sample_params->scheduler = SCHEDULER_COUNT;
2472-
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
2470+
sample_params->sample_method = SAMPLE_METHOD_COUNT;
24732471
sample_params->sample_steps = 20;
24742472
}
24752473

@@ -2627,19 +2625,19 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
26272625

26282626
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
26292627
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
2630-
SDVersion version = sd_ctx->sd->version;
2631-
if (sd_version_is_dit(version))
2632-
return EULER;
2633-
else
2634-
return EULER_A;
2628+
if (sd_version_is_dit(sd_ctx->sd->version)) {
2629+
return EULER_SAMPLE_METHOD;
2630+
}
26352631
}
2636-
return SAMPLE_METHOD_COUNT;
2632+
return EULER_A_SAMPLE_METHOD;
26372633
}
26382634

26392635
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
2640-
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
2641-
if (edm_v_denoiser) {
2642-
return EXPONENTIAL_SCHEDULER;
2636+
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
2637+
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
2638+
if (edm_v_denoiser) {
2639+
return EXPONENTIAL_SCHEDULER;
2640+
}
26432641
}
26442642
return DISCRETE_SCHEDULER;
26452643
}
@@ -2827,7 +2825,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
28272825
int C = sd_ctx->sd->get_latent_channel();
28282826
int W = width / sd_ctx->sd->get_vae_scale_factor();
28292827
int H = height / sd_ctx->sd->get_vae_scale_factor();
2830-
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
28312828

28322829
struct ggml_tensor* control_latent = nullptr;
28332830
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
@@ -3056,10 +3053,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
30563053
sd_ctx->sd->rng->manual_seed(seed);
30573054
sd_ctx->sd->sampler_rng->manual_seed(seed);
30583055

3059-
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
3060-
30613056
size_t t0 = ggml_time_ms();
30623057

3058+
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
3059+
if (sample_method == SAMPLE_METHOD_COUNT) {
3060+
sample_method = sd_get_default_sample_method(sd_ctx);
3061+
}
3062+
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
3063+
3064+
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
30633065
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
30643066

30653067
ggml_tensor* init_latent = nullptr;
@@ -3248,11 +3250,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
32483250
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
32493251
}
32503252

3251-
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
3252-
if (sample_method == SAMPLE_METHOD_DEFAULT) {
3253-
sample_method = sd_get_default_sample_method(sd_ctx);
3254-
}
3255-
32563253
sd_image_t* result_images = generate_image_internal(sd_ctx,
32573254
work_ctx,
32583255
init_latent,
@@ -3302,6 +3299,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
33023299

33033300
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
33043301

3302+
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
3303+
if (sample_method == SAMPLE_METHOD_COUNT) {
3304+
sample_method = sd_get_default_sample_method(sd_ctx);
3305+
}
3306+
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
3307+
33053308
int high_noise_sample_steps = 0;
33063309
if (sd_ctx->sd->high_noise_diffusion_model) {
33073310
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
@@ -3570,6 +3573,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35703573
// High Noise Sample
35713574
if (high_noise_sample_steps > 0) {
35723575
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
3576+
enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
3577+
if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
3578+
high_noise_sample_method = sd_get_default_sample_method(sd_ctx);
3579+
}
3580+
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
3581+
35733582
int64_t sampling_start = ggml_time_ms();
35743583

35753584
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
@@ -3588,7 +3597,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35883597
sd_vid_gen_params->high_noise_sample_params.guidance,
35893598
sd_vid_gen_params->high_noise_sample_params.eta,
35903599
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
3591-
sd_vid_gen_params->high_noise_sample_params.sample_method,
3600+
high_noise_sample_method,
35923601
high_noise_sigmas,
35933602
-1,
35943603
{},
@@ -3625,7 +3634,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36253634
sd_vid_gen_params->sample_params.guidance,
36263635
sd_vid_gen_params->sample_params.eta,
36273636
sd_vid_gen_params->sample_params.shifted_timestep,
3628-
sd_vid_gen_params->sample_params.sample_method,
3637+
sample_method,
36293638
sigmas,
36303639
-1,
36313640
{},

stable-diffusion.h

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -36,19 +36,18 @@ enum rng_type_t {
3636
};
3737

3838
enum sample_method_t {
39-
SAMPLE_METHOD_DEFAULT,
40-
EULER,
41-
HEUN,
42-
DPM2,
43-
DPMPP2S_A,
44-
DPMPP2M,
45-
DPMPP2Mv2,
46-
IPNDM,
47-
IPNDM_V,
48-
LCM,
49-
DDIM_TRAILING,
50-
TCD,
51-
EULER_A,
39+
EULER_SAMPLE_METHOD,
40+
EULER_A_SAMPLE_METHOD,
41+
HEUN_SAMPLE_METHOD,
42+
DPM2_SAMPLE_METHOD,
43+
DPMPP2S_A_SAMPLE_METHOD,
44+
DPMPP2M_SAMPLE_METHOD,
45+
DPMPP2Mv2_SAMPLE_METHOD,
46+
IPNDM_SAMPLE_METHOD,
47+
IPNDM_V_SAMPLE_METHOD,
48+
LCM_SAMPLE_METHOD,
49+
DDIM_TRAILING_SAMPLE_METHOD,
50+
TCD_SAMPLE_METHOD,
5251
SAMPLE_METHOD_COUNT
5352
};
5453

0 commit comments

Comments
 (0)