294 changes: 260 additions & 34 deletions mbridge/core/bridge.py
@@ -2,12 +2,15 @@
import os
from abc import ABC
from typing import Callable, Generator
from glob import glob
from collections import defaultdict

import torch
from megatron.core import parallel_state
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_model import ModelType
from transformers import AutoConfig
from transformers.utils.hub import cached_file
from safetensors import safe_open

from .parallel_states import ParallelStates
from .safetensor_io import SafeTensorIO
@@ -259,26 +262,134 @@ def load_weights(

def _save_weights_fast(
self,
per_tensor_generator: Generator[tuple[str, torch.Tensor], None, None],
models: list,
weights_path: str,
) -> None:
dp_rank = parallel_state.get_data_parallel_rank()
save_rank = dp_rank * self.mpu.cp_size + self.mpu.cp_rank
dp_cp_size = self.mpu.cp_size * parallel_state.get_data_parallel_world_size()

# saving dense weights this way is faster
param_cnt = 0
for hf_weight_name, tensor in per_tensor_generator:
# only tp_rank 0 should write
if param_cnt % dp_cp_size == save_rank and self.mpu.tp_rank == 0:
self.safetensor_io.save_tmp_hf_weight(
hf_weight_name, tensor, weights_path
)
param_cnt += 1
torch.distributed.barrier()
if len(glob(os.path.join(weights_path, "*.safetensors"))) > 0:
raise ValueError(f"The path {weights_path} must not already contain safetensors files")

rank = torch.distributed.get_rank()
def encode_filename(mcore_weight_name, *values):
return mcore_weight_name + '--' + '--'.join(
str(int(v)) if v is not None else '' for v in values)

def decode_filename(filename):
parts = filename.split('--')
mcore_weight_name = parts[0]
parts = parts[1:]
return [mcore_weight_name] + [None if p == '' else int(p) for p in parts]
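# Illustrative round-trip for the helpers above (example values, not part of
# this diff):
#   encode_filename("decoder.layers.0.mlp.experts.linear_fc1.weight3",
#                   1, 2, 0, 4, True, 0)
#   -> "decoder.layers.0.mlp.experts.linear_fc1.weight3--1--2--0--4--1--0"
#   decode_filename(...) recovers [name, 1, 2, 0, 4, 1, 0]; a None value is
#   encoded as an empty string and decoded back to None.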

per_tensor_generator = self.export_weights_without_gather(models)
# step 1: each rank saves its local TP/EP-split weight files
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
ep_dp_group = mpu.get_expert_data_parallel_group()
ep_save_size = torch.distributed.get_world_size(ep_dp_group)
if self.config.num_moe_experts:
assert ep_save_size == world_size // (self.mpu.ep_size * self.mpu.etp_size *
self.mpu.pp_size)
ep_save_rank = torch.distributed.get_rank(ep_dp_group)
ep_save_cnt = 0

dp_cp_group = mpu.get_data_parallel_group(True)
tp_save_size = torch.distributed.get_world_size(dp_cp_group)
tp_save_rank = torch.distributed.get_rank(dp_cp_group)
assert tp_save_size == world_size // (self.mpu.pp_size * self.mpu.tp_size)
tp_save_cnt = 0

pp_save_size = world_size // self.mpu.pp_size
dp_rank = mpu.get_data_parallel_rank()
pp_save_rank = dp_rank * self.mpu.cp_size * self.mpu.tp_size + self.mpu.cp_rank * self.mpu.tp_size + self.mpu.tp_rank
pp_save_cnt = 0
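# Worked example (illustrative, not part of this diff): with world_size=8,
# tp_size=2, pp_size=2, cp_size=1 and no experts, tp_save_size is
# 8 // (2 * 2) = 2, so the two ranks in each dp*cp group write the TP-sharded
# tensors round-robin: tp_save_rank 0 writes tensors 0, 2, 4, ... and
# tp_save_rank 1 writes tensors 1, 3, 5, ...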

for (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
partition_dim, mcore_weight) in per_tensor_generator:
assert "-" not in mcore_weight_name
filename = encode_filename(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
tensor_model_parallel, partition_dim)
# save EP/ETP
if ep_size > 0:
if ep_save_cnt % ep_save_size == ep_save_rank:
assert tp_size > 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
ep_save_cnt += 1
continue
# save tp
if tp_size > 0:
if tp_save_cnt % tp_save_size == tp_save_rank:
assert ep_size == 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
tp_save_cnt += 1
continue
# save tensors that are neither TP- nor EP-sharded
if pp_save_cnt % pp_save_size == pp_save_rank:
assert ep_size == 0 and tp_size == 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
pp_save_cnt += 1

torch.distributed.barrier()

# step 2: merge TP/EP shards and convert to HF weights
def load_file(file_tuple):
file, _, _, _, tensor_model_parallel, partition_dim = file_tuple
with safe_open(file, framework="pt", device="cpu") as f:
assert len(f.keys()) == 1
tensor = f.get_tensor(f.keys()[0])
setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
setattr(tensor, 'partition_dim', partition_dim)
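# reattach the TP metadata here so the merge step below can tell how each
# shard was partitioned before concatenating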
os.remove(file)
return tensor

# step 2.1: collect all temporary files
all_files = glob(os.path.join(weights_path, "*.safetensors"))
name2files = defaultdict(list)
for file in all_files:
(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
partition_dim) = decode_filename(os.path.basename(file).split(".safetensors")[0])
expert_id = -1
if ep_size > 0:
mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
mcore_weight_name += ".weight"
name2files[mcore_weight_name].append((
file, # 0
tp_rank, # 1
int(expert_id), # 2
tp_size, # 3
tensor_model_parallel, # 4
partition_dim, # 5
))

# step 2.2: sort the names and stripe the merge work across ranks
torch.distributed.barrier()
weight_names = sorted(list(name2files.keys()))
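# Illustrative: weight_names[rank::world_size] stripes the merge work, e.g.
# with world_size=4, rank 0 merges names 0, 4, 8, ... and rank 1 merges
# names 1, 5, 9, ...; every rank sorts the same list, so the stripes are
# disjoint and cover all weights.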
for w_name in weight_names[rank::world_size]:
w_files = sorted(name2files[w_name], key=lambda x: (x[2], x[1]))
if w_files[0][2] != -1:
# gather ep
assert len(w_files) == self.config.num_moe_experts * self.mpu.etp_size
for expert_id in range(self.config.num_moe_experts):
idx = expert_id * self.mpu.etp_size
# gather etp
params = []
for etp_idx in range(self.mpu.etp_size):
assert w_files[idx + etp_idx][2] == expert_id
params.append(load_file(w_files[idx + etp_idx]))
tmp_w_name = w_name + str(expert_id)
infer_params = self._weight_merge_across_tp(tmp_w_name, params, params[0])
for hf_name, hf_param in zip(*self._weight_to_hf_format(tmp_w_name, infer_params)):
self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
else:
# gather tp
if w_files[0][4] > 0:
assert len(w_files) == w_files[0][3]
params = [load_file(w_file) for w_file in w_files]
infer_params = self._weight_merge_across_tp(w_name, params, params[0])
else:
infer_params = load_file(w_files[0])
for hf_name, hf_param in zip(*self._weight_to_hf_format(w_name, infer_params)):
self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)

# step 3: save the huggingface checkpoint
torch.distributed.barrier()
self.safetensor_io.save_hf_weight_merge(
weights_path,
rank,
@@ -305,18 +416,18 @@ def save_weights(
is_distributed = (
torch.distributed.is_available() and torch.distributed.is_initialized()
)

rank = torch.distributed.get_rank() if is_distributed else 0
if not os.path.exists(weights_path):
os.makedirs(weights_path, exist_ok=True)
per_tensor_generator = self.export_weights(models, distributed_filesystem)

if distributed_filesystem:
assert (
memory_efficient
), "distributed_filesystem must be used together with memory_efficient"
assert is_distributed, "distributed_filesystem requires an initialized torch.distributed"
return self._save_weights_fast(per_tensor_generator, weights_path)
return self._save_weights_fast(models, weights_path)
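# Illustrative usage of the fast path (argument order assumed from this
# diff, not a documented API):
#   bridge.save_weights(models, "/mnt/shared/ckpt",
#                       memory_efficient=True, distributed_filesystem=True)
# weights_path must be on a filesystem visible to all ranks and must not
# already contain *.safetensors files.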

rank = torch.distributed.get_rank() if is_distributed else 0
per_tensor_generator = self.export_weights(models)

if rank != 0:
for _, _ in per_tensor_generator:
@@ -348,7 +459,7 @@ def set_extra_args(self, **kwargs):
self.config = self._build_config()

def export_weights(
self, models: list[torch.nn.Module], distributed_filesystem: bool = False
self, models: list[torch.nn.Module],
) -> Generator[tuple[str, torch.Tensor], None, None]:
assert (
len(self.export_weights_buff) == 0
@@ -412,18 +523,12 @@ def get_model_chunk_generator():
name = local_to_global_map[iter_name]
else:
name, param = None, None
if distributed_filesystem:
continue

if distributed_filesystem:
assert iter_pp_rank == self.mpu.pp_rank
broad_pp_param = param
else:
name = broadcast_str_from_megatron_pp(name)
broad_pp_param = broadcast_from_megatron_pp(param)
name = broadcast_str_from_megatron_pp(name)
broad_pp_param = broadcast_from_megatron_pp(param)

# EP
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size > 1:
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size
infer_params = [
@@ -501,6 +606,124 @@

yield from zip(converted_names, converted_params)

def export_weights_without_gather(
self, models: list[torch.nn.Module],
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""
Export weights without gathering across ranks; optimized for distributed filesystems.

Args:
models: list of MCore model chunks to export

Returns:
Generator[tuple]: (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
tensor_model_parallel, partition_dim, mcore_weight)
tp_size == 0: the tensor is not TP-sharded
ep_size == 0: the tensor is not an expert (EP) tensor
"""
models = [unwrap_model(model) for model in models]

def get_model_chunk_generator():
for model in models:
existing_keys = set()
for name, param in model.named_parameters():
existing_keys.add(name)
yield name, param

# note:
# there is a bug in megatron's GPTModel:
# "decoder.layers[n].mlp.router.expert_bias" is not registered in
# named_parameters(), but does appear in state_dict().
# for now we patch it by adding those keys to extra_keys.
extra_keys = [
x
for x in model.state_dict().keys()
if "_extra_state" not in x
and "expert_bias" in x
and x not in existing_keys
]
for name in extra_keys:
yield name, model.state_dict()[name].to(torch.cuda.current_device())

weights_names = []
for vpp_rank, model in enumerate(models):
existing_keys = set()
for name, param in model.named_parameters():
existing_keys.add(name)
weights_names.append((self.mpu.pp_rank, vpp_rank, name))
extra_keys = [
x
for x in model.state_dict().keys()
if "_extra_state" not in x
and "expert_bias" in x
and x not in existing_keys
]
for name in extra_keys:
weights_names.append((self.mpu.pp_rank, vpp_rank, name))

model_chunk_generator = get_model_chunk_generator()
local_to_global_maps = [
self._weight_name_mapping_mcore_local_to_global(model, consider_ep=False)
for model in models
]

for iter_pp_rank, iter_vpp_rank, iter_name in weights_names:
local_to_global_map = local_to_global_maps[iter_vpp_rank]
assert iter_pp_rank == self.mpu.pp_rank
try:
name, param = next(model_chunk_generator)
except StopIteration:
name, param = None, None
name = local_to_global_map[iter_name]


# EP
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
num_experts = self.config.num_moe_experts
num_experts_per_rank = num_experts // self.mpu.ep_size

name_prefix, local_expert_id = name.split(".weight")
local_expert_id = int(local_expert_id)
global_expert_id = num_experts_per_rank * (self.mpu.ep_rank) + local_expert_id
global_expert_name = f"{name_prefix}.weight{global_expert_id}"

yield (
global_expert_name,
self.mpu.etp_rank,
self.mpu.etp_size,
self.mpu.ep_rank,
self.mpu.ep_size,
getattr(param, "tensor_model_parallel", None),
getattr(param, "partition_dim", None),
param,
)
continue

# TP
if (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel):
# TP-sharded tensor: yield the local shard with its TP metadata
yield (
name,
self.mpu.tp_rank,
self.mpu.tp_size,
0,
0,
getattr(param, "tensor_model_parallel", None),
getattr(param, "partition_dim", None),
param,
)
else:
yield (
name,
0,
0,
0,
0,
getattr(param, "tensor_model_parallel", None),
getattr(param, "partition_dim", None),
param,
)

def _build_config(self):
"""
Build the configuration for the model.
@@ -920,13 +1143,16 @@ def _weight_merge_across_tp(
Returns:
torch.Tensor: Merged weight tensor
"""
if self.mpu.tp_size == 1:
assert len(mcore_weights) == 1
return mcore_weights[0]
if "mlp.experts.linear_fc" in mcore_weights_name:
assert len(mcore_weights) == self.mpu.etp_size
if self.mpu.etp_size == 1:
assert len(mcore_weights) == 1
return mcore_weights[0]
else:
assert len(mcore_weights) == self.mpu.tp_size
if self.mpu.tp_size == 1:
assert len(mcore_weights) == 1
return mcore_weights[0]
if (
"self_attention.linear_qkv." in mcore_weights_name
and "layer_norm" not in mcore_weights_name
2 changes: 1 addition & 1 deletion mbridge/core/safetensor_io.py
@@ -160,7 +160,7 @@ def save_hf_weight(
)
return

def save_tmp_hf_weight(
def save_tmp_weight(
self,
hf_weight_name: str,
tensor: torch.tensor,