diff --git a/mbridge/models/qwen3_vl/model.py b/mbridge/models/qwen3_vl/model.py index 0b38dc1..fb32076 100644 --- a/mbridge/models/qwen3_vl/model.py +++ b/mbridge/models/qwen3_vl/model.py @@ -394,6 +394,7 @@ def forward( tp_rank=mpu.get_tensor_model_parallel_rank(), cp_size=cp_size, cp_rank=mpu.get_context_parallel_rank(), + sequence_parallel=self.config.sequence_parallel, ) elif self.config.sequence_parallel: # THD and SP visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( @@ -403,6 +404,7 @@ def forward( tp_rank=mpu.get_tensor_model_parallel_rank(), cp_size=1, cp_rank=0, + sequence_parallel=self.config.sequence_parallel, ) if position_ids is None: diff --git a/mbridge/models/qwen3_vl/utils.py b/mbridge/models/qwen3_vl/utils.py index c61de7e..ef19b40 100644 --- a/mbridge/models/qwen3_vl/utils.py +++ b/mbridge/models/qwen3_vl/utils.py @@ -177,7 +177,11 @@ def split_deepstack_embs( tp_rank: int = 0, cp_size: int = 1, cp_rank: int = 0, + sequence_parallel: bool = True, ): + if not sequence_parallel: + tp_size = 1 + tp_rank = 0 split_size = tp_size if cp_size > 1: split_size *= cp_size * 2