8 changes: 4 additions & 4 deletions examples/wan2_1/README.md
```diff
@@ -82,10 +82,10 @@ prompt: Summer beach vacation style, a white cat wearing sunglasses sits on a su

 The code is tested in the following environments
 
-| mindspore | ascend driver | firmware | cann tookit/kernel |
-| :---: | :---: | :---: | :---: |
-| 2.5.0 | 24.1.0 |7.35.23 | 8.0.RC3.beta1 |
-
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+| :-------: | :-----------: | :---------: | :----------------: |
+| 2.6.0 | 25.0.RC1.1 | 7.7.0.1.231 | 8.1.RC1 |
+| 2.7.0 | 25.2.0 | 7.7.0.6.236 | 8.2.RC1 |
 
 ### Installation
 Clone the repo:
```
2 changes: 2 additions & 0 deletions examples/wan2_1/wan/modules/clip.py
```diff
@@ -104,6 +104,7 @@ def construct(self, x: Tensor) -> Tensor:
             value=v,
             head_num=self.num_heads,
             keep_prob=1.0 - p,
+            scalar_value=1 / math.sqrt(q.shape[-1]),
             input_layout="BSND",
         )
         x = x.reshape(b, s, c)
@@ -228,6 +229,7 @@ def construct(self, x: Tensor) -> Tensor:
             key=k,
             value=v,
             head_num=self.num_heads,
+            scalar_value=1 / math.sqrt(q.shape[-1]),
             input_layout="BSND",
         )
         x = x.reshape(b, 1, c)
```
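A note on the two `scalar_value` additions above: MindSpore's `ops.flash_attention_score` defaults `scalar_value` to 1.0, i.e. unscaled attention logits, so the usual 1/sqrt(head_dim) factor of scaled dot-product attention must be supplied explicitly. Below is a minimal sketch of the call pattern, not code from this PR; the shapes are made up, and the fused kernel assumes an Ascend device.

```python
# Minimal sketch (not from this PR): restoring the standard softmax scale.
# ops.flash_attention_score leaves logits unscaled by default (scalar_value=1.0).
import math

import mindspore as ms
from mindspore import ops

b, s, n, d = 1, 16, 8, 64  # illustrative batch, seq len, heads, head dim
q = ops.randn(b, s, n, d, dtype=ms.float16)  # "BSND": (batch, seq, heads, head_dim)
k = ops.randn(b, s, n, d, dtype=ms.float16)
v = ops.randn(b, s, n, d, dtype=ms.float16)

out = ops.flash_attention_score(
    q,
    k,
    v,
    head_num=n,
    scalar_value=1 / math.sqrt(d),  # the 1/sqrt(head_dim) scale the PR adds
    input_layout="BSND",
)
```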
13 changes: 12 additions & 1 deletion examples/wan2_1/wan/modules/xlm_roberta.py
```diff
@@ -3,6 +3,8 @@

 # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+
 import mindspore as ms
 import mindspore.mint as mint
 import mindspore.nn as nn
@@ -46,7 +48,16 @@ def construct(self, x: Tensor, mask: Tensor) -> Tensor:
         # compute attention
         p = self.dropout.p if self.training else 0.0
         # TODO: check mask
-        x = ops.flash_attention_score(q, k, v, self.num_heads, attn_mask=mask, keep_prob=1 - p)
+        x = ops.flash_attention_score(
+            q,
+            k,
+            v,
+            self.num_heads,
+            attn_mask=mask,
+            scalar_value=1 / math.sqrt(q.shape[-1]),
+            keep_prob=1 - p,
+            input_layout="BNSD",
+        )
         x = x.permute(0, 2, 1, 3).reshape(b, s, c)
 
         # output
```
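On the `input_layout` values: "BSND" is (batch, seq, heads, head_dim) while "BNSD" is (batch, heads, seq, head_dim). This file permutes q/k/v to head-major tensors before the call, hence "BNSD", whereas `clip.py` passes sequence-major tensors and uses "BSND". A hedged sketch of the head-major variant, with illustrative sizes:

```python
# Sketch of the "BNSD" (head-major) call used in xlm_roberta.py; sizes are illustrative.
import math

import mindspore as ms
from mindspore import ops

b, n, s, d = 1, 8, 16, 64
# Head-major tensors, as produced by a permute(0, 2, 1, 3) from (b, s, n, d)
q = ops.randn(b, n, s, d, dtype=ms.float16)
k = ops.randn(b, n, s, d, dtype=ms.float16)
v = ops.randn(b, n, s, d, dtype=ms.float16)

x = ops.flash_attention_score(
    q,
    k,
    v,
    head_num=n,
    scalar_value=1 / math.sqrt(d),
    keep_prob=1.0,  # inference path: dropout disabled
    input_layout="BNSD",
)
```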
37 changes: 14 additions & 23 deletions examples/wan2_1/wan/utils/prompt_extend.py
```diff
@@ -10,14 +10,15 @@
 from typing import Union
 
 from PIL import Image
-from transformers import AutoProcessor, AutoTokenizer
+from transformers import AutoTokenizer
 
 import mindspore as ms
 from mindspore import Tensor
 from mindspore.communication import GlobalComm
 from mindspore.nn.utils import no_init_parameters
 
 from mindone.trainers.zero import prepare_network
+from mindone.transformers import AutoProcessor
 from mindone.transformers.models.qwen2 import Qwen2ForCausalLM
 from mindone.transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from mindone.transformers.models.qwen2_vl.qwen_vl_utils import process_vision_info
@@ -179,7 +180,7 @@ def __init__(self, model_name=None, is_vl=False, qwen_zero3=False, **kwargs):
             min_pixels = 256 * 28 * 28
             max_pixels = 1280 * 28 * 28
             self.processor = AutoProcessor.from_pretrained(
-                self.model_name, min_pixels=min_pixels, max_pixels=max_pixels, use_fast=True
+                self.model_name, min_pixels=min_pixels, max_pixels=max_pixels, use_fast=False
             )
             with no_init_parameters():
                 self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -192,13 +193,10 @@ def __init__(self, model_name=None, is_vl=False, qwen_zero3=False, **kwargs):
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
             with no_init_parameters():
-                # TODO: change to flash attention & use cache & do sampling
                 self.model = Qwen2ForCausalLM.from_pretrained(
                     self.model_name,
                     mindspore_dtype=ms.bfloat16,
-                    attn_implementation="eager",
-                    use_cache=False,
-                    do_sample=False,
+                    attn_implementation="flash_attention_2",
                 )
             if qwen_zero3:
                 self.model = prepare_network(
@@ -215,13 +213,13 @@ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
         messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
         text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         model_inputs = self.tokenizer([text], return_tensors="np")
-        model_inputs = {k: self._to_int32(Tensor(v)) for k, v in model_inputs.items()}
+        for k, v in model_inputs.items():
+            model_inputs[k] = ms.tensor(v)
 
-        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512).asnumpy()
-        # TODO: somehow the output is aready trimmed
-        # generated_ids = [
-        #     output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        # ]
+        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512, do_sample=False).asnumpy()
+        generated_ids = [
+            output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
 
         expanded_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
         return PromptOutput(
@@ -251,21 +249,14 @@ def extend_with_img(self, prompt, system_prompt, image: Union[Image.Image, str]
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = self.process_vision_info(messages)
         inputs = self.processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
+            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="ms"
         )
 
-        inputs = {k: self._to_int32(Tensor(v.numpy())) for k, v in inputs.items()}
-
         # Inference: Generation of the output
-        generated_ids = self.model.generate(**inputs, max_new_tokens=512).asnumpy()
-        # TODO: somehow the output is aready trimmed
-        # generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
+        generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=False).asnumpy()
+        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
         expanded_prompt = self.processor.batch_decode(
-            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )[0]
         return PromptOutput(
             status=True,
```
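Two behavioral notes on this file: sampling options now go to each `generate` call (`do_sample=False`) rather than to `from_pretrained`, and the trimming that was previously commented out is reinstated, because `generate` here returns the full sequence (prompt tokens followed by new tokens), so decoding without slicing would echo the prompt back into the expanded prompt. A self-contained illustration of the trimming idiom, with made-up token ids:

```python
# Illustration only (made-up ids): slicing the prompt off generate() output.
import numpy as np

input_ids = np.array([[101, 7592, 2088, 102]])                  # pretend prompt, 4 tokens
generated_ids = np.array([[101, 7592, 2088, 102, 2003, 1037]])  # prompt + 2 new tokens

trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
print(trimmed)  # [array([2003, 1037])] -- only the new tokens reach batch_decode
```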
6 changes: 3 additions & 3 deletions examples/wan2_2/README.md
```diff
@@ -31,9 +31,9 @@ We are excited to introduce **Wan2.2**, a major upgrade to our foundational vide
 ## Run Wan2.2
 
 #### Requirements
-|mindspore | ascend driver | firmware | cann toolkit/kernel|
-|--- | --- | --- | --- |
-|2.7.0 | 24.1RC3 | 7.3.0.1.231 | 8.2.RC1 |
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+| :-------: | :-----------: | :---------: | :----------------: |
+| 2.7.0 | 25.2.0 | 7.7.0.6.236 | 8.2.RC1 |
 
 
 #### Installation
```
2 changes: 1 addition & 1 deletion examples/wan2_2/wan/utils/prompt_extend.py
```diff
@@ -143,7 +143,7 @@ def __init__(self, model_name=None, task=None, is_vl=False, **kwargs):

         with nn.no_init_parameters():
             self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, mindspore_dtype=ms.bfloat16, attn_implementation="eager"
+                self.model_name, mindspore_dtype=ms.bfloat16, attn_implementation="flash_attention_2"
             )
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

```
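Same switch as in `wan2_1`: the prompt-extend LLM now loads with the fused flash-attention backend instead of eager attention. A minimal loading sketch; the checkpoint name is a placeholder, and the `AutoModelForCausalLM` import is assumed to come from `mindone.transformers` as in this file.

```python
# Sketch (placeholder checkpoint): loading with the flash-attention backend.
import mindspore as ms
from mindone.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",              # placeholder model name
    mindspore_dtype=ms.bfloat16,             # fused attention expects fp16/bf16
    attn_implementation="flash_attention_2",
)
```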