Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions example/0.load_model_and_generate_single_gpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Use Megatron-FSDP: python example/0.load_model_and_generate_single_gpu.py --model_path /path/to/model --use_megatron_fsdp

import argparse
import os

Expand Down Expand Up @@ -25,12 +27,26 @@ def init_distributed():
model_parallel_cuda_manual_seed(0)


def load_model(hf_model_path, trust_remote_code=False, use_megatron_fsdp=False):
    """Build a model through the bridge and load the checkpoint weights into it.

    Args:
        hf_model_path: Model path handed to ``AutoBridge.from_pretrained`` and
            reused as the weight source for ``bridge.load_weights``.
        trust_remote_code: Forwarded unchanged to ``AutoBridge.from_pretrained``.
        use_megatron_fsdp: When true, request a DDP-wrapped model with
            Megatron-FSDP enabled; otherwise call ``bridge.get_model()`` with
            its defaults.

    Returns:
        The object returned by ``bridge.get_model`` after
        ``bridge.load_weights`` has been applied to it.
    """
    bridge = AutoBridge.from_pretrained(
        hf_model_path, trust_remote_code=trust_remote_code
    )
    # Build the model exactly once: the construction arguments differ between
    # the FSDP and the plain path, so there is no shared pre-built instance.
    if use_megatron_fsdp:
        ddp_config = {
            "use_distributed_optimizer": True,
            "check_for_nan_in_grad": True,
            "use_megatron_fsdp": True,
            "data_parallel_sharding_strategy": "optim_grads_params",
        }
        model = bridge.get_model(
            wrap_with_ddp=True,
            use_megatron_fsdp=True,
            ddp_config=ddp_config,
            data_parallel_random_init=False,
        )
    else:
        model = bridge.get_model()
    bridge.load_weights(model, hf_model_path)
    return model

Expand Down Expand Up @@ -105,13 +121,18 @@ def main():
parser.add_argument(
"--trust_remote_code", action="store_true", help="Trust remote code"
)
parser.add_argument(
"--use_megatron_fsdp",
action="store_true",
help="Use Megatron-FSDP",
)
args = parser.parse_args()

# Initialize distributed environment
init_distributed()

# Load model
model = load_model(args.model_path, args.trust_remote_code)
model = load_model(args.model_path, args.trust_remote_code, args.use_megatron_fsdp)
print(f"Model loaded: {args.model_path}")

# Generate text
Expand Down
24 changes: 23 additions & 1 deletion example/1.load_model_and_export_single_gpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Use Megatron-FSDP: python example/1.load_model_and_export_single_gpu.py --model_path /path/to/model --use_megatron_fsdp

import argparse
import os

Expand Down Expand Up @@ -34,6 +36,11 @@ def main():
parser.add_argument(
"--trust_remote_code", action="store_true", help="Trust remote code"
)
parser.add_argument(
"--use_megatron_fsdp",
action="store_true",
help="Use Megatron-FSDP",
)
args = parser.parse_args()

# Initialize distributed environment
Expand All @@ -44,7 +51,22 @@ def main():
bridge = AutoBridge.from_pretrained(
hf_model_path, trust_remote_code=args.trust_remote_code
)
model = bridge.get_model()
if args.use_megatron_fsdp:
ddp_config = {
"use_distributed_optimizer": True,
"check_for_nan_in_grad": True,
"use_megatron_fsdp": True,
"data_parallel_sharding_strategy": "optim_grads_params",
}
model = bridge.get_model(
wrap_with_ddp=True,
use_megatron_fsdp=True,
ddp_config=ddp_config,
data_parallel_random_init=False,
post_model_creation_callbacks=[],
)
else:
model = bridge.get_model()
bridge.load_weights(model, hf_model_path)
print(f"Model loaded: {args.model_path}")

Expand Down
23 changes: 22 additions & 1 deletion example/2.load_model_and_export_multiple_gpus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Example of using tp/pp/cp/vpp to test a dense model
# torchrun --nproc_per_node=8 2.load_model_and_export_multiple_gpus.py --model_path /path/to/model
# Use Megatron-FSDP: torchrun --nproc_per_node=8 2.load_model_and_export_multiple_gpus.py --model_path /path/to/model --use_megatron_fsdp


import argparse
Expand Down Expand Up @@ -126,6 +127,11 @@ def main():
parser.add_argument(
"--trust_remote_code", action="store_true", help="Trust remote code"
)
parser.add_argument(
"--use_megatron_fsdp",
action="store_true",
help="Use Megatron-FSDP",
)
args = parser.parse_args()

