diff --git a/example/0.load_model_and_generate_single_gpu.py b/example/0.load_model_and_generate_single_gpu.py index ad9263c..0f7bd78 100644 --- a/example/0.load_model_and_generate_single_gpu.py +++ b/example/0.load_model_and_generate_single_gpu.py @@ -1,3 +1,5 @@ +# Use Megatron-FSDP: python example/0.load_model_and_generate_single_gpu.py --model_path /path/to/model --use_megatron_fsdp + import argparse import os @@ -25,12 +27,26 @@ def init_distributed(): model_parallel_cuda_manual_seed(0) -def load_model(hf_model_path, trust_remote_code=False): +def load_model(hf_model_path, trust_remote_code=False, use_megatron_fsdp=False): """Load model""" bridge = AutoBridge.from_pretrained( hf_model_path, trust_remote_code=trust_remote_code ) - model = bridge.get_model() + if use_megatron_fsdp: + ddp_config = { + "use_distributed_optimizer": True, + "check_for_nan_in_grad": True, + "use_megatron_fsdp": True, + "data_parallel_sharding_strategy": "optim_grads_params", + } + model = bridge.get_model( + wrap_with_ddp=True, + use_megatron_fsdp=True, + ddp_config=ddp_config, + data_parallel_random_init=False, + ) + else: + model = bridge.get_model() bridge.load_weights(model, hf_model_path) return model @@ -105,13 +121,18 @@ def main(): parser.add_argument( "--trust_remote_code", action="store_true", help="Trust remote code" ) + parser.add_argument( + "--use_megatron_fsdp", + action="store_true", + help="Use Megatron-FSDP", + ) args = parser.parse_args() # Initialize distributed environment init_distributed() # Load model - model = load_model(args.model_path, args.trust_remote_code) + model = load_model(args.model_path, args.trust_remote_code, args.use_megatron_fsdp) print(f"Model loaded: {args.model_path}") # Generate text diff --git a/example/1.load_model_and_export_single_gpu.py b/example/1.load_model_and_export_single_gpu.py index 8073c8e..e5a60c6 100644 --- a/example/1.load_model_and_export_single_gpu.py +++ b/example/1.load_model_and_export_single_gpu.py @@ -1,3 +1,5 @@ +# Use Megatron-FSDP: python example/1.load_model_and_export_single_gpu.py --model_path /path/to/model --use_megatron_fsdp + import argparse import os @@ -34,6 +36,11 @@ def main(): parser.add_argument( "--trust_remote_code", action="store_true", help="Trust remote code" ) + parser.add_argument( + "--use_megatron_fsdp", + action="store_true", + help="Use Megatron-FSDP", + ) args = parser.parse_args() # Initialize distributed environment @@ -44,7 +51,22 @@ def main(): bridge = AutoBridge.from_pretrained( hf_model_path, trust_remote_code=args.trust_remote_code ) - model = bridge.get_model() + if args.use_megatron_fsdp: + ddp_config = { + "use_distributed_optimizer": True, + "check_for_nan_in_grad": True, + "use_megatron_fsdp": True, + "data_parallel_sharding_strategy": "optim_grads_params", + } + model = bridge.get_model( + wrap_with_ddp=True, + use_megatron_fsdp=True, + ddp_config=ddp_config, + data_parallel_random_init=False, + post_model_creation_callbacks=[], + ) + else: + model = bridge.get_model() bridge.load_weights(model, hf_model_path) print(f"Model loaded: {args.model_path}") diff --git a/example/2.load_model_and_export_multiple_gpus.py b/example/2.load_model_and_export_multiple_gpus.py index 7100749..6d0c793 100644 --- a/example/2.load_model_and_export_multiple_gpus.py +++ b/example/2.load_model_and_export_multiple_gpus.py @@ -1,5 +1,6 @@ # Example to use tp/pp/cp/vpp to test dense model # torchrun --nproc_per_node=8 2.load_model_and_export_multiple_gpus.py --model_path /path/to/model +# Use Megatron-FSDP: torchrun --nproc_per_node=8 2.load_model_and_export_multiple_gpus.py --model_path /path/to/model --use_megatron_fsdp import argparse @@ -126,6 +127,11 @@ def main(): parser.add_argument( "--trust_remote_code", action="store_true", help="Trust remote code" ) + parser.add_argument( + "--use_megatron_fsdp", + action="store_true", + help="Use Megatron-FSDP", + ) args = parser.parse_args() # Initialize distributed environment @@ -142,7 +148,22 @@ def main(): hf_model_path = args.model_path print(f"rank{torch.distributed.get_rank()}: start loading model") bridge = AutoBridge.from_pretrained(hf_model_path) - model = bridge.get_model(post_model_creation_callbacks=[]) + if args.use_megatron_fsdp: + ddp_config = { + "use_distributed_optimizer": True, + "check_for_nan_in_grad": True, + "use_megatron_fsdp": True, + "data_parallel_sharding_strategy": "optim_grads_params", + } + model = bridge.get_model( + wrap_with_ddp=True, + use_megatron_fsdp=True, + ddp_config=ddp_config, + data_parallel_random_init=False, + post_model_creation_callbacks=[], + ) + else: + model = bridge.get_model(post_model_creation_callbacks=[]) print( f"rank{torch.distributed.get_rank()}: start loading weights from {hf_model_path}" ) diff --git a/example/3.launch_megatron_with_ray.py b/example/3.launch_megatron_with_ray.py index b1bcd40..b603746 100644 --- a/example/3.launch_megatron_with_ray.py +++ b/example/3.launch_megatron_with_ray.py @@ -101,6 +101,7 @@ def worker_fn( vpp: int, ep: int, etp: Optional[int], + use_megatron_fsdp: bool = False, ): """Worker that runs on a single GPU. @@ -114,7 +115,22 @@ def worker_fn( # 2. Load model & weights bridge = AutoBridge.from_pretrained(hf_model_path) - model = bridge.get_model(post_model_creation_callbacks=[]) + if use_megatron_fsdp: + ddp_config = { + "use_distributed_optimizer": True, + "check_for_nan_in_grad": True, + "use_megatron_fsdp": True, + "data_parallel_sharding_strategy": "optim_grads_params", + } + model = bridge.get_model( + wrap_with_ddp=True, + use_megatron_fsdp=True, + ddp_config=ddp_config, + data_parallel_random_init=False, + post_model_creation_callbacks=[], + ) + else: + model = bridge.get_model(post_model_creation_callbacks=[]) bridge.load_weights(model, hf_model_path) @@ -175,6 +191,11 @@ def main(): parser.add_argument( "--master_port", type=int, default=12355, help="NCCL master port" ) + parser.add_argument( + "--use_megatron_fsdp", + action="store_true", + help="Use Megatron-FSDP", + ) args = parser.parse_args() # Connect to the running Ray cluster @@ -203,6 +224,7 @@ def main(): args.vpp, args.ep, args.etp, + args.use_megatron_fsdp, ) ) rank += 1 diff --git a/example/4.launch_deepseekv3_with_ray.py b/example/4.launch_deepseekv3_with_ray.py index 59fe50f..ecc59e0 100644 --- a/example/4.launch_deepseekv3_with_ray.py +++ b/example/4.launch_deepseekv3_with_ray.py @@ -100,6 +100,7 @@ def worker_fn( etp: Optional[int], num_layers_in_first_pipeline_stage: Optional[int] = None, num_layers_in_last_pipeline_stage: Optional[int] = None, + use_megatron_fsdp: bool = False, ): """Worker that runs on a single GPU. @@ -118,7 +119,22 @@ def worker_fn( num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, ) # bridge.config.mtp_num_layers = 0 - model = bridge.get_model(post_model_creation_callbacks=[], wrap_with_ddp=False) + if use_megatron_fsdp: + ddp_config = { + "use_distributed_optimizer": True, + "check_for_nan_in_grad": True, + "use_megatron_fsdp": True, + "data_parallel_sharding_strategy": "optim_grads_params", + } + model = bridge.get_model( + wrap_with_ddp=True, + use_megatron_fsdp=True, + ddp_config=ddp_config, + data_parallel_random_init=False, + post_model_creation_callbacks=[], + ) + else: + model = bridge.get_model(post_model_creation_callbacks=[], wrap_with_ddp=False) # maintain router bias dtype for m in model: @@ -224,6 +240,11 @@ def main(): parser.add_argument( "--master_port", type=int, default=12355, help="NCCL master port" ) + parser.add_argument( + "--use_megatron_fsdp", + action="store_true", + help="Use Megatron-FSDP", + ) args = parser.parse_args() # Connect to the running Ray cluster @@ -252,6 +273,7 @@ def main(): args.etp, args.num_layers_in_first_pipeline_stage, args.num_layers_in_last_pipeline_stage, + args.use_megatron_fsdp, ) ) rank += 1 diff --git a/mbridge/core/bridge.py b/mbridge/core/bridge.py index 7a18420..3dc7c23 100644 --- a/mbridge/core/bridge.py +++ b/mbridge/core/bridge.py @@ -11,7 +11,7 @@ from transformers import AutoConfig from transformers.utils.hub import cached_file from safetensors import safe_open - +from torch.distributed._tensor import DTensor from .parallel_states import ParallelStates from .safetensor_io import SafeTensorIO from .util import ( @@ -19,6 +19,7 @@ broadcast_str_from_megatron_pp, get_model, unwrap_model, + get_module_and_param_from_name, ) @@ -60,6 +61,7 @@ def __init__( self.make_vocab_size_divisible_by = make_vocab_size_divisible_by self.vocab_size = None self.padded_vocab_size = None + self.use_megatron_fsdp = False # Some moe models require multiple weights to be combined into one, # such as qwen3vl. It will cache it into this buff until all weights are collected. @@ -73,6 +75,7 @@ def get_model( fp16: bool = False, bf16: bool = True, encoder_pipeline_model_parallel_size: int = 0, + use_megatron_fsdp: bool = False, use_torch_fsdp2: bool = False, use_custom_fsdp: bool = False, use_precision_aware_optimizer: bool = False, @@ -120,6 +123,7 @@ def get_model( # and self.mpu.vpp_size > 1 # ): # raise ValueError("tie_word_embeddings is not supported for VPP > 1") + self.use_megatron_fsdp = use_megatron_fsdp model = get_model( self._model_provider( post_model_creation_callbacks, @@ -131,6 +135,7 @@ def get_model( bf16=bf16, virtual_pipeline_model_parallel_size=self.mpu.vpp_size, encoder_pipeline_model_parallel_size=encoder_pipeline_model_parallel_size, + use_megatron_fsdp=use_megatron_fsdp, use_torch_fsdp2=use_torch_fsdp2, use_custom_fsdp=use_custom_fsdp, use_precision_aware_optimizer=use_precision_aware_optimizer, @@ -198,8 +203,9 @@ def load_weights( ) # import mcore weights + unwrapped_model = unwrap_model(model) for local_name, hf_names in local_to_hf_map.items(): - param = model.state_dict()[local_name] + param = unwrapped_model.state_dict()[local_name] # hf format to mcore format if set(to_load_from_disk) & set(hf_names): if not memory_efficient: @@ -218,7 +224,7 @@ def load_weights( # skip lm_head.weight when the model is a value model continue - param_to_load = torch.empty_like(param) + param_to_load = torch.empty(param.shape, device=param.device, dtype=param.dtype) if ".mlp.experts.linear_fc" in local_name: # split mcore weights across etp if self.mpu.etp_rank == 0: @@ -258,7 +264,14 @@ def load_weights( group=self.mpu.tp_group, ) # load + if isinstance(param, DTensor): + _, local_weights = get_module_and_param_from_name(unwrapped_model, local_name) + sliced_converted_weights = param_to_load.reshape(-1)[local_weights.megatron_fsdp_slice] + param._local_tensor.reshape(-1).copy_(sliced_converted_weights) + continue param.copy_(param_to_load) + if self.use_megatron_fsdp: + model.module.install_optimized_model_weights() def _save_weights_fast( self, @@ -527,7 +540,16 @@ def get_model_chunk_generator(): name, param = None, None name = broadcast_str_from_megatron_pp(name) - broad_pp_param = broadcast_from_megatron_pp(param) + broad_pp_param = None + if isinstance(param, DTensor): + from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import ( + gather_uneven_dtensor_to_full_tensor, + ) + _, local_weights = get_module_and_param_from_name(models, iter_name, iter_vpp_rank) + full_tensor = gather_uneven_dtensor_to_full_tensor(local_weights) + broad_pp_param = full_tensor.to_local() + else: + broad_pp_param = broadcast_from_megatron_pp(param) # EP if ".mlp.experts.linear_fc" in name and self.mpu.ep_size >= 1: @@ -574,7 +596,7 @@ def get_model_chunk_generator(): if len(converted_names) == 0: continue - yield from zip(converted_names, [p.detach() for p in converted_params]) + yield from zip(converted_names, [p.detach().to(self.dtype) for p in converted_params]) continue # TP @@ -606,7 +628,7 @@ def get_model_chunk_generator(): if len(converted_names) == 0: continue - yield from zip(converted_names, [p.detach() for p in converted_params]) + yield from zip(converted_names, [p.detach().to(self.dtype) for p in converted_params]) def export_weights_without_gather( self, models: list[torch.nn.Module], diff --git a/mbridge/core/util.py b/mbridge/core/util.py index 6f6fa61..5dbaa29 100644 --- a/mbridge/core/util.py +++ b/mbridge/core/util.py @@ -7,13 +7,14 @@ from functools import lru_cache import torch +from typing import List, Optional, Tuple from megatron.core import mpu from megatron.core import parallel_state as mpu from megatron.core import tensor_parallel from megatron.core.fp8_utils import correct_amax_history_if_needed from megatron.core.models.gpt.gpt_model import ModelType from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.utils import ( StragglerDetector, check_param_hashes_across_dp_replicas, @@ -30,6 +31,7 @@ def get_model( bf16: bool = True, virtual_pipeline_model_parallel_size: int = None, encoder_pipeline_model_parallel_size: int = 0, + use_megatron_fsdp: bool = False, use_torch_fsdp2: bool = False, use_custom_fsdp: bool = False, use_precision_aware_optimizer: bool = False, @@ -164,9 +166,13 @@ def build_model(): correct_amax_history_if_needed(model) if wrap_with_ddp: - from megatron.core.distributed import DistributedDataParallelConfig - - if use_torch_fsdp2: + from megatron.core.distributed import DistributedDataParallelConfig, FullyShardedDataParallel + if use_megatron_fsdp: + from megatron.core.distributed import FullyShardedDataParallel + DP = FullyShardedDataParallel + if use_torch_fsdp2: + raise ValueError("Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported.") + elif use_torch_fsdp2: try: from megatron.core.distributed import ( TorchFullyShardedDataParallel as torch_FSDP, @@ -207,7 +213,7 @@ def build_model(): ddp_config = DistributedDataParallelConfig(**kwargs) - if not use_torch_fsdp2: + if not use_torch_fsdp2 and not use_megatron_fsdp: # In the custom FSDP and DDP use path, we need to initialize the bucket size. # If bucket_size is not provided as an input, use sane default. @@ -273,7 +279,11 @@ def build_model(): ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module) except ImportError: ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module) - +try: + from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP + ALL_MODULE_WRAPPER_CLASSNAMES = ALL_MODULE_WRAPPER_CLASSNAMES + (MegatronFSDP,) +except ImportError: + pass def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): return_list = True @@ -782,3 +792,64 @@ def postprocess_packed_seqs( output_new[i, attention_mask[i]] = tmp[:s_len] return output_new + + +def get_module_and_param_from_name( + models: MegatronModule | List[MegatronModule], + param_name: str, + vp_stage: Optional[int] = None, +) -> Tuple[torch.nn.Module, torch.Tensor] | Tuple[torch.nn.Module, torch.Tensor, Tuple]: + """ + Get parameter from models in specific VP stage. + + Args: + models: List of Megatron model instances or a submodule + param_name: Local parameter name separated by dots, e.g., "transformer.layers.0.mlp.dense.bias" + vp_stage: Virtual pipeline stage index (None for single stage) + + Returns: + Tuple of (module, parameter) where module owns the parameter + """ + + if isinstance(models, list): + if vp_stage is None: + model = models[0] + else: + if vp_stage >= len(models): + raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})") + model = models[vp_stage] + else: + model = models + + module = unwrap_model(model) + splitted_name = param_name.split(".") + + # Try to find the parameter using the given parts + def get_param(parts): + parent_module = module + previous_module = module + + for i, part in enumerate(parts): + if not hasattr(parent_module, part): + return None + parent_module = getattr(parent_module, part) + if i < len(parts) - 1: + previous_module = getattr(previous_module, part) + + return previous_module, parent_module + + # First try the full parameter name (current behavior) + result = get_param(splitted_name) + if result is not None: + return result + + # If full name doesn't work, try suffixes of the parameter name + # This handles cases where models is a submodule but param_name is absolute + for start_idx in range(1, len(splitted_name)): + suffix_parts = splitted_name[start_idx:] + result = get_param(suffix_parts) + if result is not None: + return result + + # If no approach works, raise an error + raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}")