Skip to content

Adding Native Distributed Data Parallels Support #50

@hiberfil

Description

@hiberfil

Hi, I was wondering whether there are any efforts toward great.py natively supporting DistributedDataParallel (DDP)? Currently I am working around this by editing my own trainer file and saving the model via torch.save.

Below is how I invoke it.

torchrun --nproc_per_node=8 ddptest.py

import os
import pandas as pd
from be_great import GReaT
import torch.distributed as dist
import torch
from collections import OrderedDict

def main():
    """Fine-tune a GReaT model under DistributedDataParallel and save it on rank 0.

    Reads ``LOCAL_RANK`` from the environment (the script is meant to be
    launched with ``torchrun``) and calls ``dist.get_rank()``, so the default
    process group must already be initialised by the caller.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not present in the environment.
    """
    # Set the CUDA device for this process
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataFile = "/edit/for/your/own/repo.csv"
    data = pd.read_csv(dataFile)

    great = GReaT(
        "gpt2-xl",
        batch_size=8,
        epochs=50,
        fp16=True,
    )

    # Move the model to the appropriate GPU
    great.model.to(local_rank)

    # Wrap the model for distributed training
    great.model = torch.nn.parallel.DistributedDataParallel(
        great.model, device_ids=[local_rank], output_device=local_rank
    )

    great.fit(data, data.columns.to_list())

    # Save the model only from the rank 0 process
    if dist.get_rank() == 0:
        # Build a state dict whose keys no longer carry the `module.` prefix.
        # The prefix is stripped only when actually present, so a key that is
        # not prefixed is kept intact instead of losing its first characters.
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in great.model.state_dict().items()
        )

        # Save the model with the modified state dict
        torch.save(new_state_dict, "/edit/for/your/own/model.pt")


if __name__ == "__main__":
    # Initialize the distributed process group
    dist.init_process_group(backend="nccl")
    main()
    # Release the process group's resources once training and saving are done
    dist.destroy_process_group()

Again, thank you so much for this awesome framework.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions