From 44f5a1d4e378e9286f974d3ca29ac2bf388840e2 Mon Sep 17 00:00:00 2001 From: Yan Bai Date: Mon, 23 Feb 2026 22:47:51 -0800 Subject: [PATCH 01/11] add qwen3.5 megatron sft example --- examples/sft/gsm8k/run_qwen3_5_megatron.sh | 143 ++++++++++++++++++++ verl/models/mcore/model_forward.py | 2 +- verl/utils/dataset/multiturn_sft_dataset.py | 63 +++++++-- verl/utils/megatron_utils.py | 11 +- verl/utils/model.py | 13 +- verl/utils/tensordict_utils.py | 41 +++++- verl/workers/engine_workers.py | 11 +- verl/workers/fsdp_workers.py | 11 +- 8 files changed, 264 insertions(+), 31 deletions(-) create mode 100644 examples/sft/gsm8k/run_qwen3_5_megatron.sh diff --git a/examples/sft/gsm8k/run_qwen3_5_megatron.sh b/examples/sft/gsm8k/run_qwen3_5_megatron.sh new file mode 100644 index 00000000000..43ab11ec213 --- /dev/null +++ b/examples/sft/gsm8k/run_qwen3_5_megatron.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +# Qwen3.5-397B-A17B SFT with Megatron backend + mbridge +# +# Requirements: +# - 128+ GPUs (80GB each, e.g. 16x8 H100/H200) +# - Docker: verlai/verl:vllm015 (or equivalent) +# - Additional packages on top of the base image: +# pip install --upgrade transformers +# pip install flash-linear-attention +# pip install -U git+https://github.com/ISEEKYAN/mbridge.git +# - Megatron-LM dev branch with Qwen3.5 GDN support +# +# Qwen3.5 architecture notes: +# Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does +# NOT support packed sequences (THD format) in Megatron-LM. Therefore: +# - engine.use_remove_padding=False (forces bshd compute format) +# - model.use_remove_padding=True (keeps NestedTensor in data pipeline) +# - data.use_dynamic_bsz=False (required for bshd mode) +# +# Once https://github.com/NVIDIA/Megatron-LM/pull/2644 is merged, THD +# format will be supported and engine.use_remove_padding can be set to True +# for better performance. +# +# Tested parallelism config (128 GPUs / 16 nodes): +# TP=2 PP=4 EP=32 CP=1 + +set -xeuo pipefail + +# ============================================================ +# Distributed +# ============================================================ +NUM_GPUS=${NUM_GPUS:-8} +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-29500} +NNODES=${NNODES:-16} +NODE_RANK=${NODE_RANK:-0} + +# ============================================================ +# Data +# ============================================================ +DATASET_DIR=${DATASET_DIR:-~/dataset} +TRAIN_FILES=${TRAIN_FILES:-${DATASET_DIR}/train.parquet} + +# ============================================================ +# Model +# ============================================================ +MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3.5-397B-A17B} + +# ============================================================ +# Parallelism +# ============================================================ +TP_SIZE=${TP_SIZE:-2} +PP_SIZE=${PP_SIZE:-4} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} +EP_SIZE=${EP_SIZE:-32} +ETP_SIZE=${ETP_SIZE:-1} + +# ============================================================ +# Training +# ============================================================ +TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-128} +MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2} +MAX_LENGTH=${MAX_LENGTH:-2048} +LR=${LR:-2e-5} +MIN_LR=${MIN_LR:-2e-6} +DTYPE=${DTYPE:-bfloat16} + +BACKEND=megatron +RESUME_MODE=${RESUME_MODE:-disable} + +project_name=verl_sft_qwen3_5 +exp_name=qwen3_5-${BACKEND}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-ep${EP_SIZE} +ckpts_home=${ckpts_home:-~/verl/checkpoints/${project_name}/${exp_name}} +mkdir -p "${ckpts_home}" + +# ============================================================ +# Engine config +# ============================================================ +# Key Qwen3.5 settings: +# engine.use_remove_padding=False - GDN requires bshd format (no THD) +# engine.vanilla_mbridge=True - use mbridge (not megatron-bridge) +ENGINE_CONFIG="\ + engine=${BACKEND} \ + optim=${BACKEND} \ + optim.lr=${LR} \ + optim.min_lr=${MIN_LR} \ + optim.lr_warmup_steps=10 \ + optim.weight_decay=0.1 \ + optim.betas='[0.9,0.95]' \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + +optim.override_optimizer_config.optimizer_offload_fraction=1 \ + +optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +optim.override_optimizer_config.optimizer_cpu_offload=True \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.expert_model_parallel_size=${EP_SIZE} \ + engine.expert_tensor_parallel_size=${ETP_SIZE} \ + engine.use_mbridge=True \ + engine.vanilla_mbridge=True \ + engine.dtype=${DTYPE} \ + engine.use_remove_padding=False \ + engine.override_transformer_config.attention_backend=auto \ + +engine.override_transformer_config.recompute_method=uniform \ + +engine.override_transformer_config.recompute_granularity=full \ + +engine.override_transformer_config.recompute_num_layers=1" + +# ============================================================ +# Launch +# ============================================================ +torchrun \ + --nproc_per_node=${NUM_GPUS} \ + --nnodes=${NNODES} \ + --node_rank=${NODE_RANK} \ + --master_addr=${MASTER_ADDR} \ + --master_port=${MASTER_PORT} \ + -m verl.trainer.sft_trainer \ + data.train_files="${TRAIN_FILES}" \ + data.train_batch_size=${TRAIN_BATCH_SIZE} \ + data.micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \ + data.max_length=${MAX_LENGTH} \ + data.pad_mode=no_padding \ + data.truncation=error \ + data.use_dynamic_bsz=False \ + data.max_token_len_per_gpu=${MAX_LENGTH} \ + data.messages_key=messages \ + model.path=${MODEL_PATH} \ + model.use_remove_padding=True \ + model.trust_remote_code=True \ + ${ENGINE_CONFIG} \ + trainer.test_freq=-1 \ + trainer.save_freq=500 \ + trainer.logger="['console']" \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${ckpts_home}" \ + trainer.resume_mode=${RESUME_MODE} diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index fd160fa86c9..4c82a56f09f 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -258,7 +258,7 @@ def gptmodel_forward_no_padding( output_orig = model( input_ids=input_ids_bshd, attention_mask=attention_mask_bshd, - position_ids=position_ids_bshd, + position_ids=None if vision_model else position_ids_bshd, **model_kwargs, ) if post_process and logits_processor is not None: diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 9da33228e21..e8fb4b65e37 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -22,6 +22,7 @@ from functools import wraps from typing import Any, Optional +import jinja2 import numpy as np import pandas as pd import torch @@ -215,20 +216,52 @@ def _process_single_message( Returns: Tuple of (input_ids, loss_mask, attention_mask, dict[str, torch.Tensor]) """ - processor = self.processor if self.processor is not None else self.tokenizer + has_visual_content = isinstance(message.get("content"), list) and any( + isinstance(c, dict) and c.get("type") in ("image", "video") for c in message["content"] + ) + processor = self.processor if self.processor is not None and has_visual_content else self.tokenizer apply_chat_template_kwargs = {**self.apply_chat_template_kwargs} if enable_thinking is not None: apply_chat_template_kwargs["enable_thinking"] = enable_thinking - inputs = processor.apply_chat_template( - [message], - tools=tools, - add_generation_prompt=False, - tokenize=True, - return_dict=True, - return_tensors="pt", - **apply_chat_template_kwargs, - ) + try: + inputs = processor.apply_chat_template( + [message], + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + **apply_chat_template_kwargs, + ) + except (jinja2.exceptions.TemplateError, Exception) as e: + if "No user query" not in str(e): + raise + # Chat templates that require a user message (e.g. Qwen3.5) fail + # when tokenising a single non-user message. Fallback: tokenise the + # conversation up to this turn and subtract the prefix. + inputs_full = processor.apply_chat_template( + full_message[: index + 1], + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + **apply_chat_template_kwargs, + ) + prefix_len = 0 + if index > 0: + inputs_prev = processor.apply_chat_template( + full_message[:index], + tools=tools if index == 1 else None, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + **apply_chat_template_kwargs, + ) + prefix_len = inputs_prev["input_ids"].shape[-1] + inputs = {k: v[..., prefix_len:] for k, v in inputs_full.items()} inputs = dict(inputs) input_ids = inputs.pop("input_ids")[0] @@ -266,14 +299,16 @@ def _build_messages(self, example: dict): image_offset, video_offset = 0, 0 for message in messages: - if self.image_key not in example and self.video_key not in example: - continue - assert self.processor is not None, "processor is needed to process image and video" - content = message["content"] if not isinstance(content, str): continue + if self.image_key not in example and self.video_key not in example: + if self.processor is not None: + message["content"] = [{"type": "text", "text": content}] + continue + assert self.processor is not None, "processor is needed to process image and video" + content_list = [] segments = re.split("(|