-
Notifications
You must be signed in to change notification settings - Fork 307
/
Copy pathtrain.py
537 lines (488 loc) · 20.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, time, psutil, hydra, torch
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import wandb
from hydra.core.hydra_config import HydraConfig
from physicsnemo import Module
from physicsnemo.models.diffusion import UNet, EDMPrecondSuperResolution
from physicsnemo.distributed import DistributedManager
from physicsnemo.metrics.diffusion import RegressionLoss, ResidualLoss, RegressionLossCE
from physicsnemo.utils.patching import RandomPatching2D
from physicsnemo.launch.logging.wandb import initialize_wandb
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
from physicsnemo.launch.utils import (
load_checkpoint,
save_checkpoint,
get_checkpoint_dir,
)
from datasets.dataset import init_train_valid_datasets_from_config, register_dataset
from helpers.train_helpers import (
set_patch_shape,
set_seed,
configure_cuda_for_consistent_precision,
compute_num_accumulation_rounds,
handle_and_clip_gradients,
is_time_for_periodic_task,
)
def checkpoint_list(path, suffix=".mdlus"):
    """Return the checkpoint filenames found in ``path``, sorted by ascending index.

    A file counts as a checkpoint when its name ends with ``suffix`` and its
    second-to-last dot-separated component parses as an integer, e.g.
    ``"model.0.1024.mdlus"`` -> index ``1024``. Non-matching files are skipped.

    Parameters
    ----------
    path : str
        Directory to scan (non-recursive).
    suffix : str, optional
        Filename suffix identifying checkpoints, by default ``".mdlus"``.

    Returns
    -------
    list of str
        Bare filenames (not full paths) ordered by index. Files sharing the
        same index are ordered by name so the result is deterministic.
    """
    indexed = []
    for file in os.listdir(path):
        if not file.endswith(suffix):
            continue
        parts = file.split(".")
        # A name without any dot (possible when ``suffix`` has no leading
        # dot) has no index component; skip it instead of raising IndexError.
        if len(parts) < 2:
            continue
        try:
            index = int(parts[-2])
        except ValueError:
            continue
        indexed.append((index, file))
    # Sorting the (index, filename) tuples breaks index ties by name.
    indexed.sort()
    return [file for _, file in indexed]
