@@ -235,6 +235,33 @@ def init_map_mask(img, bboxes, shrink_ratio, thresh_min, thresh_max, tags=None):
235
235
threshold_map = threshold_map * (thresh_max - thresh_min ) + thresh_min
236
236
return shrink_map , shrink_mask , threshold_map , threshold_mask
237
237
238
def augmentation(img, bboxes, shrink_ratio, thresh_min, thresh_max, img_size, tags=None):
    """Turn one raw image and its boxes into an augmented training sample.

    Pipeline, in order:
      1. ``random_scale`` is applied to the image using ``img_size[0]``.
      2. ``init_map_mask`` builds the shrink map, shrink mask, threshold map
         and threshold mask from the scaled image and ``bboxes``.
      3. The image and the four maps are passed together, as one list,
         through ``random_horizontal_flip``, ``random_rotate`` and
         ``random_crop`` (the latter receives the full ``img_size``).
      4. The image is converted to an RGB PIL image, colour-jittered
         (brightness and saturation), turned into a tensor and normalized.
      5. Each map is converted with ``torch.from_numpy(...).float()``.

    Args:
        img: Input image. NOTE(review): presumably a numpy array that
            ``Image.fromarray`` accepts after augmentation; confirm with callers.
        bboxes: Box annotations, forwarded unchanged to ``init_map_mask``.
        shrink_ratio: Forwarded to ``init_map_mask``.
        thresh_min: Forwarded to ``init_map_mask``.
        thresh_max: Forwarded to ``init_map_mask``.
        img_size: Indexable size; element 0 drives ``random_scale`` and the
            whole value is handed to ``random_crop``.
        tags: Optional per-box tags, forwarded to ``init_map_mask``.

    Returns:
        A pair ``(img, [shrink_map, shrink_mask, threshold_map,
        threshold_mask])`` where ``img`` is the normalized image tensor and
        the list holds four float tensors.
    """
    scaled = random_scale(img, img_size[0])

    # The maps are built from the already-scaled image so that every later
    # transform is applied to image and maps as one group.
    shrink_map, shrink_mask, threshold_map, threshold_mask = init_map_mask(
        scaled, bboxes, shrink_ratio, thresh_min, thresh_max, tags
    )

    stack = [scaled, shrink_map, shrink_mask, threshold_map, threshold_mask]
    for joint_transform in (random_horizontal_flip, random_rotate):
        stack = joint_transform(stack)
    stack = random_crop(stack, img_size)

    # Pull the five entries back out before any conversion takes place.
    aug_img = stack[0]
    aug_maps = [stack[1], stack[2], stack[3], stack[4]]

    # Photometric changes touch the image only; the label maps are left as is.
    picture = Image.fromarray(aug_img).convert('RGB')
    jitter = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)
    picture = jitter(picture)

    tensor = transforms.ToTensor()(picture)
    # Per-channel mean/std; presumably the usual ImageNet statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    tensor = normalize(tensor)

    targets = [torch.from_numpy(label_map).float() for label_map in aug_maps]
    return tensor, targets
