Open
Labels: good first issue (Good for newcomers), help wanted (Extra attention is needed)
Description
Hi, are there any plans for great.py to natively support PyTorch DistributedDataParallel (DDP)? As a workaround, I currently use my own training script and save the model weights with `torch.save`.
Below is how I invoke it:

```
torchrun --nproc_per_node=8 ddptest.py
```
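For context, `torchrun --nproc_per_node=8` starts eight worker processes and exports `RANK`, `LOCAL_RANK` and `WORLD_SIZE` into each one's environment. The script reads `LOCAL_RANK` directly, which raises a `KeyError` if it is started with plain `python`. Below is a minimal sketch that reads these variables and falls back to single-process defaults. The helper name `read_torchrun_env` is hypothetical and is not part of be_great or PyTorch.

```python
import os


def read_torchrun_env():
    """Return (rank, local_rank, world_size) for the current process.

    torchrun exports RANK, LOCAL_RANK and WORLD_SIZE to every worker it
    spawns. When the script runs without torchrun, none of these are set,
    so fall back to a single process on device 0.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size
```

In `main()`, `local_rank = int(os.environ["LOCAL_RANK"])` could then become `_, local_rank, _ = read_torchrun_env()`. This is a convenience only, and whether it suits your setup is your call.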
Contents of `ddptest.py`:

```python
import os
from collections import OrderedDict

import pandas as pd
import torch
import torch.distributed as dist
from be_great import GReaT


def main():
    # torchrun sets LOCAL_RANK for each process; pin this process to that GPU
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataFile = "/edit/for/your/own/repo.csv"
    data = pd.read_csv(dataFile)

    great = GReaT("gpt2-xl",
                  batch_size=8,
                  epochs=50,
                  fp16=True
                  )

    # Move the model to the appropriate GPU
    great.model.to(local_rank)

    # Wrap the model for distributed training
    great.model = torch.nn.parallel.DistributedDataParallel(
        great.model, device_ids=[local_rank], output_device=local_rank
    )

    trainer = great.fit(data, data.columns.to_list())

    # Save the model only from the rank 0 process
    if dist.get_rank() == 0:
        # DDP prefixes every parameter name with "module."; strip it so the
        # weights can later be loaded into an unwrapped model
        state_dict = great.model.state_dict()
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[len("module."):] if k.startswith("module.") else k
            new_state_dict[name] = v
        # Save the model with the modified state dict
        torch.save(new_state_dict, "/edit/for/your/own/model.pt")


if __name__ == "__main__":
    # Initialize the distributed process group
    dist.init_process_group(backend="nccl")
    main()
    # Release NCCL resources before the process exits
    dist.destroy_process_group()
```
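The key-renaming step in the script can be factored into a small helper. Below is a minimal sketch. The name `strip_ddp_prefix` is hypothetical, and the key names in the usage note are illustrative GPT-2 parameter names, not output from this run. Unlike a fixed slice such as `k[7:]`, it only strips keys that actually start with `module.`, so it is also safe to call on the state dict of an unwrapped model. Another option, not shown here, is to skip the renaming entirely by saving `great.model.module.state_dict()`, because the DDP wrapper exposes the original model as `.module`.

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP prefix removed.

    Keys that do not start with the prefix are copied through unchanged,
    and the original key order is preserved.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
```

Usage: `strip_ddp_prefix(OrderedDict([("module.lm_head.weight", w)]))` returns `OrderedDict([("lm_head.weight", w)])`. In the script, the helper would replace the manual loop before `torch.save`.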
Again, thank you so much for this awesome framework.