
Commit 4db5550

Update example Emu3 performance in mindspore 2.6.0 and 2.7.0 (#1417)
* use mint.nn.BatchNorm3d; update inference perf for ms2.6.0 & ms2.7.0
* update training perf for ms2.7.0
1 parent 2c574e1 commit 4db5550
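The unifying change across these files is keeping `mint.nn.BatchNorm3d` (rather than the old `nn.BatchNorm3d`) in fp32 under auto mixed precision. A minimal sketch of that pattern, pieced together from the example scripts below; the hub id and the `emu3.tokenizer` import path are assumptions based on the repo layout, not part of this diff:

```python
# Sketch only: load the Emu3 vision tokenizer and keep mint.nn.BatchNorm3d in fp32
# under O2 auto mixed precision, as this commit does throughout examples/emu3.
import mindspore as ms
from mindspore import mint

from mindone.utils.amp import auto_mixed_precision
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel  # assumed import path

MODEL_HUB = "BAAI/Emu3-VisionTokenizer"  # assumed hub id used by the examples
MS_DTYPE = ms.bfloat16

model = Emu3VisionVQModel.from_pretrained(
    MODEL_HUB, use_safetensors=True, mindspore_dtype=MS_DTYPE
).set_train(False)
# BatchNorm3d stays in fp32 while the rest of the network runs in the lower precision.
model = auto_mixed_precision(
    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
```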

10 files changed: +73 / -128 lines changed


examples/emu3/README.md

Lines changed: 57 additions & 47 deletions
@@ -43,7 +43,8 @@ Image VQA:
 ### Requirements
 
 |mindspore | ascend driver | firmware | cann tookit/kernel|
 |--- | --- | --- | --- |
-|2.5.0 | 24.1RC2 | 7.3.0.1.231 | 8.0.RC3.beta1|
+|2.6.0 | 24.1.RC3 | 7.5.T11.0 | 8.1.RC1|
+|2.7.0 | 24.1.RC3 | 7.5.T11.0 | 8.2.RC1|
 
 ### Dependencies

@@ -69,12 +70,6 @@ pip install -r requirements.txt
 
 </details>
 
-#### Weight conversion:
-
-For model **Emu3-VisionTokenizer**, there are some incompatible network layer variable names that cannot be automatically converted, we need to convert some weight names in advanved before loading the pre-trained weights:
-```
-python python convert_weights.py --safetensor_path ORIGINAL_MODEL.safetensors --target_safetensor_path model.safetensors
-```
 
 ## Inference

@@ -118,7 +113,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=VQ_DTYPE
 ).set_train(False)
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

@@ -220,7 +215,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=VQ_DTYPE
 ).set_train(False)
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

@@ -279,7 +274,7 @@ model = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=MS_DTYPE
 ).set_train(False)
 model = auto_mixed_precision(
-    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)

@@ -366,66 +361,81 @@ DATA_DIR
 
 Input an image or a clip of video frames, outout the reconstructed image(s).
 <br>
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines.
 
-| model name | precision* | cards | batch size| resolution | s/step | img/s |
-| --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.65 | 0.38 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.96 | 1.04 |
+- mindspore 2.6.0
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
 
-<br>
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 graph mode.
+- mindspore 2.7.0
 
-| model name | precision* | cards | batch size| resolution | graph compile | s/step | img/s |
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 15s | 3.23 | 0.31 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 15s | 5.46 | 0.18 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
 
 *note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
 
 #### Text-to-Image Generation
 Input a text prompt, output an image.
 <br>
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
 
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | img/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 0.50 | 8193 | 2.27-e4 |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 0.49 | 8193 | 2.50-e4 |
+- mindspore 2.6.0
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |
+
+- mindspore 2.7.0
+
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |
+
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
 
 #### VQA
 Input an image and a text prompt, output textual response.
 <br>
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
 
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | response/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 0.29 | 131 | 0.03 |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 0.24 | 92 | 0.05 |
+- mindspore 2.6.0
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |
 
-### Training
+- mindspore 2.7.0
 
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |
 
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 3.08 | 4993 | 0.32 |
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
 
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+### Training
 
-<br>
-Experiments are tested on ascend 910* with mindspore 2.5.0 graph mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0*.
 
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 1.95 | 5000 | 0.51 |
+|mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
+| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
+|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
+|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |
 
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+*note: currently it supports training with mindspore 2.7.0 only.
+Used mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
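As a quick cross-check (my own arithmetic, not part of the diff), the `img/s` column in the updated reconstruction tables is simply `batch size / s per step`:

```python
# Verify img/s = batch_size / seconds_per_step for the mindspore 2.6.0 rows above.
rows = [
    ("pynative, bs=1", 1, 2.42),  # table reports 0.41 img/s
    ("pynative, bs=4", 4, 0.95),  # table reports 4.21 img/s
    ("graph, bs=4", 4, 2.70),     # table reports 1.48 img/s
]
for label, batch_size, s_per_step in rows:
    print(f"{label}: {batch_size / s_per_step:.2f} img/s")
```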

examples/emu3/autoencode.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from PIL import Image
 
 import mindspore as ms
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, ops
 
 from mindone.diffusers.training_utils import pynative_no_grad as no_grad
 from mindone.utils.amp import auto_mixed_precision

@@ -26,7 +26,7 @@
 MS_DTYPE = ms.bfloat16  # float16 fail to reconstruct
 model = Emu3VisionVQModel.from_pretrained(MODEL_HUB, use_safetensors=True, mindspore_dtype=MS_DTYPE).set_train(False)
 model = auto_mixed_precision(
-    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )  # NOTE: nn.Conv3d used float16
 processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
 # Same as using AutoModel/AutoImageProcessor:

examples/emu3/convert_weights.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

examples/emu3/emu3/tests/test_emu3_infer.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from PIL import Image
 
 import mindspore as ms
-from mindspore import nn
+from mindspore import mint
 
 ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
 # ms.set_context(mode = ms.GRAPH_MODE) # NOT SUPPORTED YET

@@ -81,7 +81,7 @@
     VQ_HUB, use_safetensors=True, mindspore_dtype=VQ_DTYPE
 ).set_train(False)
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
 print("*" * 100)

examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py

Lines changed: 5 additions & 5 deletions
@@ -138,14 +138,14 @@ def __init__(
         stride = (1, 1, 1)
         kernel_size = (3, 3, 3)
 
-        self.norm1 = nn.BatchNorm3d(in_channels)
+        self.norm1 = mint.nn.BatchNorm3d(in_channels)
         self.conv1 = Emu3VisionVQCausalConv3d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
         )
-        self.norm2 = nn.BatchNorm3d(out_channels)
+        self.norm2 = mint.nn.BatchNorm3d(out_channels)
         self.dropout = nn.Dropout(p=dropout)
         self.conv2 = Emu3VisionVQCausalConv3d(
             out_channels,

@@ -702,9 +702,9 @@ def _init_weights(self, module):
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             module.gamma.set_data(initializer(Constant(1), module.gamma.shape, module.gamma.dtype))
             module.beta.set_data(initializer(Constant(0), module.beta.shape, module.beta.dtype))
-        elif isinstance(module, nn.BatchNorm3d):
-            module.bn2d.gamma.set_data(initializer(Constant(1), module.bn2d.gamma.shape, module.bn2d.gamma.dtype))
-            module.bn2d.beta.set_data(initializer(Constant(0), module.bn2d.beta.shape, module.bn2d.beta.dtype))
+        elif isinstance(module, mint.nn.BatchNorm3d):
+            module.weight.set_data(initializer(Constant(1), module.weight.shape, module.weight.dtype))
+            module.bias.set_data(initializer(Constant(0), module.bias.shape, module.bias.dtype))
 
 
 class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
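The `_init_weights` hunk reflects an API difference: `mint.nn.BatchNorm3d` exposes its affine parameters directly as `weight`/`bias`, whereas the removed `nn.BatchNorm3d` path kept `gamma`/`beta` on an inner `bn2d` cell. A standalone sketch of the new branch, with imports assumed to match those used in modeling_emu3visionvq.py:

```python
# Sketch of the new initialization branch for mint.nn.BatchNorm3d.
from mindspore import mint
from mindspore.common.initializer import Constant, initializer


def init_batchnorm3d(module):
    # mint.nn.BatchNorm3d carries weight/bias directly on the cell,
    # so the old bn2d.gamma / bn2d.beta indirection is no longer needed.
    if isinstance(module, mint.nn.BatchNorm3d):
        module.weight.set_data(initializer(Constant(1), module.weight.shape, module.weight.dtype))
        module.bias.set_data(initializer(Constant(0), module.bias.shape, module.bias.dtype))
```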

examples/emu3/image_generation.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from transformers.generation.configuration_utils import GenerationConfig
 
 import mindspore as ms
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, ops
 
 from mindone.transformers.generation.logits_process import (
     LogitsProcessorList,

@@ -44,7 +44,7 @@
     False
 )
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

examples/emu3/multimodal_understanding.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from transformers.generation.configuration_utils import GenerationConfig
 
 import mindspore as ms
-from mindspore import Tensor, nn
+from mindspore import Tensor, mint
 
 from mindone.utils.amp import auto_mixed_precision
 

@@ -38,7 +38,7 @@
     False
 )
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
 print("Loaded all models, time elapsed: %.4fs" % (time.time() - start_time))

examples/emu3/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@ transformers>=4.44.0
 tiktoken==0.6.0
 pillow
 omegaconf
-opencv-python
+opencv-python==4.9.0.80
 ezcolorlog
 mindcv==0.3.0

examples/emu3/scripts/t2i_sft_seq_parallel.sh

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ python emu3/train/train_seq_parallel.py \
     --max_position_embeddings 4200 \
     --trainable_hidden_layers 32 \
     --output_dir ${LOG_DIR} \
-    --num_train_epochs 5 \
+    --num_train_epochs 4 \
     --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
     --save_steps 1 \

examples/emu3/scripts/vqa_sft_seq_parallel.sh

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ python emu3/train/train_seq_parallel.py \
     --image_area 147456 \
     --max_position_embeddings 2560 \
     --output_dir ${LOG_DIR} \
-    --num_train_epochs 50 \
+    --num_train_epochs 4 \
     --per_device_train_batch_size 1 \
     --gradient_accumulation_steps 1 \
     --save_steps 1 \
