diff --git a/cfg.py b/cfg.py index fe01a22b..1840c6ed 100644 --- a/cfg.py +++ b/cfg.py @@ -3,6 +3,7 @@ def parse_args(): parser = argparse.ArgumentParser() + parser.add_argument('-seed', type=int, default=24, help='random seed') parser.add_argument('-net', type=str, default='sam', help='net type') parser.add_argument('-baseline', type=str, default='unet', help='baseline net type') parser.add_argument('-encoder', type=str, default='default', help='encoder type') diff --git a/dataset/__init__.py b/dataset/__init__.py index 2ab92b9f..8ca08d80 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -24,6 +24,7 @@ def get_dataloader(args): transform_train = transforms.Compose([ transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(), + transforms.Lambda(lambda x: x * 255) ]) transform_train_seg = transforms.Compose([ @@ -34,6 +35,7 @@ def get_dataloader(args): transform_test = transforms.Compose([ transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), + transforms.Lambda(lambda x: x * 255) ]) transform_test_seg = transforms.Compose([ diff --git a/function.py b/function.py index 03197fcc..f6f054d6 100644 --- a/function.py +++ b/function.py @@ -45,7 +45,6 @@ GPUdevice = torch.device('cuda', args.gpu_device) pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2 criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) -seed = torch.randint(1,11,(args.b,7)) torch.backends.cudnn.benchmark = True loss_function = DiceCELoss(to_onehot_y=True, softmax=True) @@ -102,7 +101,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader, imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs) masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks) - showp = pt + showp = pt[..., [1, 0]] mask_type = torch.float32 ind += 1 @@ -115,7 +114,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader, coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice) labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice) if(len(point_labels.shape)==1): # only one point prompt - coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :] + coords_torch, labels_torch, showp = coords_torch.unsqueeze(1), labels_torch.unsqueeze(1), showp.unsqueeze(1) pt = (coords_torch, labels_torch) '''init''' @@ -145,6 +144,8 @@ def train_sam(args, net: nn.Module, optimizer, train_loader, for n, value in net.image_encoder.named_parameters(): value.requires_grad = True + origin_imgs = imgs.clone() + imgs = net.preprocess(imgs) imge= net.image_encoder(imgs) with torch.no_grad(): if args.net == 'sam' or args.net == 'mobile_sam': @@ -191,7 +192,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader, ) # Resize to the ordered output size - pred = F.interpolate(pred,size=(args.out_size,args.out_size)) + pred = F.interpolate(pred,size=(args.out_size,args.out_size), mode="bilinear", align_corners=False) loss = lossfunc(pred, masks) @@ -215,11 +216,11 @@ def train_sam(args, net: nn.Module, optimizer, train_loader, namecat = 'Train' for na in name[:2]: namecat = namecat + na.split('/')[-1].split('.')[0] + '+' - vis_image(imgs,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp) + vis_image(origin_imgs/255,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp) pbar.update() - return loss + return epoch_loss / len(train_loader) def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): # eval mode @@ -227,6 +228,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): mask_type = torch.float32 n_val = len(val_loader) # the number of batch + dataset_size = len(val_loader.dataset) ave_res, mix_res = (0,0,0,0), (0,)*args.multimask_output*2 rater_res = [(0,0,0,0) for _ in range(6)] tot = 0 @@ -244,6 +246,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): for ind, pack in enumerate(val_loader): imgsw = pack['image'].to(dtype = torch.float32, device = GPUdevice) masksw = pack['label'].to(dtype = torch.float32, device = GPUdevice) + + cur_bsz = imgsw.shape[0] # for k,v in pack['image_meta_dict'].items(): # print(k) if 'pt' not in pack or args.thd: @@ -279,7 +283,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs) masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks) - showp = pt + showp = pt[..., [1, 0]] mask_type = torch.float32 ind += 1 @@ -292,7 +296,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice) labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice) if(len(point_labels.shape)==1): # only one point prompt - coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :] + coords_torch, labels_torch, showp = coords_torch.unsqueeze(1), labels_torch.unsqueeze(1), showp.unsqueeze(1) pt = (coords_torch, labels_torch) '''init''' @@ -303,6 +307,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): '''test''' with torch.no_grad(): + origin_imgs = imgs.clone() + imgs = net.preprocess(imgs) imge= net.image_encoder(imgs) if args.net == 'sam' or args.net == 'mobile_sam': se, de = net.prompt_encoder( @@ -348,8 +354,8 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): ) # Resize to the ordered output size - pred = F.interpolate(pred,size=(args.out_size,args.out_size)) - tot += lossfunc(pred, masks) + pred = F.interpolate(pred,size=(args.out_size,args.out_size), mode="bilinear", align_corners=False) + tot += lossfunc(pred, masks) * cur_bsz '''vis images''' if ind % args.vis == 0: @@ -359,10 +365,11 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): ]: img_name = na.split('/')[-1].split('.')[0] namecat = namecat + img_name + '+' - vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp) + vis_image(origin_imgs/255,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp) temp = eval_seg(pred, masks, threshold) + temp = tuple([number * cur_bsz for number in temp]) mix_res = tuple([sum(a) for a in zip(mix_res, temp)]) pbar.update() @@ -370,7 +377,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True): if args.evl_chunk: n_val = n_val * (imgsw.size(-1) // evl_ch) - return tot/ n_val , tuple([a/n_val for a in mix_res]) + return tot/dataset_size, tuple([a / dataset_size for a in mix_res]) def transform_prompt(coord,label,h,w): coord = coord.transpose(0,1) diff --git a/train.py b/train.py index f752e24e..87c6bb72 100644 --- a/train.py +++ b/train.py @@ -39,6 +39,9 @@ def main(): args = cfg.parse_args() + seed = args.seed + set_seed(seed) + GPUdevice = torch.device('cuda', args.gpu_device) net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed) @@ -96,7 +99,7 @@ def main(): for epoch in range(settings.EPOCH): - if epoch and epoch < 5: + if epoch < 5: if args.dataset != 'REFUGE': tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer) logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.') diff --git a/utils.py b/utils.py index 1d7f3a9c..b5dce226 100644 --- a/utils.py +++ b/utils.py @@ -76,6 +76,17 @@ # optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999)) '''end''' +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ['PYTHONHASHSEED'] = str(seed) + def get_network(args, net, use_gpu=True, gpu_device = 0, distribution = True): """ return given network """ @@ -963,6 +974,9 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N if reverse == True: pred_masks = 1 - pred_masks gt_masks = 1 - gt_masks + else: + pred_masks = pred_masks.clone() + gt_masks = gt_masks.clone() if c == 2: # for REFUGE multi mask output pred_disc, pred_cup = pred_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), pred_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w) gt_disc, gt_cup = gt_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), gt_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w) @@ -993,10 +1007,10 @@ def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = N else: ps = np.round(points.cpu()/args.image_size * args.out_size).to(dtype = torch.int) # gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype = torch.float32, device = torch.device('cuda:' + str(dev))) - for p in ps: - gt_masks[i,0,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.5 - gt_masks[i,1,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.1 - gt_masks[i,2,p[i,0]-5:p[i,0]+5,p[i,1]-5:p[i,1]+5] = 0.4 + for p in ps[i]: + gt_masks[i,0,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.5 + gt_masks[i,1,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.1 + gt_masks[i,2,p[0]-5:p[0]+5,p[1]-5:p[1]+5] = 0.4 if boxes is not None: for i in range(b): # the next line causes: ValueError: Tensor uint8 expected, got torch.float32 @@ -1021,6 +1035,7 @@ def eval_seg(pred,true_mask_p,threshold): pred: [b,2,h,w] ''' b, c, h, w = pred.size() + pred = F.sigmoid(pred) if c == 2: iou_d, iou_c, disc_dice, cup_dice = 0,0,0,0 for th in threshold: @@ -1160,6 +1175,7 @@ def random_click(mask, point_labels = 1): point_labels = max_label # max agreement position indices = np.argwhere(mask == max_label) + indices = indices[:, ::-1].copy() return point_labels, indices[np.random.randint(len(indices))]