diff --git a/mbridge/core/bridge.py b/mbridge/core/bridge.py
index 7f9ace1..6ad5503 100644
--- a/mbridge/core/bridge.py
+++ b/mbridge/core/bridge.py
@@ -2,12 +2,15 @@
 import os
 from abc import ABC
 from typing import Callable, Generator
+from glob import glob
+from collections import defaultdict
 
 import torch
-from megatron.core import parallel_state
+from megatron.core import parallel_state as mpu
 from megatron.core.models.gpt.gpt_model import ModelType
 from transformers import AutoConfig
 from transformers.utils.hub import cached_file
+from safetensors import safe_open
 
 from .parallel_states import ParallelStates
 from .safetensor_io import SafeTensorIO
@@ -259,26 +262,134 @@ def load_weights(
     def _save_weights_fast(
         self,
-        per_tensor_generator: Generator[tuple[str, torch.Tensor], None, None],
+        models: list,
         weights_path: str,
     ) -> None:
-        dp_rank = parallel_state.get_data_parallel_rank()
-        save_rank = dp_rank * self.mpu.cp_size + self.mpu.cp_rank
-        dp_cp_size = self.mpu.cp_size * parallel_state.get_data_parallel_world_size()
-
-        # save dense weight will faster
-        param_cnt = 0
-        for hf_weight_name, tensor in per_tensor_generator:
-            # only tp_rank 0 should write
-            if param_cnt % dp_cp_size == save_rank and self.mpu.tp_rank == 0:
-                self.safetensor_io.save_tmp_hf_weight(
-                    hf_weight_name, tensor, weights_path
-                )
-            param_cnt += 1
-        torch.distributed.barrier()
+        if len(glob(os.path.join(weights_path, "*.safetensors"))) > 0:
+            raise ValueError(f"The path {weights_path} must not already contain safetensors files")
 
-        rank = torch.distributed.get_rank()
+        # temp shard files are keyed as
+        # "<name>--<tp_rank>--<tp_size>--<ep_rank>--<ep_size>--<tensor_model_parallel>--<partition_dim>";
+        # None fields encode as empty strings
+        def encode_filename(mcore_weight_name, *values):
+            return mcore_weight_name + '--' + '--'.join(
+                str(int(v)) if v is not None else '' for v in values)
+
+        def decode_filename(filename):
+            parts = filename.split('--')
+            mcore_weight_name = parts[0]
+            parts = parts[1:]
+            return [mcore_weight_name] + [None if p == '' else int(p) for p in parts]
+
+        per_tensor_generator = self.export_weights_without_gather(models)
+        # step 1: every rank writes its own TP/EP shards as temp safetensors files
         world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        ep_dp_group = mpu.get_expert_data_parallel_group()
+        ep_save_size = torch.distributed.get_world_size(ep_dp_group)
+        if self.config.num_moe_experts:
+            assert ep_save_size == world_size // (self.mpu.ep_size * self.mpu.etp_size *
+                                                  self.mpu.pp_size)
+        ep_save_rank = torch.distributed.get_rank(ep_dp_group)
+        ep_save_cnt = 0
+
+        dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True)
+        tp_save_size = torch.distributed.get_world_size(dp_cp_group)
+        tp_save_rank = torch.distributed.get_rank(dp_cp_group)
+        assert tp_save_size == world_size // (self.mpu.pp_size * self.mpu.tp_size)
+        tp_save_cnt = 0
+
+        pp_save_size = world_size // self.mpu.pp_size
+        dp_rank = mpu.get_data_parallel_rank()
+        pp_save_rank = dp_rank * self.mpu.cp_size * self.mpu.tp_size + self.mpu.cp_rank * self.mpu.tp_size + self.mpu.tp_rank
+        pp_save_cnt = 0
+
+        for (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
+             partition_dim, mcore_weight) in per_tensor_generator:
+            # "--" is the field separator, so the name itself must not contain "-"
+            assert "-" not in mcore_weight_name
+            filename = encode_filename(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
+                                       tensor_model_parallel, partition_dim)
+            # save EP/ETP-sharded tensors
+            if ep_size > 0:
+                if ep_save_cnt % ep_save_size == ep_save_rank:
+                    assert tp_size > 0
+                    self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
+                ep_save_cnt += 1
+                continue
+            # save TP-sharded tensors
+            if tp_size > 0:
+                if tp_save_cnt % tp_save_size == tp_save_rank:
+                    assert ep_size == 0
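+                    # ranks in the dp/cp group hold identical copies of this TP shard, so
+                    # writes are round-robined across that group via tp_save_cnt to spread I/O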
+                    self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
+                tp_save_cnt += 1
+                continue
+            # save tensors that are neither TP- nor EP-sharded
+            if pp_save_cnt % pp_save_size == pp_save_rank:
+                assert ep_size == 0 and tp_size == 0
+                self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
+            pp_save_cnt += 1
+
+        torch.distributed.barrier()
+
+        # step 2: merge TP/EP shards and convert them to HF weights
+        def load_file(file_tuple):
+            file, _, _, _, tensor_model_parallel, partition_dim = file_tuple
+            with safe_open(file, framework="pt", device="cpu") as f:
+                assert len(f.keys()) == 1
+                tensor = f.get_tensor(f.keys()[0])
+            setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
+            setattr(tensor, 'partition_dim', partition_dim)
+            os.remove(file)
+            return tensor
+
+        # step 2.1: collect all shard files and group them by mcore weight name
+        all_files = glob(os.path.join(weights_path, "*.safetensors"))
+        name2files = defaultdict(list)
+        for file in all_files:
+            (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
+             partition_dim) = decode_filename(os.path.basename(file).split(".safetensors")[0])
+            expert_id = -1
+            if ep_size > 0:
+                mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
+                mcore_weight_name += ".weight"
+            name2files[mcore_weight_name].append((
+                file,  # 0
+                tp_rank,  # 1
+                int(expert_id),  # 2
+                tp_size,  # 3
+                tensor_model_parallel,  # 4
+                partition_dim,  # 5
+            ))
+
+        # step 2.2: sort the weight names and shard the merge work across all ranks
+        torch.distributed.barrier()
+        weight_names = sorted(list(name2files.keys()))
+        for w_name in weight_names[rank::world_size]:
+            w_files = sorted(name2files[w_name], key=lambda x: (x[2], x[1]))
+            if w_files[0][2] != -1:
+                # gather ep
+                assert len(w_files) == self.config.num_moe_experts * self.mpu.etp_size
+                for expert_id in range(self.config.num_moe_experts):
+                    idx = expert_id * self.mpu.etp_size
+                    # gather etp
+                    params = []
+                    for etp_idx in range(self.mpu.etp_size):
+                        assert w_files[idx + etp_idx][2] == expert_id
+                        params.append(load_file(w_files[idx + etp_idx]))
+                    tmp_w_name = w_name + str(expert_id)
+                    infer_params = self._weight_merge_across_tp(tmp_w_name, params, params[0])
+                    for hf_name, hf_param in zip(*self._weight_to_hf_format(tmp_w_name, infer_params)):
+                        self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
+            else:
+                # gather tp; tensor_model_parallel decodes to None for non-TP tensors,
+                # so use a truthiness check rather than "> 0"
+                if w_files[0][4]:
+                    assert len(w_files) == w_files[0][3]
+                    params = [load_file(w_file) for w_file in w_files]
+                    infer_params = self._weight_merge_across_tp(w_name, params, params[0])
+                else:
+                    infer_params = load_file(w_files[0])
+                for hf_name, hf_param in zip(*self._weight_to_hf_format(w_name, infer_params)):
+                    self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
+
+        # step 3: save the HuggingFace checkpoint
+        torch.distributed.barrier()
         self.safetensor_io.save_hf_weight_merge(
             weights_path,
             rank,
@@ -305,18 +416,18 @@ def save_weights(
         is_distributed = (
             torch.distributed.is_available() and torch.distributed.is_initialized()
         )
-
-        rank = torch.distributed.get_rank() if is_distributed else 0
         if not os.path.exists(weights_path):
             os.makedirs(weights_path, exist_ok=True)
-        per_tensor_generator = self.export_weights(models, distributed_filesystem)
         if distributed_filesystem:
             assert (
                 memory_efficient
             ), f"distributed_filesystem should use with memory_efficient"
            assert is_distributed, f"distributed_filesystem should use in distributed"
-            return self._save_weights_fast(per_tensor_generator, weights_path)
+            return self._save_weights_fast(models, weights_path)
+
+        rank = torch.distributed.get_rank() if is_distributed else 0
+
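+        # fallback path: export_weights gathers full weights on every rank; only rank 0
+        # writes the HF checkpoint, the other ranks just drain the generator below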
+        per_tensor_generator = self.export_weights(models)
 
         if rank != 0:
             for _, _ in per_tensor_generator:
@@ -348,7 +459,7 @@ def set_extra_args(self, **kwargs):
         self.config = self._build_config()
 
     def export_weights(
-        self, models: list[torch.nn.Module], distributed_filesystem: bool = False
+        self, models: list[torch.nn.Module],
     ) -> Generator[tuple[str, torch.Tensor], None, None]:
         assert (
             len(self.export_weights_buff) == 0
@@ -412,18 +523,12 @@ def get_model_chunk_generator():
                 name = local_to_global_map[iter_name]
             else:
                 name, param = None, None
-            if distributed_filesystem:
-                continue
-
-            if distributed_filesystem:
-                assert iter_pp_rank == self.mpu.pp_rank
-                broad_pp_param = param
-            else:
-                name = broadcast_str_from_megatron_pp(name)
-                broad_pp_param = broadcast_from_megatron_pp(param)
+            name = broadcast_str_from_megatron_pp(name)
+            broad_pp_param = broadcast_from_megatron_pp(param)
 
             # EP
-            if ".mlp.experts.linear_fc" in name and self.mpu.ep_size > 1:
+            if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
                 num_experts = self.config.num_moe_experts
                 num_experts_per_rank = num_experts // self.mpu.ep_size
                 infer_params = [
@@ -501,6 +606,124 @@ def get_model_chunk_generator():
 
         yield from zip(converted_names, converted_params)
 
+    def export_weights_without_gather(
+        self, models: list[torch.nn.Module],
+    ) -> Generator[tuple, None, None]:
+        """
+        Export weights without gathering them across ranks; an optimization for
+        distributed filesystems where every rank can write to the checkpoint path.
+
+        Args:
+            models: list of MCore model chunks (one per virtual pipeline stage)
+
+        Yields:
+            tuple: (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
+                tensor_model_parallel, partition_dim, mcore_weight)
+            tp_size == 0 means the tensor is not TP-sharded;
+            ep_size == 0 means the tensor is not EP-sharded.
+        """
+        models = [unwrap_model(model) for model in models]
+
+        def get_model_chunk_generator():
+            for model in models:
+                existing_keys = set()
+                for name, param in model.named_parameters():
+                    existing_keys.add(name)
+                    yield name, param
+
+                # note: there is a bug in megatron's GPTModel:
+                # "decoder.layers[n].mlp.router.expert_bias" is not registered in
+                # named_parameters() but does appear in state_dict();
+                # for now we patch it by adding those keys to extra_keys.
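+                # e.g. a key such as "decoder.layers.0.mlp.router.expert_bias"
+                # (illustrative layer index)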
+                extra_keys = [
+                    x
+                    for x in model.state_dict().keys()
+                    if "_extra_state" not in x
+                    and "expert_bias" in x
+                    and x not in existing_keys
+                ]
+                for name in extra_keys:
+                    yield name, model.state_dict()[name].to(torch.cuda.current_device())
+
+        weights_names = []
+        for vpp_rank, model in enumerate(models):
+            existing_keys = set()
+            for name, param in model.named_parameters():
+                existing_keys.add(name)
+                weights_names.append((self.mpu.pp_rank, vpp_rank, name))
+            extra_keys = [
+                x
+                for x in model.state_dict().keys()
+                if "_extra_state" not in x
+                and "expert_bias" in x
+                and x not in existing_keys
+            ]
+            for name in extra_keys:
+                weights_names.append((self.mpu.pp_rank, vpp_rank, name))
+
+        model_chunk_generator = get_model_chunk_generator()
+        local_to_global_maps = [
+            self._weight_name_mapping_mcore_local_to_global(model, consider_ep=False)
+            for model in models
+        ]
+
+        for iter_pp_rank, iter_vpp_rank, iter_name in weights_names:
+            local_to_global_map = local_to_global_maps[iter_vpp_rank]
+            assert iter_pp_rank == self.mpu.pp_rank
+            try:
+                name, param = next(model_chunk_generator)
+            except StopIteration:
+                name, param = None, None
+            name = local_to_global_map[iter_name]
+
+            # EP
+            if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
+                num_experts = self.config.num_moe_experts
+                num_experts_per_rank = num_experts // self.mpu.ep_size
+
+                name_prefix, local_expert_id = name.split(".weight")
+                local_expert_id = int(local_expert_id)
+                global_expert_id = num_experts_per_rank * self.mpu.ep_rank + local_expert_id
+                global_expert_name = f"{name_prefix}.weight{global_expert_id}"
+
+                yield (
+                    global_expert_name,
+                    self.mpu.etp_rank,
+                    self.mpu.etp_size,
+                    self.mpu.ep_rank,
+                    self.mpu.ep_size,
+                    getattr(param, "tensor_model_parallel", None),
+                    getattr(param, "partition_dim", None),
+                    param,
+                )
+                continue
+
+            # TP
+            if hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel:
+                # TP-sharded tensor: yield the local shard together with its shard metadata
+                yield (
+                    name,
+                    self.mpu.tp_rank,
+                    self.mpu.tp_size,
+                    0,
+                    0,
+                    getattr(param, "tensor_model_parallel", None),
+                    getattr(param, "partition_dim", None),
+                    param,
+                )
+            else:
+                yield (
+                    name,
+                    0,
+                    0,
+                    0,
+                    0,
+                    getattr(param, "tensor_model_parallel", None),
+                    getattr(param, "partition_dim", None),
+                    param,
+                )
+
     def _build_config(self):
         """
         Build the configuration for the model.
@@ -920,13 +1143,16 @@ def _weight_merge_across_tp(
         Returns:
             torch.Tensor: Merged weight tensor
         """
-        if self.mpu.tp_size == 1:
-            assert len(mcore_weights) == 1
-            return mcore_weights[0]
         if "mlp.experts.linear_fc" in mcore_weights_name:
             assert len(mcore_weights) == self.mpu.etp_size
+            if self.mpu.etp_size == 1:
+                assert len(mcore_weights) == 1
+                return mcore_weights[0]
         else:
             assert len(mcore_weights) == self.mpu.tp_size
+            if self.mpu.tp_size == 1:
+                assert len(mcore_weights) == 1
+                return mcore_weights[0]
         if (
             "self_attention.linear_qkv." in mcore_weights_name
             and "layer_norm" not in mcore_weights_name
diff --git a/mbridge/core/safetensor_io.py b/mbridge/core/safetensor_io.py
index 87ad96f..fa1a106 100644
--- a/mbridge/core/safetensor_io.py
+++ b/mbridge/core/safetensor_io.py
@@ -160,7 +160,7 @@ def save_hf_weight(
         )
         return
 
-    def save_tmp_hf_weight(
+    def save_tmp_weight(
         self,
         hf_weight_name: str,
         tensor: torch.tensor,
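A quick round-trip check of the encode_filename/decode_filename scheme introduced
in _save_weights_fast, sketched under the assumption that save_tmp_weight appends
the ".safetensors" suffix; the weight name and shard values below are hypothetical,
and note that decode returns ints (True comes back as 1, empty fields as None):

    def encode_filename(mcore_weight_name, *values):
        # ints joined by '--'; None encodes as an empty field
        return mcore_weight_name + '--' + '--'.join(
            str(int(v)) if v is not None else '' for v in values)

    def decode_filename(filename):
        parts = filename.split('--')
        return [parts[0]] + [None if p == '' else int(p) for p in parts[1:]]

    # hypothetical TP shard: tp_rank=1 of tp_size=4, no EP, partition_dim=0
    name = "decoder.layers.0.self_attention.linear_qkv.weight"
    f = encode_filename(name, 1, 4, 0, 0, True, 0)
    assert f == name + "--1--4--0--0--1--0"
    assert decode_filename(f) == [name, 1, 4, 0, 0, 1, 0]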