264
+
238
265
class DBTrainCocoDataset (CocoCVContoursDataset ):
239
266
def __init__ (self , deepvac_config , sample_path_prefix , target_path , img_size ):
240
267
super (DBTrainCocoDataset , self ).__init__ (deepvac_config , sample_path_prefix , target_path )
@@ -253,55 +280,26 @@ def __getitem__(self, index):
253
280
bbox = bbox .split (',' )[:- 1 ]
254
281
bbox = [np .int (i ) for i in bbox ]
255
282
bboxes [idx ] = np .asarray (bbox ) / ([w * 1.0 , h * 1.0 ] * (len (bbox )// 2 ))
256
-
257
- img = random_scale (img , self .img_size [0 ])
258
-
259
- shrink_map , shrink_mask , threshold_map , threshold_mask = init_map_mask (img , bboxes , self .shrink_ratio , self .thresh_min , self .thresh_max )
260
-
261
- imgs = [img , shrink_map , shrink_mask , threshold_map , threshold_mask ]
262
-
263
- imgs = random_horizontal_flip (imgs )
264
- imgs = random_rotate (imgs )
265
- imgs = random_crop (imgs , self .img_size )
266
-
267
- img , shrink_map , shrink_mask , threshold_map , threshold_mask = imgs [0 ], imgs [1 ], imgs [2 ], imgs [3 ], imgs [4 ]
268
-
269
- img = Image .fromarray (img )
270
- img = img .convert ('RGB' )
271
- img = transforms .ColorJitter (brightness = 32.0 / 255 , saturation = 0.5 )(img )
272
-
273
- img = transforms .ToTensor ()(img )
274
- img = transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])(img )
275
-
276
- shrink_map = torch .from_numpy (shrink_map ).float ()
277
- shrink_mask = torch .from_numpy (shrink_mask ).float ()
278
- threshold_map = torch .from_numpy (threshold_map ).float ()
279
- threshold_mask = torch .from_numpy (threshold_mask ).float ()
280
-
281
- return img , [shrink_map , shrink_mask , threshold_map , threshold_mask ]
283
+ return augmentation (img , bboxes , self .shrink_ratio , self .thresh_min , self .thresh_max , self .img_size )
282
284
283
285
class DBTrainDataset (DatasetBase ):
284
- def __init__ (self , deepvac_config , sample_path , label_path , is_transform , img_size ):
286
+ def __init__ (self , deepvac_config , sample_path , label_path , img_size ):
285
287
super (DBTrainDataset , self ).__init__ (deepvac_config )
286
288
self .img_size = img_size if (img_size is None or isinstance (img_size , tuple )) else (img_size , img_size )
287
289
data_dirs = [sample_path ]
288
290
gt_dirs = [label_path ]
289
291
self .img_paths = []
290
292
self .gt_paths = []
291
-
292
293
for data_dir , gt_dir in zip (data_dirs , gt_dirs ):
293
294
img_names = os .listdir (data_dir )
294
-
295
295
img_paths = []
296
296
gt_paths = []
297
297
for idx , img_name in enumerate (img_names ):
298
298
img_path = data_dir + img_name
299
299
img_paths .append (img_path )
300
-
301
300
gt_name = 'gt_' + img_name [:- 4 ] + '.txt'
302
301
gt_path = gt_dir + gt_name
303
302
gt_paths .append (gt_path )
304
-
305
303
self .img_paths .extend (img_paths )
306
304
self .gt_paths .extend (gt_paths )
307
305
@@ -317,35 +315,9 @@ def __len__(self):
317
315
def __getitem__(self, index):
    """Return the augmented training sample stored at position ``index``.

    Looks up the image path and ground-truth path for ``index``, loads the
    image with ``get_img``, reads boxes and tags with ``get_bboxes``, and
    hands everything to ``augmentation`` together with this dataset's
    ``shrink_ratio``, ``thresh_min``, ``thresh_max`` and ``img_size``.

    Args:
        index: Position into ``self.img_paths`` / ``self.gt_paths``.

    Returns:
        Whatever ``augmentation`` returns for this sample.
    """
    # Resolve both paths up front, before any file is read.
    img_path, gt_path = self.img_paths[index], self.gt_paths[index]

    image = get_img(img_path)
    bboxes, tags = get_bboxes(image, gt_path)

    return augmentation(
        image,
        bboxes,
        self.shrink_ratio,
        self.thresh_min,
        self.thresh_max,
        self.img_size,
        tags,
    )
349
321
350
322
class DBTestDataset (OsWalkDataset ):
351
323
def __init__ (self , deepvac_config , sample_path , long_size = 1280 ):
@@ -364,7 +336,6 @@ def scale(self, img):
364
336
def __getitem__ (self , idx ):
365
337
img = super (DBTestDataset , self ).__getitem__ (idx )
366
338
org_img = img .copy ()
367
-
368
339
img = img [:, :, [2 , 1 , 0 ]]
369
340
scaled_img = self .scale (img )
370
341
scaled_img = Image .fromarray (scaled_img )
0 commit comments