@@ -43,7 +43,8 @@ Image VQA:
4343### Requirements
4444| mindspore | ascend driver | firmware | cann toolkit/kernel|
4545| --- | --- | --- | --- |
46- | 2.5.0 | 24.1RC2 | 7.3.0.1.231 | 8.0.RC3.beta1|
46+ | 2.6.0 | 24.1.RC3 | 7.5.T11.0 | 8.1.RC1|
47+ | 2.7.0 | 24.1.RC3 | 7.5.T11.0 | 8.2.RC1|
4748
4849### Dependencies
4950
@@ -69,12 +70,6 @@ pip install -r requirements.txt
6970
7071</details >
7172
72- #### Weight conversion:
73-
74- For the **Emu3-VisionTokenizer** model, some network layer variable names are incompatible and cannot be converted automatically, so we need to convert some weight names in advance before loading the pre-trained weights:
75- ```
76- python python convert_weights.py --safetensor_path ORIGINAL_MODEL.safetensors --target_safetensor_path model.safetensors
77- ```
7873
7974## Inference
8075
@@ -118,7 +113,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
118113 mindspore_dtype=VQ_DTYPE
119114).set_train(False)
120115image_tokenizer = auto_mixed_precision(
121- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
116+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
122117)
123118processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
124119
@@ -220,7 +215,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
220215 mindspore_dtype=VQ_DTYPE
221216).set_train(False)
222217image_tokenizer = auto_mixed_precision(
223- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
218+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
224219)
225220processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
226221
@@ -279,7 +274,7 @@ model = Emu3VisionVQModel.from_pretrained(
279274 mindspore_dtype=MS_DTYPE
280275 ).set_train(False)
281276model = auto_mixed_precision(
282- model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
277+ model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
283278)
284279processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB )
285280
@@ -366,66 +361,81 @@ DATA_DIR
366361
367362Input an image or a clip of video frames, output the reconstructed image(s).
368363<br >
369- Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode .
364+ Experiments are tested on Ascend Atlas 800T A2 machines.
370365
371- | model name | precision* | cards | batch size| resolution | s/step | img/s |
372- | --- | --- | --- | --- | --- | --- | --- |
373- | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.65 | 0.38 |
374- | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.96 | 1.04 |
366+ - mindspore 2.6.0
375367
376- * note: mixed precision, ` BatchNorm3d ` uses fp32, ` Conv3d ` and ` Flash Attention ` use fp16.
368+ | mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
369+ | --- | --- | --- | --- | --- | --- | --- | --- |
370+ | pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
371+ | pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
372+ | graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
373+ | graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
377374
378- <br >
379- Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 graph mode.
375+ - mindspore 2.7.0
380376
381- | model name | precision* | cards | batch size| resolution | graph compile | s/step | img/s |
377+ | mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
382378| --- | --- | --- | --- | --- | --- | --- | --- |
383- | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 15s | 3.23 | 0.31 |
384- | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 15s | 5.46 | 0.18 |
379+ | pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
380+ | pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
381+ | graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
382+ | graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
385383
386384* note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
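
The `img/s` column in the image reconstruction tables appears to follow from `batch size` and `s/step`. The sketch below assumes that one step processes one full batch, which is not stated in the README; the helper name `throughput` is hypothetical, and the rows are copied from the mindspore 2.6.0 table.

```python
def throughput(batch_size: int, seconds_per_step: float) -> float:
    """Items processed per second, assuming one step handles one full batch."""
    if seconds_per_step <= 0:
        raise ValueError("seconds_per_step must be positive")
    return batch_size / seconds_per_step


# (mode, batch size, s/step) rows copied from the mindspore 2.6.0 table.
rows = [
    ("pynative", 1, 2.42),
    ("pynative", 4, 0.95),
    ("graph", 1, 3.06),
    ("graph", 4, 2.70),
]

for mode, batch_size, s_per_step in rows:
    rate = throughput(batch_size, s_per_step)
    print(f"{mode:8s} batch={batch_size} img/s={rate:.2f}")
```

Under that assumption the computed values agree with the reported `img/s` figures to two decimal places, which can serve as a quick consistency check when new rows are added.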
387385
388386#### Text-to-Image Generation
389387Input a text prompt, output an image.
390388<br >
391- Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
389+ Experiments are tested on Ascend Atlas 800T A2 machines in pynative mode.
392390
393- | model name | precision* | cards | batch size| resolution | flash attn | s/step | step | img/s |
394- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
395- | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 0.50 | 8193 | 2.27-e4 |
396- | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 0.49 | 8193 | 2.50-e4 |
391+ - mindspore 2.6.0
397392
398- * note: mixed precision, ` BatchNorm3d ` uses fp32, ` Conv3d ` and ` Flash Attention ` use fp16.
393+ | model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
394+ | --- | --- | --- | --- | --- | --- | --- | --- |
395+ | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
396+ | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |
397+
398+
399+ - mindspore 2.7.0
400+
401+ | model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
402+ | --- | --- | --- | --- | --- | --- | --- | --- |
403+ | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
404+ | Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |
405+
406+ * note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
399407
400408#### VQA
401409Input an image and a text prompt, output textual response.
402410<br >
403- Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
411+ Experiments are tested on Ascend Atlas 800T A2 machines in pynative mode.
404412
405- | model name | precision* | cards | batch size| resolution | flash attn | s/step | step | response/s |
406- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
407- | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 0.29 | 131 | 0.03 |
408- | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 0.24 | 92 | 0.05 |
413+ - mindspore 2.6.0
409414
410- * note: mixed precision, ` BatchNorm3d ` uses fp32, ` Conv3d ` and ` Flash Attention ` use fp16.
415+ | model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
416+ | --- | --- | --- | --- | --- | --- | --- | --- |
417+ | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
418+ | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |
411419
412- ### Training
420+ - mindspore 2.7.0
413421
414- Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
422+ | model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
423+ | --- | --- | --- | --- | --- | --- | --- | --- |
424+ | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
425+ | Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |
415426
416- | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu | flash attn | sequence parallel | s/step | step | sample/s |
417- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
418- | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
419- | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 3.08 | 4993 | 0.32 |
427+ * note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
420428
421- * note: mixed precision, ` BatchNorm3d ` and ` Emu3RMSNorm ` use fp32.
429+ ### Training
422430
423- <br >
424- Experiments are tested on ascend 910* with mindspore 2.5.0 graph mode.
431+ Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0* .
425432
426- | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu | flash attn | sequence parallel | s/step | step | sample/s |
427- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
428- | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
429- | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 1.95 | 5000 | 0.51 |
433+ | mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu | flash attn | sequence parallel | s/step | step | sample/s |
434+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
435+ | pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
436+ | pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
437+ | graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
438+ | graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |
430439
431- * note: mixed precision, ` BatchNorm3d ` and ` Emu3RMSNorm ` use fp32.
440+ * note: training is currently supported with mindspore 2.7.0 only.
441+ Mixed precision is used: `BatchNorm3d` and `Emu3RMSNorm` use fp32.