enable cpu offloading of new pipelines on XPU & use device-agnostic empty cache to make pipelines work on XPU #11671

Open · wants to merge 16 commits into base: main
6 changes: 2 additions & 4 deletions src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -41,7 +41,7 @@
replace_example_docstring,
)
from ...utils.import_utils import is_transformers_version
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel

@@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
-device_mod = getattr(torch, device.type, None)
-if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-    device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+empty_device_cache(device.type)

model_sequence = [
self.text_encoder.text_model,
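For context, the removed per-device logic above is what the new helper centralizes. A minimal sketch of what a device-agnostic `empty_device_cache` can look like is below; the real helper lives in `diffusers.utils.torch_utils`, and this body is an illustrative assumption, not the actual implementation.

```python
from typing import Optional

import torch


def empty_device_cache(device_type: Optional[str] = None) -> None:
    # Sketch: resolve the available accelerator backend if none was given,
    # then clear its allocator cache when the backend exposes empty_cache().
    if device_type is None:
        if torch.cuda.is_available():
            device_type = "cuda"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device_type = "xpu"
        else:
            device_type = "cpu"
    device_mod = getattr(torch, device_type, None)
    if device_mod is not None and hasattr(device_mod, "empty_cache"):
        device_mod.empty_cache()
```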
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/consisid/consisid_utils.py
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):

Parameters:
- model_path: Path to the directory containing model files.
-- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
+- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
- dtype: Data type (e.g., torch.float32) for model inference.

Returns:
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -37,7 +37,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1339,7 +1339,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
(file path not shown)
@@ -36,7 +36,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1311,7 +1311,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
(file path not shown)
@@ -38,7 +38,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1500,7 +1500,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
(file path not shown)
@@ -51,7 +51,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -1858,7 +1858,7 @@ def denoising_value_valid(dnv):
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
(file path not shown)
@@ -1465,7 +1465,11 @@ def __call__(

# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+if (
+    torch.cuda.is_available()
+    and (is_unet_compiled and is_controlnet_compiled)
+    and is_torch_higher_equal_2_1
+):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
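The guard added above matters because `torch._inductor.cudagraph_mark_step_begin()` is a CUDA-graphs hook: on XPU or CPU the compiled UNet/ControlNet still run, only the step marker is skipped. A minimal sketch of the pattern, with the compile and version flags assumed to have been computed earlier in the denoising loop:

```python
import torch

# Assumed to be computed earlier, e.g. via is_compiled_module(...) and is_torch_version(">=", "2.1").
is_unet_compiled = True
is_controlnet_compiled = True
is_torch_higher_equal_2_1 = True

if (
    torch.cuda.is_available()  # CUDA graphs only apply to CUDA devices
    and (is_unet_compiled and is_controlnet_compiled)
    and is_torch_higher_equal_2_1
):
    # Mark the start of a new inference step for the CUDA-graphs-aware compiler.
    torch._inductor.cudagraph_mark_step_begin()
```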
(file path not shown)
@@ -53,7 +53,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -921,7 +921,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

image = image.to(device=device, dtype=dtype)

@@ -1632,7 +1632,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
(file path not shown)
@@ -51,7 +51,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -1766,7 +1766,7 @@ def denoising_value_valid(dnv):
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
(file path not shown)
@@ -53,7 +53,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -876,7 +876,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

image = image.to(device=device, dtype=dtype)

@@ -1574,7 +1574,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
(file path not shown)
@@ -36,7 +36,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -853,7 +853,7 @@ def __call__(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-if is_controlnet_compiled and is_torch_higher_equal_2_1:
+if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -902,7 +902,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
(file path not shown)
@@ -193,7 +193,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
Contributor (PR author) comment: per the discussion in PR #11288, we change the default to None so CPU offloading can work on other accelerators like XPU without requiring application code changes.

Contributor (PR author) comment: with these changes, tests such as tests/pipelines/wuerstchen/test_wuerstchen_combined.py::WuerstchenCombinedPipelineFastTests::test_cpu_offload_forward_pass_twice and tests/pipelines/kandinsky2_2/test_kandinsky_combined.py::KandinskyV22PipelineImg2ImgCombinedFastTests::test_cpu_offload_forward_pass_twice now pass on XPU.
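For illustration, a minimal user-side sketch of what the `None` default enables (the checkpoint id and prompt are placeholders; on an XPU machine no device argument is needed):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Any Kandinsky 2.2 combined checkpoint works the same way.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)

# No device argument: with the new `None` default, offloading resolves the
# available accelerator (CUDA, XPU, ...) instead of hard-coding "cuda".
pipe.enable_sequential_cpu_offload()

image = pipe("a photo of a red panda", num_inference_steps=25).images[0]
```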

r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -411,7 +411,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -652,7 +652,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
(file path not shown)
@@ -179,7 +179,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -407,7 +407,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -417,7 +417,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -656,7 +656,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -25,7 +25,7 @@
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import KolorsPipelineOutput
from .text_encoder import ChatGLMModel
@@ -618,7 +618,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

image = image.to(device=device, dtype=dtype)

14 changes: 8 additions & 6 deletions src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -35,7 +35,7 @@
logging,
replace_example_docstring,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin


@@ -397,20 +397,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+`forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
+lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
+of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")
device_type = get_device()
device = torch.device(f"{device_type}:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
-torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+empty_device_cache() # otherwise we don't see the memory savings (but they probably exist)

model_sequence = [
self.text_encoder.text_model,
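The `get_device` helper imported above returns the name of the available accelerator backend, so the offload target is no longer hard-coded to `cuda:{gpu_id}`. A rough sketch of such a helper; the real implementation in `diffusers.utils.torch_utils` may differ:

```python
import torch


def get_device() -> str:
    # Sketch: prefer CUDA, then XPU, then MPS, falling back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# Used as in the MusicLDM change above:
gpu_id = 0
device = torch.device(f"{get_device()}:{gpu_id}")
```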
10 changes: 7 additions & 3 deletions src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -36,7 +36,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1228,7 +1228,7 @@ def __call__(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+if (
+    torch.cuda.is_available()
+    and (is_unet_compiled and is_controlnet_compiled)
+    and is_torch_higher_equal_2_1
+):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
@@ -1309,7 +1313,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
-torch.cuda.empty_cache()
+empty_device_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[