Skip to content

Commit b5b10ba

Browse files
authored
Merge pull request #13 from buptlihang/master
Add COCO dataset support (DBTrainCocoDataset) and factor the shared augmentation pipeline out of the train datasets
2 parents e0ac39c + 61639b8 commit b5b10ba

File tree

1 file changed

+30
-59
lines changed

1 file changed

+30
-59
lines changed

data/dataloader.py

+30-59
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,33 @@ def init_map_mask(img, bboxes, shrink_ratio, thresh_min, thresh_max, tags=None):
235235
threshold_map = threshold_map * (thresh_max - thresh_min) + thresh_min
236236
return shrink_map, shrink_mask, threshold_map, threshold_mask
237237

238+
def augmentation(img, bboxes, shrink_ratio, thresh_min, thresh_max, img_size, tags=None):
    """Run the full DB training augmentation pipeline on one sample.

    Scales the image, builds the shrink/threshold maps and masks, applies
    random flip / rotate / crop jointly to the image and all maps, then
    color-jitters and normalizes the image and converts everything to
    torch tensors.

    Returns:
        (img_tensor, [shrink_map, shrink_mask, threshold_map, threshold_mask])
        where the four target maps are float tensors.
    """
    img = random_scale(img, img_size[0])

    maps = init_map_mask(img, bboxes, shrink_ratio, thresh_min, thresh_max, tags)

    # Geometric transforms must hit the image and every map identically,
    # so they are applied to the whole stack at once.
    stacked = [img] + list(maps)
    for geometric in (random_horizontal_flip, random_rotate):
        stacked = geometric(stacked)
    stacked = random_crop(stacked, img_size)

    img, *targets = stacked

    # Photometric transforms apply to the image only.
    pil_img = Image.fromarray(img).convert('RGB')
    pil_img = transforms.ColorJitter(brightness = 32.0 / 255, saturation = 0.5)(pil_img)

    tensor_img = transforms.ToTensor()(pil_img)
    tensor_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor_img)

    # shrink_map, shrink_mask, threshold_map, threshold_mask -> float tensors.
    targets = [torch.from_numpy(t).float() for t in targets]

    return tensor_img, targets
264+
238265
class DBTrainCocoDataset(CocoCVContoursDataset):
239266
def __init__(self, deepvac_config, sample_path_prefix, target_path, img_size):
240267
super(DBTrainCocoDataset, self).__init__(deepvac_config, sample_path_prefix, target_path)
@@ -253,55 +280,26 @@ def __getitem__(self, index):
253280
bbox = bbox.split(',')[:-1]
254281
bbox = [np.int(i) for i in bbox]
255282
bboxes[idx] = np.asarray(bbox) / ([w * 1.0, h * 1.0] * (len(bbox)//2))
256-
257-
img = random_scale(img, self.img_size[0])
258-
259-
shrink_map, shrink_mask, threshold_map, threshold_mask = init_map_mask(img, bboxes, self.shrink_ratio, self.thresh_min, self.thresh_max)
260-
261-
imgs = [img, shrink_map, shrink_mask, threshold_map, threshold_mask]
262-
263-
imgs = random_horizontal_flip(imgs)
264-
imgs = random_rotate(imgs)
265-
imgs = random_crop(imgs, self.img_size)
266-
267-
img, shrink_map, shrink_mask, threshold_map, threshold_mask = imgs[0], imgs[1], imgs[2], imgs[3], imgs[4]
268-
269-
img = Image.fromarray(img)
270-
img = img.convert('RGB')
271-
img = transforms.ColorJitter(brightness = 32.0 / 255, saturation = 0.5)(img)
272-
273-
img = transforms.ToTensor()(img)
274-
img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
275-
276-
shrink_map = torch.from_numpy(shrink_map).float()
277-
shrink_mask = torch.from_numpy(shrink_mask).float()
278-
threshold_map = torch.from_numpy(threshold_map).float()
279-
threshold_mask = torch.from_numpy(threshold_mask).float()
280-
281-
return img, [shrink_map, shrink_mask, threshold_map, threshold_mask]
283+
return augmentation(img, bboxes, self.shrink_ratio, self.thresh_min, self.thresh_max, self.img_size)
282284

283285
class DBTrainDataset(DatasetBase):
284-
def __init__(self, deepvac_config, sample_path, label_path, is_transform, img_size):
286+
def __init__(self, deepvac_config, sample_path, label_path, img_size):
285287
super(DBTrainDataset, self).__init__(deepvac_config)
286288
self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
287289
data_dirs = [sample_path]
288290
gt_dirs = [label_path]
289291
self.img_paths = []
290292
self.gt_paths = []
291-
292293
for data_dir, gt_dir in zip(data_dirs, gt_dirs):
293294
img_names = os.listdir(data_dir)
294-
295295
img_paths = []
296296
gt_paths = []
297297
for idx, img_name in enumerate(img_names):
298298
img_path = data_dir + img_name
299299
img_paths.append(img_path)
300-
301300
gt_name = 'gt_' + img_name[:-4] + '.txt'
302301
gt_path = gt_dir + gt_name
303302
gt_paths.append(gt_path)
304-
305303
self.img_paths.extend(img_paths)
306304
self.gt_paths.extend(gt_paths)
307305

@@ -317,35 +315,9 @@ def __len__(self):
317315
def __getitem__(self, index):
318316
img_path = self.img_paths[index]
319317
gt_path = self.gt_paths[index]
320-
321318
img = get_img(img_path)
322319
bboxes, tags = get_bboxes(img, gt_path)
323-
324-
img = random_scale(img, self.img_size[0])
325-
326-
shrink_map, shrink_mask, threshold_map, threshold_mask = init_map_mask(img, bboxes, self.shrink_ratio, self.thresh_min, self.thresh_max, tags)
327-
328-
imgs = [img, shrink_map, shrink_mask, threshold_map, threshold_mask]
329-
330-
imgs = random_horizontal_flip(imgs)
331-
imgs = random_rotate(imgs)
332-
imgs = random_crop(imgs, self.img_size)
333-
334-
img, shrink_map, shrink_mask, threshold_map, threshold_mask = imgs[0], imgs[1], imgs[2], imgs[3], imgs[4]
335-
336-
img = Image.fromarray(img)
337-
img = img.convert('RGB')
338-
img = transforms.ColorJitter(brightness = 32.0 / 255, saturation = 0.5)(img)
339-
340-
img = transforms.ToTensor()(img)
341-
img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
342-
343-
shrink_map = torch.from_numpy(shrink_map).float()
344-
shrink_mask = torch.from_numpy(shrink_mask).float()
345-
threshold_map = torch.from_numpy(threshold_map).float()
346-
threshold_mask = torch.from_numpy(threshold_mask).float()
347-
348-
return img, [shrink_map, shrink_mask, threshold_map, threshold_mask]
320+
return augmentation(img, bboxes, self.shrink_ratio, self.thresh_min, self.thresh_max, self.img_size, tags)
349321

350322
class DBTestDataset(OsWalkDataset):
351323
def __init__(self, deepvac_config, sample_path, long_size = 1280):
@@ -364,7 +336,6 @@ def scale(self, img):
364336
def __getitem__(self, idx):
365337
img = super(DBTestDataset, self).__getitem__(idx)
366338
org_img = img.copy()
367-
368339
img = img[:, :, [2, 1, 0]]
369340
scaled_img = self.scale(img)
370341
scaled_img = Image.fromarray(scaled_img)

0 commit comments

Comments
 (0)