diff --git a/mbridge/core/bridge.py b/mbridge/core/bridge.py
index 7a18420..70f1e3d 100644
--- a/mbridge/core/bridge.py
+++ b/mbridge/core/bridge.py
@@ -264,6 +264,7 @@ def _save_weights_fast(
         self,
         models: list,
         weights_path: str,
+        tensors_per_file: int = 500,
     ) -> None:
         if len(glob(os.path.join(weights_path, "*.safetensors"))) > 0:
             raise ValueError(f"The path:{weights_path} should not has safetensors files")
@@ -280,7 +281,7 @@ def decode_filename(filename):
             return [mcore_weight_name] + [None if p == '' else int(p) for p in parts]

         per_tensor_generator = self.export_weights_without_gather(models)
-        # step 1: save the split_tp_ep file
+        # step 1: save the split_tp_ep file (batched)
         world_size = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
         ep_dp_group = mpu.get_expert_data_parallel_group()
@@ -302,66 +303,115 @@ def decode_filename(filename):
         pp_save_rank = dp_rank * self.mpu.cp_size * self.mpu.tp_size + self.mpu.cp_rank * self.mpu.tp_size + self.mpu.tp_rank
         pp_save_cnt = 0

+        # Batch buffer for step 1: accumulate tensors before writing
+        step1_buffer = {}
+        step1_file_idx = 0
+
+        def flush_step1_buffer():
+            nonlocal step1_buffer, step1_file_idx
+            if not step1_buffer:
+                return
+            batch_filename = f"mcore_batch_{rank}_{step1_file_idx}.safetensors"
+            self.safetensor_io.save_batch_weights(
+                step1_buffer,
+                os.path.join(weights_path, batch_filename),
+            )
+            step1_buffer = {}
+            step1_file_idx += 1
+
         for (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
              partition_dim, mcore_weight) in per_tensor_generator:
             assert "-" not in mcore_weight_name
             filename = encode_filename(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
                                        tensor_model_parallel, partition_dim)
+            should_save = False
             # save EP/ETP
             if ep_size > 0:
                 if ep_save_cnt % ep_save_size == ep_save_rank:
                     assert tp_size > 0
-                    self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
+                    should_save = True
                 ep_save_cnt += 1
-                continue
             # save tp
-            if tp_size > 0:
+            elif tp_size > 0:
                 if tp_save_cnt % tp_save_size == tp_save_rank:
                     assert ep_size == 0
-                    self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
+                    should_save = True
                 tp_save_cnt += 1
-                continue
             # save not tp and ep
-            if pp_save_cnt % pp_save_size == pp_save_rank:
-                assert ep_size == 0 and tp_size == 0
-                self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
-            pp_save_cnt += 1
+            else:
+                if pp_save_cnt % pp_save_size == pp_save_rank:
+                    assert ep_size == 0 and tp_size == 0
+                    should_save = True
+                pp_save_cnt += 1
+
+            if should_save:
+                step1_buffer[filename] = mcore_weight.detach().cpu()
+                if len(step1_buffer) >= tensors_per_file:
+                    flush_step1_buffer()
+
+        flush_step1_buffer()

         torch.distributed.barrier()

         # step 2: merge tp/ep and convert to hf weight
-        def load_file(file_tuple):
-            file, _, _, _, tensor_model_parallel, partition_dim = file_tuple
-            with safe_open(file, framework="pt", device="cpu") as f:
-                assert len(f.keys()) == 1
-                tensor = f.get_tensor(f.keys()[0])
-            setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
-            setattr(tensor, 'partition_dim', partition_dim)
-            os.remove(file)
-            return tensor
-
-        # step 2.1: collect all file
-        all_files = glob(os.path.join(weights_path, "*.safetensors"))
+        # step 2.1: build index from all batch files
+        all_batch_files = glob(os.path.join(weights_path, "mcore_batch_*.safetensors"))
         name2files = defaultdict(list)
-        for file in all_files:
-            (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
-             partition_dim) = decode_filename(os.path.basename(file).split(".safetensors")[0])
-            expert_id = -1
-            if ep_size > 0:
-                mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
-                mcore_weight_name += ".weight"
-            name2files[mcore_weight_name].append((
-                file,  # 0
-                tp_rank,  # 1
-                int(expert_id),  # 2
-                tp_size,  # 3
-                tensor_model_parallel,  # 4
-                partition_dim,  # 5
-            ))
-
-        # step 2.1: sorted and split for all rank
+        # key -> batch_file_path
+        batch_key_index = {}
+
+        for file in all_batch_files:
+            with safe_open(file, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    batch_key_index[key] = file
+                    (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
+                     tensor_model_parallel, partition_dim) = decode_filename(key)
+                    expert_id = -1
+                    if ep_size > 0:
+                        mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
+                        mcore_weight_name += ".weight"
+                    name2files[mcore_weight_name].append((
+                        key,  # 0: key in batch file
+                        tp_rank,  # 1
+                        int(expert_id),  # 2
+                        tp_size,  # 3
+                        tensor_model_parallel,  # 4
+                        partition_dim,  # 5
+                    ))
+
+        def load_tensor_from_batches(key, batch_key_index):
+            """Load a single tensor by key from batch files using pre-built index."""
+            bf = batch_key_index[key]
+            with safe_open(bf, framework="pt", device="cpu") as f:
+                return f.get_tensor(key)
+
+        def load_tensor_from_file(file_tuple):
+            key, _, _, _, tensor_model_parallel, partition_dim = file_tuple
+            tensor = load_tensor_from_batches(key, batch_key_index)
+            setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
+            setattr(tensor, 'partition_dim', partition_dim)
+            return tensor
+
+        # step 2.2: sorted and split for all rank, output batched
         torch.distributed.barrier()
         weight_names = sorted(list(name2files.keys()))
+
+        # Batch buffer for step 2 HF output
+        hf_buffer = {}
+        hf_file_idx = 0
+
+        def flush_hf_buffer():
+            nonlocal hf_buffer, hf_file_idx
+            if not hf_buffer:
+                return
+            hf_filename = f"hf_batch_{rank}_{hf_file_idx}.safetensors"
+            self.safetensor_io.save_batch_weights(
+                hf_buffer,
+                os.path.join(weights_path, hf_filename),
+            )
+            hf_buffer = {}
+            hf_file_idx += 1
+
         for w_name in weight_names[rank::world_size]:
             w_files = sorted(name2files[w_name], key=lambda x: (x[2], x[1]))
             if w_files[0][2] != -1:
@@ -373,30 +423,57 @@ def load_file(file_tuple):
                     params = []
                     for etp_idx in range(self.mpu.etp_size):
                         assert w_files[idx + etp_idx][2] == expert_id
-                        params.append(load_file(w_files[idx + etp_idx]))
+                        params.append(load_tensor_from_file(w_files[idx + etp_idx]))
                     tmp_w_name = w_name + str(expert_id)
                     infer_params = self._weight_merge_across_tp(tmp_w_name, params, params[0])
                     for hf_name, hf_param in zip(*self._weight_to_hf_format(tmp_w_name, infer_params)):
-                        self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
+                        hf_buffer[hf_name] = hf_param.detach().cpu()
+                        if len(hf_buffer) >= tensors_per_file:
+                            flush_hf_buffer()
             else:
                 # gather tp
                 if w_files[0][4] is not None and w_files[0][4] > 0:
                     assert len(w_files) == w_files[0][3]
-                    params = [load_file(w_file) for w_file in w_files]
+                    params = [load_tensor_from_file(w_file) for w_file in w_files]
                     infer_params = self._weight_merge_across_tp(w_name, params, params[0])
                 else:
-                    infer_params = load_file(w_files[0])
+                    infer_params = load_tensor_from_file(w_files[0])
                 for hf_name, hf_param in zip(*self._weight_to_hf_format(w_name, infer_params)):
-                    self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
+                    hf_buffer[hf_name] = hf_param.detach().cpu()
+                    if len(hf_buffer) >= tensors_per_file:
+                        flush_hf_buffer()

-        # step 3: save the huggingface checkpoint
+        flush_hf_buffer()
+
+        # Delete step1 batch files (no longer needed)
+        torch.distributed.barrier()
+        if rank == 0:
+            for f in all_batch_files:
+                if os.path.exists(f):
+                    try:
+                        os.remove(f)
+                    except OSError:
+                        pass
         torch.distributed.barrier()
-        self.safetensor_io.save_hf_weight_merge(
+
+        # step 3: merge batched HF files into final safetensors shards
+        torch.distributed.barrier()
+        self.safetensor_io.save_hf_weight_merge_from_batches(
             weights_path,
             rank,
             world_size,
         )
+        # Clean up HF batch files
+        torch.distributed.barrier()
+        if rank == 0:
+            for f in glob(os.path.join(weights_path, "hf_batch_*.safetensors")):
+                if os.path.exists(f):
+                    try:
+                        os.remove(f)
+                    except OSError:
+                        pass
+
         if 0 == rank:
             self.safetensor_io.save_index(weights_path)
             self.hf_config.save_pretrained(weights_path)
diff --git a/mbridge/core/safetensor_io.py b/mbridge/core/safetensor_io.py
index fa1a106..50df5e0 100644
--- a/mbridge/core/safetensor_io.py
+++ b/mbridge/core/safetensor_io.py
@@ -237,3 +237,73 @@ def save_index(self, new_hf_dir: str):
         else:
             warnings.warn("No index file found, saving index file failed")
         return
+
+    def save_batch_weights(
+        self,
+        tensors: dict[str, torch.Tensor],
+        filepath: str,
+    ):
+        """
+        Save multiple tensors into a single safetensors file.
+        This reduces IO overhead by batching multiple tensors together.
+
+        Args:
+            tensors: Dict mapping tensor names to tensor data
+            filepath: Path to save the safetensors file
+        """
+        if not tensors:
+            return
+        assert self.index, "index file is required for memory efficient saving"
+        save_file({k: v.cpu() if v.is_cuda else v for k, v in tensors.items()}, filepath)
+
+    def save_hf_weight_merge_from_batches(
+        self,
+        new_hf_dir: str,
+        rank: int = 0,
+        world_size: int = 1,
+        batch_file_pattern: str = "hf_batch_*.safetensors",
+    ):
+        """
+        Merge HF weights from batch files (each containing multiple tensors)
+        into final safetensors shards matching the original index layout.
+
+        Args:
+            new_hf_dir: Directory containing batch files and where final files will be saved
+            rank: Current rank
+            world_size: Total number of ranks
+            batch_file_pattern: Glob pattern to find batch files
+        """
+        assert self.index, "index file is required for memory efficient saving"
+
+        filename_to_keys_map = defaultdict(set)
+        for key, filename in self.index.items():
+            filename_to_keys_map[filename].add(key)
+
+        # Build reverse index: key -> which batch file contains it
+        batch_files = glob(os.path.join(new_hf_dir, batch_file_pattern))
+        key_to_batch_file = {}
+        for bf in batch_files:
+            with safe_open(bf, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    key_to_batch_file[key] = bf
+
+        filename_list = sorted(list(filename_to_keys_map.keys()))
+        if world_size > 1:
+            num_files = len(filename_list)
+            num_files_rank = (num_files + world_size - 1) // world_size
+            begin_idx = min(num_files, rank * num_files_rank)
+            end_idx = min(num_files, (rank + 1) * num_files_rank)
+            filename_list = filename_list[begin_idx:end_idx]
+
+        # For each final shard, collect tensors from batch files and write
+        for filename in filename_list:
+            keys_for_file = filename_to_keys_map[filename]
+            states = {}
+            old_keys_for_file, _ = self._mapping_weight_names_new2old(keys_for_file)
+            for old_key, key in zip(old_keys_for_file, keys_for_file):
+                bf = key_to_batch_file[old_key]
+                with safe_open(bf, framework="pt", device="cpu") as f:
+                    states[key] = f.get_tensor(old_key)
+            save_file(states, os.path.join(new_hf_dir, filename))
+
+        return batch_files
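
A minimal sketch of how the batched fast-save path above might be driven from training code. Only _save_weights_fast, its new tensors_per_file argument, the mcore_batch_* / hf_batch_* intermediate files, and the empty-directory precondition come from this diff; the bridge object, the models list, and the initialized torch.distributed / Megatron parallel state are assumptions.

    import os
    import torch

    def save_hf_checkpoint(bridge, models, weights_path: str, tensors_per_file: int = 500):
        # _save_weights_fast raises ValueError if weights_path already contains
        # *.safetensors files, so the target directory must be fresh.
        if torch.distributed.get_rank() == 0:
            os.makedirs(weights_path, exist_ok=True)
        torch.distributed.barrier()

        # Every rank participates: step 1 writes per-rank mcore_batch_* files,
        # step 2 merges TP/EP shards into hf_batch_* files, and step 3 assembles
        # the final HF shards and cleans up the intermediate batch files.
        bridge._save_weights_fast(models, weights_path, tensors_per_file=tensors_per_file)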