Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-seed', type=int, default=24, help='random seed')
parser.add_argument('-net', type=str, default='sam', help='net type')
parser.add_argument('-baseline', type=str, default='unet', help='baseline net type')
parser.add_argument('-encoder', type=str, default='default', help='encoder type')
Expand Down
2 changes: 2 additions & 0 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_dataloader(args):
transform_train = transforms.Compose([
transforms.Resize((args.image_size,args.image_size)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])

transform_train_seg = transforms.Compose([
Expand All @@ -34,6 +35,7 @@ def get_dataloader(args):
transform_test = transforms.Compose([
transforms.Resize((args.image_size, args.image_size)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])

transform_test_seg = transforms.Compose([
Expand Down
31 changes: 19 additions & 12 deletions function.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1,11,(args.b,7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
Expand Down Expand Up @@ -102,7 +101,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,

imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)
showp = pt
showp = pt[..., [1, 0]]

mask_type = torch.float32
ind += 1
Expand All @@ -115,7 +114,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
if(len(point_labels.shape)==1): # only one point prompt
coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
coords_torch, labels_torch, showp = coords_torch.unsqueeze(1), labels_torch.unsqueeze(1), showp.unsqueeze(1)
pt = (coords_torch, labels_torch)

'''init'''
Expand Down Expand Up @@ -145,6 +144,8 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
for n, value in net.image_encoder.named_parameters():
value.requires_grad = True

origin_imgs = imgs.clone()
imgs = net.preprocess(imgs)
imge= net.image_encoder(imgs)
with torch.no_grad():
if args.net == 'sam' or args.net == 'mobile_sam':
Expand Down Expand Up @@ -191,7 +192,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
)

# Resize to the ordered output size
pred = F.interpolate(pred,size=(args.out_size,args.out_size))
pred = F.interpolate(pred,size=(args.out_size,args.out_size), mode="bilinear", align_corners=False)

loss = lossfunc(pred, masks)

Expand All @@ -215,18 +216,19 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
namecat = 'Train'
for na in name[:2]:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
vis_image(imgs,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image(origin_imgs/255,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)

pbar.update()

return loss
return epoch_loss / len(train_loader)

def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
# eval mode
net.eval()

mask_type = torch.float32
n_val = len(val_loader) # the number of batch
dataset_size = len(val_loader.dataset)
ave_res, mix_res = (0,0,0,0), (0,)*args.multimask_output*2
rater_res = [(0,0,0,0) for _ in range(6)]
tot = 0
Expand All @@ -244,6 +246,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype = torch.float32, device = GPUdevice)
masksw = pack['label'].to(dtype = torch.float32, device = GPUdevice)

cur_bsz = imgsw.shape[0]
# for k,v in pack['image_meta_dict'].items():
# print(k)
if 'pt' not in pack or args.thd:
Expand Down Expand Up @@ -279,7 +283,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)

showp = pt
showp = pt[..., [1, 0]]

mask_type = torch.float32
ind += 1
Expand All @@ -292,7 +296,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
if(len(point_labels.shape)==1): # only one point prompt
coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
coords_torch, labels_torch, showp = coords_torch.unsqueeze(1), labels_torch.unsqueeze(1), showp.unsqueeze(1)
pt = (coords_torch, labels_torch)

'''init'''
Expand All @@ -303,6 +307,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):

'''test'''
with torch.no_grad():
origin_imgs = imgs.clone()
imgs = net.preprocess(imgs)
imge= net.image_encoder(imgs)
if args.net == 'sam' or args.net == 'mobile_sam':
se, de = net.prompt_encoder(
Expand Down Expand Up @@ -348,8 +354,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
)

# Resize to the ordered output size
pred = F.interpolate(pred,size=(args.out_size,args.out_size))
tot += lossfunc(pred, masks)
pred = F.interpolate(pred,size=(args.out_size,args.out_size), mode="bilinear", align_corners=False)
tot += lossfunc(pred, masks) * cur_bsz

'''vis images'''
if ind % args.vis == 0:
Expand All @@ -359,18 +365,19 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
]:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image(origin_imgs/255,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)


temp = eval_seg(pred, masks, threshold)
temp = tuple([number * cur_bsz for number in temp])
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

pbar.update()

if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)

return tot/ n_val , tuple([a/n_val for a in mix_res])
return tot/dataset_size, tuple([a / dataset_size for a in mix_res])

def transform_prompt(coord,label,h,w):
coord = coord.transpose(0,1)
Expand Down
5 changes: 4 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def main():

args = cfg.parse_args()

seed = args.seed
set_seed(seed)

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)
Expand Down Expand Up @@ -96,7 +99,7 @@ def main():

for epoch in range(settings.EPOCH):

if epoch and epoch < 5:
if epoch < 5:
if args.dataset != 'REFUGE':
tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
Expand Down
24 changes: 20 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@
# optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
'''end'''

def set_seed(seed):
    """Seed every random number generator used by the project.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA) with the same value, then switches cuDNN to its
    deterministic mode so repeated runs with the same seed pick the same
    kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Benchmark mode lets cuDNN auto-tune (and possibly vary) the algorithm
    # it selects, so it is disabled together with enabling deterministic mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this assignment
    # only influences processes spawned after this call, not the hash
    # randomization of the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)

def get_network(args, net, use_gpu=True, gpu_device = 0, distribution = True):
""" return given network
"""
Expand Down Expand Up @@ -963,6 +974,9 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N
if reverse == True:
pred_masks = 1 - pred_masks
gt_masks = 1 - gt_masks
else:
pred_masks = pred_masks.clone()
gt_masks = gt_masks.clone()
if c == 2: # for REFUGE multi mask output
pred_disc, pred_cup = pred_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), pred_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w)
gt_disc, gt_cup = gt_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), gt_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w)
Expand Down Expand Up @@ -993,10 +1007,10 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N
else:
ps = np.round(points.cpu()/args.image_size * args.out_size).to(dtype = torch.int)
# gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype = torch.float32, device = torch.device('cuda:' + str(dev)))
for p in ps:
gt_masks[i,0,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.5
gt_masks[i,1,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.1
gt_masks[i,2,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.4
for p in ps[i]:
gt_masks[i,0,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.5
gt_masks[i,1,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.1
gt_masks[i,2,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.4
if boxes is not None:
for i in range(b):
# the next line causes: ValueError: Tensor uint8 expected, got torch.float32
Expand All @@ -1021,6 +1035,7 @@ def eval_seg(pred,true_mask_p,threshold):
pred: [b,2,h,w]
'''
b, c, h, w = pred.size()
pred = F.sigmoid(pred)
if c == 2:
iou_d, iou_c, disc_dice, cup_dice = 0,0,0,0
for th in threshold:
Expand Down Expand Up @@ -1160,6 +1175,7 @@ def random_click(mask, point_labels = 1):
point_labels = max_label
# max agreement position
indices = np.argwhere(mask == max_label)
indices = indices[:, ::-1].copy()
return point_labels, indices[np.random.randint(len(indices))]


Expand Down