
fix CPU offloading related fail cases on XPU #11288


Merged · 9 commits · Apr 15, 2025

Conversation

yao-matrix
Contributor

@yao-matrix yao-matrix commented Apr 11, 2025

Symptom
As of now, the 5 CPU-offloading-related cases below fail on XPU:

tests/pipelines/kandinsky/test_kandinsky_img2img.py::KandinskyImg2ImgPipelineFastTests::test_offloads
tests/lora/test_lora_layers_sd.py::LoraIntegrationTests::test_a1111_with_model_cpu_offload
tests/lora/test_lora_layers_sd.py::LoraIntegrationTests::test_a1111_with_sequential_cpu_offload
tests/lora/test_lora_layers_sd.py::LoraIntegrationTests::test_sd_load_civitai_empty_network_alpha
tests/pipelines/stable_diffusion/test_stable_diffusion.py::StableDiffusionPipelineSlowTests::test_stable_diffusion_textual_inversion_with_model_cpu_offload

Proposal
Change the `device` argument's default value from "cuda" to None in `enable_model_cpu_offload` and `enable_sequential_cpu_offload`, and add automatic device-detection logic to these 2 functions so that an accelerator is detected when no device is specified.
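A minimal sketch of the proposed resolution logic, covering only CUDA and XPU for brevity; `_detect_accelerator` and `resolve_offload_device` are illustrative names, not the actual diffusers helpers:

```python
import torch


def _detect_accelerator() -> str:
    # Hypothetical stand-in for a get_device-style utility: prefer an accelerator,
    # fall back to CPU when none is present.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


def resolve_offload_device(device=None) -> torch.device:
    # device=None (the new default) means "use whatever accelerator is detected";
    # an explicit value such as "cuda" or "xpu" keeps the old behavior.
    return torch.device(device if device is not None else _detect_accelerator())
```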

Possible Questions
Q1: why not just change the test case code to specify the device?
[A1]: These 2 utility functions are heavily used in docs and Jupyter notebooks, and are also used in the internal LoRA loader (as here) and TI loader (as here), with the assumption that, if not specified, the argument should be the accelerator (before this PR, this narrowly meant "cuda"). We should inherit this good assumption and generalize the accelerator from CUDA to all accelerator types, so diffusers can support multiple devices with zero code changes.

Q2: why not add "npu" and other devices to get_device?
[A2]: I don't have access to these devices, so I'd prefer that people who do have access add them, so they can validate before merging their PRs. Adding an if-branch in get_device is simple; most of the effort is validation.
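For illustration only, the kind of one-line branch the answer refers to; the `npu` check is hypothetical, assumes an extension such as Ascend's torch_npu that exposes `torch.npu`, and would need validation by someone with that hardware:

```python
import torch


def get_device_sketch() -> str:
    # Simplified sketch of a get_device-style helper, not the actual diffusers code.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch, "npu") and torch.npu.is_available():  # hypothetical extra branch
        return "npu"
    return "cpu"
```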

Backward Compatibility
Full backward compatibility (a usage sketch follows the list below).

  • CUDA device, no device arg passed: before this PR, the CUDA device is used; after this PR, the system is detected and, since CUDA is found, the CUDA device is used.
  • CUDA device, "cuda" passed: the CUDA device is used both before and after this PR.
  • XPU device, "cuda" passed: before this PR, it crashes with the confusing message "Torch not compiled with CUDA enabled"; after this PR, the XPU device is automatically detected and used.
  • XPU device, "xpu" passed: the XPU device is used both before and after this PR.
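A minimal usage sketch of the combinations above, assuming a build with this change and an available accelerator; the checkpoint id is only illustrative:

```python
from diffusers import DiffusionPipeline

# Checkpoint id is illustrative; any pipeline exposes the same offloading methods.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# No device arg: after this PR the available accelerator (CUDA or XPU) is detected.
pipe.enable_model_cpu_offload()

# Explicit device arg: behaves the same before and after this PR.
# pipe.enable_model_cpu_offload(device="xpu")
```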

@yao-matrix yao-matrix marked this pull request as draft April 11, 2025 08:04
@yao-matrix
Contributor Author

@hlky, please help review and comment, thanks very much.

@yao-matrix yao-matrix marked this pull request as ready for review April 14, 2025 03:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky
Contributor

hlky commented Apr 14, 2025

@bot /style


Style fixes have been applied. View the workflow run here.

Contributor

@hlky hlky left a comment

Thanks @yao-matrix. Changes look good to me and save having to update a lot of other code/docs/tests:

docs/   enable_model_cpu_offload() 169 results in 40 files
docs/   enable_sequential_cpu_offload()  15 results in 8 files
src/    enable_model_cpu_offload()  `_load_lora_into_text_encoder`, `load_lora_adapter`, `load_attn_procs`, and 43 results in 35 files for examples
src/    enable_sequential_cpu_offload()  `_load_lora_into_text_encoder`, `load_lora_adapter`, `load_attn_procs`
tests/  enable_model_cpu_offload()  47 results in 21 files
tests/  enable_sequential_cpu_offload()  14 results in 9 files

There is one related fast test failing (others are transient Hub issues):

def test_pipe_same_device_id_offload(self):
    unet = self.dummy_cond_unet()
    scheduler = PNDMScheduler(skip_prk_steps=True)
    vae = self.dummy_vae
    bert = self.dummy_text_encoder
    tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    sd = StableDiffusionPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        text_encoder=bert,
        tokenizer=tokenizer,
        safety_checker=None,
        feature_extractor=self.dummy_extractor,
    )

    sd.enable_model_cpu_offload(gpu_id=5)
    assert sd._offload_gpu_id == 5
    sd.maybe_free_model_hooks()
    assert sd._offload_gpu_id == 5
This test could either be moved to PipelineSlowTests, or, if it's already passing when run on a non-CUDA system, we could simply add device="cuda", as the test is not actually offloading anything and only checks that _offload_gpu_id is set correctly.

sd.enable_model_cpu_offload(gpu_id=5)
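For illustration, the second option would amount to pinning the device explicitly in that call, leaving the rest of the test unchanged:

```python
sd.enable_model_cpu_offload(gpu_id=5, device="cuda")
```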

Can you confirm whether this test passes on an XPU system, @yao-matrix?

Let's also get a second opinion, cc @sayakpaul

@sayakpaul
Member

This test could either be moved to PipelineSlowTests, or, if it's already passing when run on a non-CUDA system, we could simply add device="cuda", as the test is not actually offloading anything and only checks that _offload_gpu_id is set correctly.

I would prefer the second option here, as it only targets the SD pipeline and we already have a bunch of offloading-related tests in test_pipelines_common.py. With the second option, I would also decorate the test with require_torch_gpu, as it looks like we're only going to test on CUDA devices?

@hlky
Contributor

hlky commented Apr 14, 2025

@sayakpaul This is a fast test that runs on CPU; if we decorate it with require_torch_gpu it will be skipped. The test itself is not offloading anything, so it does not actually require a torch GPU; it only checks that _offload_gpu_id is set correctly.

@sayakpaul
Member

My bad. Then all good.

@yao-matrix
Contributor Author

Can you confirm whether this test passes on an XPU system, @yao-matrix?

@hlky, this case passes on XPU; I pasted the log below:
"PASSED tests/pipelines/test_pipelines.py::PipelineFastTests::test_pipe_same_device_id_offload"

The reason is that enable_model_cpu_offload converts gpu_id to a device with [device = torch.device(f"{device_type}:{self._offload_gpu_id}")](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L1134C9-L1134C71), and before that point the device has already been detected correctly and the correct device_type assigned.
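Illustrative only: how the quoted line combines the detected device_type with the gpu_id, which is why the assertion on _offload_gpu_id passes without any CUDA device:

```python
import torch

device_type = "xpu"  # what auto-detection returns on an XPU machine
offload_gpu_id = 5   # the gpu_id passed by the test

device = torch.device(f"{device_type}:{offload_gpu_id}")
print(device)  # xpu:5
```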

Thx.

@hlky hlky merged commit 7edace9 into huggingface:main Apr 15, 2025
27 of 29 checks passed
@yao-matrix yao-matrix deleted the issue221 branch April 15, 2025 23:21