Commit d7a1a03

[docs] CP (#12331)
* init
* feedback
* feedback
* feedback
* feedback
* feedback
* feedback
1 parent b596545 commit d7a1a03

File tree

- docs/source/en/_toctree.yml
- docs/source/en/api/parallel.md
- docs/source/en/training/distributed_inference.md

3 files changed: +63 −7 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -70,8 +70,6 @@
   title: Reduce memory usage
 - local: optimization/speed-memory-optims
   title: Compiling and offloading quantized models
-- local: api/parallel
-  title: Parallel inference
 - title: Community optimizations
   sections:
   - local: optimization/pruna
@@ -282,6 +280,8 @@
   title: Outputs
 - local: api/quantization
   title: Quantization
+- local: api/parallel
+  title: Parallel inference
 - title: Modular
   sections:
   - local: api/modular_diffusers/pipeline
```

docs/source/en/api/parallel.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. -->
 
 # Parallelism
 
-Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
+Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. Refer to the [Distributed inference](../training/distributed_inference) guide to learn more.
 
 ## ParallelConfig
```

docs/source/en/training/distributed_inference.md

Lines changed: 60 additions & 4 deletions
````diff
@@ -226,8 +226,64 @@ with torch.no_grad():
     image[0].save("split_transformer.png")
 ```
 
-## Resources
+By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
 
-- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
-- For more details, check out Accelerate's [Distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
-- The `device_map` argument assign models or an entire pipeline to devices. Refer to the [device placement](../using-diffusers/loading#device-placement) docs for more information.
+## Context parallelism
+
+[Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.
+
+Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
+
+### Ring Attention
+
+In [Ring Attention](https://huggingface.co/papers/2310.01889), key (K) and value (V) blocks are passed between devices so that each sequence split eventually sees every other token's K/V. Each GPU computes attention with its local K/V block, then passes it to the next GPU in the ring. No single GPU ever holds the full sequence's K/V, which keeps memory usage low while overlapping communication with computation.
+
+Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transformer model. The config supports a `ring_degree` argument that determines how many devices to use for Ring Attention.
+
+```py
+import torch
+from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
+
+try:
+    torch.distributed.init_process_group("nccl")
+    rank = torch.distributed.get_rank()
+    device = torch.device("cuda", rank % torch.cuda.device_count())
+    torch.cuda.set_device(device)
+
+    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
+    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
+    pipeline.transformer.set_attention_backend("flash")
+
+    prompt = """
+    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+    """
+
+    # Must specify a generator so all ranks start with the same latents (or pass your own)
+    generator = torch.Generator().manual_seed(42)
+    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
+
+    if rank == 0:
+        image.save("output.png")
+
+except Exception as e:
+    print(f"An error occurred: {e}")
+    torch.distributed.breakpoint()
+    raise
+
+finally:
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+```
+
+### Ulysses Attention
+
+[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device) so that each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally over the full sequence for its heads, then performs another all-to-all to regroup the results by tokens for the next layer.
+
+[`ContextParallelConfig`] supports Ulysses Attention through the `ulysses_degree` argument, which determines how many devices to use for Ulysses Attention.
+
+Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
````
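The Ulysses snippet in the added docs shows only the configuration call. A minimal end-to-end sketch of how it could be wired up, carrying over the 2-GPU Qwen-Image setup from the Ring Attention example and assuming the transformer can be loaded without a `parallel_config` when [`~ModelMixin.enable_parallelism`] is applied afterwards:

```py
# Sketch: Ulysses Attention via enable_parallelism, mirroring the Ring Attention example above.
import torch
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

transformer = AutoModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipeline = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda"
)

# Enable Ulysses-style context parallelism on the transformer (2 devices assumed)
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
# "flash" assumes flash-attn is installed; see the attention backends table for alternatives
pipeline.transformer.set_attention_backend("flash")

# Same seed on every rank so all processes start from the same latents
generator = torch.Generator().manual_seed(42)
image = pipeline(
    "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California",
    num_inference_steps=50,
    generator=generator,
).images[0]

if rank == 0:
    image.save("ulysses_output.png")

torch.distributed.destroy_process_group()
```

Like the Ring Attention example, this is intended to be launched with one process per GPU, for example with `torchrun --nproc_per_node=2`.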

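Both examples hardcode a parallel degree of 2. When the number of launched processes varies, the degree can instead be derived from the launcher's world size; a small sketch, assuming the process group has already been initialized as in the examples above:

```py
import torch.distributed as dist
from diffusers import ContextParallelConfig

# One context-parallel shard per launched process, e.g. torchrun --nproc_per_node=4
world_size = dist.get_world_size()
config = ContextParallelConfig(ring_degree=world_size)  # or ulysses_degree=world_size
```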