diff --git a/.github/workflows/cpu_unit_tests.yml b/.github/workflows/cpu_unit_tests.yml
index c0145d6d8a1..62e63cbe613 100644
--- a/.github/workflows/cpu_unit_tests.yml
+++ b/.github/workflows/cpu_unit_tests.yml
@@ -95,7 +95,7 @@ jobs:
run: |
pip3 install -r requirements-test.txt
pip3 install --no-deps -e .
- pip3 install --upgrade "transformers<5.0.0"
+ pip3 install --upgrade "transformers>=5.0.0"
- name: Download datasets
run: |
python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
diff --git a/examples/grpo_trainer/run_qwen3_5-35b-megatron.sh b/examples/grpo_trainer/run_qwen3_5-35b-megatron.sh
new file mode 100644
index 00000000000..43563c20c59
--- /dev/null
+++ b/examples/grpo_trainer/run_qwen3_5-35b-megatron.sh
@@ -0,0 +1,166 @@
+#!/usr/bin/env bash
+# Qwen3.5-35B-A3B MoE GRPO RL with Megatron (single node, 8 GPUs, geo3k dataset)
+#
+# notes on vllm:
+# as of 20260225, the latest vllm nightly does not support qwen3.5 rollout; to use this script, you need to
+# 1. wait until vllm supports qwen3.5 officially, and build a verl docker with that version of vllm
+# 2. self build a verl docker image with vllm from source code with qwen3.5 support (main branch 20260225 is OK)
+# I succeeded in running this script with the main branch of vllm on 20260225, yet there are still some minor issues
+# in the vllm qwen3.5 initialization that need to be fixed. Also, the cuda_graph is somehow not working and needs to be
+# fixed, either by the verl team with support for vllm 0.16, or by the vllm team.
+# Requirements:
+# - 8 GPUs (80GB each, e.g. 1x8 H100/H200)
+# - Additional packages on top of the base image:
+# pip install --upgrade transformers
+# pip install flash-linear-attention
+# pip install -U git+https://github.com/ISEEKYAN/mbridge.git
+# - Megatron-LM dev branch with Qwen3.5 GDN support
+#
+# Qwen3.5 architecture notes:
+# Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does
+# NOT support packed sequences (THD format) in Megatron-LM. Therefore:
+#   - model.use_remove_padding=False (deprecated option, will be removed in the future; forces bshd compute format)
+# - actor.megatron.use_remove_padding=False (forces bshd compute format)
+# - actor.use_dynamic_bsz=False (required for bshd mode)
+#
+# Once Megatron-LM adds THD support for Qwen3.5 GDN, use_remove_padding
+# can be set to True for better performance.
+#
+# Tested parallelism config (8 GPUs / 1 node):
+# TP=2 PP=1 CP=1 EP=8 ETP=1 GEN_TP=8
+#
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export VLLM_USE_V1=1
+export VLLM_ALLREDUCE_USE_SYMM_MEM=0
+
+set -xeuo pipefail
+
+########################### Quick Config ###########################
+
+TP=${TP:-2}
+PP=${PP:-1}
+CP=${CP:-1}
+EP=${EP:-8}
+ETP=${ETP:-1}
+GEN_TP=${GEN_TP:-8}
+
+ALL_OFFLOAD=${ALL_OFFLOAD:-True}
+
+rollout_name="vllm"
+project_name='verl_grpo_qwen3_5_35b_geo3k'
+exp_name='qwen3_5_35b_megatron'
+adv_estimator=grpo
+
+HF_MODEL_PATH=${HF_MODEL_PATH:-"Qwen3.5-35B-A3B"}
+train_path=${train_path:-$HOME/data/geo3k/train.parquet}
+test_path=${test_path:-$HOME/data/geo3k/test.parquet}
+
+########################### Parameter Arrays ###########################
+
+DATA=(
+ data.train_files=${train_path}
+ data.val_files=${test_path}
+ data.train_batch_size=32
+ data.max_prompt_length=1024
+ data.max_response_length=2048
+ data.truncation='error'
+ data.filter_overlong_prompts=True
+)
+
+MODEL=(
+ actor_rollout_ref.model.path=${HF_MODEL_PATH}
+ actor_rollout_ref.model.trust_remote_code=True
+ actor_rollout_ref.model.use_remove_padding=False
+)
+
+ACTOR=(
+ actor_rollout_ref.actor.optim.lr=1e-6
+ actor_rollout_ref.actor.ppo_mini_batch_size=32
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096
+ actor_rollout_ref.actor.use_dynamic_bsz=False
+ actor_rollout_ref.actor.use_kl_loss=True
+ actor_rollout_ref.actor.kl_loss_coef=0.01
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl
+ actor_rollout_ref.actor.entropy_coeff=0
+ actor_rollout_ref.actor.megatron.use_mbridge=True
+ actor_rollout_ref.actor.megatron.vanilla_mbridge=True
+ actor_rollout_ref.actor.megatron.use_remove_padding=False
+ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP}
+ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP}
+ actor_rollout_ref.actor.megatron.context_parallel_size=${CP}
+ actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}
+ actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}
+ actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD}
+ actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD}
+ actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD}
+ actor_rollout_ref.actor.megatron.dtype=bfloat16
+ ++actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto
+ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
+ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01
+ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001
+ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
+ +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
+ +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
+ +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
+)
+
+ROLLOUT=(
+ actor_rollout_ref.rollout.name=${rollout_name}
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${GEN_TP}
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6
+ actor_rollout_ref.rollout.n=5
+ actor_rollout_ref.rollout.mode=async
+ actor_rollout_ref.rollout.enforce_eager=True
+ actor_rollout_ref.rollout.dtype=bfloat16
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096
+)
+
+REF=(
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096
+ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP}
+ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP}
+ actor_rollout_ref.ref.megatron.context_parallel_size=${CP}
+ actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP}
+ actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP}
+ actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD}
+)
+
+ALGORITHM=(
+ algorithm.adv_estimator=${adv_estimator}
+ algorithm.use_kl_in_reward=False
+)
+
+TRAINER=(
+ trainer.critic_warmup=0
+ trainer.logger='["console","wandb"]'
+ trainer.project_name=${project_name}
+ trainer.experiment_name=${exp_name}
+ trainer.n_gpus_per_node=8
+ trainer.nnodes=1
+ trainer.save_freq=20
+ trainer.val_before_train=False
+ trainer.test_freq=5
+ trainer.total_epochs=15
+)
+
+########################### Launch ###########################
+
+python3 -m verl.trainer.main_ppo \
+ --config-path=config \
+ --config-name='ppo_megatron_trainer.yaml' \
+ "${DATA[@]}" \
+ "${ALGORITHM[@]}" \
+ "${MODEL[@]}" \
+ "${ROLLOUT[@]}" \
+ "${ACTOR[@]}" \
+ "${REF[@]}" \
+ "${TRAINER[@]}" \
+ "$@"
diff --git a/examples/sft/gsm8k/run_qwen3_5_megatron.sh b/examples/sft/gsm8k/run_qwen3_5_megatron.sh
new file mode 100644
index 00000000000..102392e65c8
--- /dev/null
+++ b/examples/sft/gsm8k/run_qwen3_5_megatron.sh
@@ -0,0 +1,142 @@
+#!/usr/bin/env bash
+# Qwen3.5-397B-A17B SFT with Megatron backend + mbridge
+#
+# Requirements:
+# - 128+ GPUs (80GB each, e.g. 16x8 H100/H200)
+# - Docker: verlai/verl:vllm015 (or equivalent)
+# - Additional packages on top of the base image:
+# pip install --upgrade transformers
+# pip install flash-linear-attention
+# pip install -U git+https://github.com/ISEEKYAN/mbridge.git
+# - Megatron-LM dev branch with Qwen3.5 GDN support
+#
+# Qwen3.5 architecture notes:
+# Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does
+# NOT support packed sequences (THD format) in Megatron-LM. Therefore:
+# - engine.use_remove_padding=False (forces bshd compute format)
+# - data.use_dynamic_bsz=False (required for bshd mode)
+#
+# Once https://github.com/NVIDIA/Megatron-LM/pull/2644 is merged, THD
+# format will be supported and engine.use_remove_padding can be set to True
+# for better performance.
+#
+# Tested parallelism config (128 GPUs / 16 nodes):
+# TP=2 PP=4 EP=32 CP=1
+
+set -xeuo pipefail
+
+# ============================================================
+# Distributed
+# ============================================================
+NUM_GPUS=${NUM_GPUS:-8}
+MASTER_ADDR=${MASTER_ADDR:-localhost}
+MASTER_PORT=${MASTER_PORT:-29500}
+NNODES=${NNODES:-16}
+NODE_RANK=${NODE_RANK:-0}
+
+# ============================================================
+# Data
+# ============================================================
+DATASET_DIR=${DATASET_DIR:-~/dataset}
+TRAIN_FILES=${TRAIN_FILES:-${DATASET_DIR}/train.parquet}
+
+# ============================================================
+# Model
+# ============================================================
+MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3.5-397B-A17B}
+
+# ============================================================
+# Parallelism
+# ============================================================
+TP_SIZE=${TP_SIZE:-2}
+PP_SIZE=${PP_SIZE:-4}
+VPP_SIZE=${VPP_SIZE:-null}
+CP_SIZE=${CP_SIZE:-1}
+EP_SIZE=${EP_SIZE:-32}
+ETP_SIZE=${ETP_SIZE:-1}
+
+# ============================================================
+# Training
+# ============================================================
+TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-128}
+MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2}
+MAX_LENGTH=${MAX_LENGTH:-2048}
+LR=${LR:-2e-5}
+MIN_LR=${MIN_LR:-2e-6}
+DTYPE=${DTYPE:-bfloat16}
+
+BACKEND=megatron
+RESUME_MODE=${RESUME_MODE:-disable}
+
+project_name=verl_sft_qwen3_5
+exp_name=qwen3_5-${BACKEND}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-ep${EP_SIZE}
+ckpts_home=${ckpts_home:-~/verl/checkpoints/${project_name}/${exp_name}}
+mkdir -p "${ckpts_home}"
+
+# ============================================================
+# Engine config
+# ============================================================
+# Key Qwen3.5 settings:
+# engine.use_remove_padding=False - GDN requires bshd format (no THD)
+# engine.vanilla_mbridge=True - use mbridge (not megatron-bridge)
+ENGINE_CONFIG="\
+ engine=${BACKEND} \
+ optim=${BACKEND} \
+ optim.lr=${LR} \
+ optim.min_lr=${MIN_LR} \
+ optim.lr_warmup_steps=10 \
+ optim.weight_decay=0.1 \
+ optim.betas='[0.9,0.95]' \
+ optim.clip_grad=1.0 \
+ optim.lr_warmup_init=0 \
+ optim.lr_decay_style=cosine \
+ +optim.override_optimizer_config.optimizer_offload_fraction=1 \
+ +optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
+ +optim.override_optimizer_config.use_precision_aware_optimizer=True \
+ +optim.override_optimizer_config.optimizer_cpu_offload=True \
+ engine.tensor_model_parallel_size=${TP_SIZE} \
+ engine.pipeline_model_parallel_size=${PP_SIZE} \
+ engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
+ engine.context_parallel_size=${CP_SIZE} \
+ engine.expert_model_parallel_size=${EP_SIZE} \
+ engine.expert_tensor_parallel_size=${ETP_SIZE} \
+ engine.use_mbridge=True \
+ engine.vanilla_mbridge=True \
+ engine.dtype=${DTYPE} \
+ engine.use_remove_padding=False \
+ engine.override_transformer_config.attention_backend=auto \
+ +engine.override_transformer_config.recompute_method=uniform \
+ +engine.override_transformer_config.recompute_granularity=full \
+ +engine.override_transformer_config.recompute_num_layers=1"
+
+# ============================================================
+# Launch
+# ============================================================
+torchrun \
+ --nproc_per_node=${NUM_GPUS} \
+ --nnodes=${NNODES} \
+ --node_rank=${NODE_RANK} \
+ --master_addr=${MASTER_ADDR} \
+ --master_port=${MASTER_PORT} \
+ -m verl.trainer.sft_trainer \
+ data.train_files="${TRAIN_FILES}" \
+ data.train_batch_size=${TRAIN_BATCH_SIZE} \
+ data.micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \
+ data.max_length=${MAX_LENGTH} \
+ data.pad_mode=no_padding \
+ data.truncation=error \
+ data.use_dynamic_bsz=False \
+ data.max_token_len_per_gpu=${MAX_LENGTH} \
+ data.messages_key=messages \
+ model.path=${MODEL_PATH} \
+ model.use_remove_padding=False \
+ model.trust_remote_code=True \
+ ${ENGINE_CONFIG} \
+ trainer.test_freq=-1 \
+ trainer.save_freq=500 \
+ trainer.logger="['console']" \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.total_epochs=1 \
+ trainer.default_local_dir="${ckpts_home}" \
+ trainer.resume_mode=${RESUME_MODE}
diff --git a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
index a55417ce839..02381b972d2 100644
--- a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
+++ b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
@@ -26,9 +26,9 @@
from tensordict import TensorDict
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
-from transformers import AutoProcessor, AutoTokenizer
from transformers.utils import get_json_schema
+from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.model import extract_multi_modal_inputs
@@ -37,31 +37,29 @@
@pytest.mark.parametrize(
- "model_path",
+ "model_path, ignore_input_ids_mismatch",
[
- f"{custom_model_prefix}/Qwen/Qwen2.5-0.5B",
- f"{custom_model_prefix}/Qwen/Qwen2.5-Coder-7B-Instruct",
- f"{custom_model_prefix}/Qwen/Qwen3-30B-A3B-Instruct-2507",
- # "Qwen/Qwen3-30B-A3B-Thinking-2507" # Thinking series models add tags to last turn.
+        (f"{custom_model_prefix}/Qwen/Qwen2.5-0.5B", False),
+        (f"{custom_model_prefix}/Qwen/Qwen3-0.6B", True),
+        (f"{custom_model_prefix}/Qwen/Qwen3.5-0.8B", False),
],
)
-@pytest.mark.parametrize("enable_thinking", [False, True])
-def test_multiturn_sft_dataset(model_path: str, enable_thinking: bool):
- print(f"Starting test... model_path={model_path}, enable_thinking={enable_thinking}")
+def test_multiturn_sft_dataset(model_path: str, ignore_input_ids_mismatch: bool):
+ print(f"Starting test... model_path={model_path}, ignore_input_ids_mismatch={ignore_input_ids_mismatch}")
# Create a temporary parquet file with test data
test_data = {
"messages": [
[
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
- {"role": "user", "content": "And what is 4+4?"},
+ {"role": "tool", "content": "And what is 4+4?"},
{"role": "assistant", "content": "4+4 equals 8."},
],
[
- {"role": "system", "content": "You are a powerful assistant."},
+ # {"role": "system", "content": "You are a powerful assistant."},
{"role": "user", "content": "Tell me a joke."},
{"role": "assistant", "content": "Why did the chicken cross the road?"},
- {"role": "user", "content": "Why?"},
+ {"role": "tool", "content": "Why?"},
{"role": "assistant", "content": "To get to the other side!"},
],
]
@@ -76,14 +74,16 @@ def test_multiturn_sft_dataset(model_path: str, enable_thinking: bool):
df.to_parquet(test_file)
# Initialize tokenizer and dataset
- tokenizer = AutoTokenizer.from_pretrained(model_path)
+ tokenizer = hf_tokenizer(model_path)
+ # processor = hf_processor(model_path)
+ processor = None
config = {
"max_length": 512,
"truncation": "error",
"multiturn": {"messages_key": "messages"},
- "apply_chat_template_kwargs": {"enable_thinking": enable_thinking},
+ "ignore_input_ids_mismatch": ignore_input_ids_mismatch,
}
- dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
+ dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, processor=processor, config=config)
# Test 1: Dataset Length
assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}"
@@ -189,8 +189,15 @@ def test_multiturn_sft_dataset(model_path: str, enable_thinking: bool):
)
# Test 10: Verify padding behavior
- padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}}
- small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)
+ padding_config = {
+ "max_length": 1024,
+ "truncation": "error",
+ "multiturn": {"messages_key": "messages"},
+ "ignore_input_ids_mismatch": ignore_input_ids_mismatch,
+ }
+ small_dataset = MultiTurnSFTDataset(
+ parquet_files=test_file, tokenizer=tokenizer, processor=processor, config=padding_config
+ )
padded_item = small_dataset[0]
# Get actual sequence length (before padding)
@@ -209,8 +216,9 @@ def test_multiturn_sft_dataset(model_path: str, enable_thinking: bool):
"truncation": "error",
"multiturn": {"messages_key": "messages"},
"pad_mode": "no_padding",
+ "ignore_input_ids_mismatch": ignore_input_ids_mismatch,
}
- dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
+ dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, processor=processor, config=config)
item0 = dataset[0]
@@ -286,7 +294,7 @@ def vlm_data_file():
"content": "Let's generate a zoom-in image.",
"tool_calls": [
{
- "function": {"arguments": '{"bbox_2d": "[0, 1, 2, 4]"}', "name": "image_zoom_in_tool"},
+ "function": {"arguments": {"bbox_2d": "[0, 1, 2, 4]"}, "name": "image_zoom_in_tool"},
"type": "function",
}
],
@@ -331,13 +339,19 @@ def serialize_image(img):
return test_file
-def test_multiturn_sft_vlm_dataset_on_cpu(vlm_data_file):
+@pytest.mark.parametrize(
+ "model_path",
+ [
+        f"{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct",
+        f"{custom_model_prefix}/Qwen/Qwen3.5-0.8B",
+ ],
+)
+def test_multiturn_sft_vlm_dataset_on_cpu(model_path, vlm_data_file):
df = pd.read_parquet(vlm_data_file)
- model_path = f"{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- processor = AutoProcessor.from_pretrained(model_path)
- config = {"max_length": 512, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"}
- dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, config=config, processor=processor)
+ tokenizer = hf_tokenizer(model_path)
+ processor = hf_processor(model_path)
+ config = {"max_length": 1024, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"}
+ dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, processor=processor, config=config)
assert dataset.pad_mode == DatasetPadMode.NO_PADDING
for i in range(len(dataset)):
@@ -387,13 +401,19 @@ def test_multiturn_sft_vlm_dataset_on_cpu(vlm_data_file):
assert image_grid_thw is None, "image_grid_thw should be None when no image is provided"
-def test_multiturn_sft_vlm_dataloader_on_cpu(vlm_data_file):
+@pytest.mark.parametrize(
+ "model_path",
+ [
+        "{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct".format(custom_model_prefix=custom_model_prefix),
+        "{custom_model_prefix}/Qwen/Qwen3.5-0.8B".format(custom_model_prefix=custom_model_prefix),
+ ],
+)
+def test_multiturn_sft_vlm_dataloader_on_cpu(model_path, vlm_data_file):
df = pd.read_parquet(vlm_data_file)
- model_path = f"{custom_model_prefix}/Qwen/Qwen3-VL-2B-Instruct"
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- processor = AutoProcessor.from_pretrained(model_path)
- config = {"max_length": 512, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"}
- dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, config=config, processor=processor)
+ tokenizer = hf_tokenizer(model_path)
+ processor = hf_processor(model_path)
+ config = {"max_length": 1024, "pad_mode": "no_padding", "truncation": "error", "messages_key": "messages"}
+ dataset = MultiTurnSFTDataset(parquet_files=vlm_data_file, tokenizer=tokenizer, processor=processor, config=config)
assert dataset.pad_mode == DatasetPadMode.NO_PADDING
collate_fn = SFTTensorCollator(DatasetPadMode.NO_PADDING)
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index f52ead64570..7c11f2ca6cf 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -35,7 +35,7 @@
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.protocol import DataProto
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
-from verl.utils.chat_template import initialize_system_prompt
+from verl.utils.chat_template import apply_chat_template_single_turn, initialize_system_prompt
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
from verl.utils.model import compute_position_id_with_mask
@@ -274,7 +274,8 @@ async def apply_chat_template(
if self.processor is not None:
raw_prompt = await self.loop.run_in_executor(
None,
- lambda: self.processor.apply_chat_template(
+ lambda: apply_chat_template_single_turn(
+ self.processor,
messages,
tools=tools,
add_generation_prompt=True,
@@ -302,7 +303,8 @@ async def apply_chat_template(
else:
tokenized_prompt = await self.loop.run_in_executor(
None,
- lambda: self.tokenizer.apply_chat_template(
+ lambda: apply_chat_template_single_turn(
+ self.tokenizer,
messages,
tools=tools,
add_generation_prompt=True,
diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py
index c649a2fc3fd..cd330d53d61 100644
--- a/verl/experimental/agent_loop/tool_agent_loop.py
+++ b/verl/experimental/agent_loop/tool_agent_loop.py
@@ -249,7 +249,7 @@ async def _handle_generating_state(
return AgentState.TERMINATED
# Extract tool calls
- _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids)
+ _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids, self.tool_schemas)
# Handle interaction if needed
if self.interaction_config_file:
diff --git a/verl/experimental/agent_loop/tool_parser.py b/verl/experimental/agent_loop/tool_parser.py
index 67ad75e2bb8..b035bd16115 100644
--- a/verl/experimental/agent_loop/tool_parser.py
+++ b/verl/experimental/agent_loop/tool_parser.py
@@ -15,10 +15,12 @@
import logging
import os
from abc import ABC, abstractmethod
+from typing import Any, Optional
import regex
from pydantic import BaseModel
+from verl.tools.schemas import OpenAIFunctionToolSchema
from verl.utils.ray_utils import get_event_loop
from verl.utils.rollout_trace import rollout_trace_op
@@ -46,11 +48,14 @@ def __init__(self, tokenizer) -> None:
self.tokenizer = tokenizer
@abstractmethod
- async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
+ async def extract_tool_calls(
+ self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
+ ) -> tuple[str, list[FunctionCall]]:
"""Extract tool calls from the responses.
Args:
responses_ids (List[int]): The ids of the responses.
+ tools (List[OpenAIFunctionToolSchema], optional): OpenAI function tool schema.
Returns:
Tuple[str, List[FunctionCall]]: Content and extracted tool calls.
@@ -84,7 +89,9 @@ def __init__(self, tokenizer) -> None:
         self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
@rollout_trace_op
- async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
+ async def extract_tool_calls(
+ self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
+ ) -> tuple[str, list[FunctionCall]]:
loop = get_event_loop()
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
@@ -131,7 +138,9 @@ def __init__(self, tokenizer) -> None:
)
@rollout_trace_op
- async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
+ async def extract_tool_calls(
+ self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
+ ) -> tuple[str, list[FunctionCall]]:
loop = get_event_loop()
# We need to keep special tokens for gpt-oss model for better tool call extraction.
text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))
@@ -159,3 +168,174 @@ async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[
content = regex.sub(self.tool_call_pattern, "", text)
return content, function_calls
+
+
+@ToolParser.register("qwen3_coder")
+class Qwen3XMLToolParser(ToolParser):
+ """
+ Tool parser for qwen3_coder/qwen3.5 model.
+ Adapted from https://huggingface.co/Qwen/Qwen3-Coder-30B-A3B-Instruct/blob/main/qwen3coder_tool_parser.py
+
+ Args:
+ tokenizer: The tokenizer to use.
+ """
+
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer)
+
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_prefix: str = "<function="
+        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", regex.DOTALL)
+        self.tool_call_function_regex = regex.compile(r"<function=(.*?)</function>|<function=(.*)$", regex.DOTALL)
+        self.tool_call_parameter_regex = regex.compile(r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", regex.DOTALL)
+
+    def _parse_xml_function_call(
+        self, function_call_str: str, tools: Optional[list[OpenAIFunctionToolSchema]]
+    ) -> FunctionCall:
+ def get_arguments_config(func_name: str) -> dict:
+ for config in tools:
+ if config.type == "function" and config.function.name == func_name:
+ properties = config.function.parameters.properties
+ return {k: v.model_dump() for k, v in properties.items()}
+ logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
+ return {}
+
+ def convert_param_value(param_value: str, param_name: str, param_config: dict, func_name: str) -> Any:
+ # Handle null value for any type
+ if param_value.lower() == "null":
+ return None
+
+ if param_name not in param_config:
+ if param_config != {}:
+ logger.warning(
+ f"Parsed parameter '{param_name}' is not defined in the tool "
+ f"parameters for tool '{func_name}', directly returning the string value."
+ )
+ return param_value
+
+ if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
+ param_type = str(param_config[param_name]["type"]).strip().lower()
+ else:
+ param_type = "string"
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ return param_value
+ elif (
+ param_type.startswith("int")
+ or param_type.startswith("uint")
+ or param_type.startswith("long")
+ or param_type.startswith("short")
+ or param_type.startswith("unsigned")
+ ):
+ try:
+ param_value = int(param_value)
+ except Exception:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
+ f"'{func_name}', degenerating to string."
+ )
+ return param_value
+ elif param_type.startswith("num") or param_type.startswith("float"):
+ try:
+ float_param_value = float(param_value)
+ param_value = (
+ float_param_value if float_param_value - int(float_param_value) != 0 else int(float_param_value)
+ )
+ except Exception:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
+ f"'{func_name}', degenerating to string."
+ )
+ return param_value
+ elif param_type in ["boolean", "bool", "binary"]:
+ param_value = param_value.lower()
+ if param_value not in ["true", "false"]:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a "
+                    f"boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
+ )
+ return param_value == "true"
+ else:
+ if param_type == "object" or param_type.startswith("dict"):
+ try:
+ param_value = json.loads(param_value)
+ return param_value
+ except Exception:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid "
+ f"JSON object in tool '{func_name}', will try other methods to parse it."
+ )
+ try:
+ param_value = eval(param_value)
+ except Exception:
+ logger.warning(
+ f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted "
+ f"via Python `eval()` in tool '{func_name}', degenerating to string."
+ )
+ return param_value
+
+ # Extract function name
+ end_index = function_call_str.index(">")
+ function_name = function_call_str[:end_index]
+ param_config = get_arguments_config(function_name)
+ parameters = function_call_str[end_index + 1 :]
+ param_dict = {}
+ for match in self.tool_call_parameter_regex.findall(parameters):
+ match_text = match[0] if match[0] else match[1]
+ idx = match_text.index(">")
+ param_name = match_text[:idx]
+ param_value = str(match_text[idx + 1 :])
+ # Remove prefix and trailing \n
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ param_dict[param_name] = convert_param_value(param_value, param_name, param_config, function_name)
+ return FunctionCall(name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False))
+
+ def _get_function_calls(self, model_output: str) -> list[str]:
+ # Find all tool calls
+ matched_ranges = self.tool_call_regex.findall(model_output)
+ raw_tool_calls = [match[0] if match[0] else match[1] for match in matched_ranges]
+
+ # Back-off strategy if no tool_call tags found
+ if len(raw_tool_calls) == 0:
+ raw_tool_calls = [model_output]
+
+ raw_function_calls = []
+ for tool_call in raw_tool_calls:
+ raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
+
+ function_calls = [match[0] if match[0] else match[1] for match in raw_function_calls]
+ return function_calls
+
+ @rollout_trace_op
+ async def extract_tool_calls(
+ self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
+ ) -> tuple[str, list[FunctionCall]]:
+ loop = get_event_loop()
+ text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
+ if self.tool_call_start_token not in text:
+ return text, []
+
+ try:
+ function_calls = self._get_function_calls(text)
+ if len(function_calls) == 0:
+ return text, []
+
+ tool_calls = [
+ self._parse_xml_function_call(function_call_str, tools) for function_call_str in function_calls
+ ]
+
+ # Extract content before tool calls
+ content_index = text.find(self.tool_call_start_token)
+ content_index = content_index if content_index >= 0 else text.find(self.tool_call_prefix)
+ content = text[:content_index] # .rstrip()
+
+ return content, tool_calls
+ except Exception as e:
+ logger.exception(f"Error in extracting tool call from response: {e}")
+ return text, []
diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py
index fd160fa86c9..9955107f304 100644
--- a/verl/models/mcore/model_forward.py
+++ b/verl/models/mcore/model_forward.py
@@ -122,22 +122,33 @@ def model_forward(
When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length,
so it is recommended to disable dynamic batch size and set batch size to 1
"""
- assert not vision_model, "vision model does not support bshd format"
assert fp8 is None, "fp8 is not supported for bshd format yet"
batch_size, sequence_length = attention_mask.shape[:2]
+ position_ids_for_preprocess = (
+ torch.arange(sequence_length, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
+ if vision_model
+ else position_ids
+ )
+ pre_process_for_bshd = True if vision_model else pre_process
new_input_ids, new_attention_mask, new_position_ids = preprocess_bshd(
- input_ids, attention_mask, position_ids, sequence_parallel=sp, pre_process=pre_process
+ input_ids,
+ attention_mask,
+ position_ids_for_preprocess,
+ sequence_parallel=sp,
+ pre_process=pre_process_for_bshd,
)
output_orig = model(
input_ids=new_input_ids,
- position_ids=new_position_ids,
+ position_ids=None if vision_model else new_position_ids,
attention_mask=new_attention_mask,
**model_kwargs,
)
if post_process and logits_processor is not None:
args = {
- k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=sp, pre_process=True)[0]
+ k: preprocess_bshd(
+ v, attention_mask, position_ids_for_preprocess, sequence_parallel=sp, pre_process=True
+ )[0]
for k, v in logits_processor_args.items()
}
output_dict = logits_processor(output_orig, **args)
@@ -258,7 +269,7 @@ def gptmodel_forward_no_padding(
output_orig = model(
input_ids=input_ids_bshd,
attention_mask=attention_mask_bshd,
- position_ids=position_ids_bshd,
+ position_ids=None if vision_model else position_ids_bshd,
**model_kwargs,
)
if post_process and logits_processor is not None:
diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py
index b1b5c03406b..5776b13fcf4 100644
--- a/verl/models/mcore/registry.py
+++ b/verl/models/mcore/registry.py
@@ -30,6 +30,8 @@ class SupportedVLM(Enum):
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
QWEN3_VL = "Qwen3VLForConditionalGeneration"
+ QWEN3_5_MOE_VL = "Qwen3_5MoeForConditionalGeneration"
+ QWEN3_5_VL = "Qwen3_5ForConditionalGeneration"
supported_vlm = [member.value for member in SupportedVLM]
diff --git a/verl/utils/chat_template.py b/verl/utils/chat_template.py
index e5f8d3e9d1d..5dcef56afa1 100644
--- a/verl/utils/chat_template.py
+++ b/verl/utils/chat_template.py
@@ -2,6 +2,8 @@
import logging
import os
+from transformers import PreTrainedTokenizerBase, ProcessorMixin
+
from verl.utils.tokenizer import normalize_token_ids
logger = logging.getLogger(__name__)
@@ -46,3 +48,79 @@ def extract_system_prompt_and_generation(tokenizer):
generate_prompt = token3[len(token1) :]
return system_prompt, generate_prompt
+
+
+def apply_chat_template_single_turn(
+ processor: PreTrainedTokenizerBase | ProcessorMixin,
+ messages: list[dict],
+ *,
+ tokenize: bool = True,
+ add_generation_prompt: bool = True,
+ tools=None,
+ return_dict: bool = False,
+ return_mm_token_type_ids: bool = False,
+ **kwargs,
+) -> list[int] | str:
+    """Apply the processor's chat template to a single turn's messages.
+
+ Args:
+ processor: tokenizer or processor.
+        messages: list[dict], single-turn messages; must contain exactly one message.
+ tokenize: bool, whether to tokenize the output.
+ add_generation_prompt: bool, whether to add generation prompt.
+ tools: list[dict], tools schema.
+ return_dict: bool, whether to return a dict.
+ return_mm_token_type_ids: bool, whether to return multimodal token type ids.
+ **kwargs: additional arguments for apply_chat_template.
+
+ Returns:
+ list[int] | str: tokenized ids or text string.
+ """
+ assert isinstance(messages, list) and len(messages) == 1, f"messages must be a single turn, got {messages}"
+ try:
+ return processor.apply_chat_template(
+ messages,
+ tokenize=tokenize,
+ add_generation_prompt=add_generation_prompt,
+ tools=tools,
+ return_dict=return_dict,
+ return_mm_token_type_ids=return_mm_token_type_ids,
+ **kwargs,
+ )
+ except Exception:
+        # Qwen3.5's chat template requires at least one user message: prepend an empty dummy user turn, then strip its rendered/tokenized prefix from the result below.
+ dummy_user_message = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
+ dummy_user_prefix = processor.apply_chat_template(
+ dummy_user_message,
+ tokenize=tokenize,
+ add_generation_prompt=False,
+ tools=tools,
+ return_dict=return_dict,
+ return_mm_token_type_ids=return_mm_token_type_ids,
+ **kwargs,
+ )
+ output = processor.apply_chat_template(
+ dummy_user_message + messages,
+ tokenize=tokenize,
+ add_generation_prompt=add_generation_prompt,
+ tools=tools,
+ return_dict=return_dict,
+ return_mm_token_type_ids=return_mm_token_type_ids,
+ **kwargs,
+ )
+
+ if not tokenize: # tokenize=False
+ return output[len(dummy_user_prefix) :]
+ elif not return_dict: # tokenize=True and return_dict=False
+ if isinstance(output[0], list): # transformers>=5
+ assert len(output) == 1, "output must be a list[int] or list[list[int]]"
+ dummy_user_prefix = dummy_user_prefix[0]
+ output = output[0]
+ return output[len(dummy_user_prefix) :]
+ else: # tokenize=True and return_dict=True and return_tensors="pt"
+ dummy_user_prefix = dict(dummy_user_prefix)
+ output = dict(output)
+ prefix_len = dummy_user_prefix["input_ids"].shape[1]
+ output["input_ids"] = output["input_ids"][:, prefix_len:]
+ output["attention_mask"] = output["attention_mask"][:, prefix_len:]
+ return output
diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py
index 081d1dcfafa..8f3ae1e166a 100644
--- a/verl/utils/dataset/multiturn_sft_dataset.py
+++ b/verl/utils/dataset/multiturn_sft_dataset.py
@@ -32,7 +32,7 @@
from verl.models.transformers.qwen2_vl import get_rope_index
from verl.utils import hf_tokenizer
-from verl.utils.chat_template import extract_system_prompt_and_generation
+from verl.utils.chat_template import apply_chat_template_single_turn, extract_system_prompt_and_generation
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.dataset.vision_utils import process_image, process_video
from verl.utils.fs import copy_local_path_from_hdfs
@@ -208,8 +208,9 @@ def _process_single_message(
if enable_thinking is not None:
apply_chat_template_kwargs["enable_thinking"] = enable_thinking
- inputs = processor.apply_chat_template(
- [message],
+ inputs = apply_chat_template_single_turn(
+ processor,
+ messages=[message],
tools=tools,
add_generation_prompt=False,
tokenize=True,
@@ -254,14 +255,16 @@ def _build_messages(self, example: dict):
image_offset, video_offset = 0, 0
for message in messages:
- if self.image_key not in example and self.video_key not in example:
- continue
- assert self.processor is not None, "processor is needed to process image and video"
-
content = message["content"]
if not isinstance(content, str):
continue
+ if self.image_key not in example and self.video_key not in example:
+ if self.processor is not None:
+ message["content"] = [{"type": "text", "text": content}]
+ continue
+ assert self.processor is not None, "processor is needed to process image and video"
+
content_list = []
segments = re.split("(|