# Initialize distributed environment
Expand All @@ -142,7 +148,22 @@ def main():
hf_model_path = args.model_path
print(f"rank{torch.distributed.get_rank()}: start loading model")
bridge = AutoBridge.from_pretrained(hf_model_path)
model = bridge.get_model(post_model_creation_callbacks=[])
if args.use_megatron_fsdp:
ddp_config = {
"use_distributed_optimizer": True,
"check_for_nan_in_grad": True,
"use_megatron_fsdp": True,
"data_parallel_sharding_strategy": "optim_grads_params",
}
model = bridge.get_model(
wrap_with_ddp=True,
use_megatron_fsdp=True,
ddp_config=ddp_config,
data_parallel_random_init=False,
post_model_creation_callbacks=[],
)
else:
model = bridge.get_model(post_model_creation_callbacks=[])
print(
f"rank{torch.distributed.get_rank()}: start loading weights from {hf_model_path}"
)
Expand Down
24 changes: 23 additions & 1 deletion example/3.launch_megatron_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def worker_fn(
vpp: int,
ep: int,
etp: Optional[int],
use_megatron_fsdp: bool = False,
):
"""Worker that runs on a single GPU.

Expand All @@ -114,7 +115,22 @@ def worker_fn(

# 2. Load model & weights
bridge = AutoBridge.from_pretrained(hf_model_path)
model = bridge.get_model(post_model_creation_callbacks=[])
if use_megatron_fsdp:
ddp_config = {
"use_distributed_optimizer": True,
"check_for_nan_in_grad": True,
"use_megatron_fsdp": True,
"data_parallel_sharding_strategy": "optim_grads_params",
}
model = bridge.get_model(
wrap_with_ddp=True,
use_megatron_fsdp=True,
ddp_config=ddp_config,
data_parallel_random_init=False,
post_model_creation_callbacks=[],
)
else:
model = bridge.get_model(post_model_creation_callbacks=[])

bridge.load_weights(model, hf_model_path)

Expand Down Expand Up @@ -175,6 +191,11 @@ def main():
parser.add_argument(
"--master_port", type=int, default=12355, help="NCCL master port"
)
parser.add_argument(
"--use_megatron_fsdp",
action="store_true",
help="Use Megatron-FSDP",
)
args = parser.parse_args()

# Connect to the running Ray cluster
Expand Down Expand Up @@ -203,6 +224,7 @@ def main():
args.vpp,
args.ep,
args.etp,
args.use_megatron_fsdp,
)
)
rank += 1
Expand Down
24 changes: 23 additions & 1 deletion example/4.launch_deepseekv3_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def worker_fn(
etp: Optional[int],
num_layers_in_first_pipeline_stage: Optional[int] = None,
num_layers_in_last_pipeline_stage: Optional[int] = None,
use_megatron_fsdp: bool = False,
):
"""Worker that runs on a single GPU.

Expand All @@ -118,7 +119,22 @@ def worker_fn(
num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
)
# bridge.config.mtp_num_layers = 0
model = bridge.get_model(post_model_creation_callbacks=[], wrap_with_ddp=False)
if use_megatron_fsdp:
ddp_config = {
"use_distributed_optimizer": True,
"check_for_nan_in_grad": True,
"use_megatron_fsdp": True,
"data_parallel_sharding_strategy": "optim_grads_params",
}
model = bridge.get_model(
wrap_with_ddp=True,
use_megatron_fsdp=True,
ddp_config=ddp_config,
data_parallel_random_init=False,
post_model_creation_callbacks=[],
)
else:
model = bridge.get_model(post_model_creation_callbacks=[], wrap_with_ddp=False)

# maintain router bias dtype
for m in model:
Expand Down Expand Up @@ -224,6 +240,11 @@ def main():
parser.add_argument(
"--master_port", type=int, default=12355, help="NCCL master port"
)
parser.add_argument(
"--use_megatron_fsdp",
action="store_true",
help="Use Megatron-FSDP",
)
args = parser.parse_args()

# Connect to the running Ray cluster
Expand Down Expand Up @@ -252,6 +273,7 @@ def main():
args.etp,
args.num_layers_in_first_pipeline_stage,
args.num_layers_in_last_pipeline_stage,
args.use_megatron_fsdp,
)
)
rank += 1
Expand Down
34 changes: 28 additions & 6 deletions mbridge/core/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from transformers import AutoConfig
from transformers.utils.hub import cached_file
from safetensors import safe_open

