Skip to content

Commit

Permalink
3D Coronary artery centerline extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
BubblyYi committed Aug 7, 2020
0 parents commit 166925f
Show file tree
Hide file tree
Showing 48 changed files with 19,254 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Created by .ignore support plugin (hsz.mobi)
### Example user template template
### Example user template

# IntelliJ project files
.idea
*.iml
out
gen
5 changes: 5 additions & 0 deletions centerline_train_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -*- coding: UTF-8 -*-
# @Time : 06/08/2020 11:34
# @Author : BubblyYi
# @FileName: __init__.py.py
# @Software: PyCharm
68 changes: 68 additions & 0 deletions centerline_train_tools/centerline_train_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# -*- coding: UTF-8 -*-
# @Time : 14/05/2020 17:56
# @Author : BubblyYi
# @FileName: train_tools.py
# @Software: PyCharm
import sys
sys.path.append('..')
from models.centerline_net import CenterlineNet
from data_provider_argu import DataGenerater
from centerline_trainner import Trainer
import torch
def get_dataset(save_num = 0):
'''
:return: train set,val set
'''
train_data_info_path = "/Coronary-Artery-Tracking-via-3D-CNN-Classification/data_process_tools/patch_data/centerline_patch/train_save_d"+str(save_num)+"_train.csv"
train_pre_fix_path = "/data_process_tools/patch_data/"
train_flag = 'train'
train_transforms = None
target_transform = None
train_dataset = DataGenerater(train_data_info_path, train_pre_fix_path, 500, train_transforms, train_flag, target_transform)

val_data_info_path = "/Coronary-Artery-Tracking-via-3D-CNN-Classification/data_process_tools/patch_data/centerline_patch/train_save_d"+str(save_num)+"_val.csv"
val_pre_fix_path = "/data_process_tools/patch_data/"
val_flag = 'val'
test_valid_transforms = None
target_transform = None
val_dataset = DataGenerater(val_data_info_path, val_pre_fix_path, 500, test_valid_transforms, val_flag, target_transform)

return train_dataset, val_dataset


def cross_entropy(a, y):
epsilon = 1e-9
return torch.mean(torch.sum(-y * torch.log10(a + epsilon) - (1 - y) * torch.log10(1 - a + epsilon), dim=1))

if __name__ == '__main__':

# Here we use 8 fold cross validation, save_num means to use dataset0x as the validation set
save_num = 1
train_dataset, val_dataset = get_dataset(save_num)

curr_model_name = "centerline_net"
max_points = 500
model = CenterlineNet(n_classes = max_points)

batch_size = 64
num_workers = 16

criterion = cross_entropy
inital_lr = 0.001

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=inital_lr,weight_decay=0.001)

trainer = Trainer(batch_size,
num_workers,
train_dataset,
val_dataset,
model,
curr_model_name,
optimizer,
criterion,
max_points,
save_num = save_num,
start_epoch=0,
max_epoch=100,
initial_lr=inital_lr)
trainer.run_train()
244 changes: 244 additions & 0 deletions centerline_train_tools/centerline_trainner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# -*- coding: UTF-8 -*-
# @Time : 14/05/2020 16:48
# @Author : BubblyYi
# @FileName: trainner.py
# @Software: PyCharm

import os
import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from time import time
import sys
from datetime import datetime
class Trainer(object):
def __init__(self, batch_size, num_workers, train_dataset, val_dataset, model, model_name, optimizer, criterion,max_points=500, save_num = 0,start_epoch=0, max_epoch=1000, initial_lr=0.01, checkpoint_path=None):

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = batch_size
self.num_workers = num_workers
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.model = model
self.model_name = model_name
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_points = max_points
self.criterion = criterion
self.criterion_2 = torch.nn.MSELoss()
self.all_tr_loss = []
self.all_val_loss = []

self.all_tr_direction_loss = []
self.all_val_direction_loss = []

self.all_tr_radius_loss = []
self.all_val_radius_loss = []

self.all_tr_err = []
self.all_val_err = []

