diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index ddd5372b4dd8..cd9508c00d99 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -800,17 +800,20 @@ def __call__( ) height, width = control_image.shape[-2:] - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) if control_mode is not None: control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) @@ -819,7 +822,9 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - for control_image_ in control_image: + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -831,17 +836,18 @@ def __call__( ) height, width = control_image_.shape[-2:] - control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + if self.controlnet.nets[0].input_hint_block is None: + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) control_images.append(control_image_) @@ -955,6 +961,7 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] latents_dtype = latents.dtype