Commits (55)
c9fc2db
feat(transformers): add Qwen2VLImageProcessorFast/Qwen2VLVideoProcessor
wcrzlh Oct 27, 2025
7290e26
feat(transformers): add Qwen2VLImageProcessorFast/Qwen2VLVideoProcessor
wcrzlh Oct 27, 2025
cf2d71c
feat(transformers): add Qwen2VLImageProcessorFast/Qwen2VLVideoProcessor
wcrzlh Oct 27, 2025
80d69f7
feat(transformers): add WhisperFeatureExtractor/qwen2vl videoprocesso…
wcrzlh Oct 27, 2025
7ca3e59
fix bugs
wcrzlh Oct 27, 2025
8228daf
feat(transformers): add autoprocessor for qwen2audio
wcrzlh Oct 28, 2025
b04d93c
pre-commit
wcrzlh Oct 28, 2025
d3e6689
feat(transformers): support qwen3-omni model
wcrzlh Oct 28, 2025
653c101
pre-commit
wcrzlh Oct 28, 2025
7c09e61
pre-commit
wcrzlh Oct 28, 2025
906a399
fix bugs
wcrzlh Oct 28, 2025
e40e718
fix bugs
wcrzlh Oct 28, 2025
302b13f
fix bugs
wcrzlh Oct 28, 2025
99e5b70
fix split ops bugs
wcrzlh Oct 28, 2025
5cfc0bb
fix pad_sequence bugs
wcrzlh Oct 28, 2025
509bf1b
fix audio padded_mask bugs/ supplement qwen_omni_utils
wcrzlh Oct 28, 2025
ea36b19
fix list += bug/mask_scatter bug
wcrzlh Oct 28, 2025
6470463
fix linspace bug
wcrzlh Oct 29, 2025
be04e46
fix bugs
wcrzlh Oct 29, 2025
41a1491
fix repeat bugs
wcrzlh Oct 29, 2025
f9571bd
fix view bugs
wcrzlh Oct 29, 2025
8d354f7
fix view bugs
wcrzlh Oct 29, 2025
1ae995c
fix arange bugs
wcrzlh Oct 29, 2025
b038159
fix arange bugs
wcrzlh Oct 29, 2025
7f5d9dd
fix arange bugs
wcrzlh Oct 29, 2025
52769e3
fix arange bugs
wcrzlh Oct 29, 2025
804ed81
fix arange bugs
wcrzlh Oct 29, 2025
4494189
fix arange bugs
wcrzlh Oct 29, 2025
ba8b103
fix construct wrapper bugs
wcrzlh Oct 29, 2025
63b9f9c
fix slice index bugs
wcrzlh Oct 29, 2025
85e5ec8
fix hidden_states return bugs
wcrzlh Oct 29, 2025
efbaa3f
fix hidden_states return bugs
wcrzlh Oct 29, 2025
308a5f6
fix hidden_states return bugs
wcrzlh Oct 29, 2025
6b10c90
fix mint.cat dtype bugs
wcrzlh Oct 29, 2025
1ebee09
fix tensor index bugs
wcrzlh Oct 29, 2025
bf5e2ff
fix scatter bugs
wcrzlh Oct 29, 2025
6a1b803
fix bugs
wcrzlh Oct 29, 2025
a17e8cc
fix bugs
wcrzlh Oct 29, 2025
9240ade
fix mint empty bugs
wcrzlh Oct 29, 2025
824d5dd
fix mint empty bugs
wcrzlh Oct 29, 2025
c66d943
fix mint empty bugs
wcrzlh Oct 29, 2025
180818c
fix or_mask/and_mask bugs
wcrzlh Oct 29, 2025
2629a72
fix np.prod bugs
wcrzlh Oct 30, 2025
265ea64
fix qwen_omni_utils bugs
wcrzlh Oct 30, 2025
56d2624
fix load weight time
wcrzlh Oct 30, 2025
1c1d47b
fix load weight time
wcrzlh Oct 30, 2025
dee9236
fix load weight time
wcrzlh Oct 30, 2025
95158bc
fix load weight time
wcrzlh Oct 30, 2025
82a6197
fix load weight time
wcrzlh Oct 30, 2025
b7a88d8
add qwen3 omni ut and examples
wcrzlh Nov 7, 2025
c704520
pre-commit
wcrzlh Nov 7, 2025
5662e36
rebase
wcrzlh Nov 7, 2025
f06a68a
reformat
wcrzlh Nov 14, 2025
06ce942
reformat
wcrzlh Nov 17, 2025
efbe2f5
supplement ut
wcrzlh Nov 17, 2025
151 changes: 151 additions & 0 deletions examples/transformers/qwen3_omni_moe/README.md
@@ -0,0 +1,151 @@
# Qwen3-Omni

## Introduction
The Qwen3-Omni-MoE model is a unified multimodal model proposed in the Qwen3-Omni Technical Report from the Qwen team, Alibaba Group.

The abstract from the technical report is the following:

*We present Qwen3-Omni, a single multimodal model that, for the first time, maintains state-of-the-art performance across text, image, audio, and video without any degradation relative to single-modal counterparts. Qwen3-Omni matches the performance of same-sized single-modal models within the Qwen series and excels particularly on audio tasks. Across 36 audio and audio-visual benchmarks, Qwen3-Omni achieves open-source SOTA on 32 benchmarks and overall SOTA on 22, outperforming strong closed-source models such as Gemini-2.5-Pro, Seed-ASR, and GPT-4o-Transcribe. Qwen3-Omni adopts a Thinker-Talker MoE architecture that unifies perception and generation across text, images, audio, and video, yielding fluent text and natural real-time speech. It supports text interaction in 119 languages, speech understanding in 19 languages, and speech generation in 10 languages. To reduce first-packet latency in streaming synthesis, Talker autoregressively predicts discrete speech codecs using a multi-codebook scheme. Leveraging the representational capacity of these codebooks, we replace computationally intensive block-wise diffusion with a lightweight causal ConvNet, enabling streaming from the first codec frame. In cold-start settings, Qwen3-Omni achieves a theoretical end-to-end first-packet latency of 234 ms. To further strengthen multimodal reasoning, we introduce a Thinking model that explicitly reasons over inputs from any modality. Since the research community currently lacks a general-purpose audio captioning model, we fine-tuned Qwen3-Omni-30B-A3B to obtain Qwen3-Omni-30B-A3B-Captioner, which produces detailed, low-hallucination captions for arbitrary audio inputs. Qwen3-Omni-30B-A3B, Qwen3-Omni-30B-A3B-Thinking, and Qwen3-Omni-30B-A3B-Captioner are publicly released under the Apache 2.0 license.

# Get Started

## Requirements:
| mindspore | ascend driver | firmware | CANN toolkit/kernel |
|-----------|----------------|----------------|--------------------|
| 2.7.0 | 24.1.RC3.b080 | 7.5.T11.0.B088 | 8.1.RC1 |

### Installation:
```bash
git clone https://github.com/mindspore-lab/mindone.git
cd mindone
pip install -e .

pip install transformers==4.57.1

cd examples/transformers/qwen3_omni_moe
```

## **Notice**
Note that adjusting `min_pixels` and `max_pixels` trades off memory against accuracy. If you hit an out-of-memory (OOM) error, lower the processor's `min_pixels` and `max_pixels`.
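For example, a smaller pixel budget can be passed when creating the processor. A minimal sketch (the values below are illustrative rather than tuned, and `MODEL_PATH` is assumed to point at the checkpoint as in the Quick Start below):

```python
# Illustrative values only: shrink the visual token budget to lower memory usage.
min_pixels = 32 * 32           # smaller minimum than the 56 * 56 used in the Quick Start
max_pixels = 14 * 14 * 512     # smaller cap than the 14 * 14 * 768 used in the Quick Start
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
```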

## Quick Start

Here is a usage example of Qwen3-Omni-30B-A3B-Instruct. You can run it with the following command:

```bash
# For the audio understanding task.
# To return only text, set `return_audio=False` in the script.
msrun --worker_num=2 --local_worker_num=2 --master_port=8118 \
--log_dir=msrun_log --join=True --cluster_time_out=300 \
omni_understanding.py
```
Give it a try with various images, audio clips, and prompts 🤗🤗.

Omni understanding sample script:
`return_audio=False` can be set so that only the text result is returned.

```python
from functools import partial

import numpy as np
import soundfile as sf
from qwen_omni_utils import process_mm_info

import mindspore as ms
import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm

from mindone.trainers.zero import prepare_network
from mindone.transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor

# set up card communication
dist.init_process_group(backend="hccl")
ms.set_auto_parallel_context(parallel_mode="data_parallel")

MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
# MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Thinking"

model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
MODEL_PATH,
mindspore_dtype=ms.bfloat16,
attn_implementation="flash_attention_2",
)

# use zero3 parallel
shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
model = shard_fn(model)

min_pixels = 56 * 56
max_pixels = 14 * 14 * 768
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)

conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cars.jpg"},
{"type": "audio", "audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav"},
{"type": "text", "text": "What can you see and hear? Answer in one short sentence."},
],
},
]

# Set whether to use audio in video
USE_AUDIO_IN_VIDEO = True

# Preparation for inference
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
inputs = processor(
text=text,
audio=audios,
images=images,
videos=videos,
return_tensors="np",
padding=True,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
)

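# Convert numpy arrays to MindSpore tensors; downcast int64 index tensors to int32 and cast the remaining tensors to the model dtype.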
for key, value in inputs.items():
if isinstance(value, np.ndarray):
inputs[key] = ms.tensor(value)
if inputs[key].dtype == ms.int64:
inputs[key] = inputs[key].to(ms.int32)
elif inputs[key].dtype != ms.int32:
inputs[key] = inputs[key].to(model.dtype)

# Inference: Generation of the output text and audio
text_ids, audio = model.generate(
**inputs,
speaker="Ethan",
thinker_return_dict_in_generate=True,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
return_audio=False,
talker_do_sample=False,
)

text = processor.batch_decode(
text_ids.sequences[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(text)
if audio is not None:
sf.write(
"output.wav",
audio.reshape(-1).asnumpy(),
samplerate=24000,
)

```

> **Review comment (Collaborator):** How about the audio output? Maybe we can attach the audio output as well.
>
> **Reply (Contributor Author @wcrzlh, Nov 17, 2025):** The audio output quality is good. It retells the text output and summarizes the audio. Let me figure out how to attach the audio output.

Text generation output:

```
['The image displays four luxury cars-a Rolls-Royce, a Mercedes-Benz SUV, a Ferrari convertible and a Porsche 911-while the audio captures a person coughing.']
```

If `return_audio=True` is set, then in addition to the text output above, an audio clip that describes the image and audio is generated, as sketched below.
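A minimal sketch of that call, mirroring the script above with only the `return_audio` flag changed (the speaker name and output path are the example values used earlier):

```python
# Same call as in the sample script, but asking the Talker to synthesize speech as well.
text_ids, audio = model.generate(
    **inputs,
    speaker="Ethan",
    thinker_return_dict_in_generate=True,
    use_audio_in_video=USE_AUDIO_IN_VIDEO,
    return_audio=True,
    talker_do_sample=False,
)
if audio is not None:
    # 24 kHz waveform returned by the Talker, saved exactly as in the sample script.
    sf.write("output.wav", audio.reshape(-1).asnumpy(), samplerate=24000)
```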

## Inference Speed
| model name | mindspore version | precision | cards | model part | attention type | tokens/s |
|:------------------------------:|:-----------------:|:----------:|:-----:|:----------:|:--------------:|:----------:|
| Qwen3-Omni-30B-A3B-Instruct | 2.7.0 | bf16 | 2 | Thinker | flash_attn | 0.73 |
| Qwen3-Omni-30B-A3B-Instruct | 2.7.0 | bf16 | 2 | Talker | flash_attn | 0.88 |
115 changes: 115 additions & 0 deletions examples/transformers/qwen3_omni_moe/omni_understanding.py
@@ -0,0 +1,115 @@
import argparse
from functools import partial

import numpy as np
import soundfile as sf
from qwen_omni_utils import process_mm_info

import mindspore as ms
import mindspore.mint.distributed as dist
from mindspore.communication import GlobalComm

from mindone.trainers.zero import prepare_network
from mindone.transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor


def generate(args):
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
args.model_name,
mindspore_dtype=ms.bfloat16,
attn_implementation="flash_attention_2",
)

# use zero3 parallel
shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
model.thinker = shard_fn(model.thinker)
model.talker = shard_fn(model.talker)

min_pixels = 56 * 56
max_pixels = 14 * 14 * 768
processor = Qwen3OmniMoeProcessor.from_pretrained(args.model_name, min_pixels=min_pixels, max_pixels=max_pixels)

conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": args.image},
{"type": "audio", "audio": args.audio},
{"type": "text", "text": args.prompt},
],
},
]

# Set whether to use audio in video
USE_AUDIO_IN_VIDEO = True

# Preparation for inference
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
inputs = processor(
text=text,
audio=audios,
images=images,
videos=videos,
return_tensors="np",
padding=True,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
)

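    # Convert numpy arrays to MindSpore tensors; downcast int64 index tensors to int32 and cast the remaining tensors to the model dtype.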
for key, value in inputs.items():
if isinstance(value, np.ndarray):
inputs[key] = ms.tensor(value)
if inputs[key].dtype == ms.int64:
inputs[key] = inputs[key].to(ms.int32)
elif inputs[key].dtype != ms.int32:
inputs[key] = inputs[key].to(model.dtype)

# Inference: Generation of the output text and audio
text_ids, audio = model.generate(
**inputs,
speaker="Ethan",
thinker_return_dict_in_generate=True,
use_audio_in_video=USE_AUDIO_IN_VIDEO,
talker_do_sample=False,
)

text = processor.batch_decode(
text_ids.sequences[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
print(text)
if audio is not None:
sf.write(
"output.wav",
audio.reshape(-1).asnumpy(),
samplerate=24000,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen3OmniMoE demo.")

parser.add_argument("--prompt", type=str, default="What can you see and hear? Answer in one short sentence.")
parser.add_argument(
"--image",
type=str,
default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cars.jpg",
)
parser.add_argument(
"--audio",
type=str,
default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav",
)
parser.add_argument(
"--model_name", type=str, default="Qwen/Qwen3-Omni-30B-A3B-Instruct", help="Path to the pre-trained model."
)

# Parse the arguments
args = parser.parse_args()

# set up card communication
dist.init_process_group(backend="hccl")
ms.set_auto_parallel_context(parallel_mode="data_parallel")

generate(args)
@@ -0,0 +1,8 @@
from .audio_process import process_audio_info
from .vision_process import extract_vision_info, fetch_image, fetch_video, process_vision_info, smart_resize


def process_mm_info(conversations, use_audio_in_video, return_video_kwargs=False):
audios = process_audio_info(conversations, use_audio_in_video)
vision = process_vision_info(conversations, return_video_kwargs=return_video_kwargs)
return (audios,) + vision
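For reference, a minimal usage sketch of this helper (the conversation below reuses the audio URL from the README; the assumption that `return_video_kwargs=True` adds a kwargs dict to the returned tuple follows the qwen_vl_utils convention of `process_vision_info`):

```python
# Minimal usage sketch of process_mm_info (inputs taken from the README example above).
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav"},
            {"type": "text", "text": "What do you hear?"},
        ],
    },
]

audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)

# With return_video_kwargs=True the vision helper also returns per-video kwargs:
# audios, images, videos, video_kwargs = process_mm_info(
#     conversation, use_audio_in_video=False, return_video_kwargs=True
# )
```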
@@ -0,0 +1,55 @@
import audioread
import av
import librosa
import numpy as np


def _check_if_video_has_audio(video_path):
    # Open the container just long enough to check for an audio stream.
    container = av.open(video_path)
    try:
        return any(stream.type == "audio" for stream in container.streams)
    finally:
        container.close()


def process_audio_info(conversations: list[dict], use_audio_in_video: bool):
audios = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if not isinstance(message["content"], list):
continue
for ele in message["content"]:
if ele["type"] == "audio":
if "audio" in ele:
path = ele["audio"]
                        if isinstance(path, np.ndarray):
                            # Raw waveform passed directly; only file paths and URLs go through librosa.
                            if path.ndim > 1:
                                raise ValueError("Only mono audio is supported")
                            audios.append(path)
                        elif path.startswith("http://") or path.startswith("https://"):
                            audios.append(librosa.load(audioread.ffdec.FFmpegAudioFile(path), sr=16000)[0])
                        elif path.startswith("file://"):
                            audios.append(librosa.load(path[len("file://") :], sr=16000)[0])
                        else:
                            audios.append(librosa.load(path, sr=16000)[0])
else:
raise ValueError("Unknown audio {}".format(ele))
if use_audio_in_video and ele["type"] == "video":
if "video" in ele:
path = ele["video"]
                        assert _check_if_video_has_audio(
                            path
                        ), "Video must have an audio track when use_audio_in_video=True"
if path.startswith("http://") or path.startswith("https://"):
audios.append(librosa.load(audioread.ffdec.FFmpegAudioFile(path), sr=16000)[0])
elif path.startswith("file://"):
audios.append(librosa.load(path[len("file://") :], sr=16000)[0])
else:
audios.append(librosa.load(path, sr=16000)[0])
else:
raise ValueError("Unknown video {}".format(ele))
if len(audios) == 0:
audios = None
return audios
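A short usage sketch of `process_audio_info` with hypothetical inputs (the local path and the zero-filled array below are placeholders, not files shipped with this example):

```python
import numpy as np

# Hypothetical inputs: a local wav file (loaded via librosa at 16 kHz) and an
# in-memory mono waveform (appended unchanged).
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "file:///tmp/example.wav"},
            {"type": "audio", "audio": np.zeros(16000, dtype=np.float32)},
            {"type": "text", "text": "Describe these sounds."},
        ],
    },
]

audios = process_audio_info(conversation, use_audio_in_video=False)
print(len(audios))  # -> 2
```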