Commit 6e30e1e: reformat

Parent: f9fd201

3 files changed: +17 additions, -18 deletions

examples/transformers/qwen3_omni_moe/README.md

Lines changed: 3 additions & 4 deletions
@@ -16,8 +16,7 @@ The abstract from the technical report is the following:
 
 ### Installation:
 ```
-# TODO modify path before mergeing
-git clone https://github.com/wcrzlh/mindone.git -b vllm_patch
+git clone https://github.com/mindspore-lab/mindone.git
 cd mindone
 pip install -e .
@@ -122,7 +121,7 @@ text_ids, audio = model.generate(
     thinker_return_dict_in_generate=True,
     use_audio_in_video=USE_AUDIO_IN_VIDEO,
     return_audio=False,
-    talker_do_sampe=False,
+    talker_do_sample=False,
 )
 
 text = processor.batch_decode(
@@ -148,5 +147,5 @@ If `return_audio=True` is set, besides that above text generation results, a pie
 ## Inference Speed
 | model name | mindspore version | precision* | cards | Model part | attention type | tokens/s |
 |:------------------------------:|:-----------------:|:----------:|:-----:|:----------:|:--------------:|:----------:|
-| Qwen3-Omni-30B-A3B-Instruct | 2.7.0 | bf16 | 2 | Thinker | flash_attn | 0.36 |
+| Qwen3-Omni-30B-A3B-Instruct | 2.7.0 | bf16 | 2 | Thinker | flash_attn | 0.73 |
 | Qwen3-Omni-30B-A3B-Instruct | 2.7.0 | bf16 | 2 | Talker | flash_attn | 0.88 |
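For context, the renamed argument belongs to the README's `model.generate(...)` example. Below is a trimmed sketch of the corrected call; `model` and `USE_AUDIO_IN_VIDEO` are assumed to come from the README's setup code (not part of this diff), and the input arguments are omitted.

```python
# Trimmed sketch of the corrected generate() call from the README example.
# `model` and USE_AUDIO_IN_VIDEO are assumed to be defined by the README's
# setup code; the model/processor inputs are omitted here.
text_ids, audio = model.generate(
    thinker_return_dict_in_generate=True,
    use_audio_in_video=USE_AUDIO_IN_VIDEO,
    return_audio=False,
    talker_do_sample=False,  # keyword previously misspelled as "talker_do_sampe"
)
```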

mindone/transformers/modeling_utils.py

Lines changed: 13 additions & 13 deletions
@@ -170,23 +170,23 @@ def _convert_state_dict(m, state_dict_pt, prefix=""):
     state_dict_ms = {}
     while state_dict_pt:
         name_pt, data_pt = state_dict_pt.popitem()
-        # TODO For models contains a lot of paramters, going through state_dict and model at the same time
-        # would cause performance decrease significantly. This part for aligning prefix would need to be optimized.
-        # for name, param in m.parameters_and_names():
-        #     name_ms = param.name
-        #     length = len(prefix) + 1
-        #     if name_pt.startswith(prefix):
-        #         # When name_ms and name_pt match and name_pt has prefix, name_pt would be sliced
-        #         if name_ms.rsplit(".", 1)[0] == name_pt.rsplit(".", 1)[0][length:] or name_ms == name_pt[length:]:
-        #             name_pt = name_pt[length:]
-        #     elif not name_pt.startswith(prefix):
-        #         # When name_ms and name_pt match and name_ms has prefix, prefix would be added to name_pt
-        #         if name_pt.rsplit(".", 1)[0] == name_ms.rsplit(".", 1)[0][length:] or name_pt == name_ms[length:]:
-        #             name_pt = ".".join([prefix, name_pt])
         name_ms, data_mapping = pt2ms_mappings.get(name_pt, (name_pt, lambda x: x))
         data_ms = data_mapping(data_pt)
         if name_ms is not None:
             state_dict_ms[name_ms] = data_ms
+
+    length = len(prefix) + 1
+    model_ckpt_key = m.state_dict().keys()
+    for key in state_dict_ms.keys():
+        # When model name and state dict name match and state dict name has prefix, state dict name would be sliced
+        if key[length:] in model_ckpt_key:
+            data_ms = state_dict_ms.pop(key)
+            state_dict_ms[key[length:]] = data_ms
+        # When model name and state dict name match and model name has prefix, prefix would be added to state dict name
+        elif ".".join([prefix, key]) in model_ckpt_key:
+            data_ms = state_dict_ms.pop(key)
+            state_dict_ms[".".join([prefix, key])] = data_ms
+
     return state_dict_ms
 
 
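To illustrate what the new post-conversion pass does, here is a minimal standalone sketch of the prefix alignment on toy dictionaries. The prefix, key names, and values are invented for the example, and the keys are wrapped in `list()` here so the dict can be modified while iterating.

```python
# Minimal sketch of the prefix-alignment pass on toy data. The prefix, keys,
# and values are illustrative only; in modeling_utils the model keys come from
# m.state_dict().keys() and state_dict_ms from the PT->MS conversion above.
prefix = "model"
model_ckpt_key = {"layer.weight", "model.head.bias"}        # keys the model expects
state_dict_ms = {"model.layer.weight": 1, "head.bias": 2}   # converted checkpoint keys

length = len(prefix) + 1
for key in list(state_dict_ms.keys()):  # list() so the dict can be mutated safely
    # Checkpoint key carries the prefix but the model key does not: slice it off.
    if key[length:] in model_ckpt_key:
        state_dict_ms[key[length:]] = state_dict_ms.pop(key)
    # Model key carries the prefix but the checkpoint key does not: prepend it.
    elif ".".join([prefix, key]) in model_ckpt_key:
        state_dict_ms[".".join([prefix, key])] = state_dict_ms.pop(key)

print(state_dict_ms)  # {'layer.weight': 1, 'model.head.bias': 2}
```

The commit moves this alignment out of the per-parameter conversion loop (the deleted commented-out block) into a single pass over the already-converted state dict.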
tests/transformers_tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
@@ -34,7 +34,7 @@
 MODES = [1]
 
 if transformers.__version__ >= "4.57.0":
-    from transformers.models.qwen3_omni_moe import Qwen3OmniMoeTalkerConfig, Qwen3OmniMoeThinkerConfig
+    from transformers.models.qwen3_omni_moe import Qwen3OmniMoeThinkerConfig
 
 class Qwen3OmniModelTester:
     def __init__(