# Train the CorrDiff model using the configurations in "conf/config_training.yaml"
@hydra.main(version_base="1.2", config_path="conf", config_name="config_training")
def main(cfg: DictConfig) -> None:
    """Train a CorrDiff regression or diffusion model.

    The function does the following:

    - Sets up the distributed environment, loggers, data loaders, model, loss
      and Adam optimizer.
    - Tries to resume from the checkpoint directory.
    - Runs the training loop until ``cfg.training.hp.training_duration``
      images have been processed, with periodic validation, progress logging
      and checkpointing.

    Parameters
    ----------
    cfg : DictConfig
        Composed Hydra configuration (``conf/config_training.yaml``).

    Raises
    ------
    ValueError
        Raised in three cases:

        - ``cfg.model.name`` is unknown.
        - A regression model is combined with patch-based training.
        - A diffusion model is configured without
          ``training.io.regression_checkpoint_path``.
    FileNotFoundError
        If the configured regression checkpoint path does not exist.
    """
    # Initialize distributed environment for training
    DistributedManager.initialize()
    dist = DistributedManager()

    # Initialize loggers. The TensorBoard writer only exists on rank 0; every
    # use of `writer` below is guarded by `dist.rank == 0`.
    if dist.rank == 0:
        writer = SummaryWriter(log_dir="tensorboard")
    logger = PythonLogger("main")  # General python logger
    logger0 = RankZeroLoggingWrapper(logger, dist)  # Rank 0 logger
    initialize_wandb(
        project="Modulus-Launch",
        entity="Modulus",
        name=f"CorrDiff-Training-{HydraConfig.get().job.name}",
        group="CorrDiff-DDP-Group",
        mode=cfg.wandb.mode,
        config=OmegaConf.to_container(cfg),
        results_dir=cfg.wandb.results_dir,
    )

    # Resolve and parse configs
    OmegaConf.resolve(cfg)
    dataset_cfg = OmegaConf.to_container(cfg.dataset)  # TODO needs better handling

    # Register custom dataset if specified in config
    register_dataset(cfg.dataset.type)
    logger0.info(f"Using dataset: {cfg.dataset.type}")

    if hasattr(cfg, "validation"):
        train_test_split = True
        validation_dataset_cfg = OmegaConf.to_container(cfg.validation)
    else:
        train_test_split = False
        validation_dataset_cfg = None

    # Precision settings: "fp16" is forwarded to the network as `use_fp16`,
    # while "amp-*" enables torch.autocast with the matching dtype.
    fp_optimizations = cfg.training.perf.fp_optimizations
    songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level
    fp16 = fp_optimizations == "fp16"
    enable_amp = fp_optimizations.startswith("amp")
    amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
    logger0.info(f"Saving the outputs in {os.getcwd()}")
    checkpoint_dir = get_checkpoint_dir(
        str(cfg.training.io.get("checkpoint_dir", ".")), cfg.model.name
    )
    if cfg.training.hp.batch_size_per_gpu == "auto":
        cfg.training.hp.batch_size_per_gpu = (
            cfg.training.hp.total_batch_size // dist.world_size
        )

    # Set seeds and configure CUDA and cuDNN settings to ensure consistent precision
    set_seed(dist.rank)
    configure_cuda_for_consistent_precision()

    # Instantiate the dataset
    data_loader_kwargs = {
        "pin_memory": True,
        "num_workers": cfg.training.perf.dataloader_workers,
        "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None,
    }
    (
        dataset,
        dataset_iterator,
        validation_dataset,
        validation_dataset_iterator,
    ) = init_train_valid_datasets_from_config(
        dataset_cfg,
        data_loader_kwargs,
        batch_size=cfg.training.hp.batch_size_per_gpu,
        seed=0,
        validation_dataset_cfg=validation_dataset_cfg,
        train_test_split=train_test_split,
    )

    # Parse image configuration & update model args
    dataset_channels = len(dataset.input_channels())
    img_in_channels = dataset_channels
    img_shape = dataset.image_shape()
    img_out_channels = len(dataset.output_channels())
    if cfg.model.hr_mean_conditioning:
        img_in_channels += img_out_channels

    if cfg.model.name == "lt_aware_ce_regression":
        prob_channels = dataset.get_prob_channel_index()
    else:
        prob_channels = None

    # Parse the patch shape
    if (
        cfg.model.name == "patched_diffusion"
        or cfg.model.name == "lt_aware_patched_diffusion"
    ):
        patch_shape_x = cfg.training.hp.patch_shape_x
        patch_shape_y = cfg.training.hp.patch_shape_y
    else:
        patch_shape_x = None
        patch_shape_y = None
    patch_shape = (patch_shape_y, patch_shape_x)
    use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
    if use_patching:
        # Utility to perform patches extraction and batching
        patching = RandomPatching2D(
            img_shape=img_shape,
            patch_shape=patch_shape,
            patch_num=getattr(cfg.training.hp, "patch_num", 1),
        )
        logger0.info("Patch-based training enabled")
    else:
        patching = None
        logger0.info("Patch-based training disabled")
    # interpolate global channel if patch-based model is used
    if use_patching:
        img_in_channels += dataset_channels

    # Instantiate the model and move to device.
    model_args = {  # default parameters for all networks
        "img_out_channels": img_out_channels,
        "img_resolution": list(img_shape),
        "use_fp16": fp16,
        "checkpoint_level": songunet_checkpoint_level,
    }
    if cfg.model.name == "lt_aware_ce_regression":
        model_args["prob_channels"] = prob_channels
    if hasattr(cfg.model, "model_args"):  # override defaults from config file
        model_args.update(OmegaConf.to_container(cfg.model.model_args))
    if cfg.model.name == "regression":
        model = UNet(
            img_in_channels=img_in_channels + model_args["N_grid_channels"],
            **model_args,
        )
    elif cfg.model.name == "lt_aware_ce_regression":
        model = UNet(
            img_in_channels=img_in_channels
            + model_args["N_grid_channels"]
            + model_args["lead_time_channels"],
            **model_args,
        )
    elif cfg.model.name == "lt_aware_patched_diffusion":
        model = EDMPrecondSuperResolution(
            img_in_channels=img_in_channels
            + model_args["N_grid_channels"]
            + model_args["lead_time_channels"],
            **model_args,
        )
    elif cfg.model.name == "diffusion":
        model = EDMPrecondSuperResolution(
            img_in_channels=img_in_channels + model_args["N_grid_channels"],
            **model_args,
        )
    elif cfg.model.name == "patched_diffusion":
        model = EDMPrecondSuperResolution(
            img_in_channels=img_in_channels + model_args["N_grid_channels"],
            **model_args,
        )
    else:
        raise ValueError(f"Invalid model: {cfg.model.name}")

    model.train().requires_grad_(True).to(dist.device)

    # Check if regression model is used with patching
    if (
        cfg.model.name in ["regression", "lt_aware_ce_regression"]
        and patching is not None
    ):
        raise ValueError(
            f"Regression model ({cfg.model.name}) cannot be used with patch-based training. "
        )

    # Enable distributed data parallel if applicable
    if dist.world_size > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[dist.local_rank],
            broadcast_buffers=True,
            output_device=dist.device,
            find_unused_parameters=dist.find_unused_parameters,
        )

    if cfg.wandb.watch_model and dist.rank == 0:
        wandb.watch(model)

    # Load the frozen regression network if configured; it is consumed by
    # ResidualLoss for the diffusion variants below.
    regression_net = None
    if (
        hasattr(cfg.training.io, "regression_checkpoint_path")
        and cfg.training.io.regression_checkpoint_path is not None
    ):
        regression_checkpoint_path = to_absolute_path(
            cfg.training.io.regression_checkpoint_path
        )
        if not os.path.exists(regression_checkpoint_path):
            raise FileNotFoundError(
                f"Expected this regression checkpoint but not found: {regression_checkpoint_path}"
            )
        regression_net = Module.from_checkpoint(regression_checkpoint_path)
        regression_net.eval().requires_grad_(False).to(dist.device)
        logger0.success("Loaded the pre-trained regression model")

    # Instantiate the loss function
    if cfg.model.name in (
        "diffusion",
        "patched_diffusion",
        "lt_aware_patched_diffusion",
    ):
        # Fail with a clear message instead of a NameError on `regression_net`.
        if regression_net is None:
            raise ValueError(
                f"Model {cfg.model.name} requires a pre-trained regression network; "
                "set training.io.regression_checkpoint_path"
            )
        loss_fn = ResidualLoss(
            regression_net=regression_net,
            hr_mean_conditioning=cfg.model.hr_mean_conditioning,
        )
    elif cfg.model.name == "regression":
        loss_fn = RegressionLoss()
    elif cfg.model.name == "lt_aware_ce_regression":
        loss_fn = RegressionLossCE(prob_channels=prob_channels)

    # Instantiate the optimizer
    optimizer = torch.optim.Adam(
        params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8
    )

    # Record the current time to measure the duration of subsequent operations.
    start_time = time.time()

    # Compute the number of required gradient accumulation rounds
    # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size
    batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds(
        cfg.training.hp.total_batch_size,
        cfg.training.hp.batch_size_per_gpu,
        dist.world_size,
    )
    batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
    logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")

    ## Resume training from previous checkpoints if exists
    if dist.world_size > 1:
        torch.distributed.barrier()
    try:
        cur_nimg = load_checkpoint(
            path=checkpoint_dir,
            models=model,
            optimizer=optimizer,
            device=dist.device,
        )
    except Exception as err:
        # Best effort: start from scratch, but report the failure on every
        # rank rather than silently discarding it. Catching Exception (not a
        # bare except) lets KeyboardInterrupt/SystemExit propagate.
        logger.warning(
            f"Could not resume from checkpoint in {checkpoint_dir} ({err!r}); "
            "starting training from scratch."
        )
        cur_nimg = 0

    ############################################################################
    #                            MAIN TRAINING LOOP                            #
    ############################################################################

    logger0.info(f"Training for {cfg.training.hp.training_duration} images...")
    done = False

    # init variables to monitor running mean of average loss since last periodic
    average_loss_running_mean = 0
    n_average_loss_running_mean = 1

    while not done:
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        # Compute & accumulate gradients
        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0
        for _ in range(num_accumulation_rounds):
            img_clean, img_lr, *lead_time_label = next(dataset_iterator)
            img_clean = img_clean.to(dist.device).to(torch.float32).contiguous()
            img_lr = img_lr.to(dist.device).to(torch.float32).contiguous()
            loss_fn_kwargs = {
                "net": model,
                "img_clean": img_clean,
                "img_lr": img_lr,
                "augment_pipe": None,
            }
            # Sample new random patches for this iteration and add patching to
            # loss arguments
            if patching is not None:
                patching.reset_patch_indices()
                loss_fn_kwargs.update({"patching": patching})
            if lead_time_label:
                lead_time_label = lead_time_label[0].to(dist.device).contiguous()
                loss_fn_kwargs.update({"lead_time_label": lead_time_label})
            else:
                lead_time_label = None
            with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
                loss = loss_fn(**loss_fn_kwargs)
            loss = loss.sum() / batch_size_per_gpu
            # Accumulate a detached value for logging so the autograd graph is
            # not kept alive after backward().
            loss_accum += loss.detach() / num_accumulation_rounds
            loss.backward()

        loss_sum = torch.tensor([loss_accum], device=dist.device)
        if dist.world_size > 1:
            torch.distributed.barrier()
            torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
        average_loss = (loss_sum / dist.world_size).cpu().item()

        # update running mean of average loss since last periodic task
        average_loss_running_mean += (
            average_loss - average_loss_running_mean
        ) / n_average_loss_running_mean
        n_average_loss_running_mean += 1

        if dist.rank == 0:
            writer.add_scalar("training_loss", average_loss, cur_nimg)
            writer.add_scalar(
                "training_loss_running_mean", average_loss_running_mean, cur_nimg
            )
            wandb.log(
                {
                    "training_loss": average_loss,
                    "training_loss_running_mean": average_loss_running_mean,
                }
            )

        ptt = is_time_for_periodic_task(
            cur_nimg,
            cfg.training.io.print_progress_freq,
            done,
            cfg.training.hp.total_batch_size,
            dist.rank,
            rank_0_only=True,
        )
        if ptt:
            # reset running mean of average loss
            average_loss_running_mean = 0
            n_average_loss_running_mean = 1

        # Update weights.
        # The learning rate is recomputed from the base LR every step: a linear
        # ramp-up over `lr_rampup` images, then a step decay by `lr_decay`
        # every 5e6 images. Deriving it from the base value (rather than
        # scaling g["lr"] in place) keeps the decay from compounding on every
        # step when lr_rampup <= 0, where the ramp-up never resets g["lr"].
        lr_rampup = cfg.training.hp.lr_rampup
        current_lr = cfg.training.hp.lr
        if lr_rampup > 0:
            current_lr *= min(cur_nimg / lr_rampup, 1)
        if cur_nimg >= lr_rampup:
            current_lr *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6)
        for g in optimizer.param_groups:
            g["lr"] = current_lr
        if dist.rank == 0:
            writer.add_scalar("learning_rate", current_lr, cur_nimg)
        handle_and_clip_gradients(
            model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold
        )
        optimizer.step()

        cur_nimg += cfg.training.hp.total_batch_size
        done = cur_nimg >= cfg.training.hp.training_duration

        # Validation
        if validation_dataset_iterator is not None:
            valid_loss_accum = 0
            if is_time_for_periodic_task(
                cur_nimg,
                cfg.training.io.validation_freq,
                done,
                cfg.training.hp.total_batch_size,
                dist.rank,
            ):
                with torch.no_grad():
                    for _ in range(cfg.training.io.validation_steps):
                        img_clean_valid, img_lr_valid, *lead_time_label_valid = next(
                            validation_dataset_iterator
                        )
                        img_clean_valid = (
                            img_clean_valid.to(dist.device)
                            .to(torch.float32)
                            .contiguous()
                        )
                        img_lr_valid = (
                            img_lr_valid.to(dist.device).to(torch.float32).contiguous()
                        )
                        loss_valid_kwargs = {
                            "net": model,
                            "img_clean": img_clean_valid,
                            "img_lr": img_lr_valid,
                            "augment_pipe": None,
                        }
                        # Patch-based models are built with extra global input
                        # channels, so they must be evaluated through the same
                        # patching utility used in training. The patch indices
                        # from the last training step are reused, not resampled.
                        if patching is not None:
                            loss_valid_kwargs.update({"patching": patching})
                        if lead_time_label_valid:
                            lead_time_label_valid = (
                                lead_time_label_valid[0].to(dist.device).contiguous()
                            )
                            loss_valid_kwargs.update(
                                {"lead_time_label": lead_time_label_valid}
                            )
                        loss_valid = loss_fn(**loss_valid_kwargs)

                        loss_valid = (
                            (loss_valid.sum() / batch_size_per_gpu).cpu().item()
                        )
                        valid_loss_accum += (
                            loss_valid / cfg.training.io.validation_steps
                        )
                    valid_loss_sum = torch.tensor(
                        [valid_loss_accum], device=dist.device
                    )
                    if dist.world_size > 1:
                        torch.distributed.barrier()
                        torch.distributed.all_reduce(
                            valid_loss_sum, op=torch.distributed.ReduceOp.SUM
                        )
                    # Log a Python float (not a device tensor), like the
                    # training loss above.
                    average_valid_loss = (
                        (valid_loss_sum / dist.world_size).cpu().item()
                    )
                    if dist.rank == 0:
                        writer.add_scalar(
                            "validation_loss", average_valid_loss, cur_nimg
                        )
                        wandb.log(
                            {
                                "validation_loss": average_valid_loss,
                            }
                        )

        if is_time_for_periodic_task(
            cur_nimg,
            cfg.training.io.print_progress_freq,
            done,
            cfg.training.hp.total_batch_size,
            dist.rank,
            rank_0_only=True,
        ):
            # Print stats if we crossed the printing threshold with this batch
            tick_end_time = time.time()
            fields = []
            fields += [f"samples {cur_nimg:<9.1f}"]
            fields += [f"training_loss {average_loss:<7.2f}"]
            fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"]
            fields += [f"learning_rate {current_lr:<7.8f}"]
            fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"]
            fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"]
            fields += [
                f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}"
            ]
            fields += [
                f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
            ]
            fields += [
                f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}"
            ]
            fields += [
                f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}"
            ]
            logger0.info(" ".join(fields))
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

        # Save checkpoints
        if dist.world_size > 1:
            torch.distributed.barrier()
        if is_time_for_periodic_task(
            cur_nimg,
            cfg.training.io.save_checkpoint_freq,
            done,
            cfg.training.hp.total_batch_size,
            dist.rank,
            rank_0_only=True,
        ):
            save_checkpoint(
                path=checkpoint_dir,
                models=model,
                optimizer=optimizer,
                epoch=cur_nimg,
            )
            # Retain only the recent n checkpoints, if desired
            if cfg.training.io.save_n_recent_checkpoints > 0:
                for suffix in [".mdlus", ".pt"]:
                    ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
                    while len(ckpts) > cfg.training.io.save_n_recent_checkpoints:
                        os.remove(os.path.join(checkpoint_dir, ckpts[0]))
                        ckpts = ckpts[1:]

    # Done.
    logger0.info("Training Completed.")


if __name__ == "__main__":
    main()