Commit fd28225

feedback
1 parent d7f2e88 commit fd28225

File tree

1 file changed: +5, -11 lines


docs/source/en/training/distributed_inference.md

Lines changed: 5 additions & 11 deletions
@@ -244,7 +244,7 @@ By selectively loading and unloading the models you need at a given stage and sh
Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.

-Call [`parallelize`] on the model and pass a [`ContextParallelConfig`]. The config supports the `ring_degree` argument that determines how many devices to use for Ring Attention.
+Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transformer model. The config supports the `ring_degree` argument that determines how many devices to use for Ring Attention.

Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized [attention backend](../optimization/attention_backends). The example below uses the FlashAttention backend.

@@ -258,32 +258,26 @@ Refer to the table below for the supported attention backends enabled by [`~Mode
```py
import torch
-from diffusers import QwenImagePipeline, ContextParallelConfig, enable_parallelism
+from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

-    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda")
-
-    pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))
+    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
+    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
    pipeline.transformer.set_attention_backend("flash")
-```

-Pass your pipeline to [`~ModelMixin.enable_parallelism`] as a context manager to activate and coordinate context parallelism.
-
-```py
    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
-    with enable_parallelism(pipeline):
-        image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
+    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

    if rank == 0:
        image.save("output.png")
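
The hunk ends at `image.save("output.png")`, so the code that closes the example's `try:` block sits outside the diff. Below is a minimal sketch of the setup/teardown skeleton the example assumes; the `finally` cleanup and the `torchrun` launch command follow the usual `torch.distributed` pattern rather than anything this commit shows, and the script name is hypothetical.

```py
# Sketch of the distributed setup/teardown the example above relies on.
# Launch one process per GPU so the world size matches ring_degree=2,
# e.g. `torchrun --nproc_per_node=2 cp_inference.py` (script name hypothetical).
import torch

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    # ... load the transformer with parallel_config, build the pipeline,
    # and run inference here, as in the updated example in the diff ...

finally:
    # Clean up the process group on every rank, even if inference fails.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```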
