diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 575423ee80e7..9b0cfe0bbde9 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -171,7 +171,7 @@ def set_shift(self, shift: float): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], + timestep: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ @@ -180,8 +180,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`, *optional*): + The noise tensor. Returns: `torch.FloatTensor`: @@ -212,6 +214,9 @@ def scale_noise( while len(sigma.shape) < len(sample.shape): sigma = sigma.unsqueeze(-1) + if noise is None: + noise = torch.randn_like(sample) + sample = sigma * noise + (1.0 - sigma) * sample return sample diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py index 2addc5f3eeec..ceaa1c3746fe 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -110,7 +110,7 @@ def set_begin_index(self, begin_index: int = 0): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], + timestep: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ @@ -119,8 +119,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`, *optional*): + The noise tensor. Returns: `torch.FloatTensor`: @@ -130,6 +132,10 @@ def scale_noise( self._init_step_index(timestep) sigma = self.sigmas[self.step_index] + + if noise is None: + noise = torch.randn_like(sample) + sample = sigma * noise + (1.0 - sigma) * sample return sample diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py index d79556ae8077..617687d254ec 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py +++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py @@ -192,7 +192,7 @@ def set_scale_factors(self, scale_factors: list, upscale_mode): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], + timestep: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ @@ -201,8 +201,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`, *optional*): + The noise tensor. Returns: `torch.FloatTensor`: @@ -233,6 +235,9 @@ def scale_noise( while len(sigma.shape) < len(sample.shape): sigma = sigma.unsqueeze(-1) + if noise is None: + noise = torch.randn_like(sample) + sample = sigma * noise + (1.0 - sigma) * sample return sample