diff --git a/examples/emu3/README.md b/examples/emu3/README.md
index 1009a5d473..ad37e654a2 100644
--- a/examples/emu3/README.md
+++ b/examples/emu3/README.md
@@ -43,7 +43,8 @@ Image VQA:
### Requirements
|mindspore | ascend driver | firmware | cann toolkit/kernel|
|--- | --- | --- | --- |
-|2.5.0 | 24.1RC2 | 7.3.0.1.231 | 8.0.RC3.beta1|
+|2.6.0 | 24.1.RC3 | 7.5.T11.0 | 8.1.RC1|
+|2.7.0 | 24.1.RC3 | 7.5.T11.0 | 8.2.RC1|
### Dependencies
@@ -69,12 +70,6 @@ pip install -r requirements.txt
-#### Weight conversion:
-
-For model **Emu3-VisionTokenizer**, there are some incompatible network layer variable names that cannot be automatically converted, we need to convert some weight names in advanved before loading the pre-trained weights:
-```
-python python convert_weights.py --safetensor_path ORIGINAL_MODEL.safetensors --target_safetensor_path model.safetensors
-```
## Inference
@@ -118,7 +113,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
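For orientation, the hunks above only show the changed lines; a minimal, self-contained version of the vision-tokenizer setup they belong to looks roughly like the sketch below. The import paths and hub id are assumptions for illustration, while the `from_pretrained`/`auto_mixed_precision` calls mirror the snippet above.

```python
import mindspore as ms
from mindspore import mint

from mindone.utils.amp import auto_mixed_precision

# import path and hub id below are illustrative assumptions
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel

VQ_HUB = "BAAI/Emu3-VisionTokenizer"
VQ_DTYPE = ms.bfloat16

image_processor = Emu3VisionVQImageProcessor.from_pretrained(VQ_HUB)
image_tokenizer = Emu3VisionVQModel.from_pretrained(
    VQ_HUB, use_safetensors=True, mindspore_dtype=VQ_DTYPE
).set_train(False)

# O2 mixed precision: the tokenizer computes in bf16 while the
# mint.nn.BatchNorm3d cells are kept in fp32 (see the precision notes below)
image_tokenizer = auto_mixed_precision(
    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
```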
@@ -220,7 +215,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
@@ -279,7 +274,7 @@ model = Emu3VisionVQModel.from_pretrained(
mindspore_dtype=MS_DTYPE
).set_train(False)
model = auto_mixed_precision(
- model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
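The reconstruction benchmark further down is an encode/decode round trip through this model. A minimal sketch, assuming the tokenizer exposes `encode`/`decode` and the image processor a `postprocess` method as in the upstream Emu3 release; `model`, `processor` and `MS_DTYPE` refer to the objects defined in the snippet above, and the input path is hypothetical.

```python
from PIL import Image

from mindspore import Tensor
from mindone.diffusers.training_utils import pynative_no_grad as no_grad

# preprocess a single image to pixel values and add a batch dimension
image = Image.open("assets/demo.png")  # hypothetical path
pixel_values = processor(image, return_tensors="np")["pixel_values"]
pixel_values = Tensor(pixel_values, dtype=MS_DTYPE)[None, ...]

with no_grad():
    codes = model.encode(pixel_values)  # discrete visual tokens
    recon = model.decode(codes)         # reconstructed frames

# flatten the batch/time dims and convert back to a PIL image
recon = recon.reshape(-1, *recon.shape[2:])
recon_image = processor.postprocess(recon)["pixel_values"][0]
recon_image.save("recon.png")
```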
@@ -366,66 +361,81 @@ DATA_DIR
Input an image or a clip of video frames, output the reconstructed image(s).
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines.
-| model name | precision* | cards | batch size| resolution | s/step | img/s |
-| --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.65 | 0.38 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.96 | 1.04 |
+- mindspore 2.6.0
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
-
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 graph mode.
+- mindspore 2.7.0
-| model name | precision* | cards | batch size| resolution | graph compile | s/step | img/s |
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 15s | 3.23 | 0.31 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 15s | 5.46 | 0.18 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
#### Text-to-Image Generation
Input a text prompt, output an image.
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines in pynative mode.
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | img/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 0.50 | 8193 | 2.27-e4 |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 0.49 | 8193 | 2.50-e4 |
+- mindspore 2.6.0
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |
+
+
+- mindspore 2.7.0
+
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |
+
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
#### VQA
Input an image and a text prompt, output textual response.
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines in pynative mode.
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | response/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 0.29 | 131 | 0.03 |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 0.24 | 92 | 0.05 |
+- mindspore 2.6.0
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |
-### Training
+- mindspore 2.7.0
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 3.08 | 4993 | 0.32 |
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+### Training
-
-Experiments are tested on ascend 910* with mindspore 2.5.0 graph mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0*.
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 1.95 | 5000 | 0.51 |
+|mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
+| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
+|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
+|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+*note: training is currently supported with mindspore 2.7.0 only.
+Mixed precision is used: `BatchNorm3d` and `Emu3RMSNorm` use fp32.
diff --git a/examples/emu3/autoencode.py b/examples/emu3/autoencode.py
index 822f1c592d..1023705be7 100644
--- a/examples/emu3/autoencode.py
+++ b/examples/emu3/autoencode.py
@@ -10,7 +10,7 @@
from PIL import Image
import mindspore as ms
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, ops
from mindone.diffusers.training_utils import pynative_no_grad as no_grad
from mindone.utils.amp import auto_mixed_precision
@@ -26,7 +26,7 @@
MS_DTYPE = ms.bfloat16 # float16 fail to reconstruct
model = Emu3VisionVQModel.from_pretrained(MODEL_HUB, use_safetensors=True, mindspore_dtype=MS_DTYPE).set_train(False)
model = auto_mixed_precision(
- model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
) # NOTE: nn.Conv3d used float16
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
# Same as using AutoModel/AutoImageProcessor:
diff --git a/examples/emu3/convert_weights.py b/examples/emu3/convert_weights.py
deleted file mode 100644
index 43946cb4a3..0000000000
--- a/examples/emu3/convert_weights.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""
-A script to convert pytorch safetensors to mindspore compatible safetensors:
-
-Because some weights and variables in networks cannot be auto-converted, e.g. BatchNorm3d.bn2d.weight vs BatchNorm3d.gamma
-
-To run this script, you should have installed both pytorch and mindspore.
-
-Usage:
-python convert_model.py --safetensor_path pytorch_model.safetensors --target_safetensor_path model.safetensors
-
-The converted model `model.safetensors` will be saved in the same directory as this file belonging to.
-"""
-
-import argparse
-import os
-
-import torch
-from safetensors import safe_open
-from safetensors.torch import load_file, save_file
-
-
-def convert_safetensors(args):
- with safe_open(args.safetensor_path, framework="np") as f:
- metadata = f.metadata()
-
- weights_safetensors = load_file(args.safetensor_path)
- weights_ms_safetensors = {}
-
- # For BatchNorm3d:
- # turn torch key : X.time_res_stack.X.norm*.weight/bias/running_mean/running_var
- # to ms key : X.time_res_stack.X.norm*.bn2d.gamma/beta/moving_mean/moving_variance
- for key, value in weights_safetensors.items():
- if (".time_res_stack" in key) and (".norm" in key):
- origin_key = key
- if key.endswith("norm1.weight") or key.endswith("norm2.weight"):
- key = key.replace("weight", "bn2d.gamma")
- elif key.endswith("norm1.bias") or key.endswith("norm2.bias"):
- key = key.replace("bias", "bn2d.beta")
- elif key.endswith("norm1.running_mean") or key.endswith("norm2.running_mean"):
- key = key.replace("running_mean", "bn2d.moving_mean")
- elif key.endswith("norm1.running_var") or key.endswith("norm2.running_var"):
- key = key.replace("running_var", "bn2d.moving_variance")
- print(f"{origin_key} -> {key}")
-
- weights_ms_safetensors[key] = torch.from_numpy(value.numpy())
-
- save_file_dir = os.path.join(os.path.dirname(args.safetensor_path), args.target_safetensor_path)
- save_file(weights_ms_safetensors, save_file_dir, metadata=metadata)
- print(f"Safetensors is converted and saved as {save_file_dir}")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Emu2 model weight conversion")
- parser.add_argument(
- "--safetensor_path",
- type=str,
- help="path to Emu3 weight from torch (model.safetensors)",
- )
- parser.add_argument(
- "--target_safetensor_path", type=str, help="path to sdxl lora weight from mindone kohya (xxx.safetensors)"
- )
-
- args, _ = parser.parse_known_args()
- print("Converting...")
- convert_safetensors(args)
diff --git a/examples/emu3/emu3/tests/test_emu3_infer.py b/examples/emu3/emu3/tests/test_emu3_infer.py
index a4993a5e52..9e451bcaaf 100644
--- a/examples/emu3/emu3/tests/test_emu3_infer.py
+++ b/examples/emu3/emu3/tests/test_emu3_infer.py
@@ -8,7 +8,7 @@
from PIL import Image
import mindspore as ms
-from mindspore import nn
+from mindspore import mint
ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
# ms.set_context(mode = ms.GRAPH_MODE) # NOT SUPPORTED YET
@@ -81,7 +81,7 @@
VQ_HUB, use_safetensors=True, mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
print("*" * 100)
diff --git a/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py b/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py
index 7e9a7e3329..183d7de480 100644
--- a/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py
+++ b/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py
@@ -138,14 +138,14 @@ def __init__(
stride = (1, 1, 1)
kernel_size = (3, 3, 3)
- self.norm1 = nn.BatchNorm3d(in_channels)
+ self.norm1 = mint.nn.BatchNorm3d(in_channels)
self.conv1 = Emu3VisionVQCausalConv3d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
)
- self.norm2 = nn.BatchNorm3d(out_channels)
+ self.norm2 = mint.nn.BatchNorm3d(out_channels)
self.dropout = nn.Dropout(p=dropout)
self.conv2 = Emu3VisionVQCausalConv3d(
out_channels,
@@ -702,9 +702,9 @@ def _init_weights(self, module):
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
module.gamma.set_data(initializer(Constant(1), module.gamma.shape, module.gamma.dtype))
module.beta.set_data(initializer(Constant(0), module.beta.shape, module.beta.dtype))
- elif isinstance(module, nn.BatchNorm3d):
- module.bn2d.gamma.set_data(initializer(Constant(1), module.bn2d.gamma.shape, module.bn2d.gamma.dtype))
- module.bn2d.beta.set_data(initializer(Constant(0), module.bn2d.beta.shape, module.bn2d.beta.dtype))
+ elif isinstance(module, mint.nn.BatchNorm3d):
+ module.weight.set_data(initializer(Constant(1), module.weight.shape, module.weight.dtype))
+ module.bias.set_data(initializer(Constant(0), module.bias.shape, module.bias.dtype))
class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
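This rename is also what lets the separate weight-conversion step be dropped: `mint.nn.BatchNorm3d` exposes its parameters under the PyTorch-style names found in the original safetensors (`weight`, `bias`, plus running statistics), whereas the legacy `nn.BatchNorm3d` wrapped an internal `bn2d` cell with `gamma`/`beta`/`moving_mean`/`moving_variance`. A quick way to inspect the names is sketched below; the running-statistic names are an assumption based on the mint API mirroring torch.

```python
from mindspore import mint

bn = mint.nn.BatchNorm3d(8)

# parameter names now line up with the torch checkpoint keys
# (X.time_res_stack.X.norm*.weight/bias/...), so no renaming pass is needed
for name, param in bn.parameters_and_names():
    print(name, tuple(param.shape))
# expected (assumed): weight, bias, running_mean, running_var
```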
diff --git a/examples/emu3/image_generation.py b/examples/emu3/image_generation.py
index 88a1bde60e..530488a340 100644
--- a/examples/emu3/image_generation.py
+++ b/examples/emu3/image_generation.py
@@ -7,7 +7,7 @@
from transformers.generation.configuration_utils import GenerationConfig
import mindspore as ms
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, ops
from mindone.transformers.generation.logits_process import (
LogitsProcessorList,
@@ -44,7 +44,7 @@
False
)
image_tokenizer = auto_mixed_precision(
- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
diff --git a/examples/emu3/multimodal_understanding.py b/examples/emu3/multimodal_understanding.py
index 6fd6c5e0db..e8885382cb 100644
--- a/examples/emu3/multimodal_understanding.py
+++ b/examples/emu3/multimodal_understanding.py
@@ -8,7 +8,7 @@
from transformers.generation.configuration_utils import GenerationConfig
import mindspore as ms
-from mindspore import Tensor, nn
+from mindspore import Tensor, mint
from mindone.utils.amp import auto_mixed_precision
@@ -38,7 +38,7 @@
False
)
image_tokenizer = auto_mixed_precision(
- image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+ image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
print("Loaded all models, time elapsed: %.4fs" % (time.time() - start_time))
diff --git a/examples/emu3/requirements.txt b/examples/emu3/requirements.txt
index 4c6f3a0a18..1237f0993e 100644
--- a/examples/emu3/requirements.txt
+++ b/examples/emu3/requirements.txt
@@ -2,6 +2,6 @@ transformers>=4.44.0
tiktoken==0.6.0
pillow
omegaconf
-opencv-python
+opencv-python==4.9.0.80
ezcolorlog
mindcv==0.3.0
diff --git a/examples/emu3/scripts/t2i_sft_seq_parallel.sh b/examples/emu3/scripts/t2i_sft_seq_parallel.sh
index dab937e44f..3f3510b2c1 100644
--- a/examples/emu3/scripts/t2i_sft_seq_parallel.sh
+++ b/examples/emu3/scripts/t2i_sft_seq_parallel.sh
@@ -27,7 +27,7 @@ python emu3/train/train_seq_parallel.py \
--max_position_embeddings 4200 \
--trainable_hidden_layers 32 \
--output_dir ${LOG_DIR} \
- --num_train_epochs 5 \
+ --num_train_epochs 4 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1 \
diff --git a/examples/emu3/scripts/vqa_sft_seq_parallel.sh b/examples/emu3/scripts/vqa_sft_seq_parallel.sh
index 2b856b8ac8..ce80b05890 100644
--- a/examples/emu3/scripts/vqa_sft_seq_parallel.sh
+++ b/examples/emu3/scripts/vqa_sft_seq_parallel.sh
@@ -26,7 +26,7 @@ python emu3/train/train_seq_parallel.py \
--image_area 147456 \
--max_position_embeddings 2560 \
--output_dir ${LOG_DIR} \
- --num_train_epochs 50 \
+ --num_train_epochs 4 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1 \