#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from datetime import datetime

import torch
import torch.nn.functional as F
import torch.utils.data as Data
from einops import rearrange
from torch.utils.data.distributed import DistributedSampler

from flux.sampling import prepare
from flux.util import load_ae, load_clip, load_flow_model, load_t5

from dataset import DummyClsDataset, loader
from internlm.checkpoint.checkpoint_manager import CheckpointManager
from internlm.core.context import (
    IS_REPLICA_ZERO_PARALLEL,
    IS_WEIGHT_ZERO_PARALLEL,
)
from internlm.core.context import global_context as gpc
from internlm.data.train_state import get_train_state
from internlm.initialize import initialize_distributed_env
from internlm.train.pipeline import initialize_optimizer, initialize_parallel_communicator
from internlm.utils.common import get_current_device, parse_args
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


def get_models(model_cfg: dict, device, is_schnell: bool):
    name = model_cfg.model_name
    # Text encoders: schnell uses a shorter T5 context than dev.
    t5 = load_t5(
        tokenizer_path=model_cfg.t5_tokenizer,
        model_path=model_cfg.t5_ckpt,
        device=device,
        max_length=256 if is_schnell else 512,
    )
    clip = load_clip(
        tokenizer_path=model_cfg.clip_tokenizer,
        model_path=model_cfg.clip_ckpt,
        device=device,
    )
    model = load_flow_model(name, device=device).to(device)
    vae = load_ae(
        name,
        ckpt_path=model_cfg.vae_ckpt,
    ).to(device)
    # Mark every parameter that is not weight-parallel as replicated, so the
    # ZeRO optimizer knows how to handle it.
    for p in model.parameters():
        if not hasattr(p, IS_WEIGHT_ZERO_PARALLEL):
            setattr(p, IS_REPLICA_ZERO_PARALLEL, True)
    return model, vae, t5, clip
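
# Note on the flags above: IS_WEIGHT_ZERO_PARALLEL and IS_REPLICA_ZERO_PARALLEL
# are attribute-name constants imported from internlm.core.context. Tagging a
# parameter with the replica flag appears to tell InternEvo's ZeRO/ISP machinery
# to treat it as fully replicated rather than sharded along the weight-parallel
# dimension (my reading of the InternEvo convention, not stated in this file).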


def main(args):
    # Obtain the data and model configs.
    data_cfg = gpc.config.flux.data
    model_cfg = gpc.config.flux.model
    if model_cfg.weight_dtype == "bfloat16":
        weight_dtype = torch.bfloat16
    else:
        # Fall back to fp32 so `weight_dtype` is always defined.
        weight_dtype = torch.float32
    is_schnell = model_cfg.model_name == "flux-schnell"
    device = get_current_device()
    dit, vae, t5, clip = get_models(model_cfg=model_cfg, device=device, is_schnell=is_schnell)
    isp_communicator = initialize_parallel_communicator(dit)
    # Only the DiT is trained; the VAE and both text encoders stay frozen.
    vae.requires_grad_(False)
    t5.requires_grad_(False)
    clip.requires_grad_(False)
    dit = dit.to(torch.bfloat16)
    dit.train()
    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(dit, isp_communicator)
    # train_dataset = DummyClsDataset([3, 256, 256])
    # sampler = DistributedSampler(
    #     train_dataset,
    #     num_replicas=1,  # important
    #     rank=0,  # important
    #     shuffle=True,
    #     seed=global_seed,
    # )
    # train_dataloader = Data.DataLoader(
    #     dataset=train_dataset,
    #     batch_size=1,
    #     shuffle=False,
    #     sampler=sampler,
    #     num_workers=4,
    #     pin_memory=True,
    #     drop_last=True,
    # )
    train_dataloader = loader(
        train_batch_size=data_cfg.batch_size,
        num_workers=data_cfg.num_workers,
        img_dir=data_cfg.train_folder,
        img_size=data_cfg.img_size,
    )
    train_iter = iter(train_dataloader)
    # Keep the raw config text so the checkpoint records the exact run config.
    with open(args.config, "r") as f:
        config_lines = f.readlines()
    # Initialize the checkpoint manager.
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=dit,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dl=train_dataloader,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
    )
    train_state = get_train_state(train_dataloader)
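
    # The loop below trains with a rectified-flow (flow-matching) objective:
    # sample t from a logit-normal distribution (sigmoid of a standard normal),
    # interpolate x_t = (1 - t) * x_1 + t * x_0 between the clean latent x_1 and
    # Gaussian noise x_0, and regress the DiT output onto the constant velocity
    # x_0 - x_1.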
    for step in range(data_cfg.total_steps):
        gpc.step = step
        try:
            batch = next(train_iter)
        except StopIteration:
            # Restart the iterator when the dataloader is exhausted.
            train_iter = iter(train_dataloader)
            batch = next(train_iter)
        img, prompts = batch[0], batch[1]
        # Encode images to VAE latents and the prompts to text embeddings;
        # none of this needs gradients.
        with torch.no_grad():
            x_1 = vae.encode(img.to(device))
            inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts)
            x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
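            # The rearrange above patchifies the latent from (b, c, h, w) into
            # Flux's token layout (b, h/2 * w/2, c * 2 * 2): each 2x2 spatial
            # patch becomes one token of dimension 4c.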
        bs = img.shape[0]
        t = torch.sigmoid(torch.randn((bs,), device=device))
        x_0 = torch.randn_like(x_1).to(device)
        # Reshape t to (bs, 1, 1) so it broadcasts over the token and channel
        # dimensions of the latent sequence.
        t_exp = t.view(bs, 1, 1)
        x_t = (1 - t_exp) * x_1 + t_exp * x_0
        guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)
        # Predict the velocity and compute the loss.
        model_pred = dit(
            img=x_t.to(weight_dtype),
            img_ids=inp["img_ids"].to(weight_dtype),
            txt=inp["txt"].to(weight_dtype),
            txt_ids=inp["txt_ids"].to(weight_dtype),
            y=inp["vec"].to(weight_dtype),
            timesteps=t.to(weight_dtype),
            guidance=guidance_vec.to(weight_dtype),
        )
        loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
        # Backpropagate and step the optimizer and LR scheduler.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        if gpc.is_rank_for_log():
            logger.info(f"{datetime.now()}: step = {step}, loss = {loss.item()}")
        if ckpt_manager.try_save_checkpoint(train_state):
            ckpt_manager.wait_async_upload_finish()


if __name__ == "__main__":
    args = parse_args()
    # Initialize the distributed environment from the config file.
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
    # Run the main function with the parsed arguments.
    main(args)
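
# Example launch (a sketch, not this repo's documented command; the config path,
# GPU count, and port below are assumptions):
#
#   torchrun --nproc_per_node=8 train_internevo.py \
#       --config configs/flux.py --launcher torch --port 29500 --seed 1024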