SwinUNETR/Pretrain: Fix resume mechanism #120

Open · wants to merge 1 commit into base: main
62 changes: 43 additions & 19 deletions SwinUNETR/Pretrain/main.py
@@ -91,13 +91,18 @@ def train(args, global_step, train_loader, val_best, scaler):
writer.add_image("Validation/x1_aug", img_list[1], global_step, dataformats="HW")
writer.add_image("Validation/x1_recon", img_list[2], global_step, dataformats="HW")

checkpoint = {
"global_step": global_step,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"val_best": val_best,
}
if args.amp:
checkpoint["scaler"] = scaler.state_dict()

if val_loss_recon < val_best:
val_best = val_loss_recon
checkpoint = {
"global_step": global_step,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
save_ckp(checkpoint, logdir + "/model_bestValRMSE.pt")
print(
"Model was saved ! Best Recon. Val Loss: {:.4f}, Recon. Val Loss: {:.4f}".format(
@@ -110,6 +115,8 @@ def train(args, global_step, train_loader, val_best, scaler):
val_best, val_loss_recon
)
)

save_ckp(checkpoint, logdir + "/last.pt")
return global_step, loss, val_best
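
Note: with this change every validation pass writes the complete training state to last.pt, while model_bestValRMSE.pt is only refreshed when the reconstruction validation loss improves. The checkpoint now carries the scheduler, the best validation loss, and (under AMP) the GradScaler state alongside the weights and optimizer. A minimal sketch of the save helper this relies on; save_ckp is the repo's existing helper, and the body below is an assumption about it, not a quote:

import torch

def save_ckp(state, checkpoint_path):
    # Assumed behaviour of save_ckp: serialize the whole training-state dict
    # (weights, optimizer, scheduler, scaler, step counter) into one file.
    torch.save(state, checkpoint_path)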

def validation(args, test_loader):
@@ -234,13 +241,6 @@ def validation(args, test_loader):
elif args.opt == "sgd":
optimizer = optim.SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.decay)

if args.resume:
model_pth = args.resume
model_dict = torch.load(model_pth)
model.load_state_dict(model_dict["state_dict"])
model.epoch = model_dict["epoch"]
model.optimizer = model_dict["optimizer"]

if args.lrdecay:
if args.lr_schedule == "warmup_cosine":
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps)
@@ -252,21 +252,45 @@ def lambdas(epoch):

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas)

if args.amp:
scaler = GradScaler(1024)
else:
scaler = None

global_step = 0
best_val = 1e8
if args.resume:
model_pth = args.resume
model_dict = torch.load(model_pth)
model.load_state_dict({k[7:]: v for k, v in model_dict["state_dict"].items()})
optimizer.load_state_dict(model_dict["optimizer"])
global_step = model_dict["global_step"]
if "scaler" in model_dict:
scaler.load_state_dict(model_dict["scaler"])
if "val_best" in model_dict:
best_val = model_dict["val_best"]
if "scheduler" in model_dict:
scheduler.load_state_dict(model_dict["scheduler"])
else:
scheduler.last_epoch = global_step - 1
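
Note: the k[7:] slice strips the "module." prefix that DistributedDataParallel prepends to parameter names, so a checkpoint saved from the DDP-wrapped model can be loaded into the bare model before it is wrapped again below. A hedged variant that only strips the prefix when it is actually present (the slice in the diff assumes every key carries it, which would corrupt keys coming from a non-DDP checkpoint):

# Illustration only: remove an optional "module." prefix from checkpoint keys.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in model_dict["state_dict"].items()
}
model.load_state_dict(state_dict)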

loss_function = Loss(args.batch_size * args.sw_batch_size, args)
if args.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[args.local_rank])
train_loader, test_loader = get_loader(args)

global_step = 0
best_val = 1e8
if args.amp:
scaler = GradScaler()
else:
scaler = None
while global_step < args.num_steps:
global_step, loss, best_val = train(args, global_step, train_loader, best_val, scaler)
checkpoint = {"epoch": args.epochs, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
checkpoint = {
"global_step": global_step,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"val_best": best_val
}
if args.amp:
checkpoint["scaler"] = scaler.state_dict()

if args.distributed:
if dist.get_rank() == 0:
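
Usage note: resuming an interrupted pretraining run then amounts to re-launching the script with --resume pointed at the rolling checkpoint, e.g. python main.py --resume <logdir>/last.pt together with the original training flags. Because global_step, optimizer, scheduler, scaler, and the best validation loss are all restored, the while loop over args.num_steps continues from where it stopped. The exact logdir path depends on how the run was configured, and all flags other than --resume are assumed to match the original invocation.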