import torch

from torchvision import transforms
from Dataset import data_folder as fd


def load_data(args, crop_height=112, crop_width=112, height=112, width=112,
              num_workers=4, dataset_mode="vehicle_logo"):
    """Build train/test/validation DataLoaders for the requested dataset mode."""
    train_loader = None
    test_loader = None
    val_loader = None

    if dataset_mode == "vehicle_logo":
        # Data transforms
        mean = [0.5071, 0.4867, 0.4408]
        stdv = [0.2675, 0.2565, 0.2761]
        transform_train = transforms.Compose([
            transforms.RandomCrop((crop_height, crop_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        # Center cropping is disabled in the testing phase.
        transform_test = transforms.Compose([
            # transforms.CenterCrop((crop_height, crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        print(args.train_dir)
        trainDataset = fd.ImageFolder(root=args.train_dir, transform=transform_train,
                                      height=height, width=width)
        train_loader = torch.utils.data.DataLoader(
            trainDataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=num_workers
        )

        testDataset = fd.ImageFolder(root=args.test_dir, transform=transform_test,
                                     height=height, width=width)
        test_loader = torch.utils.data.DataLoader(
            testDataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        valDataset = fd.ImageFolder(root=args.val_dir, transform=transform_test,
                                    height=height, width=width)
        val_loader = torch.utils.data.DataLoader(
            valDataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    elif dataset_mode == "CCML_vehicle_logo":
        # Data transforms
        mean = [0.5071, 0.4867, 0.4408]
        stdv = [0.2675, 0.2565, 0.2761]
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        # Center cropping is disabled in the testing phase.
        transform_test = transforms.Compose([
            # transforms.CenterCrop((crop_height, crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        mask_tran = transforms.Compose([
            transforms.ToTensor(),
        ])

        print(args.train_dir)
        trainDataset = fd.CCML_Train_ImageFolder(root=args.train_dir, mask_path=args.mask_path,
                                                 train_type=True, transform=transform_train,
                                                 mask_transform=mask_tran,
                                                 crop_height=crop_height, crop_width=crop_width,
                                                 height=height, width=width)
        train_loader = torch.utils.data.DataLoader(
            trainDataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=num_workers
        )

        # Center cropping is disabled in the testing phase.
        testDataset = fd.CCML_Test_ImageFolder(root=args.test_dir, mask_path=args.mask_path,
                                               transform=transform_test, mask_transform=mask_tran,
                                               crop_height=height, crop_width=width,
                                               height=height, width=width)
        test_loader = torch.utils.data.DataLoader(
            testDataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        # Center cropping is disabled in the validation phase.
        valDataset = fd.CCML_Train_ImageFolder(root=args.val_dir, mask_path=args.mask_path,
                                               train_type=False, transform=transform_test,
                                               crop_height=height, crop_width=width,
                                               height=height, width=width)
        val_loader = torch.utils.data.DataLoader(
            valDataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    return train_loader, test_loader, val_loader
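

# Minimal usage sketch (not part of the original file): load_data() only reads
# a few attributes from `args`, so any argparse-style namespace works. The
# directory paths below are placeholders and must point to real dataset folders
# laid out as expected by Dataset.data_folder for this to actually run.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        train_dir="./data/train",   # placeholder path
        test_dir="./data/test",     # placeholder path
        val_dir="./data/val",       # placeholder path
        mask_path="./data/masks",   # only used by the CCML_vehicle_logo branch
        batch_size=32,
    )

    train_loader, test_loader, val_loader = load_data(args, dataset_mode="vehicle_logo")
    print(len(train_loader.dataset), len(test_loader.dataset), len(val_loader.dataset))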