diff --git a/example/qwen3_5/hf_fwd_moe.py b/example/qwen3_5/hf_fwd_moe.py new file mode 100644 index 0000000..b8cf0a4 --- /dev/null +++ b/example/qwen3_5/hf_fwd_moe.py @@ -0,0 +1,50 @@ +import argparse + +import torch + +try: + from transformers import Qwen3_5MoeForConditionalGeneration +except: + print(f"your install the tranformers>=5.2.0 or install from source") + +from example.qwen3_5.load_model_and_forward import get_sample_for_forward + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser(description="Load model and generate text") + parser.add_argument( + "--model_path", type=str, required=True, help="HuggingFace model path" + ) + parser.add_argument( + "--sample_type", + type=str, + default="image", + choices=["image", "video", "mix"], + help="sample type", + ) + args = parser.parse_args() + + # default: Load the model on the available device(s) + torch.set_grad_enabled(False) + model = Qwen3_5MoeForConditionalGeneration.from_pretrained( + args.model_path, + dtype="auto", + device_map="cuda:0", + ) + + for pname, params in model.named_parameters(): + print(f"Model weight {pname=} {params.shape} {params.dtype} {params.sum()}") + + # Preparation for inference + inputs = get_sample_for_forward(args.model_path, args.sample_type) + + for k in inputs: + inputs[k] = inputs[k].cuda() + + # Inference: Generation of the output + hf_output = model.forward(**inputs) + + print(hf_output.logits.shape, hf_output.logits.device, hf_output.logits.dtype) + torch.save(hf_output.logits.cpu(), "qwen3_5_save/hf_qwen3_5.pt") + + print(f"hf Done") diff --git a/example/qwen3_5/load_model_and_forward.py b/example/qwen3_5/load_model_and_forward.py new file mode 100644 index 0000000..1f6cba0 --- /dev/null +++ b/example/qwen3_5/load_model_and_forward.py @@ -0,0 +1,455 @@ +# Example to use tp/pp/cp/vpp to test dense model +# torchrun --nproc_per_node=8 example/qwen3vl/load_model_and_forward.py --model_path /path/to/model + +import argparse 
+import os + +import requests + +try: + from transformers import Qwen3VLProcessor +except: + print(f"your install the tranformers>=5.2.0 or install from source") + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state as mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.tensor_parallel.mappings import ( + gather_from_tensor_model_parallel_region, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +from mbridge import AutoBridge + + +def download_img(filename): + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + try: + response = requests.get(image_url, stream=True) + response.raise_for_status() + + with open(filename, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + except requests.exceptions.RequestException as e: + print(f"downlaod fail: {e}") + raise e + + +def download_video(filename): + video_url = ( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4" + ) + try: + response = requests.get(video_url, stream=True) + response.raise_for_status() + + with open(filename, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + except requests.exceptions.RequestException as e: + print(f"downlaod fail: {e}") + raise e + + +def get_image_sample_for_forward(hf_model_path): + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + # text = "Please describe this picture completely and in detail, including the details, characters, scenes, etc." + text = "Describe this image in shortly." 
+ filename = "../australia.jpg" + if True: + if not os.path.exists(filename): + download_img(filename) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": filename}, + {"type": "text", "text": text}, + ], + } + ] + else: + text = "Given the accelerating trajectory of artificial intelligence, where do you foresee the most critical point of divergence between a future in which AI acts as a fundamentally benevolent, symbiotic partner in elevating human consciousness, collective intelligence, and our capacity to solve existential challenges, and a future where it inadvertently becomes an insidious, alienating force that amplifies societal biases, erodes human agency, and creates a new, opaque class structure based on access to and control of cognitive capital—and what specific, measurable factors in our current approach to AI development, governance, and education will be the primary determinants in steering us toward one outcome over the other?" + messages = [{"role": "user", "content": [{"type": "text", "text": text}]}] + + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + max_pixels=256 * 28 * 28, + ) + inputs.pop("token_type_ids", None) + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) + + return inputs + + +def get_video_sample_for_forward(hf_model_path): + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + video_file = "../space_woaudio.mp4" + if not os.path.exists(video_file): + download_video(video_file) + + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + # Messages containing a video url(or a local path) and a text query + messages = [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": video_file, + }, + {"type": "text", "text": "Describe those videos in shortly."}, + ], + } + ] + # Preparation for inference + inputs = 
processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs.pop("token_type_ids", None) + assert "pixel_values" not in inputs + inputs["pixel_values_videos"] = inputs["pixel_values_videos"].to(torch.bfloat16) + + return inputs + + +def get_mix_sample_for_forward(hf_model_path): + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + video_file = "../space_woaudio.mp4" + if not os.path.exists(video_file): + download_video(video_file) + image_file = "../australia.jpg" + if not os.path.exists(image_file): + download_img(image_file) + + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + # Messages containing a video url(or a local path) and a text query + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image_file, + }, + { + "type": "video", + "video": video_file, + }, + { + "type": "text", + "text": "Describe those videos and images in shortly and respectively", + }, + ], + } + ] + # Preparation for inference + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + add_vision_id=True, # have better to add this + ) + inputs.pop("token_type_ids", None) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) + inputs["pixel_values_videos"] = inputs["pixel_values_videos"].to(torch.bfloat16) + + return inputs + + +def get_sample_for_forward(hf_model_path, sample_type="image"): + if sample_type == "image": + return get_image_sample_for_forward(hf_model_path) + elif sample_type == "video": + return get_video_sample_for_forward(hf_model_path) + elif sample_type == "mix": + return get_mix_sample_for_forward(hf_model_path) + else: + assert False + + +def gather_output_from_cp(input_: torch.Tensor, seq_dim, cp_size, cp_group): + assert seq_dim in [0, 1] and input_.dim() > seq_dim + # Split local_logits into two parts + input_ = input_.view( 
+ *input_.shape[0:seq_dim], + 2, + input_.shape[seq_dim] // 2, + *input_.shape[(seq_dim + 1) :], + ) + + gathered_logits = [torch.zeros_like(input_) for _ in range(cp_size)] + torch.distributed.all_gather(gathered_logits, input_, group=cp_group) + + reorded_logits = [None for _ in range(2 * cp_size)] + if seq_dim == 1: + for rank in range(cp_size): + reorded_logits[rank] = gathered_logits[rank][:, 0] + reorded_logits[2 * cp_size - rank - 1] = gathered_logits[rank][:, 1] + elif seq_dim == 0: + for rank in range(cp_size): + reorded_logits[rank] = gathered_logits[rank][0] + reorded_logits[2 * cp_size - rank - 1] = gathered_logits[rank][1] + else: + assert False + gathered_logits = torch.cat(reorded_logits, dim=seq_dim) + + return gathered_logits + + +# hf logits vs megatron logits +def cos_similarity(a, b): + print(f"a {a.shape} b {b.shape}") + a = a.float() + # a = a / a.norm(dim=-1, keepdim=True) + a = torch.exp(a - a.max(dim=-1, keepdim=True)[0]) + a = a / a.norm(dim=-1, keepdim=True) + """ + a = (a - a.mean(dim=-1, keepdim=True)) + a = a / a.norm(dim=-1, keepdim=True) + """ + b = b.float() + # b = b / b.norm(dim=-1, keepdim=True) + b = torch.exp(b - b.max(dim=-1, keepdim=True)[0]) + b = b / b.norm(dim=-1, keepdim=True) + """ + b = (b - b.mean(dim=-1, keepdim=True)) + b = b / b.norm(dim=-1, keepdim=True) + """ + sim = (a * b).sum(dim=-1) + print( + f"hf vs megatron cos_similarity min: {sim.min()}; max: {sim.max()}; mean: {sim.mean()}" + ) + + +def init_distributed(tp=2, pp=1, cp=1, vpp=1, ep=1, etp=None): + """Initialize distributed environment""" + torch.distributed.init_process_group("nccl") + torch.cuda.set_device(torch.distributed.get_rank() % 8) + if pp <= 1: + vpp = None + mpu.initialize_model_parallel( + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + virtual_pipeline_model_parallel_size=vpp, + context_parallel_size=cp, + expert_model_parallel_size=ep, + expert_tensor_parallel_size=etp, + ) + model_parallel_cuda_manual_seed(0) + + +def 
get_args(): + # Parse command line arguments + parser = argparse.ArgumentParser(description="Load model and generate text") + parser.add_argument( + "--model_path", type=str, required=True, help="HuggingFace model path" + ) + parser.add_argument("--tp", type=int, default=2, help="Tensor model parallel size") + parser.add_argument( + "--pp", type=int, default=1, help="Pipeline model parallel size" + ) + parser.add_argument("--cp", type=int, default=1, help="Context parallel size") + parser.add_argument( + "--vpp", type=int, default=None, help="Virtual pipeline model parallel size" + ) + parser.add_argument("--ep", type=int, default=1, help="Expert model parallel size") + parser.add_argument( + "--etp", type=int, default=None, help="Expert tensor parallel size" + ) + parser.add_argument("--check_export", action="store_true", help="Trust remote code") + + parser.add_argument( + "--sample_type", + type=str, + default="image", + choices=["image", "video", "mix"], + help="sample type", + ) + parser.add_argument("--save_path", type=str, default=None, help="Save path") + + args = parser.parse_args() + return args + + +def mcore_fwd_fn(data_iterator, model): + sample = next(data_iterator) + + output_tensor = model( + input_ids=sample["input_ids"].cuda(), + position_ids=None, + attention_mask=None, + pixel_values=( + sample["pixel_values"].cuda() if "pixel_values" in sample else None + ), + image_grid_thw=( + sample["image_grid_thw"].cuda() if "image_grid_thw" in sample else None + ), + pixel_values_videos=( + sample["pixel_values_videos"].cuda() + if "pixel_values_videos" in sample + else None + ), + video_grid_thw=( + sample["video_grid_thw"].cuda() if "video_grid_thw" in sample else None + ), + ) + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[0] + assert isinstance(output_tensor, torch.Tensor) + + def loss_fn(output_tensor, non_loss_data=True): + loss = output_tensor.mean() + return loss, { + "loss": loss.detach(), + "logits": 
output_tensor.detach(), + } + + return output_tensor, loss_fn + + +def main(): + args = get_args() + print(f"{args=}") + + # Initialize distributed environment + init_distributed( + tp=args.tp, + pp=args.pp, + cp=args.cp, + vpp=args.vpp, + ep=args.ep, + etp=args.etp, + ) + + # Load megatron model + hf_model_path = args.model_path + print(f"rank{torch.distributed.get_rank()}: start loading model ...") + bridge = AutoBridge.from_pretrained(hf_model_path) + bridge.config.sequence_parallel = True + if args.pp > 1: + num_layer = bridge.hf_config.text_config.num_hidden_layers + first_last_layer = num_layer - (num_layer + args.pp - 1) // args.pp * ( + args.pp - 2 + ) + assert first_last_layer > 1 + bridge.set_extra_args( + num_layers_in_first_pipeline_stage=first_last_layer // 2, + num_layers_in_last_pipeline_stage=(first_last_layer + 1) // 2, + ) + model = bridge.get_model(model_type=ModelType.encoder_or_decoder) + assert len(model) == 1 + bridge.load_weights(model, hf_model_path, memory_efficient=False) + for pname, params in model[0].named_parameters(): + if torch.distributed.get_rank() == 0: + print(f"Model weight {pname=} {params.shape} {params.dtype} {params.sum()}") + + # # check the export + if args.check_export: + print( + f"rank{torch.distributed.get_rank()}: end load weight, start check export ..." 
+ ) + keys = bridge.safetensor_io.load_hf_weight_names() + loaded_keys = set() + # export weights + for k, v in bridge.export_weights(model): + gt = bridge.safetensor_io.load_one_hf_weight(k).cuda() + assert v.shape == gt.shape, f"mismatch of {k} {v.shape=} {gt.shape=}" + gt_dtype = gt.dtype + v_dtype = v.dtype + if v_dtype != gt_dtype: + gt = gt.to(v_dtype) + print(f"dtype_check {k} {v_dtype=} {gt_dtype=}") + + if torch.equal(v, gt): + print(f"weight equal {k=}") + elif torch.allclose(v, gt, atol=1e-3, rtol=1e-3): + print(f"weight close {k=} {v.sum()} {gt.sum()} {v=} {gt=}") + else: + assert False, f"weight_mismatch of {k=} {v.sum()} {gt.sum()} {v=} {gt=}" + + loaded_keys.add(k) + assert len(bridge.export_weights_buff) == 0 + + missing_keys = set(keys) - loaded_keys + missing_keys = sorted(list(missing_keys)) + print(f"missing keys: {missing_keys} {len(missing_keys)==0}") + + print(f"rank{torch.distributed.get_rank()}: end load weight, start forward ...") + + sample = get_sample_for_forward(hf_model_path, args.sample_type) + real_seq_length = sample["input_ids"].shape[-1] + torch.distributed.barrier() + seq_length_factor = args.tp + if args.cp > 1: + seq_length_factor *= args.cp * 2 + with torch.no_grad(): + fwd_bwd_function = get_forward_backward_func() + + seq_length = real_seq_length + if real_seq_length % seq_length_factor != 0: + seq_length = ( + (real_seq_length + seq_length_factor - 1) + // seq_length_factor + * seq_length_factor + ) + sample["input_ids"] = F.pad( + sample["input_ids"], + (0, seq_length - real_seq_length, 0, 0), + value=0, + ) + + mcore_output = fwd_bwd_function( + forward_step_func=mcore_fwd_fn, + data_iterator=iter([sample]), + model=model, + num_microbatches=1, + forward_only=True, + seq_length=seq_length, + decoder_seq_length=seq_length, + micro_batch_size=1, + ) + + if mpu.is_pipeline_last_stage(): + megatron_output = mcore_output[0]["logits"] + if mpu.get_context_parallel_world_size() > 1: + megatron_output = gather_output_from_cp( + 
megatron_output, + 1, + mpu.get_context_parallel_world_size(), + mpu.get_context_parallel_group(), + ) + if mpu.get_tensor_model_parallel_world_size() > 1: + megatron_output = gather_from_tensor_model_parallel_region( + megatron_output + ) + + megatron_output = megatron_output[:, :real_seq_length, :] + + torch.save( + megatron_output, + f"qwen3_5_save/mlm_tp{args.tp}_pp{args.pp}_cp{args.cp}_ep{args.ep}.pt", + ) + + print(f"Finish Done") + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/example/qwen3_5/load_model_and_inference.py b/example/qwen3_5/load_model_and_inference.py new file mode 100644 index 0000000..1003c3a --- /dev/null +++ b/example/qwen3_5/load_model_and_inference.py @@ -0,0 +1,202 @@ +# Example to use tp/pp/cp/vpp to test dense model +# torchrun --nproc_per_node=8 example/qwen3vl/load_model_and_forward.py --model_path /path/to/model + +import argparse +from typing import Any + +from tqdm import trange + +try: + from transformers import Qwen3VLProcessor +except: + print(f"your install the tranformers>=5.2.0 or install from source") + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state as mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.tensor_parallel.mappings import ( + gather_from_tensor_model_parallel_region, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +from example.qwen3_5.load_model_and_forward import ( + gather_output_from_cp, + get_args, + get_sample_for_forward, + mcore_fwd_fn, +) +from mbridge import AutoBridge +from mbridge.core.util import unwrap_model + + +def init_distributed(tp=2, pp=1, cp=1, vpp=1, ep=1, etp=None): + """Initialize distributed environment""" + torch.distributed.init_process_group("nccl") + torch.cuda.set_device(torch.distributed.get_rank() % 8) + if pp 
<= 1: + vpp = None + mpu.initialize_model_parallel( + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + virtual_pipeline_model_parallel_size=vpp, + context_parallel_size=cp, + expert_model_parallel_size=ep, + expert_tensor_parallel_size=etp, + ) + model_parallel_cuda_manual_seed(0) + + +def broadcast_object_within_pp(obj: Any) -> Any: + group = mpu.get_pipeline_model_parallel_group() + + if torch.distributed.get_world_size(group) > 1: + obj_list = [obj] + torch.distributed.broadcast_object_list( + obj_list, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=group, + ) + return obj_list[0] + else: + return obj + + +def main(): + args = get_args() + print(f"{args=}") + + # Initialize distributed environment + init_distributed( + tp=args.tp, + pp=args.pp, + cp=args.cp, + vpp=args.vpp, + ep=args.ep, + etp=args.etp, + ) + + # Load megatron model + hf_model_path = args.model_path + print(f"rank{torch.distributed.get_rank()}: start loading model ...") + bridge = AutoBridge.from_pretrained(hf_model_path) + if args.pp > 1: + num_layer = bridge.hf_config.text_config.num_hidden_layers + first_last_layer = num_layer - (num_layer + args.pp - 1) // args.pp * ( + args.pp - 2 + ) + assert first_last_layer > 1 + bridge.set_extra_args( + num_layers_in_first_pipeline_stage=first_last_layer // 2, + num_layers_in_last_pipeline_stage=(first_last_layer + 1) // 2, + ) + bridge.config.sequence_parallel = True + model = bridge.get_model(model_type=ModelType.encoder_or_decoder) + assert len(model) == 1 + bridge.load_weights(model, hf_model_path, memory_efficient=True) + + if args.save_path is not None: + save_weights_kwargs = {} + import inspect + + save_func_sig = inspect.signature(bridge.save_weights) + if "distributed_filesystem" in save_func_sig.parameters: + save_weights_kwargs["distributed_filesystem"] = True + unwrapped_model = unwrap_model(model) + bridge.save_weights( + unwrapped_model, + args.save_path, + memory_efficient=True, + **save_weights_kwargs, + ) 
+ + print(f"rank{torch.distributed.get_rank()}: end load weight, start forward ...") + + eos_token_id = bridge.hf_config.text_config.eos_token_id + sample = get_sample_for_forward(hf_model_path, args.sample_type) + input_ids = sample["input_ids"].tolist() + generated_tokens = [] + max_new_tokens = 200 + torch.distributed.barrier() + seq_length_factor = args.tp + if args.cp > 1: + seq_length_factor *= args.cp * 2 + with torch.no_grad(): + fwd_bwd_function = get_forward_backward_func() + + for i in trange( + max_new_tokens, disable=(mpu.get_tensor_model_parallel_rank() == 0) + ): + real_seq_length = sample["input_ids"].shape[-1] + seq_length = real_seq_length + if real_seq_length % seq_length_factor != 0: + seq_length = ( + (real_seq_length + seq_length_factor - 1) + // seq_length_factor + * seq_length_factor + ) + sample["input_ids"] = F.pad( + sample["input_ids"], + (0, seq_length - real_seq_length, 0, 0), + value=0, + ) + + mcore_output = fwd_bwd_function( + forward_step_func=mcore_fwd_fn, + data_iterator=iter([sample]), + model=model, + num_microbatches=1, + forward_only=True, + seq_length=seq_length, + decoder_seq_length=seq_length, + micro_batch_size=1, + ) + + next_token = -1 + if mpu.is_pipeline_last_stage(): + megatron_output = mcore_output[0]["logits"] + if mpu.get_context_parallel_world_size() > 1: + megatron_output = gather_output_from_cp( + megatron_output, + 1, + mpu.get_context_parallel_world_size(), + mpu.get_context_parallel_group(), + ) + if mpu.get_tensor_model_parallel_world_size() > 1: + megatron_output = gather_from_tensor_model_parallel_region( + megatron_output + ) + + megatron_output = megatron_output[:, :real_seq_length, :] + next_token = megatron_output[:, -1, :].argmax(dim=-1)[0].item() + if ( + torch.distributed.get_rank() + == torch.distributed.get_world_size() - 1 + ): + print(f"{i=} {next_token=}") + + next_token = broadcast_object_within_pp(next_token) + generated_tokens.append(next_token) + input_ids[0].append(next_token) + 
sample["input_ids"] = torch.tensor( + input_ids, device=torch.cuda.current_device() + ) + if next_token == eos_token_id: + break + + if torch.distributed.get_rank() == 0: + print(f"{generated_tokens=}") + processor = Qwen3VLProcessor.from_pretrained(hf_model_path) + output_text = processor.batch_decode( + [generated_tokens], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + print(f"{output_text=}") + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/example/qwen3_5/mpirun.sh b/example/qwen3_5/mpirun.sh new file mode 100644 index 0000000..c2020d0 --- /dev/null +++ b/example/qwen3_5/mpirun.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +cat /etc/mpi/hostfile > /root/hostfile +sed -i 's/slots=8/slots=1/g' /root/hostfile + +export _MASTER_ADDR=${__POD_IP__:-localhost} + +mpirun -v --allow-run-as-root \ + --bind-to none --map-by slot --hostfile /root/hostfile \ + --mca btl_tcp_if_include bond1 --mca oob_tcp_if_include bond1 --mca routed direct \ + -x PATH -x LIBRARY_PATH -x LD_LIBRARY_PATH -x _MASTER_ADDR \ + bash example/qwen3_5/run_load_and_fwd.sh diff --git a/example/qwen3_5/offline_calcul_cos.py b/example/qwen3_5/offline_calcul_cos.py new file mode 100644 index 0000000..c7a3bfb --- /dev/null +++ b/example/qwen3_5/offline_calcul_cos.py @@ -0,0 +1,47 @@ +import torch + + +def cos_similarity(a, b): + print(f"a {a.shape} b {b.shape}") + a = a.to(b.device) + a = a.float() + # a = a / a.norm(dim=-1, keepdim=True) + a = torch.exp(a) + a = a / a.norm(dim=-1, keepdim=True) + """ + a = (a - a.mean(dim=-1, keepdim=True)) + a = a / a.norm(dim=-1, keepdim=True) + """ + b = b.float() + # b = b / b.norm(dim=-1, keepdim=True) + b = torch.exp(b) + b = b / b.norm(dim=-1, keepdim=True) + """ + b = (b - b.mean(dim=-1, keepdim=True)) + b = b / b.norm(dim=-1, keepdim=True) + """ + sim = (a * b).sum(dim=-1) + print( + f"hf vs megatron cos_similarity min: {sim.min()}; max: {sim.max()}; mean: {sim.mean()}" 
+ ) + + +path1 = "qwen3_5_save/hf_qwen3_5.pt" + +path2_list = [ + # "qwen3_5_save/mlm_tp1_pp1_cp1_ep1.pt", + # "qwen3_5_save/mlm_tp2_pp1_cp1_ep1.pt", + # "qwen3_5_save/mlm_tp2_pp1_cp1_ep4.pt", + # "qwen3_5_save/mlm_tp2_pp1_cp2_ep4.pt", + "qwen3_5_save/mlm_tp2_pp2_cp2_ep4.pt", + "qwen3_5_save/mlm_tp4_pp2_cp1_ep4.pt", + "qwen3_5_save/mlm_tp2_pp2_cp1_ep2.pt", +] + +a = torch.load(path1) +for path2 in path2_list: + print(f"load from {path1=} {path2=}") + b = torch.load(path2) + + cos = cos_similarity(a, b) + print(f"{cos=} {a.sum()} {b.sum()} {a.dtype} {b.dtype}") diff --git a/example/qwen3_5/run_load_and_fwd.sh b/example/qwen3_5/run_load_and_fwd.sh new file mode 100644 index 0000000..7d00d3d --- /dev/null +++ b/example/qwen3_5/run_load_and_fwd.sh @@ -0,0 +1,56 @@ +#!/bin/bash +ps -ef | grep python | awk '{print $2}' | xargs -I {} kill -9 {} +sleep 1 + +MLM_PATH="../3rdparty/Megatron-LM/" + +export PYTHONPATH=$PWD:$MLM_PATH:$PYTHONPATH +echo "PYTHONPATH ${PYTHONPATH}" +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HF_DATASETS_OFFLINE=1 +export GLOO_SOCKET_IFNAME=bond1 +export NCCL_SOCKET_IFNAME=bond1 + +readonly GPUS_PER_NODE=8 +readonly NODE_RANK="${OMPI_COMM_WORLD_RANK:-0}" +readonly NNODES="${OMPI_COMM_WORLD_SIZE:-1}" +readonly WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +readonly MASTER_PORT=65535 +export MASTER_ADDR="${_MASTER_ADDR:-localhost}" + +readonly TP_SIZE=2 +readonly PP_SIZE=1 +readonly CP_SIZE=1 +readonly EP_SIZE=16 + +echo "INFO +__POD_IP__ $__POD_IP__ +NODE_RANK $NODE_RANK +NNODES $NNODES +TP_SIZE $TP_SIZE +PP_SIZE $PP_SIZE +CP_SIZE $CP_SIZE +EP_SIZE $EP_SIZE +" + +# torchrun distributed args +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + +SAMPLE_TYPE="image" + +torchrun $DISTRIBUTED_ARGS \ + example/qwen3_5/load_model_and_inference.py \ + --tp $TP_SIZE \ + --pp $PP_SIZE \ + --ep $EP_SIZE \ + --etp 1 \ + --cp $CP_SIZE \ + --model_path 
hf-hub/Qwen/Qwen3.5-397B-A17B/ \ + --sample_type $SAMPLE_TYPE \ + diff --git a/example/qwen3_5/run_test.sh b/example/qwen3_5/run_test.sh new file mode 100644 index 0000000..a751128 --- /dev/null +++ b/example/qwen3_5/run_test.sh @@ -0,0 +1,60 @@ +#!/bin/bash +ps -ef | grep python | awk '{print $2}' | xargs -I {} kill -9 {} +sleep 1 + +MLM_PATH="../3rdparty/Megatron-LM/" + +export PYTHONPATH=$PWD:$MLM_PATH:$PYTHONPATH +echo "PYTHONPATH ${PYTHONPATH}" +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HF_DATASETS_OFFLINE=1 +export GLOO_SOCKET_IFNAME=bond1 +export NCCL_SOCKET_IFNAME=bond1 + +readonly GPUS_PER_NODE=8 +readonly NODE_RANK="${OMPI_COMM_WORLD_RANK:-0}" +readonly NNODES="${OMPI_COMM_WORLD_SIZE:-1}" +readonly WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +readonly MASTER_PORT=65535 +export MASTER_ADDR="${_MASTER_ADDR:-localhost}" + +readonly TP_SIZE=2 +readonly PP_SIZE=2 +readonly CP_SIZE=1 +readonly EP_SIZE=2 + +echo "INFO +__POD_IP__ $__POD_IP__ +NODE_RANK $NODE_RANK +NNODES $NNODES +TP_SIZE $TP_SIZE +PP_SIZE $PP_SIZE +CP_SIZE $CP_SIZE +EP_SIZE $EP_SIZE +" + +# torchrun distributed args +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + + +SAMPLE_TYPE="image" + +# python example/qwen3_5/hf_fwd_moe.py \ +# --model_path hf-hub/Qwen/Qwen3.5_small/ \ +# --sample_type $SAMPLE_TYPE + +torchrun $DISTRIBUTED_ARGS example/qwen3_5/load_model_and_forward.py \ + --tp $TP_SIZE \ + --pp $PP_SIZE \ + --ep $EP_SIZE \ + --etp 1 \ + --cp $CP_SIZE \ + --model_path hf-hub/Qwen/Qwen3.5_small/ \ + --sample_type $SAMPLE_TYPE \ + --check_export diff --git a/mbridge/core/safetensor_io.py b/mbridge/core/safetensor_io.py index fa1a106..d20cd4a 100644 --- a/mbridge/core/safetensor_io.py +++ b/mbridge/core/safetensor_io.py @@ -14,7 +14,7 @@ class SafeTensorIO: def __init__(self, hf_dir: str): index_file = os.path.join(hf_dir, "model.safetensors.index.json") - config = 
AutoConfig.from_pretrained(hf_dir) + config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True) self.index = {} self.origin_index = {} diff --git a/mbridge/models/__init__.py b/mbridge/models/__init__.py index e195cef..48b0a50 100644 --- a/mbridge/models/__init__.py +++ b/mbridge/models/__init__.py @@ -30,3 +30,7 @@ from .gemma3 import Gemma3Bridge from .internvl3 import InternVL3Bridge from .qwen3_vl import Qwen3VLBridge, Qwen3VLBridge + +from contextlib import suppress +with suppress(ImportError): + from .qwen3_5 import Qwen3_5VlBridge, Qwen3_5MoeVlBridge diff --git a/mbridge/models/qwen3_5/__init__.py b/mbridge/models/qwen3_5/__init__.py new file mode 100644 index 0000000..aab10e1 --- /dev/null +++ b/mbridge/models/qwen3_5/__init__.py @@ -0,0 +1 @@ +from mbridge.models.qwen3_5.qwen3_5_vl_bridge import Qwen3_5MoeVlBridge, Qwen3_5VlBridge diff --git a/mbridge/models/qwen3_5/attention.py b/mbridge/models/qwen3_5/attention.py new file mode 100644 index 0000000..0f519de --- /dev/null +++ b/mbridge/models/qwen3_5/attention.py @@ -0,0 +1,378 @@ +from typing import Optional, Tuple, Union + +import torch +from megatron.core.transformer.attention import * +from torch import Tensor + +from mbridge.models.qwen3_vl.rope_utils import apply_rotary_pos_emb_absolute + + +class Qwen3_5VLSelfAttention(SelfAttention): + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> tuple[Tensor, Tensor]: + """ + Perform a forward pass through the 
attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + # Check if we need to skip RoPE + # no_rope is 0-indexed array and self.layer_number is 1-indexed + no_rope = ( + self.config.no_rope_freq[self.layer_number - 1] + if self.config.no_rope_freq + else False + ) + if no_rope: + rotary_pos_emb = None + + inference_context = deprecate_inference_params( + inference_context, inference_params + ) + + if inference_context and inference_context.is_dynamic_batching(): + assert HAVE_FA3 or is_fa_min_version( + "2.7.3" + ), "flash attn verion v2.7.3 and above is required for dynamic batching." 
+ + # hidden_states: [sq, b, h] + is_inference_mode = inference_context is not None and not self.training + # is_using_flash_decode - True is we are using the static inference engine with flash decode + is_using_flash_decode = is_inference_mode and self.config.flash_decode + # is_using_flashinfer_rope - True if we are using the dynamic inference engine + # with flashinfer fused rope + is_using_flashinfer_rope = is_inference_mode and ( + not inference_context.is_static_batching() + and inference_context.use_flashinfer_fused_rope + ) + if is_using_flash_decode or is_using_flashinfer_rope: + # flash decode and flash-infer fused rope use rotary_pos_cos and rotary_pos_sin + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + nvtx_range_push(suffix="qkv") + split_qkv = (self.attention_type == "cross") or not all( + [ + not self.config.test_mode, + self.config.fused_single_qkv_rope, + inference_context is None, + packed_seq_params is None, + ( + rotary_pos_emb is not None + and rotary_pos_emb[0] is not None + and rotary_pos_emb[1] is not None + ), + not self.config.flash_decode, + HAVE_FUSED_QKV_ROPE, + self.q_layernorm is None or isinstance(self.q_layernorm, IdentityOp), + self.k_layernorm is None or isinstance(self.k_layernorm, IdentityOp), + ] + ) + output_gate = self.config.attention_output_gate + # Check if fused_single_qkv_rope is requested but either unavailable or not + # supported for the current use case. 
+ if self.attention_type != "cross": + assert not ( + self.config.fused_single_qkv_rope and split_qkv + ), "fused_single_qkv_rope requested but not available/supported for the config." + if output_gate: + assert ( + split_qkv + ), "output_gate is not supported for unsplit mixed_qkv tensor." + + with off_interface( + self.offload_qkv_linear, hidden_states, "qkv_linear" + ) as hidden_states: + qkv_output = self.get_query_key_value_tensors( + hidden_states, + key_value_states, + split_qkv=split_qkv, + output_gate=self.config.attention_output_gate, + ) + if self.offload_qkv_linear: + # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. + qkv_output = off_interface.group_commit( + qkv_output, name="qkv_linear", forced_released_tensors=[] + ) + attn_mask_type = self.attn_mask_type + block_table = None + gate = None + if split_qkv: + if self.config.attention_output_gate: + query, key, value, gate = qkv_output + else: + query, key, value = qkv_output + mixed_qkv = qkv_split_arg_list = None + else: + assert ( + not self.config.attention_output_gate + ), "attention_output_gate is not supported for unsplit mixed_qkv tensor." + mixed_qkv, qkv_split_arg_list = qkv_output + nvtx_range_pop(suffix="qkv") + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + in_decode_mode = ( + inference_context is not None + and inference_context.is_decode_only() + and not self.training + ) + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. 
+ nvtx_range_push(suffix="adjust_key_value") + if in_decode_mode and self.config.flash_decode: + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = ( + inference_context.key_value_memory_dict[self.layer_number] + ) + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + rotary_interleaved=self.config.rotary_interleaved, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + if ( + in_decode_mode + and self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + and inference_context.is_static_batching() + ): + raise ValueError(f"CUDA graphs must use flash decode with static batching!") + + if split_qkv: + query, key, value, rotary_pos_emb, attn_mask_type, block_table = ( + self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + rotary_pos_cos_sin, + sequence_len_offset, + ) + ) + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + nvtx_range_pop(suffix="adjust_key_value") + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + nvtx_range_push(suffix="rotary_pos_emb") + if rotary_pos_emb is not None and ( + not self.config.flash_decode or inference_context is None + ): + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None and 
packed_seq_params.qkv_format == "thd": + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if split_qkv: + if q_pos_emb is not None: + # TODO VIJAY: simplify + if ( + inference_context is None + or inference_context.is_static_batching() + ): + query = apply_rotary_pos_emb_absolute( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + mscale=_yarn_get_concentration_factor_from_config( + self.config + ), + cp_group=self.pg_collection.cp, + ) + else: + query = inference_context.apply_rotary_emb_query( + query, + q_pos_emb, + self.config, + cu_seqlens_q, + self.pg_collection.cp, + ) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + mscale=_yarn_get_concentration_factor_from_config(self.config), + cp_group=self.pg_collection.cp, + ) + else: + raise ValueError( + "fused_qkv_rotary_pos_emb is not supported for unsplit mixed_qkv tensor." + ) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. 
+ # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + nvtx_range_pop(suffix="rotary_pos_emb") + + # ================================== + # core attention computation + # ================================== + + nvtx_range_push(suffix="core_attention") + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + with off_interface( + self.offload_core_attention and self.training, query, "core_attn" + ) as query: + core_attn_out = apply_module(self.core_attention)( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. 
+ q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, kv_lengths, max_seqlen_k = ( + inference_context.cu_kv_lengths() + ) + + core_attn_out = self.flash_decode_and_prefill( + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_query_lengths, + cu_kv_lengths, + kv_lengths, + block_table, + inference_context.is_decode_only(), + ) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + # Clear the outputs for padding tokens when using quantization scales + # to avoid corrupting amax calculations + if is_using_quantization_scales(self.config): + core_attn_out[inference_context.padding_slice] = 0.0 + + if self.offload_core_attention and self.training: + core_attn_out = off_interface.group_commit( + core_attn_out, + name="core_attn", + forced_released_tensors=[query, key, value], + ) + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix="core_attention") + + # Output gate + if gate is not None: + nvtx_range_push(suffix="output_gate") + core_attn_out = self._apply_output_gate(core_attn_out, gate) + nvtx_range_pop(suffix="output_gate") + + # ================= + # Output. 
[sq, b, h] + # ================= + nvtx_range_push(suffix="linear_proj") + with off_interface( + self.offload_attn_proj, core_attn_out, "attn_proj" + ) as core_attn_out: + output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_proj: + output = off_interface.group_commit( + output, name="attn_proj", forced_released_tensors=[core_attn_out] + ) + nvtx_range_pop(suffix="linear_proj") + + return output, bias diff --git a/mbridge/models/qwen3_5/base_bridge.py b/mbridge/models/qwen3_5/base_bridge.py new file mode 100644 index 0000000..99da78c --- /dev/null +++ b/mbridge/models/qwen3_5/base_bridge.py @@ -0,0 +1,850 @@ +import inspect +import logging +from copy import deepcopy +from typing import Callable, Optional + +import torch + +from mbridge.core import VLMBridge +from mbridge.core.util import unwrap_model +from mbridge.models.qwen3_5.qwen3_5_safetensor import Qwen3_5SafeTensorIO + + +class Qwen3_5VlBaseBridge(VLMBridge): + + def _get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + """ + Gets the transformer layer specification. + + Creates and returns a specification for the transformer layers based on + the current configuration. 
+ + Returns: + TransformerLayerSpec: Specification for transformer layers + + Raises: + AssertionError: If normalization is not RMSNorm + """ + assert ( + self.config.normalization == "RMSNorm" + ), "only RMSNorm is supported for now" + + try: + from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, + is_linear_attention_variant, + ) + except ImportError: + is_linear_attention_variant = None + get_transformer_block_with_experimental_attention_variant_spec = None + + if is_linear_attention_variant is not None and is_linear_attention_variant( + getattr(self.config, "experimental_attention_variant", None) + ): + # check if get_transformer_block_with_experimental_attention_variant_spec has vp_stage parameter + sig = inspect.signature( + get_transformer_block_with_experimental_attention_variant_spec + ) + self.has_vp_stage = ( + "vp_stage" in sig.parameters + ) # for mcore 0.12 compatibility + extra_args = {} + if self.has_vp_stage: + extra_args["vp_stage"] = vp_stage + + # Use experimental attention variant spec for linear attention (e.g., gated_delta_net) + transformer_layer_spec = ( + get_transformer_block_with_experimental_attention_variant_spec( + self.config, + **extra_args, + ) + ) + else: + raise ImportError( + "experimental_attention_variant is not supported, please megatron-lm dev branch" + ) + + return transformer_layer_spec + + def _adjust_mapping_for_shared_weights(self): + if getattr(self.hf_config.text_config, "tie_word_embeddings", False): + self._DIRECT_MAPPING["language_model.output_layer.weight"] = ( + "model.language_model.embed_tokens.weight" + ) + + def _get_hf_shared_weight_keys(self): + if getattr(self.hf_config.text_config, "tie_word_embeddings", False): + return ["model.language_model.embed_tokens.weight"] + return [] + + def _get_mcore_config_by_name(self, mcore_weights_name: str): + return self.config + + def _get_safetensor_io(self, weights_path: 
str): + # TODO: MTP layers are not handled yet + return Qwen3_5SafeTensorIO( + self._get_actual_hf_path(weights_path), ignore_mtp=True + ) + + def _weight_name_mapping_mcore_local_to_global( + self, model: torch.nn.Module, consider_ep: bool = True + ) -> dict[str, str]: + # vpp + local_layer_to_global_layer = {} + model = unwrap_model(model) + if hasattr(model, "language_model") and hasattr( + model.language_model, "decoder" + ): + for idx, layer in enumerate(model.language_model.decoder.layers): + local_layer_to_global_layer[idx] = layer.layer_number - 1 + all_param_names = [ + k for k in model.state_dict().keys() if "_extra_state" not in k + ] + ret = {} + for param_name in all_param_names: + keyword = "language_model.decoder.layers." + if keyword in param_name: + layer_idx = int(param_name.split(keyword)[1].split(".")[0]) + global_layer_idx = local_layer_to_global_layer[layer_idx] + ret[param_name] = param_name.replace( + f"layers.{layer_idx}.", f"layers.{global_layer_idx}." + ) + else: + ret[param_name] = param_name + + # ep + if self.mpu.ep_size > 1 and consider_ep: + num_experts = self.config.num_moe_experts + num_experts_per_rank = num_experts // self.mpu.ep_size + local_expert_to_global_expert = { + i: i + num_experts_per_rank * self.mpu.ep_rank + for i in range(num_experts_per_rank) + } + for k in ret.keys(): + v = ret[k] + if ".mlp.experts.linear_fc" in v: + name_prefix, local_expert_id = v.split(".weight") + global_expert_idx = local_expert_to_global_expert[ + int(local_expert_id) + ] + ret[k] = f"{name_prefix}.weight{global_expert_idx}" + + return ret + + def _weight_name_mapping_attention(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = split_name[3] + split_name[3] = "{layer_number}" + key = ".".join(split_name) + convert_names = [] + mapping_names = self._ATTENTION_MAPPING[key] + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + if len(convert_names) == 0: + raise 
NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names + + def _weight_name_mapping_mlp(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = split_name[3] + split_name[3] = "{layer_number}" + key = ".".join(split_name) + convert_names = [] + mapping_names = self._MLP_MAPPING[key] + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + if len(convert_names) == 0: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names + + def _weight_name_mapping_other(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = split_name[3] + split_name[3] = "{layer_number}" + key = ".".join(split_name) + convert_names = [] + mapping_names = self._OTHER_MAPPING[key] + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + if len(convert_names) == 0: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names + + def _weight_name_mapping_visual(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = split_name[2] + split_name[2] = "{layer_number}" + key = ".".join(split_name) + convert_names = [] + mapping_names = self._VISUAL_MAPPING[key] + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + if len(convert_names) == 0: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names + + def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]: + """ + Map MCore weight names to Hugging Face weight names. 
+ + Args: + mcore_weights_name: MCore weight name + + Returns: + list: Corresponding Hugging Face weight names + """ + assert ( + "_extra_state" not in mcore_weights_name + ), "extra_state should not be loaded" + + if mcore_weights_name in self._DIRECT_MAPPING: + return [self._DIRECT_MAPPING[mcore_weights_name]] + + if "vision_model" in mcore_weights_name: + return self._weight_name_mapping_visual(mcore_weights_name) + + if ".self_attention." in mcore_weights_name: + return self._weight_name_mapping_attention(mcore_weights_name) + elif "mlp" in mcore_weights_name: + return self._weight_name_mapping_mlp(mcore_weights_name) + else: + raise NotImplementedError( + f"Unsupported parameter name: {mcore_weights_name} {self._DIRECT_MAPPING}" + ) + + def _weight_to_hf_format( + self, mcore_weights_name: str, mcore_weights: torch.Tensor + ) -> tuple[list[str], list[torch.Tensor]]: + """ + Export MCore weights to Hugging Face format. + + Takes MCore weight names and tensor, outputs Hugging Face weight names and tensors. + Due to MCore's runtime optimizations involving weight merging, output can be a list. 
+ + Args: + mcore_weights_name: MCore weight name + mcore_weights: MCore weight tensor + + Returns: + tuple: (hf_names, hf_weights) - lists of Hugging Face weight names and tensors + + Raises: + NotImplementedError: If the parameter name is unsupported + """ + hf_names = self._weight_name_mapping_mcore_to_hf(mcore_weights_name) + + self_attn_output_gate = getattr(self.config, "attention_output_gate", False) + + if len(hf_names) == 1: + # pad embeding and output layer + if self.make_vocab_size_divisible_by is not None and ( + "embedding.word_embeddings.weight" in mcore_weights_name + or "output_layer.weight" in mcore_weights_name + ): + assert mcore_weights.shape[0] == self.padded_vocab_size + assert self.vocab_size is not None + + return [hf_names[0]], [mcore_weights[: self.vocab_size]] + + # moe + if ".mlp.experts.linear_fc" in mcore_weights_name: + # get export index + experts_key = hf_names[0] + experts_idx = int(mcore_weights_name.split(".weight")[-1]) + + if experts_key not in self.export_weights_buff: + self.export_weights_buff[experts_key] = {} + assert experts_idx not in self.export_weights_buff[experts_key] + self.export_weights_buff[experts_key][experts_idx] = mcore_weights + + if ( + len(self.export_weights_buff[experts_key]) + < self.config.num_moe_experts + ): + return [], [] + + mcore_weights_list = [] + for idx in range(self.config.num_moe_experts): + mcore_weights_list.append( + self.export_weights_buff[experts_key].pop(idx) + ) + self.export_weights_buff.pop(experts_key) + return [hf_names[0]], [torch.stack(mcore_weights_list)] + elif "self_attention.out_norm.weight" in mcore_weights_name: + return [hf_names[0]], [mcore_weights + 1] + + return [hf_names[0]], [mcore_weights] + + if ( + "self_attention.linear_qkv." 
in mcore_weights_name + and "layer_norm" not in mcore_weights_name + ): + assert "vision_model" not in mcore_weights_name + # split qkv + assert len(hf_names) == 3 + # split qkv + num_key_value_heads = self.hf_config.text_config.num_key_value_heads + hidden_dim = self.hf_config.text_config.hidden_size + num_attention_heads = self.hf_config.text_config.num_attention_heads + + head_dim = getattr( + self.hf_config.text_config, + "head_dim", + hidden_dim // num_attention_heads, + ) + out_shape = ( + [num_key_value_heads, -1, hidden_dim] + if ".bias" not in mcore_weights_name + else [num_key_value_heads, -1] + ) + qkv = mcore_weights.view(*out_shape) + q_len = head_dim * num_attention_heads // num_key_value_heads + k_len = head_dim + v_len = head_dim + single_out_shape = ( + [-1, hidden_dim] if ".bias" not in mcore_weights_name else [-1] + ) + q = qkv[:, :q_len].reshape(*single_out_shape) + g = None + if self_attn_output_gate: + g = qkv[:, q_len : q_len + q_len].reshape(*single_out_shape) + q_len += q_len + + k = qkv[:, q_len : q_len + k_len].reshape(*single_out_shape) + v = qkv[:, q_len + k_len :].reshape(*single_out_shape) + + if self_attn_output_gate: + _out_shape = ( + [num_attention_heads, -1, hidden_dim] + if ".bias" not in mcore_weights_name + else [num_attention_heads, -1] + ) + q = q.view(_out_shape) + g = g.view(_out_shape) + + q = torch.cat([q, g], dim=1).view(*single_out_shape).contiguous() + + return hf_names, [q, k, v] + + elif "vision_model" not in mcore_weights_name and ( + "linear_fc1.weight" in mcore_weights_name + or "linear_fc1.bias" in mcore_weights_name + ): + # split gate_proj and up_proj + assert len(hf_names) == 2 + gate, up = mcore_weights.chunk(2) + return hf_names, [gate, up] + + elif "self_attention.in_proj.weight" in mcore_weights_name: + assert len(hf_names) == 4 + hidden_size = self.hf_config.text_config.hidden_size + linear_num_key_heads = self.hf_config.text_config.linear_num_key_heads + linear_key_head_dim = 
self.hf_config.text_config.linear_key_head_dim + linear_num_value_heads = self.hf_config.text_config.linear_num_value_heads + linear_value_head_dim = self.hf_config.text_config.linear_value_head_dim + + k_dim = linear_num_key_heads * linear_key_head_dim + v_dim = linear_num_value_heads * linear_value_head_dim + split_shape = [ + k_dim, + k_dim, + v_dim, + v_dim, + linear_num_value_heads, + linear_num_value_heads, + ] + weight_lst = mcore_weights.split(split_shape, dim=0) + # weight_lst: [wq, wk, wv, wz, wb, wa] + assert len(weight_lst) == 6 + wq, wk, wv, wz, wb, wa = weight_lst + + # qk_shape = [linear_num_key_heads, linear_key_head_dim, -1] + # vz_shape = [linear_num_key_heads, v_dim // linear_num_key_heads, -1] + # ba_shape = [linear_num_key_heads, linear_num_value_heads // linear_num_key_heads, -1] + wq = wq.view([-1, hidden_size]) + wk = wk.view([-1, hidden_size]) + wv = wv.view([-1, hidden_size]) + + wz = wz.view([-1, hidden_size]).contiguous() + wb = wb.view([-1, hidden_size]).contiguous() + wa = wa.view([-1, hidden_size]).contiguous() + + qkv_weight = torch.cat([wq, wk, wv], dim=0).contiguous() + + return hf_names, [qkv_weight, wz, wb, wa] + + raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") + + def _weight_to_mcore_format( + self, mcore_weights_name: str, hf_weights: list[torch.Tensor] + ) -> torch.Tensor: + """ + Import Hugging Face weights to MCore format. + + Takes Hugging Face weight names and tensors, outputs MCore weight tensor. + Due to MCore's runtime optimizations involving weight merging, input is a list. 
+ + Args: + mcore_weights_name: MCore weight name + hf_weights: List of Hugging Face weight tensors + + Returns: + torch.Tensor: MCore weight tensor + + Raises: + NotImplementedError: If the parameter name is unsupported + """ + self_attn_output_gate = getattr(self.config, "attention_output_gate", False) + + if len(hf_weights) == 1: + # vision model + hf_names = self._weight_name_mapping_mcore_to_hf(mcore_weights_name) + + # pad embeding and output layer + if self.make_vocab_size_divisible_by is not None and ( + "embedding.word_embeddings.weight" in mcore_weights_name + or "output_layer.weight" in mcore_weights_name + ): + assert hf_weights[0].shape[0] == self.vocab_size + assert self.padded_vocab_size is not None + + embed_dim = hf_weights[0].shape[1] + extra_zeros = torch.zeros( + (self.padded_vocab_size - self.vocab_size, embed_dim), + device=hf_weights[0].device, + dtype=hf_weights[0].dtype, + ) + return torch.cat((hf_weights[0], extra_zeros), dim=0) + + # moe + if ".mlp.experts.linear_fc" in mcore_weights_name: + # get export index + local_experts_idx = int(mcore_weights_name.split(".weight")[-1]) + num_experts = self.config.num_moe_experts + num_experts_per_rank = num_experts // self.mpu.ep_size + experts_idx = ( + local_experts_idx + num_experts_per_rank * self.mpu.ep_rank + ) + return hf_weights[0][experts_idx].clone().contiguous() + # return hf_weights[0][experts_idx].T.clone().contiguous() + elif "self_attention.out_norm.weight" in mcore_weights_name: + return hf_weights[0] - 1 + + return hf_weights[0] + + if ( + "self_attention.linear_qkv." 
in mcore_weights_name + and "layer_norm" not in mcore_weights_name + ): + # merge qkv + assert len(hf_weights) == 3 + num_key_value_heads = self.hf_config.text_config.num_key_value_heads + hidden_dim = self.hf_config.text_config.hidden_size + num_attention_heads = self.hf_config.text_config.num_attention_heads + if "vision_model" in mcore_weights_name: + num_attention_heads = self.hf_config.text_config.vision_config.num_heads + num_key_value_heads = self.hf_config.text_config.vision_config.num_heads + head_dim = getattr( + self.hf_config.text_config, + "head_dim", + hidden_dim // num_attention_heads, + ) + group_dim = head_dim * num_attention_heads // num_key_value_heads + q, k, v = hf_weights + # q k v might be tp split + if self_attn_output_gate: + real_num_key_value_heads = q.shape[0] // 2 // group_dim + + combined_w = q.reshape((num_attention_heads, 2 * head_dim, -1)) + q_w = combined_w.narrow(1, 0, head_dim).reshape( + (num_attention_heads * head_dim, -1) + ) + g_w = combined_w.narrow(1, head_dim, head_dim).reshape( + (num_attention_heads * head_dim, -1) + ) + + q = q_w.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) + g = g_w.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) + else: + real_num_key_value_heads = q.shape[0] // group_dim + q = q.view( + [ + real_num_key_value_heads, + group_dim, + -1, + ] + ) + + k = k.view([real_num_key_value_heads, head_dim, -1]) + v = v.view([real_num_key_value_heads, head_dim, -1]) + out_shape = [-1, hidden_dim] if ".bias" not in mcore_weights_name else [-1] + + if self_attn_output_gate: + qkv = torch.cat([q, g, k, v], dim=1).view(*out_shape).contiguous() + else: + qkv = torch.cat([q, k, v], dim=1).view(*out_shape).contiguous() + return qkv + elif "vision_model" not in mcore_weights_name and ( + "linear_fc1.weight" in mcore_weights_name + or "linear_fc1.bias" in mcore_weights_name + ): + # merge gate_proj and up_proj + assert len(hf_weights) == 2 + gate, up = hf_weights + return 
torch.cat([gate, up], dim=0) + elif "self_attention.in_proj.weight" in mcore_weights_name: + assert len(hf_weights) == 4 + hidden_size = self.hf_config.text_config.hidden_size + in_proj_qkv, in_proj_z, in_proj_b, in_proj_a = hf_weights + linear_num_key_heads = self.hf_config.text_config.linear_num_key_heads + linear_key_head_dim = self.hf_config.text_config.linear_key_head_dim + linear_num_value_heads = self.hf_config.text_config.linear_num_value_heads + linear_value_head_dim = self.hf_config.text_config.linear_value_head_dim + key_dim = linear_key_head_dim * linear_num_key_heads + value_dim = linear_value_head_dim * linear_num_value_heads + + split_shape = [ + key_dim, + key_dim, + value_dim, + ] + wq, wk, wv = in_proj_qkv.split(split_shape, dim=0) + + # qkv_dim_per_partition = 2 * linear_key_head_dim + value_dim // linear_num_key_heads + # in_proj_qkv_ = in_proj_qkv.reshape((linear_num_key_heads, qkv_dim_per_partition, -1)) + # wq = in_proj_qkv_.narrow(1, 0, linear_key_head_dim).reshape(key_dim, -1) + # wk = in_proj_qkv_.narrow(1, linear_key_head_dim, linear_key_head_dim).reshape(key_dim, -1) + # wv = in_proj_qkv_.narrow(1, 2 * linear_key_head_dim, + # value_dim // linear_num_key_heads).reshape(value_dim, -1) + + wz = in_proj_z.reshape(value_dim, -1) + wb = in_proj_b.reshape((linear_num_value_heads, -1)) + wa = in_proj_a.reshape((linear_num_value_heads, -1)) + return torch.cat([wq, wk, wv, wz, wb, wa], dim=0) + else: + logging.warning( + f"Unhandled weights {mcore_weights_name}: count={len(hf_weights)} shapes={[hf_w.shape for hf_w in hf_weights]}" + ) + raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") + + def _weight_merge_across_tp( + self, + mcore_weights_name: str, + mcore_weights: list[torch.Tensor], + param: torch.Tensor, + ) -> torch.Tensor: + """ + Merge weights across tensor parallel ranks. 
+ In mcore format + + Args: + mcore_weights_name: MCore weight name + mcore_weights: List of MCore weight tensors from different TP ranks + param: Parameter tensor + + Returns: + torch.Tensor: Merged weight tensor + """ + if "mlp.experts.linear_fc" in mcore_weights_name: + assert len(mcore_weights) == self.mpu.etp_size + if self.mpu.etp_size == 1: + assert len(mcore_weights) == 1 + return mcore_weights[0] + else: + assert len(mcore_weights) == self.mpu.tp_size + if self.mpu.tp_size == 1: + assert len(mcore_weights) == 1 + return mcore_weights[0] + if ( + "self_attention.linear_qkv." in mcore_weights_name + and "layer_norm" not in mcore_weights_name + ): + return torch.cat(mcore_weights, dim=0) + elif "vision_model" not in mcore_weights_name and ( + "linear_fc1.weight" in mcore_weights_name + or "linear_fc1.bias" in mcore_weights_name + ): + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + if not mcore_config.gated_linear_unit: + return torch.cat(mcore_weights, dim=0) + + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + for mcore_weight in mcore_weights: + gate, up = mcore_weight.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + ret = torch.cat((gate, up), dim=0) + + elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: # moe + ret = torch.cat(mcore_weights, dim=1) + elif ( + "self_attention.linear_kv_down_proj.weight" in mcore_weights_name + or "self_attention.linear_q_down_proj.weight" in mcore_weights_name + ): + # self_attention.linear_kv_down_proj.weight and self_attention.linear_q_down_proj.weight are copied + return mcore_weights[0] + elif "self_attention.in_proj.weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + tp_size = len(mcore_weights) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = ( + mcore_config.linear_num_value_heads * 
mcore_config.linear_value_head_dim + ) + split_shape = [ + k_dim // tp_size, + k_dim // tp_size, + v_dim // tp_size, + v_dim // tp_size, + mcore_config.linear_num_value_heads // tp_size, + mcore_config.linear_num_value_heads // tp_size, + ] + # split_shape for [wq, wk, wv, wz, wb, wa] + ret = self._split_weight_by_size_and_merge_across_tp( + mcore_weights, split_shape + ) + elif "self_attention.conv1d" in mcore_weights_name: + if "weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + tp_size = len(mcore_weights) + k_dim = ( + mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + ) + v_dim = ( + mcore_config.linear_num_value_heads + * mcore_config.linear_value_head_dim + ) + split_shape = [ + k_dim // tp_size, + k_dim // tp_size, + v_dim // tp_size, + ] + # split_shape for [X, B, C] + ret = self._split_weight_by_size_and_merge_across_tp( + mcore_weights, split_shape + ) + else: + raise NotImplementedError(f"{mcore_weights_name} not supported yet") + else: + assert ( + hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + ) + ret = torch.cat(mcore_weights, dim=param.partition_dim) + + return ret + + def _weight_split_across_tp( + self, + mcore_weights_name: str, + mcore_weights: torch.Tensor, + param: torch.Tensor, + tp_split_size: int, + ) -> list[torch.Tensor]: + """ + Split weight tensor across tensor parallel ranks. + + Args: + mcore_weights_name: MCore weight name + mcore_weights: MCore weight tensor + param: Parameter tensor + + Returns: + list: List of weight tensors split for each TP rank + """ + if tp_split_size == 1: + return [mcore_weights] + + if ( + "self_attention.linear_qkv." 
in mcore_weights_name + and "layer_norm" not in mcore_weights_name + ): + return mcore_weights.chunk(tp_split_size) + elif "vision_model" not in mcore_weights_name and ( + "linear_fc1.weight" in mcore_weights_name + or "linear_fc1.bias" in mcore_weights_name + ): + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + if not mcore_config.gated_linear_unit: + return mcore_weights.chunk(tp_split_size) + + gate, up = mcore_weights.chunk(2) + gates = gate.chunk(tp_split_size) + ups = up.chunk(tp_split_size) + ret = [torch.cat([g, u], dim=0) for g, u in zip(gates, ups)] + elif "mlp.experts.linear_fc2.weight" in mcore_weights_name: # moe + ret = mcore_weights.chunk(tp_split_size, dim=1) + elif "self_attention.in_proj.weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + k_dim = mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + v_dim = ( + mcore_config.linear_num_value_heads * mcore_config.linear_value_head_dim + ) + split_shape = [ + k_dim, + k_dim, + v_dim, + v_dim, + mcore_config.linear_num_value_heads, + mcore_config.linear_num_value_heads, + ] + split_w_lst = mcore_weights.split(split_shape, dim=0) + # split_w_lst: [wq, wk, wv, wz, wb, wa] + assert len(split_w_lst) == 6, f"split_shape {split_shape} not supported" + weight_list = [] + for weight in split_w_lst: + weight_list.append(weight.chunk(tp_split_size)) + ret = [ + torch.cat( + [wq_slice, wk_slice, wv_slice, wz_slice, wb_slice, wa_slice], dim=0 + ) + for wq_slice, wk_slice, wv_slice, wz_slice, wb_slice, wa_slice in zip( + *weight_list + ) + ] + elif "self_attention.conv1d" in mcore_weights_name: + if "weight" in mcore_weights_name: + mcore_config = self._get_mcore_config_by_name(mcore_weights_name) + k_dim = ( + mcore_config.linear_num_key_heads * mcore_config.linear_key_head_dim + ) + v_dim = ( + mcore_config.linear_num_value_heads + * mcore_config.linear_value_head_dim + ) + split_shape = [ + k_dim, + k_dim, + v_dim, + ] + 
split_w_lst = mcore_weights.split(split_shape, dim=0) + # split_w_lst: [X, B, C] + assert len(split_w_lst) == 3, f"split_shape {split_shape} not supported" + weight_list = [] + for weight in split_w_lst: + weight_list.append(weight.chunk(tp_split_size)) + ret = [ + torch.cat([x_slice, b_slice, c_slice], dim=0) + for x_slice, b_slice, c_slice in zip(*weight_list) + ] + else: + raise NotImplementedError(f"{mcore_weights_name} not supported yet") + else: + if param.shape == mcore_weights.shape: + return [mcore_weights for _ in range(tp_split_size)] + assert len(param.shape) == len(mcore_weights.shape) + for partition_dim, (s1, s2) in enumerate( + zip(param.shape, mcore_weights.shape) + ): + if s1 != s2: + break + + ret = mcore_weights.chunk(tp_split_size, dim=partition_dim) + return ret + + def _split_weight_by_size_and_merge_across_tp( + self, + mcore_weights: list[torch.Tensor], + split_shape: list[int], + ) -> torch.Tensor: + """ + First split weight by splist_shape and then merge across tensor parallel ranks + + use for linear attn in_proj and linear attn conv1d layer weight + """ + tp_size = len(mcore_weights) + + weight_lst = [[] for _ in range(len(split_shape))] + for mcore_weight in mcore_weights: + split_w_lst = mcore_weight.split(split_shape, dim=0) + assert len(split_w_lst) == len(weight_lst) + for wi, split_w in enumerate(split_w_lst): + weight_lst[wi].append(split_w) + for weight in weight_lst: + assert len(weight) == tp_size + ret = torch.cat([torch.cat(w_split, dim=0) for w_split in weight_lst], dim=0) + return ret + + def _model_provider( + self, post_model_creation_callbacks: list[Callable[[torch.nn.Module], None]] + ): + """ + Creates and returns a model provider function. + + The returned function creates a GPTModel with the specified configuration + when called with pre_process and post_process parameters. 
+ + Args: + post_model_creation_callbacks: List of callbacks to be called after model creation + + Returns: + function: A provider function that creates and returns a GPTModel instance + """ + from mbridge.models.qwen3_5.model import Qwen3_5VLModel + + share_embeddings_and_output_weights = getattr( + self.hf_config, "tie_word_embeddings", False + ) + + def provider( + pre_process, + post_process, + add_decoder=True, + add_encoder=True, + vp_stage: Optional[int] = None, + ): + transformer_layer_spec = self._get_transformer_layer_spec(vp_stage) + + model = Qwen3_5VLModel( + language_transformer_config=self.config, + language_transformer_layer_spec=transformer_layer_spec, + language_mtp_block_spec=None, + language_vocab_size=self.hf_config.text_config.vocab_size, + language_max_sequence_length=self.hf_config.text_config.max_position_embeddings, + hf_config=self.hf_config, + hf_vision_cls=self.HfVisionClass, + language_rotary_base=self.hf_config.text_config.rope_scaling.get( + "rope_theta", 10000000 + ), + position_embedding_type="mrope", + pre_process=pre_process, + post_process=post_process, + add_decoder=add_decoder, + add_encoder=add_encoder, + parallel_output=True, + language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, + image_token_id=self.hf_config.image_token_id, + video_token_id=self.hf_config.video_token_id, + vision_start_token_id=self.hf_config.vision_start_token_id, + ) + + for callback in post_model_creation_callbacks: + callback( + model, + pre_process=pre_process, + post_process=post_process, + config=self.config, + hf_config=self.hf_config, + ) + + return model + + return provider diff --git a/mbridge/models/qwen3_5/model.py b/mbridge/models/qwen3_5/model.py new file mode 100644 index 0000000..e7b950f --- /dev/null +++ b/mbridge/models/qwen3_5/model.py @@ -0,0 +1,369 @@ +import logging +from typing import Optional + +import torch +from megatron.core import InferenceParams, mpu, tensor_parallel +from 
class Qwen3_5GPTModel(GPTModel):
    """GPTModel variant whose rotary embedding is replaced by the Qwen3-VL
    multimodal (mrope) rotary embedding after base construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Swap the default rope for the multimodal one.  NOTE: the
        # "rotary_percent" kwarg is required by this subclass (KeyError
        # if absent); the other rope kwargs fall back to GPTModel defaults.
        interpolation = kwargs.get("seq_len_interpolation_factor", None)
        base = kwargs.get("rotary_base", 10000)
        self.rotary_pos_emb = Qwen3VLMultimodalRotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=kwargs["rotary_percent"],
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=interpolation,
            rotary_base=base,
        )
