-
Notifications
You must be signed in to change notification settings - Fork 318
/
Copy pathtrain_transolver_darcy.py
177 lines (157 loc) · 6.78 KB
/
train_transolver_darcy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hydra
from omegaconf import DictConfig
from math import ceil
from torch.nn import MSELoss
from utils.testloss import TestLoss
from torch.optim import Adam, lr_scheduler
from physicsnemo.models.transolver import Transolver
from physicsnemo.datapipes.benchmarks.darcy import Darcy2D
from physicsnemo.distributed import DistributedManager
from physicsnemo.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.launch.logging import PythonLogger, LaunchLogger
from physicsnemo.launch.logging.mlflow import initialize_mlflow
from validator import GridValidator
def _build_model(cfg: DictConfig, device) -> Transolver:
    """Construct the Transolver network described by ``cfg.model`` on ``device``.

    The grid size ``H`` x ``W`` is taken from ``cfg.training.resolution``, the
    same resolution the Darcy2D datapipe is configured with below.
    """
    model_cfg = cfg.model
    resolution = cfg.training.resolution
    return Transolver(
        space_dim=model_cfg.space_dim,
        n_layers=model_cfg.n_layers,
        n_hidden=model_cfg.n_hidden,
        dropout=model_cfg.dropout,
        n_head=model_cfg.n_head,
        Time_Input=model_cfg.Time_Input,
        act=model_cfg.act,
        mlp_ratio=model_cfg.mlp_ratio,
        fun_dim=model_cfg.fun_dim,
        out_dim=model_cfg.out_dim,
        slice_num=model_cfg.slice_num,
        ref=model_cfg.ref,
        unified_pos=model_cfg.unified_pos,
        H=resolution,
        W=resolution,
    ).to(device)


def _warn_if_rounded_up(
    log, label: str, sample_size: int, batch_size: int, num_batches: int
) -> None:
    """Warn when ``sample_size`` is not a multiple of ``batch_size``.

    The training/validation loops always consume whole batches, so the number of
    samples actually drawn is ``num_batches * batch_size``; that figure is logged.
    """
    if sample_size % batch_size != 0:
        # Single-line f-string: a backslash continuation inside the literal would
        # embed the next source line's indentation in the logged message.
        log.warning(
            f"increased {label} to multiple of batch size: {num_batches * batch_size}"
        )


