examples/pytorch/llama/llama_example.py: 36 additions & 1 deletion
@@ -24,11 +24,44 @@
 import torch
 import torch.distributed as dist
 from transformers import AutoTokenizer
+TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(dir_path + "/../../..")
 from examples.pytorch.llama.utils.llama import Llama
 
+def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
+    # taken from https://github.com/microsoft/DeepSpeed/blob/4559aa9b02bf9f113951e0092ebecc21debf8e20/deepspeed/comm/comm.py#L630
+    from mpi4py import MPI
+    import subprocess
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    world_size = comm.Get_size()
+
+    master_addr = None
+    if rank == 0:
+        hostname_cmd = ["hostname -I"]
+        result = subprocess.check_output(hostname_cmd, shell=True)
+        master_addr = result.decode('utf-8').split()[0]
+    master_addr = comm.bcast(master_addr, root=0)
+
+    # Determine local rank by assuming hostnames are unique
+    proc_name = MPI.Get_processor_name()
+    all_procs = comm.allgather(proc_name)
+    local_rank = sum([i == proc_name for i in all_procs[:rank]])
+
+    os.environ['RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['LOCAL_RANK'] = str(local_rank)
+    os.environ['MASTER_ADDR'] = master_addr
+    os.environ['MASTER_PORT'] = str(distributed_port)
+
+    if verbose:
+        print(
+            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
+            format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
+                   os.environ['MASTER_PORT']))
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--output_len', type=int, default=32,
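
For context, the local-rank derivation in mpi_discovery works like this: every rank allgathers its hostname, and a rank's local rank is the number of lower-numbered ranks that report the same host, which is correct only while hostnames are unique per node. A standalone sketch of that counting rule, using a hypothetical allgather result in place of MPI:

# Sketch of the local_rank rule from mpi_discovery: my local rank is the
# number of earlier ranks that landed on the same host as me.
def local_rank_of(rank, all_procs):
    return sum(p == all_procs[rank] for p in all_procs[:rank])

all_procs = ["node-a", "node-a", "node-b", "node-a", "node-b"]  # hypothetical allgather output
print([local_rank_of(r, all_procs) for r in range(len(all_procs))])  # [0, 1, 0, 2, 1]
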
@@ -109,8 +142,10 @@ def main():
print("{}: {}".format(arg, getattr(args, arg)))
print("=========================================\n")

os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(TORCH_DISTRIBUTED_DEFAULT_PORT))
mpi_discovery(distributed_port=int(os.environ["MASTER_PORT"]), verbose=True)
if tensor_para_size * pipeline_para_size > 1:
dist.init_process_group(backend=dist.Backend.MPI)
dist.init_process_group(rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']))
rank = dist.get_rank() if dist.is_initialized() else 0
device_count = dist.get_world_size() if dist.is_initialized() else 1
device = rank % device_count
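
The replaced init_process_group call drops the requirement that PyTorch be built with the MPI backend: it relies on torch.distributed's default env:// rendezvous, reading MASTER_ADDR and MASTER_PORT from the environment that mpi_discovery just populated, with rank and world_size passed explicitly (the PR's call omits backend, which recent PyTorch versions resolve to a default; older versions may require it). A minimal single-process smoke test of the same call path, assuming the gloo backend and no MPI launcher:

# Sketch: env:// rendezvous as exercised by the new code path, with the
# environment populated by hand instead of by mpi_discovery().
import os
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

dist.init_process_group(backend='gloo', rank=0, world_size=1)
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()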