from torch.distributed._tensor import DTensor
from .parallel_states import ParallelStates
from .safetensor_io import SafeTensorIO
from .util import (
broadcast_from_megatron_pp,
broadcast_str_from_megatron_pp,
get_model,
unwrap_model,
get_module_and_param_from_name,
)


Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
self.vocab_size = None
self.padded_vocab_size = None
self.use_megatron_fsdp = False

# Some moe models require multiple weights to be combined into one,
# such as qwen3vl. It will cache it into this buff until all weights are collected.
Expand All @@ -73,6 +75,7 @@ def get_model(
fp16: bool = False,
bf16: bool = True,
encoder_pipeline_model_parallel_size: int = 0,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
use_custom_fsdp: bool = False,
use_precision_aware_optimizer: bool = False,
Expand Down Expand Up @@ -120,6 +123,7 @@ def get_model(
# and self.mpu.vpp_size > 1
# ):
# raise ValueError("tie_word_embeddings is not supported for VPP > 1")
self.use_megatron_fsdp = use_megatron_fsdp
model = get_model(
self._model_provider(
post_model_creation_callbacks,
Expand All @@ -131,6 +135,7 @@ def get_model(
bf16=bf16,
virtual_pipeline_model_parallel_size=self.mpu.vpp_size,
encoder_pipeline_model_parallel_size=encoder_pipeline_model_parallel_size,
use_megatron_fsdp=use_megatron_fsdp,
use_torch_fsdp2=use_torch_fsdp2,
use_custom_fsdp=use_custom_fsdp,
use_precision_aware_optimizer=use_precision_aware_optimizer,
Expand Down Expand Up @@ -198,8 +203,9 @@ def load_weights(
)

# import mcore weights
unwrapped_model = unwrap_model(model)
for local_name, hf_names in local_to_hf_map.items():
param = model.state_dict()[local_name]
param = unwrapped_model.state_dict()[local_name]
# hf format to mcore format
if set(to_load_from_disk) & set(hf_names):
if not memory_efficient:
Expand All @@ -218,7 +224,7 @@ def load_weights(
# skip lm_head.weight when the model is a value model
continue

param_to_load = torch.empty_like(param)
param_to_load = torch.empty(param.shape, device=param.device, dtype=param.dtype)
if ".mlp.experts.linear_fc" in local_name:
# split mcore weights across etp
if self.mpu.etp_rank == 0:
Expand Down Expand Up @@ -258,7 +264,14 @@ def load_weights(
group=self.mpu.tp_group,
)
# load
if isinstance(param, DTensor):
_, local_weights = get_module_and_param_from_name(unwrapped_model, local_name)
sliced_converted_weights = param_to_load.reshape(-1)[local_weights.megatron_fsdp_slice]
param._local_tensor.reshape(-1).copy_(sliced_converted_weights)
continue
param.copy_(param_to_load)
if self.use_megatron_fsdp:
model.module.install_optimized_model_weights()

def _save_weights_fast(
self,
Expand Down Expand Up @@ -527,7 +540,16 @@ def get_model_chunk_generator():
name, param = None, None

name = broadcast_str_from_megatron_pp(name)
broad_pp_param = broadcast_from_megatron_pp(param)
broad_pp_param = None
if isinstance(param, DTensor):
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
gather_uneven_dtensor_to_full_tensor,
)
_, local_weights = get_module_and_param_from_name(models, iter_name, iter_vpp_rank)
full_tensor = gather_uneven_dtensor_to_full_tensor(local_weights)
broad_pp_param = full_tensor.to_local()
else:
broad_pp_param = broadcast_from_megatron_pp(param)

# EP
if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1:
Expand Down Expand Up @@ -574,7 +596,7 @@ def get_model_chunk_generator():
if len(converted_names) == 0:
continue

yield from zip(converted_names, [p.detach() for p in converted_params])
yield from zip(converted_names, [p.detach().to(self.dtype) for p in converted_params])
continue

# TP
Expand Down Expand Up @@ -606,7 +628,7 @@ def get_model_chunk_generator():
if len(converted_names) == 0:
continue

yield from zip(converted_names, [p.detach() for p in converted_params])
yield from zip(converted_names, [p.detach().to(self.dtype) for p in converted_params])

def export_weights_without_gather(
self, models: list[torch.nn.Module],
Expand Down
Loading