169 changes: 123 additions & 46 deletions mbridge/core/bridge.py
@@ -264,6 +264,7 @@ def _save_weights_fast(
self,
models: list,
weights_path: str,
tensors_per_file: int = 500,
) -> None:
if len(glob(os.path.join(weights_path, "*.safetensors"))) > 0:
raise ValueError(f"The path:{weights_path} should not has safetensors files")
@@ -280,7 +281,7 @@ def decode_filename(filename):
return [mcore_weight_name] + [None if p == '' else int(p) for p in parts]

per_tensor_generator = self.export_weights_without_gather(models)
# step 1: save the split_tp_ep file
# step 1: save the split_tp_ep file (batched)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
ep_dp_group = mpu.get_expert_data_parallel_group()
@@ -302,66 +303,115 @@ def decode_filename(filename):
pp_save_rank = dp_rank * self.mpu.cp_size * self.mpu.tp_size + self.mpu.cp_rank * self.mpu.tp_size + self.mpu.tp_rank
pp_save_cnt = 0

# Batch buffer for step 1: accumulate tensors before writing
step1_buffer = {}
step1_file_idx = 0

def flush_step1_buffer():
nonlocal step1_buffer, step1_file_idx
if not step1_buffer:
return
batch_filename = f"mcore_batch_{rank}_{step1_file_idx}.safetensors"
self.safetensor_io.save_batch_weights(
step1_buffer,
os.path.join(weights_path, batch_filename),
)
step1_buffer = {}
step1_file_idx += 1

for (mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
partition_dim, mcore_weight) in per_tensor_generator:
assert "-" not in mcore_weight_name
filename = encode_filename(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
tensor_model_parallel, partition_dim)
should_save = False
# save EP/ETP
if ep_size > 0:
if ep_save_cnt % ep_save_size == ep_save_rank:
assert tp_size > 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
should_save = True
ep_save_cnt += 1
continue
# save tp
if tp_size > 0:
elif tp_size > 0:
if tp_save_cnt % tp_save_size == tp_save_rank:
assert ep_size == 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
should_save = True
tp_save_cnt += 1
continue
# save not tp and ep
if pp_save_cnt % pp_save_size == pp_save_rank:
assert ep_size == 0 and tp_size == 0
self.safetensor_io.save_tmp_weight(filename, mcore_weight, weights_path)
pp_save_cnt += 1
else:
if pp_save_cnt % pp_save_size == pp_save_rank:
assert ep_size == 0 and tp_size == 0
should_save = True
pp_save_cnt += 1

if should_save:
step1_buffer[filename] = mcore_weight.detach().cpu()
if len(step1_buffer) >= tensors_per_file:
flush_step1_buffer()

flush_step1_buffer()

torch.distributed.barrier()

# step 2: merge tp/ep and convert to hf weight
def load_file(file_tuple):
file, _, _, _, tensor_model_parallel, partition_dim = file_tuple
with safe_open(file, framework="pt", device="cpu") as f:
assert len(f.keys()) == 1
tensor = f.get_tensor(f.keys()[0])
setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
setattr(tensor, 'partition_dim', partition_dim)
os.remove(file)
return tensor

# step 2.1: collect all file
all_files = glob(os.path.join(weights_path, "*.safetensors"))
# step 2.1: build index from all batch files
all_batch_files = glob(os.path.join(weights_path, "mcore_batch_*.safetensors"))
name2files = defaultdict(list)
for file in all_files:
(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size, tensor_model_parallel,
partition_dim) = decode_filename(os.path.basename(file).split(".safetensors")[0])
expert_id = -1
if ep_size > 0:
mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
mcore_weight_name += ".weight"
name2files[mcore_weight_name].append((
file, # 0
tp_rank, # 1
int(expert_id), # 2
tp_size, # 3
tensor_model_parallel, # 4
partition_dim, # 5
))

# step 2.1: sorted and split for all rank
# key -> batch_file_path
batch_key_index = {}

for file in all_batch_files:
with safe_open(file, framework="pt", device="cpu") as f:
for key in f.keys():
batch_key_index[key] = file
(mcore_weight_name, tp_rank, tp_size, ep_rank, ep_size,
tensor_model_parallel, partition_dim) = decode_filename(key)
expert_id = -1
if ep_size > 0:
mcore_weight_name, expert_id = mcore_weight_name.split(".weight")
mcore_weight_name += ".weight"
name2files[mcore_weight_name].append((
key, # 0: key in batch file
tp_rank, # 1
int(expert_id), # 2
tp_size, # 3
tensor_model_parallel, # 4
partition_dim, # 5
))

def load_tensor_from_batches(key, batch_key_index):
"""Load a single tensor by key from batch files using pre-built index."""
bf = batch_key_index[key]
with safe_open(bf, framework="pt", device="cpu") as f:
return f.get_tensor(key)

def load_tensor_from_file(file_tuple):
key, _, _, _, tensor_model_parallel, partition_dim = file_tuple
tensor = load_tensor_from_batches(key, batch_key_index)
setattr(tensor, 'tensor_model_parallel', tensor_model_parallel)
setattr(tensor, 'partition_dim', partition_dim)
return tensor

# step 2.2: sorted and split for all rank, output batched
torch.distributed.barrier()
weight_names = sorted(list(name2files.keys()))

# Batch buffer for step 2 HF output
hf_buffer = {}
hf_file_idx = 0

def flush_hf_buffer():
nonlocal hf_buffer, hf_file_idx
if not hf_buffer:
return
hf_filename = f"hf_batch_{rank}_{hf_file_idx}.safetensors"
self.safetensor_io.save_batch_weights(
hf_buffer,
os.path.join(weights_path, hf_filename),
)
hf_buffer = {}
hf_file_idx += 1

for w_name in weight_names[rank::world_size]:
w_files = sorted(name2files[w_name], key=lambda x: (x[2], x[1]))
if w_files[0][2] != -1:
@@ -373,30 +423,57 @@ def load_file(file_tuple):
params = []
for etp_idx in range(self.mpu.etp_size):
assert w_files[idx + etp_idx][2] == expert_id
params.append(load_file(w_files[idx + etp_idx]))
params.append(load_tensor_from_file(w_files[idx + etp_idx]))
tmp_w_name = w_name + str(expert_id)
infer_params = self._weight_merge_across_tp(tmp_w_name, params, params[0])
for hf_name, hf_param in zip(*self._weight_to_hf_format(tmp_w_name, infer_params)):
self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
hf_buffer[hf_name] = hf_param.detach().cpu()
if len(hf_buffer) >= tensors_per_file:
flush_hf_buffer()
else:
# gather tp
if w_files[0][4] is not None and w_files[0][4] > 0:
assert len(w_files) == w_files[0][3]
params = [load_file(w_file) for w_file in w_files]
params = [load_tensor_from_file(w_file) for w_file in w_files]
infer_params = self._weight_merge_across_tp(w_name, params, params[0])
else:
infer_params = load_file(w_files[0])
infer_params = load_tensor_from_file(w_files[0])
for hf_name, hf_param in zip(*self._weight_to_hf_format(w_name, infer_params)):
self.safetensor_io.save_tmp_weight(hf_name, hf_param, weights_path)
hf_buffer[hf_name] = hf_param.detach().cpu()
if len(hf_buffer) >= tensors_per_file:
flush_hf_buffer()

# step 3: save the huggingface checkpoint
flush_hf_buffer()

# Delete step1 batch files (no longer needed)
torch.distributed.barrier()
if rank == 0:
for f in all_batch_files:
if os.path.exists(f):
try:
os.remove(f)
except OSError:
pass
torch.distributed.barrier()
self.safetensor_io.save_hf_weight_merge(

# step 3: merge batched HF files into final safetensors shards
torch.distributed.barrier()
self.safetensor_io.save_hf_weight_merge_from_batches(
weights_path,
rank,
world_size,
)

# Clean up HF batch files
torch.distributed.barrier()
if rank == 0:
for f in glob(os.path.join(weights_path, "hf_batch_*.safetensors")):
if os.path.exists(f):
try:
os.remove(f)
except OSError:
pass

if 0 == rank:
self.safetensor_io.save_index(weights_path)
self.hf_config.save_pretrained(weights_path)
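A minimal, self-contained sketch of the buffer-and-flush pattern that step 1 above relies on (not part of the diff; batched_save, named_tensors, and the toy loop are invented for illustration, while the mcore_batch_{rank}_{idx}.safetensors naming and the tensors_per_file default of 500 come from the PR):

import os
import torch
from safetensors.torch import save_file

def batched_save(named_tensors, out_dir, rank=0, tensors_per_file=500):
    buffer, file_idx = {}, 0

    def flush():
        nonlocal buffer, file_idx
        if not buffer:
            return
        path = os.path.join(out_dir, f"mcore_batch_{rank}_{file_idx}.safetensors")
        # safetensors expects CPU tensors; .detach().cpu() mirrors the diff.
        save_file({k: v.detach().cpu() for k, v in buffer.items()}, path)
        buffer, file_idx = {}, file_idx + 1

    for name, tensor in named_tensors:
        buffer[name] = tensor
        if len(buffer) >= tensors_per_file:
            flush()
    flush()  # write the trailing partial batch

With tensors_per_file=500, a rank holding 1,200 tensors writes three batch files instead of 1,200 single-tensor files, which is the IO overhead the batching is meant to remove.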
70 changes: 70 additions & 0 deletions mbridge/core/safetensor_io.py
@@ -237,3 +237,73 @@ def save_index(self, new_hf_dir: str):
else:
warnings.warn("No index file found, saving index file failed")
return

def save_batch_weights(
self,
tensors: dict[str, torch.Tensor],
filepath: str,
):
"""
Save multiple tensors into a single safetensors file.
This reduces IO overhead by batching multiple tensors together.

Args:
tensors: Dict mapping tensor names to tensor data
filepath: Path to save the safetensors file
"""
if not tensors:
return
assert self.index, "index file is required for memory efficient saving"
save_file({k: v.cpu() if v.is_cuda else v for k, v in tensors.items()}, filepath)

def save_hf_weight_merge_from_batches(
self,
new_hf_dir: str,
rank: int = 0,
world_size: int = 1,
batch_file_pattern: str = "hf_batch_*.safetensors",
):
"""
Merge HF weights from batch files (each containing multiple tensors)
into final safetensors shards matching the original index layout.

Args:
new_hf_dir: Directory containing batch files and where final files will be saved
rank: Current rank
world_size: Total number of ranks
batch_file_pattern: Glob pattern to find batch files
"""
assert self.index, "index file is required for memory efficient saving"

filename_to_keys_map = defaultdict(set)
for key, filename in self.index.items():
filename_to_keys_map[filename].add(key)

# Build reverse index: key -> which batch file contains it
batch_files = glob(os.path.join(new_hf_dir, batch_file_pattern))
key_to_batch_file = {}
for bf in batch_files:
with safe_open(bf, framework="pt", device="cpu") as f:
for key in f.keys():
key_to_batch_file[key] = bf

filename_list = sorted(list(filename_to_keys_map.keys()))
if world_size > 1:
num_files = len(filename_list)
num_files_rank = (num_files + world_size - 1) // world_size
begin_idx = min(num_files, rank * num_files_rank)
end_idx = min(num_files, (rank + 1) * num_files_rank)
filename_list = filename_list[begin_idx:end_idx]

# For each final shard, collect tensors from batch files and write
for filename in filename_list:
keys_for_file = filename_to_keys_map[filename]
states = {}
old_keys_for_file, _ = self._mapping_weight_names_new2old(keys_for_file)
for old_key, key in zip(old_keys_for_file, keys_for_file):
bf = key_to_batch_file[old_key]
with safe_open(bf, framework="pt", device="cpu") as f:
states[key] = f.get_tensor(old_key)
save_file(states, os.path.join(new_hf_dir, filename))

return batch_files
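For the read-back side, a short sketch (again not part of the PR; build_key_index and gather_shard are hypothetical helpers) of the key-to-batch-file index that both load_tensor_from_batches in bridge.py and save_hf_weight_merge_from_batches above rely on. safe_open only needs to parse the file header to list keys, so indexing every batch file stays cheap even when the tensors themselves are large:

import os
from glob import glob
from safetensors import safe_open
from safetensors.torch import save_file

def build_key_index(directory, pattern="hf_batch_*.safetensors"):
    # Map each tensor key to the batch file that holds it; header-only reads.
    key_to_file = {}
    for bf in glob(os.path.join(directory, pattern)):
        with safe_open(bf, framework="pt", device="cpu") as f:
            for key in f.keys():
                key_to_file[key] = bf
    return key_to_file

def gather_shard(keys, key_to_file, out_path):
    # Pull only the requested keys and write them out as one final shard.
    states = {}
    for key in keys:
        with safe_open(key_to_file[key], framework="pt", device="cpu") as f:
            states[key] = f.get_tensor(key)
    save_file(states, out_path)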