Conversation

@stevhliu stevhliu commented Sep 15, 2025

companion docs for context parallelism with Ring/Ulysses attention (see #11941)

cc @a-r-r-o-w

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@stevhliu stevhliu requested review from sayakpaul and DN6 September 15, 2025 20:01
@sayakpaul sayakpaul changed the base branch from main to attn-dispatcher-cp-and-training September 22, 2025 08:47
@sayakpaul sayakpaul changed the base branch from attn-dispatcher-cp-and-training to main September 22, 2025 08:48
@sayakpaul sayakpaul (Member) left a comment

Thanks! I think we should merge this PR in #11941.

I will try to gather some benchmarks to include in the docs.

Comment on lines 269 to 272
```py
pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))
pipeline.transformer.set_attention_backend("flash")
```
Member

Is it better to call parallelize() after loading the model, or is it better to pass a parallel config when initializing the model? Or are both approaches the same?
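
For concreteness, a minimal sketch of the two approaches being compared; the `parallel_config` keyword to `from_pretrained` in the second form is a hypothetical illustration, not a confirmed API:

```py
from diffusers import AutoModel, ContextParallelConfig

model_id = "black-forest-labs/FLUX.1-dev"
cp_config = ContextParallelConfig(ring_degree=2)

# Approach 1: load the model first, then parallelize it.
transformer = AutoModel.from_pretrained(model_id, subfolder="transformer")
transformer.parallelize(config=cp_config)

# Approach 2 (hypothetical kwarg): pass the parallel config at initialization.
transformer = AutoModel.from_pretrained(
    model_id, subfolder="transformer", parallel_config=cp_config
)
```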

Member

Oh, I think it's just enable_parallelism() now.

[`ContextParallelConfig`] also supports Ulysses Attention through the `ulysses_degree` argument. This determines the number of devices to use for Ulysses Attention.

```py
pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
```
Member

Where is ParallelConfig used?

Member Author

I didn't include ParallelConfig because it seems like you just pass ContextParallelConfig to it. So I opted to use ContextParallelConfig directly.

Is the ParallelConfig class meant to support other parallelism strategies not yet implemented?

Contributor

@sayakpaul So, the intention with ParallelConfig is to support different kinds of parallelism easily. If you pass just a ContextParallelConfig, it creates a ParallelConfig from it automatically.

I think the current example is sufficient, but we can of course revise it once more parallelisms are supported natively.
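
A minimal sketch of the wrapping described above; `pipeline` is a loaded pipeline as in the other snippets in this thread, and the `context_parallel_config` field name on `ParallelConfig` is assumed for illustration:

```py
from diffusers import ContextParallelConfig, ParallelConfig

cp = ContextParallelConfig(ring_degree=2)

# Passing the ContextParallelConfig directly; it is wrapped in a
# ParallelConfig internally.
pipeline.transformer.enable_parallelism(config=cp)

# Equivalent explicit form (field name assumed for illustration).
pipeline.transformer.enable_parallelism(
    config=ParallelConfig(context_parallel_config=cp)
)
```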

Member

Sure, thanks! Can we supplement it with a ParallelConfig example as well? 👀

Member

Also, I don't see any parallelize() method in #11941

@a-r-r-o-w (Contributor)

Sorry for the delay! Please LMK if I can help with anything :) The CP PR is currently blocked because I can't make updates to it (the branch is in the diffusers repo and not a personal fork, so I can't push changes). Hopefully someone can address the tests there and we can proceed here too

@a-r-r-o-w a-r-r-o-w (Contributor) left a comment

Thanks @stevhliu! LGTM in general, but the examples are a bit outdated. The latest inference snippet removes enable_parallelism and handles that internally.

The final code looks like this: #11941 (comment)

Sorry for the inconvenience! I forgot to update the description of that PR

@stevhliu (Member Author)

Ah my bad, I missed that! Code snippet should be updated now. Let me know if there are any more changes :)

@stevhliu stevhliu marked this pull request as ready for review September 25, 2025 16:06
[`ContextParallelConfig`] supports Ulysses Attention through the `ulysses_degree` argument. This determines how many devices to use for Ulysses Attention.

```py
pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
```
Contributor

Suggested change:

```diff
-pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```

Just one last change and this should be good I think. Off to you @sayakpaul

@stevhliu stevhliu requested a review from sayakpaul September 26, 2025 15:52
@sayakpaul sayakpaul (Member) left a comment

Looking good. I would also link the distributed_inference doc from parallel.md.

yh8899 commented Sep 30, 2025

I used this demo code to run Flux, but the results are slightly different between CP and the original. Is there numerical instability with CP? This is my code:

```py
import torch
from diffusers import AutoModel, FluxPipeline, ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    model_id = "black-forest-labs/FLUX.1-dev"
    transformer = AutoModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
    pipeline = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16).to(device)
    pipeline.transformer.set_attention_backend("_native_cudnn")

    pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator(device="cpu").manual_seed(42)
    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

    if rank == 0:
        image.save("output_cp.png")

except Exception as e:
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```
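
(A hedged usage note, not part of the original comment: with ring_degree=2 the script above expects exactly two ranks, launched e.g. via `torchrun --nproc_per_node=2 demo_cp.py`, where the filename is assumed. A small guard makes a mismatched launch explicit, assuming the world size must match the CP degree:)

```py
# Sketch: fail fast if the launcher did not supply the two ranks that
# ring_degree=2 requires (assumes world size must equal the CP degree).
import torch.distributed as dist

if dist.get_world_size() != 2:
    raise RuntimeError(f"ring_degree=2 needs 2 ranks, got {dist.get_world_size()}")
```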

Without CP:

```py
import torch
from diffusers import AutoModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
transformer = AutoModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")
pipeline.transformer.set_attention_backend("_native_cudnn")

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

# Same fixed seed as the CP run for a like-for-like comparison
generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

image.save("output.png")
```
[Images: output_cp vs. output]

@sayakpaul sayakpaul (Member) left a comment

Thanks for iterating!

@stevhliu stevhliu merged commit d7a1a03 into huggingface:main Sep 30, 2025
1 check passed
@stevhliu stevhliu deleted the cp branch September 30, 2025 16:33