From 21a1f8b9fc88f21ea100abf0438ce4669bd357c1 Mon Sep 17 00:00:00 2001
From: chenyingshu
Date: Thu, 16 Oct 2025 18:32:11 +0800
Subject: [PATCH 1/2] use mint.nn.BatchNorm3d; update inference perf for ms2.6.0&ms2.7.0

---
 examples/emu3/README.md                       | 103 ++++++++++--------
 examples/emu3/autoencode.py                   |   4 +-
 examples/emu3/convert_weights.py              |  65 -----------
 examples/emu3/emu3/tests/test_emu3_infer.py   |   4 +-
 .../emu3/tokenizer/modeling_emu3visionvq.py   |  10 +-
 examples/emu3/image_generation.py             |   4 +-
 examples/emu3/multimodal_understanding.py     |   4 +-
 examples/emu3/requirements.txt                |   2 +-
 8 files changed, 71 insertions(+), 125 deletions(-)
 delete mode 100644 examples/emu3/convert_weights.py

diff --git a/examples/emu3/README.md b/examples/emu3/README.md
index 1009a5d473..537958e0f7 100644
--- a/examples/emu3/README.md
+++ b/examples/emu3/README.md
@@ -43,7 +43,8 @@ Image VQA:
 ### Requirements
 |mindspore | ascend driver | firmware | cann toolkit/kernel|
 |--- | --- | --- | --- |
-|2.5.0 | 24.1RC2 | 7.3.0.1.231 | 8.0.RC3.beta1|
+|2.6.0 | 24.1.RC3 | 7.5.T11.0 | 8.1.RC1|
+|2.7.0 | 24.1.RC3 | 7.5.T11.0 | 8.2.RC1|
 
 ### Dependencies
 
@@ -69,12 +70,6 @@ pip install -r requirements.txt
 
-#### Weight conversion:
-
-For model **Emu3-VisionTokenizer**, there are some incompatible network layer variable names that cannot be automatically converted, we need to convert some weight names in advanved before loading the pre-trained weights:
-```
-python python convert_weights.py --safetensor_path ORIGINAL_MODEL.safetensors --target_safetensor_path model.safetensors
-```
 
 ## Inference
 
@@ -118,7 +113,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=VQ_DTYPE
 ).set_train(False)
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
 
@@ -220,7 +215,7 @@ image_tokenizer = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=VQ_DTYPE
 ).set_train(False)
 image_tokenizer = auto_mixed_precision(
-    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
 
@@ -279,7 +274,7 @@ model = Emu3VisionVQModel.from_pretrained(
     mindspore_dtype=MS_DTYPE
 ).set_train(False)
 model = auto_mixed_precision(
-    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )
 processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
 
@@ -366,66 +361,82 @@ DATA_DIR
 
 Input an image or a clip of video frames, output the reconstructed image(s).
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines.
 
-| model name | precision* | cards | batch size| resolution | s/step | img/s |
-| --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.65 | 0.38 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.96 | 1.04 |
+- mindspore 2.6.0
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
-
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 graph mode.
+- mindspore 2.7.0
 
-| model name | precision* | cards | batch size| resolution | graph compile | s/step | img/s |
+|mode | model name | precision* | cards | batch size| resolution | s/step | img/s |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 15s | 3.23 | 0.31 |
-| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 15s | 5.46 | 0.18 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
+|pynative| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
+|graph| Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
 
 *note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
 
 #### Text-to-Image Generation
 Input a text prompt, output an image.
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
 
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | img/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 0.50 | 8193 | 2.27-e4 |
-| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 0.49 | 8193 | 2.50-e4 |
+- mindspore 2.6.0
+
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+
+- mindspore 2.7.0
+
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
+| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |
+
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
 
 #### VQA
 Input an image and a text prompt, output a textual response.
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
 
-|model name | precision* | cards | batch size| resolution | flash attn | s/step | step | response/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 0.29 | 131 | 0.03 |
-| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 0.24 | 92 | 0.05 |
+- mindspore 2.6.0
 
-*note: mixed precision, `BatchNorm3d` uses fp32, `Conv3d` and `Flash Attention` use fp16.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |
 
-### Training
+- mindspore 2.7.0
 
-Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode.
+|model name | precision* | cards | batch size| resolution | flash attn | tokens/s | step |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
+| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |
 
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 3.08 | 4993 | 0.32 |
+*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32, `Conv3d` and `Flash Attention` use fp16.
 
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+### Training
-
-Experiments are tested on ascend 910* with mindspore 2.5.0 graph mode.
+Experiments are tested on Ascend Atlas 800T A2 machines.
+
+- mindspore 2.5.0
 
-| stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
-| stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 8 shards | 1.95 | 5000 | 0.51 |
+|mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
+| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 3.08 | 4993 | 0.32 |
+|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
+|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.95 | 5000 | 0.51 |
 
 *note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
 
diff --git a/examples/emu3/autoencode.py b/examples/emu3/autoencode.py
index 822f1c592d..1023705be7 100644
--- a/examples/emu3/autoencode.py
+++ b/examples/emu3/autoencode.py
@@ -10,7 +10,7 @@
 from PIL import Image
 
 import mindspore as ms
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, ops
 
 from mindone.diffusers.training_utils import pynative_no_grad as no_grad
 from mindone.utils.amp import auto_mixed_precision
@@ -26,7 +26,7 @@
 MS_DTYPE = ms.bfloat16  # float16 fail to reconstruct
 model = Emu3VisionVQModel.from_pretrained(MODEL_HUB, use_safetensors=True, mindspore_dtype=MS_DTYPE).set_train(False)
 model = auto_mixed_precision(
-    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[nn.BatchNorm3d]
+    model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
 )  # NOTE: nn.Conv3d used float16
 processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
 # Same as using AutoModel/AutoImageProcessor:
diff --git a/examples/emu3/convert_weights.py b/examples/emu3/convert_weights.py
deleted file mode 100644
index 43946cb4a3..0000000000
--- a/examples/emu3/convert_weights.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""
-A script to convert pytorch safetensors to mindspore compatible safetensors:
-
-Because some weights and variables in networks cannot be auto-converted, e.g. BatchNorm3d.bn2d.weight vs BatchNorm3d.gamma
-
-To run this script, you should have installed both pytorch and mindspore.
-
-Usage:
-python convert_model.py --safetensor_path pytorch_model.safetensors --target_safetensor_path model.safetensors
-
-The converted model `model.safetensors` will be saved in the same directory as this file belonging to.
-""" - -import argparse -import os - -import torch -from safetensors import safe_open -from safetensors.torch import load_file, save_file - - -def convert_safetensors(args): - with safe_open(args.safetensor_path, framework="np") as f: - metadata = f.metadata() - - weights_safetensors = load_file(args.safetensor_path) - weights_ms_safetensors = {} - - # For BatchNorm3d: - # turn torch key : X.time_res_stack.X.norm*.weight/bias/running_mean/running_var - # to ms key : X.time_res_stack.X.norm*.bn2d.gamma/beta/moving_mean/moving_variance - for key, value in weights_safetensors.items(): - if (".time_res_stack" in key) and (".norm" in key): - origin_key = key - if key.endswith("norm1.weight") or key.endswith("norm2.weight"): - key = key.replace("weight", "bn2d.gamma") - elif key.endswith("norm1.bias") or key.endswith("norm2.bias"): - key = key.replace("bias", "bn2d.beta") - elif key.endswith("norm1.running_mean") or key.endswith("norm2.running_mean"): - key = key.replace("running_mean", "bn2d.moving_mean") - elif key.endswith("norm1.running_var") or key.endswith("norm2.running_var"): - key = key.replace("running_var", "bn2d.moving_variance") - print(f"{origin_key} -> {key}") - - weights_ms_safetensors[key] = torch.from_numpy(value.numpy()) - - save_file_dir = os.path.join(os.path.dirname(args.safetensor_path), args.target_safetensor_path) - save_file(weights_ms_safetensors, save_file_dir, metadata=metadata) - print(f"Safetensors is converted and saved as {save_file_dir}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Emu2 model weight conversion") - parser.add_argument( - "--safetensor_path", - type=str, - help="path to Emu3 weight from torch (model.safetensors)", - ) - parser.add_argument( - "--target_safetensor_path", type=str, help="path to sdxl lora weight from mindone kohya (xxx.safetensors)" - ) - - args, _ = parser.parse_known_args() - print("Converting...") - convert_safetensors(args) diff --git a/examples/emu3/emu3/tests/test_emu3_infer.py b/examples/emu3/emu3/tests/test_emu3_infer.py index a4993a5e52..9e451bcaaf 100644 --- a/examples/emu3/emu3/tests/test_emu3_infer.py +++ b/examples/emu3/emu3/tests/test_emu3_infer.py @@ -8,7 +8,7 @@ from PIL import Image import mindspore as ms -from mindspore import nn +from mindspore import mint ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True) # ms.set_context(mode = ms.GRAPH_MODE) # NOT SUPPORTED YET @@ -81,7 +81,7 @@ VQ_HUB, use_safetensors=True, mindspore_dtype=VQ_DTYPE ).set_train(False) image_tokenizer = auto_mixed_precision( - image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d] + image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d] ) processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) print("*" * 100) diff --git a/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py b/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py index 7e9a7e3329..183d7de480 100644 --- a/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py +++ b/examples/emu3/emu3/tokenizer/modeling_emu3visionvq.py @@ -138,14 +138,14 @@ def __init__( stride = (1, 1, 1) kernel_size = (3, 3, 3) - self.norm1 = nn.BatchNorm3d(in_channels) + self.norm1 = mint.nn.BatchNorm3d(in_channels) self.conv1 = Emu3VisionVQCausalConv3d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, ) - self.norm2 = nn.BatchNorm3d(out_channels) + self.norm2 = mint.nn.BatchNorm3d(out_channels) self.dropout = nn.Dropout(p=dropout) self.conv2 = 
Emu3VisionVQCausalConv3d( out_channels, @@ -702,9 +702,9 @@ def _init_weights(self, module): elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): module.gamma.set_data(initializer(Constant(1), module.gamma.shape, module.gamma.dtype)) module.beta.set_data(initializer(Constant(0), module.beta.shape, module.beta.dtype)) - elif isinstance(module, nn.BatchNorm3d): - module.bn2d.gamma.set_data(initializer(Constant(1), module.bn2d.gamma.shape, module.bn2d.gamma.dtype)) - module.bn2d.beta.set_data(initializer(Constant(0), module.bn2d.beta.shape, module.bn2d.beta.dtype)) + elif isinstance(module, mint.nn.BatchNorm3d): + module.weight.set_data(initializer(Constant(1), module.weight.shape, module.weight.dtype)) + module.bias.set_data(initializer(Constant(0), module.bias.shape, module.bias.dtype)) class Emu3VisionVQModel(Emu3VisionVQPretrainedModel): diff --git a/examples/emu3/image_generation.py b/examples/emu3/image_generation.py index 88a1bde60e..530488a340 100644 --- a/examples/emu3/image_generation.py +++ b/examples/emu3/image_generation.py @@ -7,7 +7,7 @@ from transformers.generation.configuration_utils import GenerationConfig import mindspore as ms -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, ops from mindone.transformers.generation.logits_process import ( LogitsProcessorList, @@ -44,7 +44,7 @@ False ) image_tokenizer = auto_mixed_precision( - image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d] + image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d] ) processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) diff --git a/examples/emu3/multimodal_understanding.py b/examples/emu3/multimodal_understanding.py index 6fd6c5e0db..e8885382cb 100644 --- a/examples/emu3/multimodal_understanding.py +++ b/examples/emu3/multimodal_understanding.py @@ -8,7 +8,7 @@ from transformers.generation.configuration_utils import GenerationConfig import mindspore as ms -from mindspore import Tensor, nn +from mindspore import Tensor, mint from mindone.utils.amp import auto_mixed_precision @@ -38,7 +38,7 @@ False ) image_tokenizer = auto_mixed_precision( - image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[nn.BatchNorm3d] + image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d] ) processor = Emu3Processor(image_processor, image_tokenizer, tokenizer) print("Loaded all models, time elapsed: %.4fs" % (time.time() - start_time)) diff --git a/examples/emu3/requirements.txt b/examples/emu3/requirements.txt index 4c6f3a0a18..1237f0993e 100644 --- a/examples/emu3/requirements.txt +++ b/examples/emu3/requirements.txt @@ -2,6 +2,6 @@ transformers>=4.44.0 tiktoken==0.6.0 pillow omegaconf -opencv-python +opencv-python==4.9.0.80 ezcolorlog mindcv==0.3.0 From abae16504c58644d4dc98d99c8fe62a32f7e808a Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Wed, 5 Nov 2025 16:41:21 +0800 Subject: [PATCH 2/2] update training perf for ms2.7.0 --- examples/emu3/README.md | 15 +++++++-------- examples/emu3/scripts/t2i_sft_seq_parallel.sh | 2 +- examples/emu3/scripts/vqa_sft_seq_parallel.sh | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/emu3/README.md b/examples/emu3/README.md index 537958e0f7..ad37e654a2 100644 --- a/examples/emu3/README.md +++ b/examples/emu3/README.md @@ -428,15 +428,14 @@ Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode. ### Training -Experiments are tested on Ascend Atlas 800T A2 machines. 
-
-- mindspore 2.5.0
+Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0*.
 
 |mode | stage | pre-trained model | precision* | cards | batch size| resolution | max token | init lr | recompute | zero stage | grad accu |flash attn | sequence parallel | s/step | step | sample/s |
 | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 2.61 | 4996 | 0.38 |
-| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 3.08 | 4993 | 0.32 |
-|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.93 | 4993 | 0.52 |
-|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.95 | 5000 | 0.51 |
+| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
+| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
+|graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
+|graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |
 
-*note: mixed precision, `BatchNorm3d` and `Emu3RMSNorm` use fp32.
+*note: training is currently supported with mindspore 2.7.0 only.
+Mixed precision is used: `BatchNorm3d` and `Emu3RMSNorm` use fp32.
diff --git a/examples/emu3/scripts/t2i_sft_seq_parallel.sh b/examples/emu3/scripts/t2i_sft_seq_parallel.sh
index dab937e44f..3f3510b2c1 100644
--- a/examples/emu3/scripts/t2i_sft_seq_parallel.sh
+++ b/examples/emu3/scripts/t2i_sft_seq_parallel.sh
@@ -27,7 +27,7 @@ python emu3/train/train_seq_parallel.py \
   --max_position_embeddings 4200 \
   --trainable_hidden_layers 32 \
   --output_dir ${LOG_DIR} \
-  --num_train_epochs 5 \
+  --num_train_epochs 4 \
   --per_device_train_batch_size 1 \
   --gradient_accumulation_steps 1 \
   --save_steps 1 \
diff --git a/examples/emu3/scripts/vqa_sft_seq_parallel.sh b/examples/emu3/scripts/vqa_sft_seq_parallel.sh
index 2b856b8ac8..ce80b05890 100644
--- a/examples/emu3/scripts/vqa_sft_seq_parallel.sh
+++ b/examples/emu3/scripts/vqa_sft_seq_parallel.sh
@@ -26,7 +26,7 @@ python emu3/train/train_seq_parallel.py \
   --image_area 147456 \
   --max_position_embeddings 2560 \
   --output_dir ${LOG_DIR} \
-  --num_train_epochs 50 \
+  --num_train_epochs 4 \
   --per_device_train_batch_size 1 \
   --gradient_accumulation_steps 1 \
   --save_steps 1 \
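The `convert_weights.py` deletion follows from the `mint.nn.BatchNorm3d` switch: the legacy `nn.BatchNorm3d` stored its parameters under MindSpore-specific names (`bn2d.gamma`, `bn2d.beta`, `bn2d.moving_mean`, `bn2d.moving_variance`), while `mint.nn.BatchNorm3d` uses torch-style names (`weight`, `bias`, `running_mean`, `running_var`), so pretrained torch safetensors load without any key renaming. A minimal sketch to check the parameter names on a given build (assuming a MindSpore version that ships the `mint` API, per the requirements table above):

```python
# Sketch: list the parameter names of mint.nn.BatchNorm3d to confirm they
# match torch checkpoint keys (weight / bias / running_mean / running_var),
# which is what makes the convert_weights.py renaming step unnecessary.
from mindspore import mint

bn = mint.nn.BatchNorm3d(num_features=4)
print([name for name, _ in bn.parameters_and_names()])
# Expected (assumption): names such as "weight" and "bias", plus the running
# statistics, rather than the legacy nested "bn2d.gamma" / "bn2d.beta".
```

This also explains the `_init_weights` change above: the new branch initializes `module.weight` and `module.bias` directly instead of reaching into the old `bn2d` submodule.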