self.best_test_loss = 2**31
self.log_file = None

self.start_epoch = start_epoch
self.max_epoch = max_epoch

self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
self.checkpoint_path = checkpoint_path
self.output_folder = "logs"
if not os.path.exists(self.output_folder):
os.makedirs(self.output_folder)
self.save_num = save_num

def train_step(self, epoch):
self.print_to_log_file("\nEpoch: ", epoch + 1)
self.model.train()
training_loss = 0.
train_loss = 0.
train_direction_loss = 0.
train_radius_loss = 0.
correct = 0
total = 0
punishment_factor = 15
for idx, (inputs, labels, r) in enumerate(self.train_loader):
inputs, labels, r = inputs.to(self.device), labels.to(self.device), r.to(self.device)
outputs = self.model(inputs)
outputs = outputs.view((len(labels),self.max_points+1))
outputs_1 = outputs[:,:len(outputs[0])-1]
outputs_2 = outputs[:,-1]
outputs_1 = torch.nn.functional.softmax(outputs_1,1)
loss_1 = self.criterion(outputs_1.float(), labels.float())
loss_2 = self.criterion_2(outputs_2.float(), r.float())
train_direction_loss+=loss_1.item()
train_radius_loss+=loss_2.item()
loss = loss_1+punishment_factor*loss_2
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
training_loss += loss.item()
train_loss += loss.item()
total += labels.size(0)

print_str = "Train Loss:{:.5f} Direction Train Loss:{:.5f} Radius Train Loss:{:.5f}".format(training_loss / len(self.train_loader),
train_direction_loss / len(self.train_loader),
train_radius_loss / len(self.train_loader))

self.print_to_log_file(print_str)

return train_loss / len(self.train_loader), 1. - correct / total, train_direction_loss/ len(self.train_loader),train_radius_loss/ len(self.train_loader)

# 验证模型
def val_step(self, epoch):
self.model.eval()
test_loss = 0.
val_direction_loss = 0.
val_radius_loss = 0.
correct = 0
total = 0
punishment_factor = 15
if True:
for idx, (inputs, labels, r) in enumerate(self.val_loader):
inputs, labels, r = inputs.to(self.device), labels.to(self.device), r.to(self.device)
outputs = self.model(inputs)

outputs = outputs.view((len(labels), self.max_points+1))
outputs_1 = outputs[:, :len(outputs[0]) - 1]
outputs_2 = outputs[:, -1]
outputs_1 = torch.nn.functional.softmax(outputs_1,1)
loss_1 = self.criterion(outputs_1.float(), labels.float())
loss_2 = self.criterion_2(outputs_2.float(), r.float())
val_direction_loss+=loss_1.item()
val_radius_loss+=loss_2.item()
loss = loss_1+punishment_factor*loss_2
test_loss += loss.item()
total += labels.size(0)
print_str = "Val Loss:{:.5f} Direction Val Loss:{:.5f} Radius Val Loss:{:.5f}".format(test_loss / len(self.val_loader),
val_direction_loss / len(self.val_loader),
val_radius_loss / len(self.val_loader))

self.print_to_log_file(print_str)
print("test loss",test_loss/len(self.val_loader))
print("best test loss", self.best_test_loss)
if (test_loss/len(self.val_loader)) < self.best_test_loss:
print("saving models")
self.best_test_loss = test_loss/len(self.val_loader)
save_fold = "../checkpoint/classification_checkpoints"
if not os.path.exists(save_fold):
os.makedirs(save_fold)
model_save_path = save_fold+"/"+ self.model_name + "_model_s"+str(self.save_num)+".pkl"
self.save_best_checkpoint(model_save_path, test_loss, epoch)
print_str = "Saving parameters to " + model_save_path
self.print_to_log_file(print_str)

return test_loss / len(self.val_loader), 1. - correct / total, val_direction_loss/ len(self.val_loader), val_radius_loss/ len(self.val_loader)

def poly_lr(self, epoch, max_epochs, initial_lr, exponent=0.9):
return initial_lr * (1 - epoch / max_epochs) ** exponent

