
Port parallel_state.py (mpu) from Megatron-Deepspeed #7176

Draft: stas00 wants to merge 1 commit into base: master

Conversation

stas00 (Collaborator) commented Mar 26, 2025

Since we need `mpu` for Ulysses outside of Meg-DS, this PR ports the mpu code.

It appears non-trivial to trim this file down to just the SP groups, since DeepSpeed calls into many other methods of this class when `mpu is not None`.

attn: @jeffra
stas00 (Collaborator, Author) commented Mar 26, 2025

Here is an attempt at a slim SP-only version, which wasn't fully successful because DeepSpeed calls other non-SP methods when an mpu is passed:

import torch.distributed as dist
import torch

# For DeepSpeed's sequence parallel
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_WORLD_SIZE = None
_SEQUENCE_PARALLEL_RANK = None

# XXX: fixme
def initialize_model_parallel(sequence_parallel_size: int = 1):
    """Initialize the sequence parallel groups (contiguous blocks of ranks)."""

    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size: int = dist.get_world_size()
    rank = dist.get_rank()

    if world_size % sequence_parallel_size != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})"
        )
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size

    # Build the sequence parallel groups.
    global _SEQUENCE_PARALLEL_GROUP
    # XXX: fixme - for now silently allow repeated initialization instead of asserting
    if _SEQUENCE_PARALLEL_GROUP is not None:
        return
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    assert _SEQUENCE_PARALLEL_GROUP is not None, \
        'sequence parallel group is not initialized'
    return _SEQUENCE_PARALLEL_GROUP

def set_sequence_parallel_world_size(world_size):
    """Set the sequence  parallel size"""
    global _SEQUENCE_PARALLEL_WORLD_SIZE
    _SEQUENCE_PARALLEL_WORLD_SIZE = world_size

def get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
        return _SEQUENCE_PARALLEL_WORLD_SIZE
    return dist.get_world_size(group=get_sequence_parallel_group())

def set_sequence_parallel_rank(rank):
    """Set sequence parallel rank."""
    global _SEQUENCE_PARALLEL_RANK
    _SEQUENCE_PARALLEL_RANK = rank


def get_sequence_parallel_rank():
    """Return my rank for the sequence parallel group."""
    if _SEQUENCE_PARALLEL_RANK is not None:
        return _SEQUENCE_PARALLEL_RANK
    return dist.get_rank(group=get_sequence_parallel_group())



def get_sequence_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the sequence parallel group."""
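    # e.g. with sequence_parallel_size=4, global ranks 4..7 form one group and all map to src rank 4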
    global_rank = dist.get_rank()
    local_world_size = get_sequence_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


# No-op placeholders for the other methods DeepSpeed calls when an mpu object is passed
def get_data_parallel_group(): return None
# XXX: fix me - hardcoded placeholder value
def get_data_parallel_world_size(): return 2
def get_model_parallel_world_size(): return 1

def destroy_model_parallel():
    """Set the groups to none."""
    global _SEQUENCE_PARALLEL_GROUP
    _SEQUENCE_PARALLEL_GROUP = None
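
For reference, here's a minimal sketch of how this slim module could be wired up end to end (assuming it is saved as a standalone `sequence_parallel_mpu.py`; the module name and the exact integration point with DeepSpeed/Ulysses are assumptions, not part of this PR):

```python
import torch
import torch.distributed as dist

import sequence_parallel_mpu as mpu  # hypothetical module name for the code above

# under torchrun, one process per rank with the env vars init_process_group needs
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# e.g. world_size=8 and sequence_parallel_size=4 -> SP groups [0..3] and [4..7]
mpu.initialize_model_parallel(sequence_parallel_size=4)

sp_group = mpu.get_sequence_parallel_group()
print(f"rank={dist.get_rank()} "
      f"sp_rank={mpu.get_sequence_parallel_rank()} "
      f"sp_world_size={mpu.get_sequence_parallel_world_size()} "
      f"sp_src_rank={mpu.get_sequence_parallel_src_rank()}")

# the module would then be handed to DeepSpeed, e.g.
# engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config, mpu=mpu)
```

That last call is where DeepSpeed starts invoking the non-SP methods mentioned above, which is why the no-op placeholders at the bottom of the file are needed.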
