Skip to content

Commit 8a16cf6

Browse files
committed
add data example, train script
1 parent 3e677cb commit 8a16cf6

22 files changed

+1775
-16
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ It will output the preprocessed image, generated 6-view images and CCMs and a 3D
6161
**Tips:** (1) If the result is unsatisfactory, please check whether the input image is correctly pre-processed into a grey background. Otherwise the results will be unpredictable.
6262
(2) Different from the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM), this official implementation uses UV texture instead of vertex color. It has better texture than the online demo but longer generating time owing to the UV texturing.
6363

64+
## train
65+
We provide a training script for multiview generation; for usage and data requirements, see `launch_train.sh`.
66+
67+
6468
## Todo List
6569
- [x] Release inference code.
6670
- [x] Release pretrained models.

acce.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
compute_environment: LOCAL_MACHINE
2+
deepspeed_config:
3+
deepspeed_multinode_launcher: standard
4+
offload_optimizer_device: none
5+
offload_param_device: none
6+
zero3_init_flag: false
7+
zero_stage: 2
8+
distributed_type: DEEPSPEED
9+
mixed_precision: fp16
10+
num_processes: 8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
config:
2+
# others
3+
seed: 1234
4+
num_frames: 7
5+
mode: pixel
6+
offset_noise: true
7+
# model related
8+
models:
9+
config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
10+
resume: release_models/sd-v2.1-base-4view-ipmv.pt
11+
# sampler related
12+
sampler:
13+
target: libs.sample.ImageDreamDiffusion
14+
params:
15+
mode: pixel
16+
num_frames: 7
17+
camera_views: [1, 2, 3, 4, 5, 0, 0]
18+
ref_position: 6
19+
random_background: false
20+
offset_noise: true
21+
resize_rate: 1.0
22+
23+
# config datasets
24+
train_data:
25+
target: libs.data.DataRelativeStroke
26+
params:
27+
base_dir: train_examples
28+
caption_csv: train_examples/caption.csv
29+
image_size: 256
30+
repeat: 1
31+
camera_views: [1, 2, 3, 4, 5, 0, 0]
32+
ref_indexs: [0, 1, 3, 4, 5, 2]
33+
ref_position: 6
34+
split: train
35+
num_frames: 7
36+
random_background: true
37+
resize_rate: 0.95
38+
stroke_p: 0.5
39+
eval_size: 100
40+
resize_range:
41+
- 0.5
42+
- 1.0
43+
eval_data:
44+
target: libs.data.DataRelativeStroke
45+
params:
46+
base_dir: train_examples
47+
caption_csv: train_examples/caption.csv
48+
image_size: 256
49+
repeat: 1
50+
camera_views: [1, 2, 3, 4, 5, 0, 0] # camera views are relative views
51+
ref_indexs: [0, 1, 3, 4, 5, 2]
52+
ref_position: 6
53+
split: eval
54+
num_frames: 7
55+
random_background: true
56+
resize_rate: 0.95
57+
stroke_p: 0.5
58+
eval_size: 100
59+
resize_range:
60+
- 0.5
61+
- 1.0
62+
63+
in_the_wild_images:
64+
target: libs.data.InTheWildImages
65+
params:
66+
base_dirs:
67+
- examples
68+
69+
# optimizer related
70+
optimizer:
71+
lr: 5e-5
72+
gradient_accumulation_steps: 12
73+
74+
# wandb related parameters
75+
project: CRM
76+
wandb_run_name: CRM-pixel
77+
wandb_mode: offline
78+
79+
80+
# training hyperparameters
81+
batch_size: 16
82+
dataloader:
83+
num_workers: 10
84+
shuffle: true
85+
drop_last: true
86+
87+
save_interval: 600000
88+
log_interval: 5000
89+
eval_interval: 300000
90+
max_step: 10000000

