@@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
4747};
4848
4949const char * sampling_methods_str[] = {
50- " default" ,
5150 " Euler" ,
51+ " Euler A" ,
5252 " Heun" ,
5353 " DPM2" ,
5454 " DPM++ (2s)" ,
@@ -59,7 +59,6 @@ const char* sampling_methods_str[] = {
5959 " LCM" ,
6060 " DDIM \" trailing\" " ,
6161 " TCD" ,
62- " Euler A" ,
6362};
6463
6564/* ================================================== Helper Functions ================================================*/
@@ -2228,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
22282227}
22292228
22302229const char * sample_method_to_str[] = {
2231- " default" ,
22322230 " euler" ,
2231+ " euler_a" ,
22332232 " heun" ,
22342233 " dpm2" ,
22352234 " dpm++2s_a" ,
@@ -2240,7 +2239,6 @@ const char* sample_method_to_str[] = {
22402239 " lcm" ,
22412240 " ddim_trailing" ,
22422241 " tcd" ,
2243- " euler_a" ,
22442242};
22452243
22462244const char * sd_sample_method_name (enum sample_method_t sample_method) {
@@ -2469,7 +2467,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
24692467 sample_params->guidance .slg .layer_end = 0 .2f ;
24702468 sample_params->guidance .slg .scale = 0 .f ;
24712469 sample_params->scheduler = SCHEDULER_COUNT;
2472- sample_params->sample_method = SAMPLE_METHOD_DEFAULT ;
2470+ sample_params->sample_method = SAMPLE_METHOD_COUNT ;
24732471 sample_params->sample_steps = 20 ;
24742472}
24752473
@@ -2627,19 +2625,19 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
26272625
26282626enum sample_method_t sd_get_default_sample_method (const sd_ctx_t * sd_ctx) {
26292627 if (sd_ctx != nullptr && sd_ctx->sd != nullptr ) {
2630- SDVersion version = sd_ctx->sd ->version ;
2631- if (sd_version_is_dit (version))
2632- return EULER;
2633- else
2634- return EULER_A;
2628+ if (sd_version_is_dit (sd_ctx->sd ->version )) {
2629+ return EULER_SAMPLE_METHOD;
2630+ }
26352631 }
2636- return SAMPLE_METHOD_COUNT ;
2632+ return EULER_A_SAMPLE_METHOD ;
26372633}
26382634
26392635enum scheduler_t sd_get_default_scheduler (const sd_ctx_t * sd_ctx) {
2640- auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd ->denoiser );
2641- if (edm_v_denoiser) {
2642- return EXPONENTIAL_SCHEDULER;
2636+ if (sd_ctx != nullptr && sd_ctx->sd != nullptr ) {
2637+ auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd ->denoiser );
2638+ if (edm_v_denoiser) {
2639+ return EXPONENTIAL_SCHEDULER;
2640+ }
26432641 }
26442642 return DISCRETE_SCHEDULER;
26452643}
@@ -2827,7 +2825,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
28272825 int C = sd_ctx->sd ->get_latent_channel ();
28282826 int W = width / sd_ctx->sd ->get_vae_scale_factor ();
28292827 int H = height / sd_ctx->sd ->get_vae_scale_factor ();
2830- LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
28312828
28322829 struct ggml_tensor * control_latent = nullptr ;
28332830 if (sd_version_is_control (sd_ctx->sd ->version ) && image_hint != nullptr ) {
@@ -3056,10 +3053,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
30563053 sd_ctx->sd ->rng ->manual_seed (seed);
30573054 sd_ctx->sd ->sampler_rng ->manual_seed (seed);
30583055
3059- int sample_steps = sd_img_gen_params->sample_params .sample_steps ;
3060-
30613056 size_t t0 = ggml_time_ms ();
30623057
3058+ enum sample_method_t sample_method = sd_img_gen_params->sample_params .sample_method ;
3059+ if (sample_method == SAMPLE_METHOD_COUNT) {
3060+ sample_method = sd_get_default_sample_method (sd_ctx);
3061+ }
3062+ LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
3063+
3064+ int sample_steps = sd_img_gen_params->sample_params .sample_steps ;
30633065 std::vector<float > sigmas = sd_ctx->sd ->denoiser ->get_sigmas (sample_steps, sd_img_gen_params->sample_params .scheduler , sd_ctx->sd ->version );
30643066
30653067 ggml_tensor* init_latent = nullptr ;
@@ -3248,11 +3250,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
32483250 LOG_INFO (" encode_first_stage completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
32493251 }
32503252
3251- enum sample_method_t sample_method = sd_img_gen_params->sample_params .sample_method ;
3252- if (sample_method == SAMPLE_METHOD_DEFAULT) {
3253- sample_method = sd_get_default_sample_method (sd_ctx);
3254- }
3255-
32563253 sd_image_t * result_images = generate_image_internal (sd_ctx,
32573254 work_ctx,
32583255 init_latent,
@@ -3302,6 +3299,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
33023299
33033300 int vae_scale_factor = sd_ctx->sd ->get_vae_scale_factor ();
33043301
3302+ enum sample_method_t sample_method = sd_vid_gen_params->sample_params .sample_method ;
3303+ if (sample_method == SAMPLE_METHOD_COUNT) {
3304+ sample_method = sd_get_default_sample_method (sd_ctx);
3305+ }
3306+ LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
3307+
33053308 int high_noise_sample_steps = 0 ;
33063309 if (sd_ctx->sd ->high_noise_diffusion_model ) {
33073310 high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params .sample_steps ;
@@ -3570,6 +3573,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35703573 // High Noise Sample
35713574 if (high_noise_sample_steps > 0 ) {
35723575 LOG_DEBUG (" sample(high noise) %dx%dx%d" , W, H, T);
3576+ enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params .sample_method ;
3577+ if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
3578+ high_noise_sample_method = sd_get_default_sample_method (sd_ctx);
3579+ }
3580+ LOG_INFO (" sampling(high noise) using %s method" , sampling_methods_str[high_noise_sample_method]);
3581+
35733582 int64_t sampling_start = ggml_time_ms ();
35743583
35753584 std::vector<float > high_noise_sigmas = std::vector<float >(sigmas.begin (), sigmas.begin () + high_noise_sample_steps + 1 );
@@ -3588,7 +3597,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35883597 sd_vid_gen_params->high_noise_sample_params .guidance ,
35893598 sd_vid_gen_params->high_noise_sample_params .eta ,
35903599 sd_vid_gen_params->high_noise_sample_params .shifted_timestep ,
3591- sd_vid_gen_params-> high_noise_sample_params . sample_method ,
3600+ high_noise_sample_method ,
35923601 high_noise_sigmas,
35933602 -1 ,
35943603 {},
@@ -3625,7 +3634,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36253634 sd_vid_gen_params->sample_params .guidance ,
36263635 sd_vid_gen_params->sample_params .eta ,
36273636 sd_vid_gen_params->sample_params .shifted_timestep ,
3628- sd_vid_gen_params-> sample_params . sample_method ,
3637+ sample_method,
36293638 sigmas,
36303639 -1 ,
36313640 {},
0 commit comments