Describe the bug
UlyssesSPDataLoaderAdapter returns duplicate data: each GPU receives N copies of the same SP chunk, where N is the SP world size.
To Reproduce
Run the following script (adapted from the official tutorial on integrating DeepSpeed Ulysses into HF Transformers) on a machine with 8 NVIDIA GPUs, launching it with deepspeed filename.py.
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.runtime.utils import move_to_device
from deepspeed.utils import groups
from torch import tensor
from transformers import AutoModelForCausalLM
import deepspeed
import deepspeed.comm as dist
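import torch
import torch.distributed.nn.functional  # explicit import so torch.distributed.nn.functional.all_gather is available below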
model_name_or_path = 'Qwen/Qwen3-0.6B'
max_length = 24
sequence_parallel_size = 8
micro_batch_size = 1
config_dict = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "train_batch_size": micro_batch_size,
    "world_size": 8,
    "zero_optimization": {
        "stage": 3,
    },
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-3
        }
    },
    "sequence_parallel_size": sequence_parallel_size,
}
dtype = torch.bfloat16
# a simple Dataset
# replace with a real dataset but make sure `position_ids` are returned
input_ids = tensor(
    [
        list(range(0, 17)),
        list(range(17, 34)),
    ],
    dtype=torch.long,
)
position_ids = tensor(
    [
        list(range(0, 17)),
        list(range(17, 34)),
    ],
    dtype=torch.long,
)
n_gpus = 8
pad_token_id = 100
ds = torch.utils.data.TensorDataset(input_ids, position_ids)
def collate_fn(batch):
    input_ids = torch.stack([d[0] for d in batch])  # (batch_size, seq_len)
    position_ids = torch.stack([d[1] for d in batch])  # (batch_size, seq_len)
    # Pad the sequence dimension to a multiple of n_gpus (no padding if already aligned)
    seq_len = input_ids.shape[1]
    pad_length = (n_gpus - seq_len % n_gpus) % n_gpus
    input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=pad_token_id)
    position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=pad_token_id)
    print(f"[{dist.get_rank()}] input_ids: {input_ids.shape}")
    print(f"[{dist.get_rank()}] position_ids: {position_ids.shape}")
    return dict(input_ids=input_ids,
                position_ids=position_ids,
                labels=input_ids)
dist.init_distributed(dist_backend='nccl', dist_init_required=True)
# Ulysses injection into HF Transformers
mpu = UlyssesSPAttentionHF.register_with_transformers(
    model_name_or_path=model_name_or_path,
    core_attn_implementation="sdpa",
    sequence_parallel_size=sequence_parallel_size,
    max_length=max_length,
    micro_batch_size=micro_batch_size,
    seq_length_is_variable=True,
)
# DeepSpeed setup
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model, _, _, _ = deepspeed.initialize(config=config_dict,
                                      model=model,
                                      model_parameters=model.parameters(),
                                      mpu=mpu)
# UlyssesSPDataLoaderAdapter injection
sp_group = groups._get_sequence_parallel_group()
sp_world_size = groups._get_sequence_parallel_world_size()
sp_rank = groups._get_sequence_parallel_rank()
print('=' * 100)
print(f"[{dist.get_rank()}] sp_rank: {sp_rank}")
print(f"[{dist.get_rank()}] sp_world_size: {sp_world_size}")
print(f"[{dist.get_rank()}] sp_group: {sp_group}")
print(f"[{dist.get_rank()}] model.device: {model.device}")
print('=' * 100)
dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)
print(f"[{dist.get_rank()}] len(dl): {len(dl)}")
dl = UlyssesSPDataLoaderAdapter(
    dl,
    sp_rank=sp_rank,
    sp_group=sp_group,
    sp_world_size=sp_world_size,
    device=model.device,
)
print(f"[{dist.get_rank()}] len(dl): {len(dl)}")
# Normal training loop
for step, batch in enumerate(dl):
    batch = move_to_device(batch, model.device)
    outputs = model(**batch)
    # as of this writing HF doesn't calculate loss with shift_labels yet and requires us to do it manually (liger does that automatically)
    shift_labels = batch["shift_labels"]
    print('=' * 100)
    print(f"[{dist.get_rank()}, {step}] Input IDs: {batch['input_ids']}", flush=True)
    # print(f"[{dist.get_rank()}] Labels: {batch['labels']}")
    # print(f"[{dist.get_rank()}, {step}] Shift Labels: {shift_labels}", flush=True)
    print('=' * 100)
    loss = model.module.loss_function(
        logits=outputs.logits,
        labels=None,
        shift_labels=shift_labels,
        vocab_size=model.module.config.vocab_size,
    )
    # differentiable weighted per-shard-loss aggregation across ranks
    losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
    # special dealing with SFT that has prompt tokens that aren't used in loss computation
    good_tokens = (shift_labels != -100).sum()
    good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
    total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
    total_good_tokens = sum(good_tokens_per_rank)
    loss = total_loss / total_good_tokens
    model.backward(loss)
    model.step()
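The loss aggregation at the end of the training loop computes a token-weighted mean of the per-shard losses, so ranks whose shard has more labelled tokens count for more. A minimal plain-Python sketch of that arithmetic, assuming each per-rank loss is a mean over that rank's labelled tokens (aggregate_sp_loss is a hypothetical helper, not a DeepSpeed API):

```python
def aggregate_sp_loss(losses_per_rank, good_tokens_per_rank):
    """Token-weighted mean of per-shard mean losses.

    Assumption: losses_per_rank[r] is the mean loss over the labelled
    (label != -100) tokens on rank r, and good_tokens_per_rank[r] is how
    many such tokens rank r has.
    """
    total_good_tokens = sum(good_tokens_per_rank)
    if total_good_tokens == 0:
        raise ValueError("no tokens contribute to the loss")
    # Undo each per-rank mean (loss * count), then divide by the global count
    total_loss = sum(loss * n for loss, n in zip(losses_per_rank, good_tokens_per_rank))
    return total_loss / total_good_tokens


# Rank 1 has no labelled tokens, so it gets zero weight
print(aggregate_sp_loss([2.0, 0.0, 4.0], [3, 0, 1]))  # (2*3 + 0*0 + 4*1) / 4 = 2.5
```

For per-token losses [1, 2, 3] on rank 0 and [4] on rank 2, the per-rank means are 2.0 and 4.0, and the weighted result 2.5 equals the plain mean of all four token losses, which is why the weighting is used instead of averaging the per-rank losses directly.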
When the reproduction script runs, each rank receives N duplicates of the same SP chunk.
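One way to confirm the duplication from the per-rank "Input IDs" prints is to count how often each chunk appears. A minimal sketch, assuming the printed tensors have been copied into plain lists; find_duplicate_chunks is a hypothetical helper, and the example values are illustrative for rank 0 under the assumption that each rank gets a contiguous 3-token slice of each padded 24-token sample:

```python
from collections import Counter


def find_duplicate_chunks(chunks):
    """Return {chunk: count} for every chunk a rank saw more than once.

    `chunks` is the sequence of input-ID lists one rank printed, one per iteration.
    """
    counts = Counter(tuple(chunk) for chunk in chunks)
    return {chunk: n for chunk, n in counts.items() if n > 1}


# Illustrative pattern: 2 samples, each chunk repeated 8 times (SP world size 8)
observed = [[0, 1, 2]] * 8 + [[17, 18, 19]] * 8
print(find_duplicate_chunks(observed))  # {(0, 1, 2): 8, (17, 18, 19): 8}
```

An empty result means no chunk repeated, which is what the expected behavior describes.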
Expected behavior
Each rank (GPU) should print each input-ID chunk only once. Instead, on a node with 8 GPUs, each rank prints the same input IDs 8 times.
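For the first sample in the script, a minimal sketch of the expected per-rank chunks, assuming the adapter splits the padded sequence into sp_world_size contiguous, equal-length chunks along the sequence dimension and rank r keeps chunk r. pad_to_multiple and expected_sp_chunk are hypothetical helpers that mirror the collate_fn padding; they are not DeepSpeed APIs:

```python
def pad_to_multiple(tokens, multiple, pad_token_id):
    """Right-pad `tokens` so its length is a multiple of `multiple`."""
    # The trailing % multiple leaves an already-aligned sequence unchanged
    pad_length = (multiple - len(tokens) % multiple) % multiple
    return tokens + [pad_token_id] * pad_length


def expected_sp_chunk(tokens, sp_rank, sp_world_size):
    """Contiguous slice of the sequence that `sp_rank` is expected to receive."""
    if len(tokens) % sp_world_size != 0:
        raise ValueError("sequence length must be divisible by sp_world_size")
    chunk_len = len(tokens) // sp_world_size
    return tokens[sp_rank * chunk_len:(sp_rank + 1) * chunk_len]


sample = pad_to_multiple(list(range(0, 17)), 8, pad_token_id=100)  # 17 tokens -> 24
for rank in range(8):
    print(rank, expected_sp_chunk(sample, rank, 8))
```

Under these assumptions rank 0 should see [0, 1, 2], rank 5 should see [15, 16, 100], and ranks 6 and 7 should see only padding, each exactly once per sample.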
ds_report output
[2025-08-12 15:33:07,474] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-devel package with yum
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
[WARNING] gds requires the dev libaio .so object and headers but these were not found.
[WARNING] gds: please install the libaio-devel package with yum
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.6
[WARNING] using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/test/test07/miniconda3/envs/cyf/lib/python3.12/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/home/test/test07/miniconda3/envs/cyf/lib/python3.12/site-packages/deepspeed']
deepspeed info ................... 0.17.4, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.5, cuda 12.4
shared memory (/dev/shm) size .... 503.66 GB
System info
- OS: Linux (CentOS 7), x86_64
- GPUs: one machine with 8 A800-SXM4-80GB GPUs
- Python version: 3.12.3
- CUDA version: 12.2, Driver version: 535.183.06
Launcher context
Launched with deepspeed ...