language_share_embeddings_and_output_weights: bool = False, + image_token_id: int = 151655, + video_token_id: int = 151656, + vision_start_token_id: int = 151652, + rope_scaling: bool = False, + ) -> None: + super().__init__(config=language_transformer_config) + + for layer_spec in language_transformer_layer_spec.layer_specs: + # only replace SelfAttention + if isinstance(layer_spec.submodules.self_attention.module, SelfAttention): + layer_spec.submodules.self_attention.module = Qwen3_5VLSelfAttention + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.hf_config = hf_config + self.encoder_hidden_state = None + self.vision_model = None + self.language_model = None + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.spatial_merge_size = self.hf_config.vision_config.spatial_merge_size + self.square_merge_size = self.spatial_merge_size**2 + + if self.pre_process: + self.vision_model = hf_vision_cls._from_config(hf_config.vision_config) + self._cast_rotary_emb_to_fp32(self.vision_model) + self._mark_vision_params_sequence_parallel(self.vision_model) + + self.language_model = Qwen3_5GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + parallel_output=parallel_output, + share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, + position_embedding_type=position_embedding_type, + rotary_percent=language_rotary_percent, + rotary_base=language_rotary_base, + rope_scaling=rope_scaling, + mtp_block_spec=language_mtp_block_spec, + scatter_embedding_sequence_parallel=False, + ) + + @staticmethod + def _cast_rotary_emb_to_fp32(module: 
torch.nn.Module): + """Force all RotaryEmbedding inv_freq buffers to stay at original float32 precision.""" + for submodule in module.modules(): + if hasattr(submodule, "inv_freq") and submodule.inv_freq is not None: + # Save the original float32 inv_freq (this runs BEFORE Float16Module) + submodule._inv_freq_fp32_original = ( + submodule.inv_freq.detach().clone().float() + ) + + def _hook(mod, args): + if hasattr(mod, "_inv_freq_fp32_original"): + # Restore inv_freq from the saved fp32_copied + mod.inv_freq = mod._inv_freq_fp32_original.to( + device=mod.inv_freq.device + ) + + submodule.register_forward_pre_hook(_hook) + + def _mark_vision_params_sequence_parallel(self, module: torch.nn.Module): + """Mark all vision model parameters with sequence_parallel=True.""" + for param in module.parameters(): + setattr(param, "sequence_parallel", self.config.sequence_parallel) + + @property + def share_embeddings_and_output_weights(self): + return self.language_model.share_embeddings_and_output_weights + + @property + def decoder(self): + return self.language_model.decoder + + def shared_embedding_or_output_weight(self): + return self.language_model.shared_embedding_or_output_weight() + + def set_input_tensor(self, input_tensor) -> None: + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1" + + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze( + self, + freeze_language_model: bool, + freeze_vision_model: bool, + freeze_vision_projection: bool, + ): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. 
+ """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_model is not None: + modules.append(self.vision_model.merger) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + if freeze_vision_model and not freeze_vision_projection: + if self.vision_model is not None: + for param in self.vision_model.merger.parameters(): + param.requires_grad = True + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + inference_params: Optional[BaseInferenceContext] = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + # can set at dataset + image_input_mask: torch.Tensor = None, + video_input_mask: torch.Tensor = None, + cp_img_num: list[int] = None, + images_padded: list[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert inference_context is None, "not support inference yet" + + vision_grid_thw = None + vision_data = None + vision_mask = None + # TODO: this approach may need rethinking + cp_size = mpu.get_context_parallel_world_size() + + if self.pre_process: + # can reorganize_inputs at dataset + vision_data, vision_grid_thw, vision_mask = reorganize_inputs( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + image_token_id=self.image_token_id, + 
video_token_id=self.video_token_id, + square_merge_size=self.square_merge_size, + ) + + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + if cp_size > 1: + if cp_img_num is None: + assert images_padded is None + vision_data, vision_grid_thw, cp_img_num, images_padded = ( + qwen3vl_cp_split( + cp_size, + vision_data, + vision_grid_thw, + ) + ) + vision_data, vision_grid_thw, seqlen_on_cp_ranks = ( + get_vision_cp_data( + vision_data, + vision_grid_thw, + self.square_merge_size, + cp_img_num, + images_padded, + ) + ) + vision_grid_thw = collapse_thw(vision_grid_thw) + if vision_data.shape[0] > 0: + vision_embeds = self.vision_model( + hidden_states=vision_data, + grid_thw=vision_grid_thw, + ).pooler_output + # Encodes images into continuous embeddings that can be forwarded to the language model. + split_sizes = ( + vision_grid_thw.prod(-1) // self.spatial_merge_size**2 + ).tolist() + vision_embeds = torch.split(vision_embeds, split_sizes) + vision_embeds = torch.cat(vision_embeds, dim=0) + else: + vision_embeds = torch.zeros( + (0, self.language_model.config.hidden_size), + device=vision_data.device, + dtype=torch.bfloat16, + ) + if cp_size > 1: + vision_embeds = AllGatherVisionEmbeddings.apply( + vision_embeds, + seqlen_on_cp_ranks, + ) + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ).clone() # [text_seq_len, b, h_language] + + if vision_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings[vision_mask] = vision_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + if ( + combined_embeddings is not None + and cp_size > 1 + and packed_seq_params is None + ): + combined_embeddings = split_data_cp_rank( + combined_embeddings, cp_size, 0 + ) + if packed_seq_params is not None: + input_ids_thd, _ = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=True + ) 
+ vision_mask_thd = (input_ids_thd == self.image_token_id) | ( + input_ids_thd == self.video_token_id + ) + + vision_mask = vision_mask_thd + combined_embeddings_thd = ( + preprocess_packed_seqs( + combined_embeddings.transpose(0, 1).contiguous(), + attention_mask, + pre_process=True, + )[0] + .transpose(0, 1) + .contiguous() + ) + combined_embeddings = combined_embeddings_thd + + if self.config.sequence_parallel: + combined_embeddings = ( + tensor_parallel.scatter_to_sequence_parallel_region( + combined_embeddings + ) + ) + combined_embeddings = combined_embeddings.contiguous() + + else: + combined_embeddings = None + + if position_ids is None: + # BSHD + position_ids, _ = get_rope_index( + self.config.spatial_merge_size, + self.image_token_id, + self.video_token_id, + self.vision_start_token_id, + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) # [3*b*s] + if packed_seq_params is not None: + # convert position_ids to THD format + position_ids = ( + preprocess_packed_seqs( + position_ids.permute(1, 2, 0), attention_mask, pre_process=True + )[0] + .permute(2, 0, 1) + .contiguous() + ) + attention_mask = None + self.language_model.rotary_pos_emb.is_thd_format = True + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + inference_context=inference_context, + packed_seq_params=packed_seq_params, # currently always None + runtime_gather_output=runtime_gather_output, + inference_params=inference_params, # currently always None + **(extra_block_kwargs or {}), + **kwargs, + ) + + return output diff --git a/mbridge/models/qwen3_5/qwen3_5_safetensor.py b/mbridge/models/qwen3_5/qwen3_5_safetensor.py new file mode 100644 index 0000000..e33b5f0 --- /dev/null +++ 
class Qwen3_5SafeTensorIO(SafeTensorIO):
    """SafeTensor index reader for Qwen3.5 checkpoints.

    Builds ``self.index`` (weight name -> shard filename) either from the
    ``model.safetensors.index.json`` shard index or, for single-file
    checkpoints, by scanning the lone ``*.safetensors`` file.  MTP weights
    can be filtered out, and ``lm_head.weight`` is dropped when the config
    ties it to the input embeddings.
    """

    def __init__(self, hf_dir: str, ignore_mtp: bool = False):
        index_path = os.path.join(hf_dir, "model.safetensors.index.json")
        config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True)

        self.index = {}
        self.origin_index = {}

        if os.path.exists(index_path):
            with open(index_path, "r") as fh:
                origin_index = json.load(fh)

            # Optionally drop MTP weights from the shard map.
            origin_index["weight_map"] = {
                name: shard
                for name, shard in origin_index["weight_map"].items()
                if not (ignore_mtp and "mtp" in name)
            }
            self.index = origin_index["weight_map"]

            # Tied embeddings: lm_head shares the embedding weight, so the
            # standalone lm_head entry must not be loaded.
            tied = getattr(config, "tie_word_embeddings", False) or getattr(
                getattr(config, "text_config", None), "tie_word_embeddings", False
            )
            if tied and "lm_head.weight" in self.index.keys():
                self.index.pop("lm_head.weight")
            self.origin_index = origin_index
        else:
            # No index file: only the single-shard layout is supported.
            shard_files = glob(os.path.join(hf_dir, "*.safetensors"))
            if len(shard_files) == 1:
                shard = shard_files[0]
                shard_name = os.path.basename(shard)
                with safe_open(shard, framework="pt", device="cpu") as fh:
                    for key in fh.keys():
                        if ignore_mtp and "mtp" in key:
                            continue
                        self.index[key] = shard_name

        self.hf_dir = hf_dir
import Qwen3_5VLTransformerConfig + +_QWEN3p5VIT_DIRECT_MAPPING = { + "vision_model.patch_embed.proj.weight": "model.visual.patch_embed.proj.weight", + "vision_model.patch_embed.proj.bias": "model.visual.patch_embed.proj.bias", + "vision_model.pos_embed.weight": "model.visual.pos_embed.weight", + "vision_model.merger.norm.weight": "model.visual.merger.norm.weight", + "vision_model.merger.norm.bias": "model.visual.merger.norm.bias", + "vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight", + "vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias", + "vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight", + "vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias", +} + +_QWEN3p5_VISUAL_MAPPING = { + # visual attn + "vision_model.blocks.{layer_number}.attn.proj.weight": [ + "model.visual.blocks.{layer_number}.attn.proj.weight", + ], + "vision_model.blocks.{layer_number}.attn.proj.bias": [ + "model.visual.blocks.{layer_number}.attn.proj.bias", + ], + "vision_model.blocks.{layer_number}.attn.qkv.bias": [ + "model.visual.blocks.{layer_number}.attn.qkv.bias", + ], + "vision_model.blocks.{layer_number}.attn.qkv.weight": [ + "model.visual.blocks.{layer_number}.attn.qkv.weight", + ], + # visual mlp + "vision_model.blocks.{layer_number}.mlp.linear_fc1.weight": [ + "model.visual.blocks.{layer_number}.mlp.linear_fc1.weight", + ], + "vision_model.blocks.{layer_number}.mlp.linear_fc1.bias": [ + "model.visual.blocks.{layer_number}.mlp.linear_fc1.bias", + ], + "vision_model.blocks.{layer_number}.mlp.linear_fc2.weight": [ + "model.visual.blocks.{layer_number}.mlp.linear_fc2.weight", + ], + "vision_model.blocks.{layer_number}.mlp.linear_fc2.bias": [ + "model.visual.blocks.{layer_number}.mlp.linear_fc2.bias", + ], + # visual norm + "vision_model.blocks.{layer_number}.norm1.weight": [ + "model.visual.blocks.{layer_number}.norm1.weight", + ], + "vision_model.blocks.{layer_number}.norm1.bias": [ 
+ "model.visual.blocks.{layer_number}.norm1.bias", + ], + "vision_model.blocks.{layer_number}.norm2.weight": [ + "model.visual.blocks.{layer_number}.norm2.weight", + ], + "vision_model.blocks.{layer_number}.norm2.bias": [ + "model.visual.blocks.{layer_number}.norm2.bias", + ], +} + +_QWEN3p5TEXT_DIRECT_MAPPING = { + "language_model.embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight", + "language_model.decoder.final_layernorm.weight": "model.language_model.norm.weight", + "language_model.output_layer.weight": "lm_head.weight", +} + +_QWEN3p5TEXT_ATTENTION_MAPPING = { + "language_model.decoder.layers.{layer_number}.self_attention.linear_proj.weight": [ + "model.language_model.layers.{layer_number}.self_attn.o_proj.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.linear_qkv.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.input_layernorm.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.q_layernorm.weight": [ + "model.language_model.layers.{layer_number}.self_attn.q_norm.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.k_layernorm.weight": [ + "model.language_model.layers.{layer_number}.self_attn.k_norm.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.linear_qkv.weight": [ + "model.language_model.layers.{layer_number}.self_attn.q_proj.weight", + "model.language_model.layers.{layer_number}.self_attn.k_proj.weight", + "model.language_model.layers.{layer_number}.self_attn.v_proj.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.linear_qkv.bias": [ + "model.language_model.layers.{layer_number}.self_attn.q_proj.bias", + "model.language_model.layers.{layer_number}.self_attn.k_proj.bias", + "model.language_model.layers.{layer_number}.self_attn.v_proj.bias", + ], + # linear attention + "language_model.decoder.layers.{layer_number}.self_attention.dt_bias": [ + 
"model.language_model.layers.{layer_number}.linear_attn.dt_bias" + ], + "language_model.decoder.layers.{layer_number}.self_attention.A_log": [ + "model.language_model.layers.{layer_number}.linear_attn.A_log" + ], + "language_model.decoder.layers.{layer_number}.self_attention.in_proj.weight": [ + "model.language_model.layers.{layer_number}.linear_attn.in_proj_qkv.weight", + "model.language_model.layers.{layer_number}.linear_attn.in_proj_z.weight", + "model.language_model.layers.{layer_number}.linear_attn.in_proj_b.weight", + "model.language_model.layers.{layer_number}.linear_attn.in_proj_a.weight", + ], + "language_model.decoder.layers.{layer_number}.self_attention.conv1d.weight": [ + "model.language_model.layers.{layer_number}.linear_attn.conv1d.weight" + ], + "language_model.decoder.layers.{layer_number}.self_attention.out_norm.weight": [ + "model.language_model.layers.{layer_number}.linear_attn.norm.weight" + ], + "language_model.decoder.layers.{layer_number}.self_attention.out_proj.weight": [ + "model.language_model.layers.{layer_number}.linear_attn.out_proj.weight" + ], + "language_model.decoder.layers.{layer_number}.self_attention.in_proj.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.input_layernorm.weight" + ], +} + +_QWEN3p5TEXT_MLP_MAPPING = { + "language_model.decoder.layers.{layer_number}.mlp.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.up_proj.weight", + ], + "language_model.decoder.layers.{layer_number}.mlp.linear_fc1.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.post_attention_layernorm.weight", + ], + "language_model.decoder.layers.{layer_number}.mlp.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.down_proj.weight", + ], +} + +_QWEN3p5TEXT_MOE_MLP_MAPPING = { + "language_model.decoder.layers.{layer_number}.pre_mlp_layernorm.weight": [ + 
"model.language_model.layers.{layer_number}.post_attention_layernorm.weight", + ], + "language_model.decoder.layers.{layer_number}.mlp.router.weight": [ + "model.language_model.layers.{layer_number}.mlp.gate.weight", + ], + "language_model.decoder.layers.{layer_number}.mlp.experts.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.experts.gate_up_proj", + ], + "language_model.decoder.layers.{layer_number}.mlp.experts.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.experts.down_proj", + ], + # shared expert + "language_model.decoder.layers.{layer_number}.mlp.shared_experts.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.shared_expert.up_proj.weight", + ], + "language_model.decoder.layers.{layer_number}.mlp.shared_experts.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.shared_expert.down_proj.weight" + ], + "language_model.decoder.layers.{layer_number}.mlp.shared_experts.gate_weight": [ + "model.language_model.layers.{layer_number}.mlp.shared_expert_gate.weight" + ], +} + + +@register_model("qwen3_5") +class Qwen3_5VlBridge(Qwen3_5VlBaseBridge): + try: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5VisionModel + except: + Qwen3_5VisionModel = None + HfVisionClass: type = Qwen3_5VisionModel + TransformerConfigClass = Qwen3_5VLTransformerConfig + + _CONFIG_MAPPING = { + "num_layers": "num_hidden_layers", + "hidden_size": "hidden_size", + "num_attention_heads": "num_attention_heads", + "num_query_groups": "num_key_value_heads", + "ffn_hidden_size": "intermediate_size", + "attention_dropout": "attention_dropout", + "layernorm_epsilon": "rms_norm_eps", + "hidden_dropout": ("hidden_dropout", 0.0), + "kv_channels": ("head_dim", None), + } + + _DIRECT_MAPPING = { + **_QWEN3p5VIT_DIRECT_MAPPING, + **_QWEN3p5TEXT_DIRECT_MAPPING, + } + _ATTENTION_MAPPING = { + 
**_QWEN3p5TEXT_ATTENTION_MAPPING, + } + _MLP_MAPPING = { + **_QWEN3p5TEXT_MLP_MAPPING, + } + _OTHER_MAPPING = {} + _VISUAL_MAPPING = { + **_QWEN3p5_VISUAL_MAPPING, + } + + def _build_config(self): + return self._build_base_config( + text_config_key="text_config", + layernorm_epsilon=self.hf_config.text_config.rms_norm_eps, + use_cpu_initialization=False, + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + masked_softmax_fusion=False, + deallocate_pipeline_outputs=True, + async_tensor_model_parallel_allreduce=True, + distribute_saved_activations=False, + cp_comm_type="p2p", + # Qwen3.5 specific + qk_layernorm=True, + layernorm_zero_centered_gamma=True, + attention_output_gate=True, + kv_channels=self.hf_config.text_config.head_dim, + experimental_attention_variant="gated_delta_net", + linear_attention_freq=self.hf_config.text_config.full_attention_interval, + linear_conv_kernel_dim=self.hf_config.text_config.linear_conv_kernel_dim, + linear_key_head_dim=self.hf_config.text_config.linear_key_head_dim, + linear_value_head_dim=self.hf_config.text_config.linear_value_head_dim, + linear_num_key_heads=self.hf_config.text_config.linear_num_key_heads, + linear_num_value_heads=self.hf_config.text_config.linear_num_value_heads, + rotary_percent=self.hf_config.text_config.rope_scaling.get( + "partial_rotary_factor", 0.25 + ), + rotary_interleaved=self.hf_config.text_config.rope_scaling.get( + "mrope_interleaved", True + ), + mrope_section=self.hf_config.text_config.rope_scaling.get( + "mrope_section", + [11, 11, 10], + ), + patch_size=self.hf_config.vision_config.patch_size, + temporal_patch_size=self.hf_config.vision_config.temporal_patch_size, + in_channels=self.hf_config.vision_config.in_channels, + spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, + num_position_embeddings=self.hf_config.vision_config.num_position_embeddings, + out_hidden_size=self.hf_config.vision_config.out_hidden_size, + 
apply_rotary_pos_emb_in_fp32=True, + ) + + +@register_model("qwen3_5_moe") +class Qwen3_5MoeVlBridge(Qwen3_5VlBaseBridge): + try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeVisionModel, + ) + except: + Qwen3_5MoeVisionModel = None + HfVisionClass: type = Qwen3_5MoeVisionModel + + TransformerConfigClass = Qwen3_5VLTransformerConfig + + _CONFIG_MAPPING = { + "num_layers": "num_hidden_layers", + "hidden_size": "hidden_size", + "num_attention_heads": "num_attention_heads", + "num_query_groups": "num_key_value_heads", + "ffn_hidden_size": "moe_intermediate_size", + "attention_dropout": "attention_dropout", + "layernorm_epsilon": "rms_norm_eps", + "hidden_dropout": ("hidden_dropout", 0.0), + "kv_channels": ("head_dim", None), + } + + _DIRECT_MAPPING = { + **_QWEN3p5VIT_DIRECT_MAPPING, + **_QWEN3p5TEXT_DIRECT_MAPPING, + } + _ATTENTION_MAPPING = { + **_QWEN3p5TEXT_ATTENTION_MAPPING, + } + _MLP_MAPPING = { + **_QWEN3p5TEXT_MOE_MLP_MAPPING, + } + _OTHER_MAPPING = {} + _VISUAL_MAPPING = { + **_QWEN3p5_VISUAL_MAPPING, + } + + def _weight_name_mapping_mlp(self, name: str) -> list[str]: + if ( + name.startswith("vision_model.") + or ".pre_mlp_layernorm.weight" in name + or ".mlp.router.weight" in name + or ".shared_experts" in name + ): + return super()._weight_name_mapping_mlp(name) + + assert ".mlp.experts.linear_fc" in name, f"{name=}" + split_name = name.split(".") + layer_number = split_name[3] + split_name[3] = "{layer_number}" + key = ".".join(split_name) + key = key.split(".weight")[0] + ".weight" + convert_names = [] + mapping_names = self._MLP_MAPPING[key] + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + if len(convert_names) == 0: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names + + def _build_config(self): + return self._build_base_config( + text_config_key="text_config", + layernorm_epsilon=self.hf_config.text_config.rms_norm_eps, + 
use_cpu_initialization=False, + # MoE specific + moe_ffn_hidden_size=self.hf_config.text_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=self.hf_config.text_config.num_experts_per_tok, + num_moe_experts=self.hf_config.text_config.num_experts, + moe_aux_loss_coeff=self.hf_config.text_config.router_aux_loss_coef, + moe_token_dispatcher_type="alltoall", + moe_permute_fusion=True, + moe_router_dtype="fp32", + moe_router_load_balancing_type="none", # default None for RL + moe_shared_expert_overlap=True, + moe_grouped_gemm=True, + moe_router_score_function="softmax", + moe_shared_expert_intermediate_size=self.hf_config.text_config.shared_expert_intermediate_size, + moe_shared_expert_gate=self.hf_config.text_config.shared_expert_intermediate_size + > 0, + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + masked_softmax_fusion=False, + deallocate_pipeline_outputs=True, + async_tensor_model_parallel_allreduce=True, + distribute_saved_activations=False, + cp_comm_type="p2p", + # Qwen3.5 specific + moe_router_pre_softmax=False, + qk_layernorm=True, + layernorm_zero_centered_gamma=True, + attention_output_gate=True, + kv_channels=self.hf_config.text_config.head_dim, + experimental_attention_variant="gated_delta_net", + linear_attention_freq=self.hf_config.text_config.full_attention_interval, + linear_conv_kernel_dim=self.hf_config.text_config.linear_conv_kernel_dim, + linear_key_head_dim=self.hf_config.text_config.linear_key_head_dim, + linear_value_head_dim=self.hf_config.text_config.linear_value_head_dim, + linear_num_key_heads=self.hf_config.text_config.linear_num_key_heads, + linear_num_value_heads=self.hf_config.text_config.linear_num_value_heads, + rotary_percent=self.hf_config.text_config.rope_scaling.get( + "partial_rotary_factor", 0.25 + ), + rotary_interleaved=self.hf_config.text_config.rope_scaling.get( + "mrope_interleaved", True + ), + 
mrope_section=self.hf_config.text_config.rope_scaling.get( + "mrope_section", + [11, 11, 10], + ), + patch_size=self.hf_config.vision_config.patch_size, + temporal_patch_size=self.hf_config.vision_config.temporal_patch_size, + in_channels=self.hf_config.vision_config.in_channels, + spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, + num_position_embeddings=self.hf_config.vision_config.num_position_embeddings, + out_hidden_size=self.hf_config.vision_config.out_hidden_size, + apply_rotary_pos_emb_in_fp32=True, + ) diff --git a/mbridge/models/qwen3_5/rope_utils.py b/mbridge/models/qwen3_5/rope_utils.py new file mode 100644 index 0000000..59eb08a --- /dev/null +++ b/mbridge/models/qwen3_5/rope_utils.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +from typing import List, Optional + +import torch +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, + get_pos_emb_on_this_cp_rank, +) +from torch import Tensor, nn + +from mbridge.models.qwen3_vl.transformer_config import Qwen3VLTransformerConfig + +# Prefer fused RoPE from Apex as we need the `transpose_output_memory` argument for the bshd trick. +# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2469. 
def apply_rotary_pos_emb_thd_absolute(
    t: Tensor,
    cu_seqlens: Tensor,
    freqs: Tensor,
    rotary_interleaved: bool = False,
    multi_latent_attention: bool = False,
    mscale: float = 1.0,
    cp_group: torch.distributed.ProcessGroup = None,
) -> Tensor:
    """Apply RoPE to a packed `thd`-format tensor using absolute positions.

    Args:
        t: input of shape [t, h, d].
        cu_seqlens: cumulative sequence lengths, shape [b + 1], torch.int32.
            Unused here: `freqs` is already absolute per token, so no
            per-sequence offset is needed.
        freqs: rotary embedding frequencies, shape [max_s, 1, 1, d].
        cp_group: unused in this baseline implementation.

    Returns:
        Tensor of shape [t, h, d] with RoPE applied.
    """
    # Insert a singleton batch axis so the bshd kernel accepts the packed
    # tensor, then drop that axis again on the way out.
    expanded = t[:, None]
    rotated = _apply_rotary_pos_emb_bshd(
        expanded,
        freqs,
        rotary_interleaved=rotary_interleaved,
        multi_latent_attention=multi_latent_attention,
        mscale=mscale,
    )
    return rotated.squeeze(1)
rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + cp_group=cp_group, + ) + + if config.apply_rotary_pos_emb_in_fp32: + result = result.to(orig_t_dtype) + + return result diff --git a/mbridge/models/qwen3_5/transformer_config.py b/mbridge/models/qwen3_5/transformer_config.py new file mode 100644 index 0000000..2c36a6b --- /dev/null +++ b/mbridge/models/qwen3_5/transformer_config.py @@ -0,0 +1,21 @@ +from copy import deepcopy +from dataclasses import dataclass, field +from functools import partial +from typing import List + +import torch +import torch.nn.functional as F +from megatron.core.transformer import TransformerConfig + + +@dataclass +class Qwen3_5VLTransformerConfig(TransformerConfig): + patch_size: int = 14 + in_channels: int = 3 + spatial_merge_size: int = 2 + temporal_patch_size: int = 2 + num_position_embeddings: int = 2304 + out_hidden_size: int = 2304 + apply_rotary_pos_emb_in_fp32: bool = False + rotary_percent: float = 1.0 + rotary_base: float = 10000 diff --git a/mbridge/models/qwen3_5/utils.py b/mbridge/models/qwen3_5/utils.py new file mode 100644 index 0000000..69bfcbf --- /dev/null +++ b/mbridge/models/qwen3_5/utils.py @@ -0,0 +1,91 @@ +import torch + +from mbridge.models.qwen3_vl.utils import find_vision_id_index + + +def reorganize_inputs( + input_ids: torch.Tensor, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + image_input_mask: torch.Tensor = None, + video_input_mask: torch.Tensor = None, + image_token_id: int = 151655, + video_token_id: int = 151656, + square_merge_size: int = 4, +): + if pixel_values is None: + if video_input_mask is None and pixel_values_videos is not None: + video_input_mask = (input_ids == video_token_id).contiguous() + return pixel_values_videos, video_grid_thw, video_input_mask + + if pixel_values_videos is None: + if image_input_mask is 
None and pixel_values is not None: + image_input_mask = (input_ids == image_token_id).contiguous() + return pixel_values, image_grid_thw, image_input_mask + + image_thw_cpu = image_grid_thw.tolist() + video_thw_cpu = video_grid_thw.tolist() + vision_indexs = find_vision_id_index( + input_ids.view(-1), image_token_id, video_token_id + ) + len_split = sum([thw[0] for thw in image_thw_cpu]) + len_split += sum([thw[0] for thw in video_thw_cpu]) + assert len_split == len(vision_indexs) + + vision_values = [] + vision_grid_thw = [] + idx = 0 + video_idx = 0 + image_idx = 0 + video_seqlen = 0 + image_seqlen = 0 + while idx < len(vision_indexs): + start, end, token_id = vision_indexs[idx] + if token_id == image_token_id: + seqlen = 0 + thw = image_thw_cpu[image_idx] + for i in range(thw[0]): + start, end, token_id = vision_indexs[idx + i] + assert token_id == image_token_id + seqlen += (end - start) * square_merge_size + assert seqlen == thw[0] * thw[1] * thw[2] + vision_values.append(pixel_values[image_seqlen : (image_seqlen + seqlen)]) + vision_grid_thw.append(thw) + + image_idx += 1 + idx += thw[0] + image_seqlen += seqlen + elif token_id == video_token_id: + seqlen = 0 + thw = video_thw_cpu[video_idx] + for i in range(thw[0]): + start, end, token_id = vision_indexs[idx + i] + assert token_id == video_token_id + seqlen += (end - start) * square_merge_size + assert seqlen == thw[0] * thw[1] * thw[2] + vision_values.append( + pixel_values_videos[video_seqlen : (video_seqlen + seqlen)] + ) + vision_grid_thw.append(thw) + + video_idx += 1 + idx += thw[0] + video_seqlen += seqlen + else: + assert False, f"should not have {token_id=}" + + if video_input_mask is None: + video_input_mask = input_ids == video_token_id + + if image_input_mask is None: + image_input_mask = input_ids == image_token_id + + vision_values = torch.cat(vision_values) + vision_grid_thw = torch.tensor( + vision_grid_thw, device=image_grid_thw.device, dtype=image_grid_thw.dtype + ) + vision_input_mask = 
video_input_mask | image_input_mask + + return vision_values, vision_grid_thw, vision_input_mask diff --git a/mbridge/models/qwen3_vl/model.py b/mbridge/models/qwen3_vl/model.py index fb32076..4a75c17 100644 --- a/mbridge/models/qwen3_vl/model.py +++ b/mbridge/models/qwen3_vl/model.py @@ -340,8 +340,9 @@ def forward( input_ids_thd, _ = preprocess_packed_seqs( input_ids, attention_mask, pre_process=True ) - vision_mask_thd = (input_ids_thd == self.image_token_id) \ - | (input_ids_thd == self.video_token_id) + vision_mask_thd = (input_ids_thd == self.image_token_id) | ( + input_ids_thd == self.video_token_id + ) if deepstack_feature_lists is not None: tmp_embeddings = torch.zeros_like( combined_embeddings.transpose(0, 1) diff --git a/mbridge/models/qwen3_vl/rope_utils.py b/mbridge/models/qwen3_vl/rope_utils.py index 7a812bb..ed3c503 100644 --- a/mbridge/models/qwen3_vl/rope_utils.py +++ b/mbridge/models/qwen3_vl/rope_utils.py @@ -72,7 +72,7 @@ def __init__( if rotary_percent < 1.0: dim = int(dim * rotary_percent) self.rotary_interleaved = rotary_interleaved - assert not self.rotary_interleaved, "only support qwen3vl" + # assert not self.rotary_interleaved, "only support qwen3vl" self.seq_len_interpolation_factor = seq_len_interpolation_factor self.inv_freq = 1.0 / ( @@ -103,7 +103,9 @@ def apply_interleaved_mrope(self, freqs, mrope_section): freqs_t[..., idx] = freqs[dim, ..., idx] return freqs_t - def forward(self, position_ids: torch.Tensor, mrope_section: List[int], **kwargs) -> Tensor: + def forward( + self, position_ids: torch.Tensor, mrope_section: List[int], **kwargs + ) -> Tensor: """Forward pass of multimodal RoPE embedding. 
Args: diff --git a/mbridge/utils/post_creation_callbacks.py b/mbridge/utils/post_creation_callbacks.py index 3c3f9a4..239cf8b 100644 --- a/mbridge/utils/post_creation_callbacks.py +++ b/mbridge/utils/post_creation_callbacks.py @@ -15,7 +15,7 @@ def freeze_moe_router(model, pre_process, post_process, config, hf_config): if hasattr(layer.mlp, "router"): if hasattr(layer.mlp.router, "weight"): layer.mlp.router.weight.requires_grad = False - if hasattr(layer.mlp.router, "bias"): + if hasattr(layer.mlp.router, "bias") and layer.mlp.router.bias is not None: layer.mlp.router.bias.requires_grad = False if hasattr(layer.mlp, "shared_experts"): if hasattr(layer.mlp.shared_experts, "gate_weight"):