Merged
104 changes: 57 additions & 47 deletions examples/emu3/README.md
Expand Up @@ -43,7 +43,8 @@ Image VQA:
### Requirements
|mindspore | ascend driver | firmware | cann toolkit/kernel|
|--- | --- | --- | --- |
|2.5.0 | 24.1RC2 | 7.3.0.1.231 | 8.0.RC3.beta1|
|2.6.0 | 24.1.RC3 | 7.5.T11.0 | 8.1.RC1|
|2.7.0 | 24.1.RC3 | 7.5.T11.0 | 8.2.RC1|

### Dependencies

Expand All @@ -69,12 +70,6 @@ pip install -r requirements.txt

</details>

#### Weight conversion:

For model **Emu3-VisionTokenizer**, some network layer variable names are incompatible and cannot be converted automatically, so these weight names must be converted in advance, before loading the pre-trained weights:
```
python convert_weights.py --safetensor_path ORIGINAL_MODEL.safetensors --target_safetensor_path model.safetensors
```

## Inference

Expand Down Expand Up @@ -118,7 +113,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

Expand Down Expand Up @@ -220,7 +215,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

Expand Down Expand Up @@ -279,7 +274,7 @@ model = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=MS_DTYPE
).set_train(False)
model = auto_mixed_precision(
model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)

Expand Down Expand Up @@ -366,66 +361,81 @@ DATA_DIR

Input an image or a clip of video frames, output the reconstructed image(s).
<br>
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
Experiments are tested on Ascend Atlas 800T A2 machines.

| model name | precision* | cards | batch size| resolution | s/step | img/s |
| --- | --- | --- | --- | --- | --- | --- |
| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.65 | 0.38 |
| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.96 | 1.04 |
- mindspore 2.6.0

*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |

<br>
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 graph mode.
- mindspore 2.7.0

| model name | precision* | cards | batch size| resolution | graph compile | s/step | img/s |
|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 15s | 3.23 | 0.31 |
| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 15s | 5.46 | 0.18 |
|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |

*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
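
The `img/s` column in the tables above appears to be the batch size divided by `s/step` (for example, 4 frames at 0.95 s/step gives about 4.21 img/s). A minimal sketch of that arithmetic, assuming this is how the column was derived; the helper name `images_per_second` is hypothetical and not part of the repository:

```python
def images_per_second(batch_size: int, seconds_per_step: float) -> float:
    """Throughput in images (or video frames) per second for one step.

    Assumption: throughput = batch size / seconds per step.
    """
    if seconds_per_step <= 0:
        raise ValueError("seconds_per_step must be positive")
    return batch_size / seconds_per_step


# (mode, batch size, s/step) rows copied from the mindspore 2.6.0 table
rows = [
    ("pynative", 1, 2.42),
    ("pynative", 4, 0.95),
    ("graph", 1, 3.06),
    ("graph", 4, 2.70),
]

for mode, batch, s_per_step in rows:
    rate = images_per_second(batch, s_per_step)
    print(f"{mode:8s} batch={batch} -> {rate:.2f} img/s")
```

The same helper can be used to re-check a row after re-running a benchmark on different hardware or a different MindSpore version.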

#### Text-to-Image Generation
Input a text prompt, output an image.
<br>
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.

|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | img/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 0.50 | 8193 | 2.27e-4 |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 0.49 | 8193 | 2.50e-4 |
- mindspore 2.6.0

*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |


- mindspore 2.7.0

|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |

*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.

#### VQA
Input an image and a text prompt, output textual response.
<br>
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.

|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | response/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 0.29 | 131 | 0.03 |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 0.24 | 92 | 0.05 |
- mindspore 2.6.0

*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |

### Training
- mindspore 2.7.0

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |

| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 3.08 | 4993 | 0.32 |
*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.

*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
### Training

<br>
Experiments are tested on ascend 910* with mindspore 2.5.0 graph mode.
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0*.

| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 1.95 | 5000 | 0.51 |
|mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |

*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
*note: training is currently supported with mindspore 2.7.0 only.
Mixed precision is used; `BatchNorm3d` and `Emu3RMSNorm` use fp32.
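
The `sample/s` values in the training table are consistent with one sample per step divided by `s/step`, which is what one would expect if all cards in a sequence-parallel group work on the same sample. A minimal sketch under that assumption; the function name `samples_per_second` and the data-parallel formula are illustrative and not taken from the training scripts:

```python
def samples_per_second(
    per_device_batch: int,
    cards: int,
    sequence_parallel_shards: int,
    seconds_per_step: float,
) -> float:
    """Estimate training throughput in samples per second.

    Assumption: cards are split into groups of `sequence_parallel_shards`,
    and each group processes `per_device_batch` samples per step.
    """
    if cards % sequence_parallel_shards != 0:
        raise ValueError("cards must be divisible by sequence_parallel_shards")
    data_parallel_groups = cards // sequence_parallel_shards
    global_batch = per_device_batch * data_parallel_groups
    return global_batch / seconds_per_step


# Rows from the table: (mode, stage, cards, shards, s/step)
rows = [
    ("pynative", "stage2-T2I", 8, 8, 1.79),
    ("pynative", "stage2-VQA", 4, 4, 1.79),
    ("graph", "stage2-T2I", 8, 8, 34.11),
    ("graph", "stage2-VQA", 4, 4, 20.10),
]

for mode, stage, cards, shards, s_per_step in rows:
    rate = samples_per_second(1, cards, shards, s_per_step)
    print(f"{mode:8s} {stage}: {rate:.2f} sample/s")
```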
4 changes: 2 additions & 2 deletions examples/emu3/autoencode.py
Expand Up @@ -10,7 +10,7 @@
from PIL import Image

import mindspore as ms
from mindspore import Tensor, nn, ops
from mindspore import Tensor, mint, ops

from mindone.diffusers.training_utils import pynative_no_grad as no_grad
from mindone.utils.amp import auto_mixed_precision
Expand All @@ -26,7 +26,7 @@
MS_DTYPE = ms.bfloat16 # float16 fails to reconstruct
model = Emu3VisionVQModel.from_pretrained(MODEL_HUB, use_safetensors=True, mindspore_dtype=MS_DTYPE).set_train(False)
model = auto_mixed_precision(
model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
) # NOTE: nn.Conv3d used float16
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
# Same as using AutoModel/AutoImageProcessor:
Expand Down
65 changes: 0 additions & 65 deletions examples/emu3/convert_weights.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/emu3/emu3/tests/test_emu3_infer.py
Expand Up @@ -8,7 +8,7 @@
from PIL import Image

import mindspore as ms
from mindspore import nn
from mindspore import mint

ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
# ms.set_context(mode = ms.GRAPH_MODE) # NOT SUPPORTED YET
Expand Down Expand Up @@ -81,7 +81,7 @@
VQ_HUB, use_safetensors=True, mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
print("*" * 100)
Expand Down
10 changes: 5 additions & 5 deletions examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py
Expand Up @@ -138,14 +138,14 @@ def __init__(
stride = (1, 1, 1)
kernel_size = (3, 3, 3)

self.norm1 = nn.BatchNorm3d(in_channels)
self.norm1 = mint.nn.BatchNorm3d(in_channels)
self.conv1 = Emu3VisionVQCausalConv3d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
)
self.norm2 = nn.BatchNorm3d(out_channels)
self.norm2 = mint.nn.BatchNorm3d(out_channels)
self.dropout = nn.Dropout(p=dropout)
self.conv2 = Emu3VisionVQCausalConv3d(
out_channels,
Expand Down Expand Up @@ -702,9 +702,9 @@ def _init_weights(self, module):
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
module.gamma.set_data(initializer(Constant(1), module.gamma.shape, module.gamma.dtype))
module.beta.set_data(initializer(Constant(0), module.beta.shape, module.beta.dtype))
elif isinstance(module, nn.BatchNorm3d):
module.bn2d.gamma.set_data(initializer(Constant(1), module.bn2d.gamma.shape, module.bn2d.gamma.dtype))
module.bn2d.beta.set_data(initializer(Constant(0), module.bn2d.beta.shape, module.bn2d.beta.dtype))
elif isinstance(module, mint.nn.BatchNorm3d):
module.weight.set_data(initializer(Constant(1), module.weight.shape, module.weight.dtype))
module.bias.set_data(initializer(Constant(0), module.bias.shape, module.bias.dtype))


class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
Expand Down
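
The `_init_weights` hunk above shows the parameter attributes changing from `bn2d.gamma` / `bn2d.beta` (on `nn.BatchNorm3d`) to `weight` / `bias` (on `mint.nn.BatchNorm3d`). A minimal sketch of what the equivalent key renaming could look like for a checkpoint saved with the old names, assuming checkpoint keys follow the attribute path. The function `rename_batchnorm3d_keys` and the example keys are hypothetical, and the moving-statistics names are not covered because the diff does not show them:

```python
def rename_batchnorm3d_keys(state_dict: dict) -> dict:
    """Return a copy of `state_dict` with old BatchNorm3d key suffixes renamed.

    Only the two suffixes visible in the diff are mapped (assumption:
    keys end with the attribute path of the parameter).
    """
    suffix_map = {
        ".bn2d.gamma": ".weight",
        ".bn2d.beta": ".bias",
    }
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old_suffix, new_suffix in suffix_map.items():
            if key.endswith(old_suffix):
                new_key = key[: -len(old_suffix)] + new_suffix
                break
        renamed[new_key] = value
    return renamed


# Hypothetical keys, for illustration only
old = {
    "encoder.block.norm1.bn2d.gamma": "g",
    "encoder.block.norm1.bn2d.beta": "b",
    "encoder.block.conv1.weight": "w",
}
print(rename_batchnorm3d_keys(old))
```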
4 changes: 2 additions & 2 deletions examples/emu3/image_generation.py
Expand Up @@ -7,7 +7,7 @@
from transformers.generation.configuration_utils import GenerationConfig

import mindspore as ms
from mindspore import Tensor, nn, ops
from mindspore import Tensor, mint, ops

from mindone.transformers.generation.logits_process import (
LogitsProcessorList,
Expand Down Expand Up @@ -44,7 +44,7 @@
False
)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

Expand Down
4 changes: 2 additions & 2 deletions examples/emu3/multimodal_understanding.py
Expand Up @@ -8,7 +8,7 @@
from transformers.generation.configuration_utils import GenerationConfig

import mindspore as ms
from mindspore import Tensor, nn
from mindspore import Tensor, mint

from mindone.utils.amp import auto_mixed_precision

Expand Down Expand Up @@ -38,7 +38,7 @@
False
)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
print("Loaded all models, time elapsed: %.4fs" % (time.time() - start_time))
Expand Down
2 changes: 1 addition & 1 deletion examples/emu3/requirements.txt
Expand Up @@ -2,6 +2,6 @@ transformers>=4.44.0
tiktoken==0.6.0
pillow
omegaconf
opencv-python
opencv-python==4.9.0.80
ezcolorlog
mindcv==0.3.0
2 changes: 1 addition & 1 deletion examples/emu3/scripts/t2i_sft_seq_parallel.sh
Expand Up @@ -27,7 +27,7 @@ python emu3/train/train_seq_parallel.py \
--max_position_embeddings 4200 \
--trainable_hidden_layers 32 \
--output_dir ${LOG_DIR} \
--num_train_epochs 5 \
--num_train_epochs 4 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1 \
Expand Down
2 changes: 1 addition & 1 deletion examples/emu3/scripts/vqa_sft_seq_parallel.sh
Expand Up @@ -26,7 +26,7 @@ python emu3/train/train_seq_parallel.py \
--image_area 147456 \
--max_position_embeddings 2560 \
--output_dir ${LOG_DIR} \
--num_train_epochs 50 \
--num_train_epochs 4 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1 \
Expand Down