@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml")
def darcy_trainer(cfg: DictConfig) -> None:
    """Training for the 2D Darcy flow benchmark problem.

    Builds a Transolver model, an Adam optimizer with an exponentially decaying
    ``LambdaLR`` schedule and a ``Darcy2D`` datapipe, then resumes from
    ``./checkpoints`` if a checkpoint exists. Training runs until
    ``cfg.training.max_pseudo_epochs``. Along the way it:

    * saves a checkpoint every ``cfg.training.rec_results_freq`` pseudo epochs;
    * validates every ``cfg.validation.validation_pseudo_epochs`` pseudo epochs
      (skipped when the validation sample size yields zero batches);
    * decays the learning rate every ``cfg.scheduler.decay_pseudo_epochs``
      pseudo epochs.

    Args:
        cfg: Hydra config loaded from ``config.yaml`` (``model``, ``training``,
            ``scheduler``, ``normaliser`` and ``validation`` sections are read).
    """
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere

    # initialize monitoring
    log = PythonLogger(name="darcy_transolver")
    log.file_logging()
    initialize_mlflow(
        experiment_name="Darcy_Transolver",
        experiment_desc="training a Transformer-based PDE solver for the Darcy problem",
        run_name="Darcy Transolver training",
        run_desc="training Transolver for Darcy",
        user_name="Haixu Wu, Huakun Luo, Haowen Wang",
        mode="offline",
    )
    LaunchLogger.initialize(use_mlflow=True)  # PhysicsNeMo launch logger

    # define model, loss, optimiser, scheduler, data loader
    model = _build_model(cfg, dist.device)
    # NOTE(review): TestLoss comes from the local utils package; size_average=False
    # presumably sums the per-sample loss over the batch — confirm.
    loss_fun = TestLoss(size_average=False)
    optimizer = Adam(model.parameters(), lr=cfg.scheduler.initial_lr)
    # LambdaLR scales the initial LR by decay_rate**k, where k counts the
    # scheduler.step() calls made below (one per decay_pseudo_epochs).
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: cfg.scheduler.decay_rate**step
    )

    norm_vars = cfg.normaliser
    normaliser = {
        "permeability": (norm_vars.permeability.mean, norm_vars.permeability.std_dev),
        "darcy": (norm_vars.darcy.mean, norm_vars.darcy.std_dev),
    }
    dataloader = Darcy2D(
        resolution=cfg.training.resolution,
        batch_size=cfg.training.batch_size,
        normaliser=normaliser,
    )
    validator = GridValidator(loss_fun=TestLoss(size_average=False), norm=normaliser)

    ckpt_args = {
        "path": "./checkpoints",
        "optimizer": optimizer,
        "scheduler": scheduler,
        "models": model,
    }
    loaded_pseudo_epoch = load_checkpoint(device=dist.device, **ckpt_args)

    # calculate steps per pseudo epoch; loops consume whole batches, so the
    # requested sample sizes are effectively rounded up to a batch multiple.
    batch_size = cfg.training.batch_size
    steps_per_pseudo_epoch = ceil(cfg.training.pseudo_epoch_sample_size / batch_size)
    validation_iters = ceil(cfg.validation.sample_size / batch_size)
    log_args = {
        "name_space": "train",
        "num_mini_batch": steps_per_pseudo_epoch,
        "epoch_alert_freq": 1,
    }
    _warn_if_rounded_up(
        log,
        "pseudo_epoch_sample_size",
        cfg.training.pseudo_epoch_sample_size,
        batch_size,
        steps_per_pseudo_epoch,
    )
    _warn_if_rounded_up(
        log,
        "validation sample size",
        cfg.validation.sample_size,
        batch_size,
        validation_iters,
    )

    # define forward passes for training and inference
    @StaticCaptureTraining(
        model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False
    )
    def forward_train(invars, target):
        pred = model(invars)
        loss = loss_fun(pred, target)
        return loss

    @StaticCaptureEvaluateNoGrad(
        model=model, logger=log, use_amp=False, use_graphs=False
    )
    def forward_eval(invars):
        return model(invars)

    if loaded_pseudo_epoch == 0:
        log.success("Training started...")
    else:
        log.warning(f"Resuming training from pseudo epoch {loaded_pseudo_epoch+1}.")

    for pseudo_epoch in range(
        max(1, loaded_pseudo_epoch + 1), cfg.training.max_pseudo_epochs + 1
    ):
        # Wrap epoch in launch logger for console / MLFlow logs
        with LaunchLogger(**log_args, epoch=pseudo_epoch) as logger:
            for _, batch in zip(range(steps_per_pseudo_epoch), dataloader):
                loss = forward_train(batch["permeability"], batch["darcy"])
                logger.log_minibatch({"loss": loss.detach()})
            logger.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})

        # save checkpoint
        if pseudo_epoch % cfg.training.rec_results_freq == 0:
            save_checkpoint(**ckpt_args, epoch=pseudo_epoch)

        # validation step; skipped when there are no validation batches, which
        # would otherwise divide by zero when averaging the error.
        # NOTE(review): validation batches come from the same Darcy2D datapipe
        # as training — presumably freshly generated samples; confirm.
        if (
            validation_iters > 0
            and pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0
        ):
            with LaunchLogger("valid", epoch=pseudo_epoch) as val_logger:
                total_loss = 0.0
                for _, batch in zip(range(validation_iters), dataloader):
                    val_loss = validator.compare(
                        batch["permeability"],
                        batch["darcy"],
                        forward_eval(batch["permeability"]),
                        pseudo_epoch,
                        val_logger,
                    )
                    total_loss += val_loss
                val_logger.log_epoch(
                    {"Validation error": total_loss / validation_iters}
                )

        # update learning rate
        if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0:
            scheduler.step()

    save_checkpoint(**ckpt_args, epoch=cfg.training.max_pseudo_epochs)
    log.success("Training completed *yay*")
if __name__ == "__main__":
    # The @hydra.main decorator builds `cfg` from config.yaml and any command-line
    # overrides, so the entry point is invoked with no arguments.
    darcy_trainer()