def lr_decay(self, epoch, max_epochs, initial_lr):
for params in self.optimizer.param_groups:
params['lr'] = self.poly_lr(epoch, max_epochs, initial_lr, exponent=1.5)
lr = params['lr']
print_str = "Learning rate adjusted to {}".format(lr)
self.print_to_log_file(print_str)

def plot_progress(self, epoch):

x_epoch = list(range(len(self.all_tr_loss)))
plt.plot(x_epoch, self.all_tr_direction_loss, color="b", linestyle="--", marker="*", label='train')
plt.plot(x_epoch, self.all_val_direction_loss, color="r", linestyle="--", marker="*", label='val')
plt.legend()
plt.rcParams['savefig.dpi'] = 300 #图片像素"
plt.rcParams['figure.dpi'] = 300 #分辨率"
plt.savefig("Direction_loss_"+str(self.save_num)+".jpg")
plt.close()

plt.plot(x_epoch, self.all_tr_radius_loss, color="b", linestyle="--", marker="*", label='train')
plt.plot(x_epoch, self.all_val_radius_loss, color="r", linestyle="--", marker="*", label='val')
plt.legend()
plt.rcParams['savefig.dpi'] = 300 # 图片像素"
plt.rcParams['figure.dpi'] = 300 # 分辨率"
plt.savefig("Radius_loss_s"+str(self.save_num)+".jpg")
plt.close()

plt.plot(x_epoch, self.all_tr_loss, color="b", linestyle="--", marker="*", label='train')
plt.plot(x_epoch, self.all_val_loss, color="r", linestyle="--", marker="*", label='val')
plt.legend()
plt.rcParams['savefig.dpi'] = 300 # 图片像素"
plt.rcParams['figure.dpi'] = 300 # 分辨率"
plt.savefig("Total_loss"+str(self.save_num)+".jpg")
plt.close()


def save_best_checkpoint(self, model_save_path, acc, epoch):
checkpoint = {
'net_dict': self.model.state_dict(),
'acc': acc,
'epoch': epoch,
'optimizer_state_dict': self.optimizer.state_dict(),
'batch_size': self.batch_size,
'train_loss': self.all_tr_loss,
'train_err': self.all_tr_err,
'val_loss': self.all_val_loss,
'val_err': self.all_val_err,
'initial_lr': self.initial_lr
}
torch.save(checkpoint, model_save_path)

def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):

timestamp = time()
dt_object = datetime.fromtimestamp(timestamp)

if add_timestamp:
args = ("%s:" % dt_object, *args)

if self.log_file is None:
if not os.path.isdir(self.output_folder):
os.mkdir(self.output_folder)
timestamp = datetime.now()
self.log_file = os.path.join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
(timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
timestamp.second))
with open(self.log_file, 'w') as f:
f.write("Starting... \n")
successful = False
max_attempts = 5
ctr = 0
while not successful and ctr < max_attempts:
try:
with open(self.log_file, 'a+') as f:
for a in args:
f.write(str(a))
f.write(" ")
f.write("\n")
successful = True
except IOError:
print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
ctr += 1
if also_print_to_console:
print(*args)

def run_train(self):
print("Start training")
self.model.to(self.device)
for epoch in range(self.start_epoch, self.max_epoch):
train_loss, train_err,train_d_loss,train_r_loss = self.train_step(epoch)
val_loss, val_err,val_d_loss,val_r_loss = self.val_step(epoch)
self.all_tr_loss.append(train_loss)
self.all_tr_err.append(train_err)
self.all_val_loss.append(val_loss)
self.all_val_err.append(val_err)
self.all_tr_direction_loss.append(train_d_loss)
self.all_tr_radius_loss.append(train_r_loss)
self.all_val_direction_loss.append(val_d_loss)
self.all_val_radius_loss.append(val_r_loss)
self.plot_progress(epoch)
self.lr_decay(epoch, self.max_epoch, self.initial_lr)



Loading

0 comments on commit 166925f

Please sign in to comment.