
Commit 54fa074

Improve docstrings and type hints in scheduling_dpmsolver_singlestep.py (#12798)
feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler
1 parent 3d02cd5 commit 54fa074

File tree

1 file changed (+74, -50)
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 74 additions & 50 deletions
@@ -86,42 +86,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, *optional*):
+        trained_betas (`np.ndarray` or `List[float]`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        solver_order (`int`, defaults to 2):
+        solver_order (`int`, defaults to `2`):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
+            `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-        sample_max_value (`float`, defaults to 1.0):
+        sample_max_value (`float`, defaults to `1.0`):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
             type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
             `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
             paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
             sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
-        lower_order_final (`bool`, defaults to `True`):
+        lower_order_final (`bool`, defaults to `False`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -132,15 +132,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to `1.0`):
+            The flow shift parameter for flow-based models.
+        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the noise schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shifting to apply.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
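The two docstring hunks above document the flow-related options (`use_flow_sigmas`, `flow_shift`, `use_dynamic_shifting`, `time_shift_type`) this commit surfaces. A minimal sketch of how they might be exercised; the values are illustrative, not recommendations from this commit:

```python
from diffusers import DPMSolverSinglestepScheduler

# Illustrative only: flow_shift=3.0 is an arbitrary example value.
scheduler = DPMSolverSinglestepScheduler(
    prediction_type="flow_prediction",  # model predicts the flow field
    use_flow_sigmas=True,               # flow-matching sigmas for the schedule
    flow_shift=3.0,                     # shift parameter for flow-based models
)
print(scheduler.config.use_flow_sigmas, scheduler.config.flow_shift)
```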
@@ -152,27 +160,27 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: str = "dpmsolver++",
-        solver_type: str = "midpoint",
+        algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
+        solver_type: Literal["midpoint", "heun"] = "midpoint",
         lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
         lambda_min_clipped: float = -float("inf"),
-        variance_type: Optional[str] = None,
+        variance_type: Optional[Literal["learned", "learned_range"]] = None,
         use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
-    ):
+        time_shift_type: Literal["exponential"] = "exponential",
+    ) -> None:
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
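Replacing plain `str` annotations with `Literal[...]` lets static type checkers validate config values at the call site. A small sketch of the benefit, assuming mypy or pyright is run over the calling code:

```python
from diffusers import DPMSolverSinglestepScheduler

# Valid members of the Literal types check cleanly:
scheduler = DPMSolverSinglestepScheduler(
    algorithm_type="sde-dpmsolver++",
    solver_type="heun",
)

# A typo such as solver_type="huen" is now a static type error
# (not a valid member of Literal["midpoint", "heun"]) instead of
# surfacing only as a runtime exception.
```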
@@ -242,6 +250,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
         Args:
             num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
+
+        Returns:
+            `List[int]`:
+                The list of solver orders for each timestep.
         """
         steps = num_inference_steps
         order = self.config.solver_order
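The new `Returns:` section documents what `get_order_list` already produced. A quick sketch of calling it; the exact pattern depends on `solver_order`, `lower_order_final`, and the step count:

```python
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(solver_order=2)
orders = scheduler.get_order_list(num_inference_steps=8)
print(orders)  # with order 2 and an even step count: a repeating [1, 2, 1, 2, ...] pattern
```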
@@ -276,21 +288,29 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
         return orders
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
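`set_begin_index` exists so pipelines that start denoising part-way through the schedule (img2img-style workflows) can set the first index up front rather than having the scheduler infer it from the timestep. A self-contained sketch:

```python
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_begin_index(2)        # begin two timesteps into the schedule
assert scheduler.begin_index == 2
assert scheduler.step_index is None  # step_index stays None until stepping begins
```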
@@ -302,19 +322,21 @@ def set_begin_index(self, begin_index: int = 0):
 
     def set_timesteps(
         self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
         mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
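The hunk above also documents the previously undocumented `mu` argument, which feeds the dynamic-shifting path; `mu` only matters for schedulers built with `use_dynamic_shifting=True`. A plain call needs just the step count, as in this sketch:

```python
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_timesteps(num_inference_steps=25, device="cpu")
print(len(scheduler.timesteps))  # 25 discrete timesteps, on CPU
```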
@@ -453,7 +475,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
@@ -490,7 +512,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
 
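For context on the helper being annotated here: under the usual variance-preserving parameterization, `alpha_t**2 + sigma_t**2 = 1` and the Karras-style sigma is `sigma = sigma_t / alpha_t`, which pins down the conversion. The sketch below is a standalone re-derivation of that math, not the library's exact code path (the flow-sigma case presumably uses a different, linear mapping):

```python
from typing import Tuple

import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # From alpha_t**2 + sigma_t**2 = 1 and sigma = sigma_t / alpha_t:
    alpha_t = 1.0 / (sigma**2 + 1.0) ** 0.5
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

alpha_t, sigma_t = sigma_to_alpha_sigma_t(torch.tensor(1.0))
print(alpha_t, sigma_t)         # both 1/sqrt(2)
print(alpha_t**2 + sigma_t**2)  # 1.0
```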
@@ -512,7 +534,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
@@ -637,7 +659,7 @@ def convert_model_output(
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -733,7 +755,7 @@ def dpm_solver_first_order_update(
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -797,7 +819,7 @@ def singlestep_dpm_solver_second_order_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -908,7 +930,7 @@ def singlestep_dpm_solver_third_order_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1030,8 +1052,8 @@ def singlestep_dpm_solver_update(
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
-        order: int = None,
+        sample: Optional[torch.Tensor] = None,
+        order: Optional[int] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1125,7 +1147,7 @@ def index_for_timestep(
         return step_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.
 
@@ -1146,7 +1168,7 @@ def step(
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -1156,11 +1178,13 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`int`):
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            generator (`torch.Generator`, *optional*):
+                A random number generator for stochastic sampling.
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
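With `generator` now explicitly typed as `Optional[torch.Generator]`, a seeded denoising loop reads as in the sketch below. The model is a zero-returning stand-in, purely a placeholder for a real noise-prediction network:

```python
import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(algorithm_type="sde-dpmsolver++")
scheduler.set_timesteps(num_inference_steps=10)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 4, 64, 64, generator=generator)

def dummy_model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Placeholder for a real epsilon-prediction network.
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    noise_pred = dummy_model(sample, t)
    # The generator seeds the stochastic noise drawn by the SDE solver variant.
    sample = scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
```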
@@ -1277,5 +1301,5 @@ def add_noise(
         noisy_samples = alpha_t * original_samples + sigma_t * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
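`add_noise` (context shown in the final hunk) implements the forward diffusion `x_t = alpha_t * x_0 + sigma_t * eps`. A small sketch, assuming `set_timesteps` has been called so the sigma schedule exists:

```python
import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_timesteps(num_inference_steps=10)

clean = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(clean)
t = scheduler.timesteps[:1]  # noise up to the first scheduled timestep
noisy = scheduler.add_noise(clean, noise, t)
print(noisy.shape)           # torch.Size([1, 4, 8, 8])
```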