configs/stage2-v2-snr_train.yaml

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
config:
2+
# others
3+
seed: 1234
4+
num_frames: 6
5+
mode: pixel
6+
offset_noise: true
7+
gd_type: xyz
8+
# model related
9+
models:
10+
config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
11+
resume: release_models/ImageDream/sd-v2.1-base-4view-ipmv.pt
12+
resume_unet: null
13+
14+
# eval related
15+
sampler:
16+
target: libs.sample.ImageDreamDiffusionStage2
17+
params:
18+
mode: pixel
19+
num_frames: 6
20+
camera_views: [1, 2, 3, 4, 5, 0]
21+
ref_position: null
22+
random_background: false
23+
offset_noise: true
24+
resize_rate: 1.0
25+
26+
# config datasets
27+
train_data:
28+
target: libs.data.DataHQCRelative
29+
params:
30+
xyz_base: train_examples
31+
base_dir: train_examples
32+
caption_csv: train_examples/caption.csv
33+
image_size: 256
34+
repeat: 1
35+
camera_views: [1, 2, 3, 4, 5, 0]
36+
ref_indexs: [0, 1, 3, 4]
37+
ref_position: null
38+
split: train
39+
num_frames: 6
40+
random_background: true
41+
resize_rate: 0.95
42+
eval_data:
43+
target: libs.data.DataHQCRelative
44+
params:
45+
xyz_base: train_examples
46+
base_dir: train_examples
47+
caption_csv: train_examples/caption.csv
48+
image_size: 256
49+
repeat: 1
50+
camera_views: [1, 2, 3, 4, 5, 0] # in pixel mode, the last image will be covered by the ref image
51+
ref_indexs: [0, 1, 3, 4]
52+
ref_position: null
53+
split: eval
54+
num_frames: 6
55+
random_background: true
56+
resize_rate: 0.95
57+
58+
# optimizer related
59+
optimizer:
60+
lr: 5e-5
61+
gradient_accumulation_steps: 12
62+
63+
# wandb related parameters
64+
project: CRM
65+
wandb_run_name: CRM-xyz
66+
wandb_mode: offline
67+
68+
69+
# training hyperparameters
70+
batch_size: 16
71+
dataloader:
72+
num_workers: 10
73+
shuffle: true
74+
drop_last: true
75+
76+
save_interval: 400000
77+
log_interval: 5000
78+
eval_interval: 50000
79+
max_step: 100000000

launch_train.sh

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
# Launches multi-view generation training through `accelerate launch`.
#
# Environment variables (each falls back to a default when unset or empty):
#   ADDR         passed as --main_process_ip    (default: 127.0.0.1)
#   WORLD_SIZE   passed as --num_machines       (default: 1)
#   RANK         passed as --machine_rank       (default: 0)
#   MASTER_PORT  passed as --main_process_port  (default: 29501)
#
# OMP_NUM_THREADS and WANDB_MODE are always overwritten by this script.

# set default values for the environment variables
export OMP_NUM_THREADS=8
if [ -z "$ADDR" ]; then
  export ADDR=127.0.0.1
fi

if [ -z "$WORLD_SIZE" ]; then
  export WORLD_SIZE=1
fi

if [ -z "$RANK" ]; then
  export RANK=0
fi

if [ -z "$MASTER_PORT" ]; then
  export MASTER_PORT=29501
fi

export WANDB_MODE=offline

# Kept as one flat string on purpose: it is expanded UNQUOTED in the launch
# command below so the shell splits it into separate arguments.
# NOTE(review): acce.yaml declares num_processes: 8 while 1 is passed here --
# confirm which value is intended for real (non-example) runs.
accelerate_args="--config_file acce.yaml --num_machines $WORLD_SIZE \
--machine_rank $RANK --num_processes 1 \
--main_process_port $MASTER_PORT \
--main_process_ip $ADDR"
echo "$accelerate_args"

# train stage 1
# shellcheck disable=SC2086  # word splitting of $accelerate_args is intended
accelerate launch $accelerate_args train.py --config configs/nf7_v3_SNR_rd_size_stroke_train.yaml \
  config.batch_size=1 \
  config.eval_interval=100


# train stage 2
# accelerate launch $accelerate_args train_stage2.py --config configs/stage2-v2-snr_train.yaml \
#   config.batch_size=1 \
#   config.eval_interval=100

0 commit comments

Comments
 (0)