-
Notifications
You must be signed in to change notification settings - Fork 88
/
transforms.py
391 lines (348 loc) · 15.2 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
#-------------------------------------------------------------------------------
# Author: Lukasz Janyst <[email protected]>
# Date: 18.09.2017
#-------------------------------------------------------------------------------
# This file is part of SSD-TensorFlow.
#
# SSD-TensorFlow is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SSD-TensorFlow is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SSD-Tensorflow. If not, see <http://www.gnu.org/licenses/>.
#-------------------------------------------------------------------------------
import cv2
import random
import numpy as np
from ssdutils import get_anchors_for_preset, get_preset_by_name, anchors2array
from ssdutils import box2array, compute_overlap, compute_location, anchors2array
from utils import Size, Sample, Point, Box, abs2prop, prop2abs
from math import sqrt
#-------------------------------------------------------------------------------
class Transform:
    """
    Base class of the transforms; every keyword argument passed to the
    constructor is stored as an attribute of the same name
    """
    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
        # Assigned last, so an 'initialized' keyword argument gets overridden
        self.initialized = False
#-------------------------------------------------------------------------------
class ImageLoaderTransform(Transform):
    """
    Load an image from the file specified in the Sample object
    """
    def __call__(self, data, label, gt):
        # The incoming data is ignored and replaced with whatever cv2.imread
        # returns for gt.filename; label and gt are passed through
        image = cv2.imread(gt.filename)
        return image, label, gt
#-------------------------------------------------------------------------------
def process_overlap(overlap, box, anchor, matches, num_classes, vec):
    """
    Assign the box to the anchor row pointed to by the overlap, unless that
    row already holds a match with an equal or a higher score. Both matches
    (anchor index -> score) and vec (the label array) are updated in place.
    """
    idx = overlap.idx
    score = overlap.score
    #---------------------------------------------------------------------------
    # Keep the existing assignment if it is at least as good as this one
    #---------------------------------------------------------------------------
    if idx in matches and matches[idx] >= score:
        return
    matches[idx] = score
    #---------------------------------------------------------------------------
    # Columns [0, num_classes] hold the class scores (the last of them is
    # the background), the remaining ones hold the location
    #---------------------------------------------------------------------------
    loc_start = num_classes+1
    vec[idx, :loc_start] = 0
    vec[idx, box.labelid] = 1
    vec[idx, loc_start:] = compute_location(box, anchor)
#-------------------------------------------------------------------------------
class LabelCreatorTransform(Transform):
    """
    Create a label vector out of a ground truth sample
    Parameters: preset, num_classes
    """
    #---------------------------------------------------------------------------
    def initialize(self):
        """
        Compute the anchors of the preset and the dimensions of the label
        vector; invoked by the first __call__
        """
        anchors = get_anchors_for_preset(self.preset)
        # Anchors and boxes are both converted to arrays using this fixed size
        img_size = Size(1000, 1000)
        self.anchors = anchors
        self.vheight = len(anchors)
        # One column per class, one for the background, four for the location
        self.vwidth = self.num_classes+5
        self.img_size = img_size
        self.anchors_arr = anchors2array(anchors, img_size)
        self.initialized = True

    #---------------------------------------------------------------------------
    def __call__(self, data, label, gt):
        """
        Return the data and the ground truth untouched; the label is replaced
        with a float32 array having one row per anchor and num_classes+5
        columns
        """
        if not self.initialized:
            self.initialize()

        #-----------------------------------------------------------------------
        # Start with every anchor marked as background and a zeroed location
        # (x offset, y offset, log width scale, log height scale)
        #-----------------------------------------------------------------------
        background = self.num_classes
        vec = np.zeros((self.vheight, self.vwidth), dtype=np.float32)
        vec[:, background] = 1
        vec[:, background+1:background+5] = 0

        #-----------------------------------------------------------------------
        # For every box compute the best match and all the matches above 0.5
        # Jaccard overlap
        #-----------------------------------------------------------------------
        overlaps = {
            box: compute_overlap(box2array(box, self.img_size),
                                 self.anchors_arr, 0.5)
            for box in gt.boxes
        }

        #-----------------------------------------------------------------------
        # First pass: the good matches, conflicts resolved in favor of
        # a better score
        #-----------------------------------------------------------------------
        matches = {}
        for box in gt.boxes:
            for good in overlaps[box].good:
                process_overlap(good, box, self.anchors[good.idx], matches,
                                self.num_classes, vec)

        #-----------------------------------------------------------------------
        # Second pass: the best match of every box. The scores are tracked
        # from scratch, so a best match replaces whatever the first pass
        # wrote to the same anchor
        #-----------------------------------------------------------------------
        matches = {}
        for box in gt.boxes:
            best = overlaps[box].best
            if best:
                process_overlap(best, box, self.anchors[best.idx], matches,
                                self.num_classes, vec)

        return data, vec, gt
#-------------------------------------------------------------------------------
class ResizeTransform(Transform):
    """
    Scale an image to a fixed size
    Parameters: width, height, algorithms
    """
    def __call__(self, data, label, gt):
        # The interpolation algorithm is drawn at random on every call
        interpolation = random.choice(self.algorithms)
        target = (self.width, self.height)
        scaled = cv2.resize(data, target, interpolation=interpolation)
        return scaled, label, gt
#-------------------------------------------------------------------------------
class RandomTransform(Transform):
    """
    Apply the wrapped transform with a given probability, otherwise pass the
    input through unchanged
    Parameters: prob, transform
    """
    def __call__(self, data, label, gt):
        fire = random.uniform(0, 1) < self.prob
        if not fire:
            return data, label, gt
        return self.transform(data, label, gt)
#-------------------------------------------------------------------------------
class ComposeTransform(Transform):
    """
    Chain the transforms: the result of each one is unpacked into the
    arguments of the next one
    Parameters: transforms
    """
    def __call__(self, data, label, gt):
        state = (data, label, gt)
        for step in self.transforms:
            state = step(*state)
        # With no transforms at all this is the input triple
        return state
#-------------------------------------------------------------------------------
class TransformPickerTransform(Transform):
    """
    Delegate to one transform drawn at random from the list
    Parameters: transforms
    """
    def __call__(self, data, label, gt):
        candidates = self.transforms
        chosen = candidates[random.randrange(len(candidates))]
        return chosen(data, label, gt)
#-------------------------------------------------------------------------------
class BrightnessTransform(Transform):
    """
    Add a random integer offset to every value of the image
    Parameters: delta
    """
    def __call__(self, data, label, gt):
        # Compute in float32 so that the sum can leave the [0, 255] range
        # before it is clamped
        shifted = data.astype(np.float32)
        shifted += random.randint(-self.delta, self.delta)
        np.clip(shifted, 0, 255, out=shifted)
        return shifted.astype(np.uint8), label, gt
#-------------------------------------------------------------------------------
class ContrastTransform(Transform):
    """
    Multiply every value of the image by a random factor
    Parameters: lower, upper
    """
    def __call__(self, data, label, gt):
        # Compute in float32 so that the product can leave the [0, 255]
        # range before it is clamped
        scaled = data.astype(np.float32)
        scaled *= random.uniform(self.lower, self.upper)
        np.clip(scaled, 0, 255, out=scaled)
        return scaled.astype(np.uint8), label, gt
#-------------------------------------------------------------------------------
class HueTransform(Transform):
    """
    Shift the hue of the whole image by a random integer offset
    Parameters: delta
    """
    def __call__(self, data, label, gt):
        #-----------------------------------------------------------------------
        # Work on a float copy of the HSV image so that the shifted values
        # may leave the valid range before being wrapped
        #-----------------------------------------------------------------------
        hsv = cv2.cvtColor(data, cv2.COLOR_BGR2HSV).astype(np.float32)
        delta = random.randint(-self.delta, self.delta)

        #-----------------------------------------------------------------------
        # The hue is channel 0 of the last axis; indexing with a bare [0]
        # would select the first image row instead. The 8-bit OpenCV hue
        # lies in [0, 180), so wrap the shifted values around modulo 180.
        #-----------------------------------------------------------------------
        hsv[:, :, 0] = np.mod(hsv[:, :, 0] + delta, 180)

        data = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
        return data, label, gt
#-------------------------------------------------------------------------------
class SaturationTransform(Transform):
    """
    Scale the saturation of the whole image by a random factor
    Parameters: lower, upper
    """
    def __call__(self, data, label, gt):
        #-----------------------------------------------------------------------
        # Work on a float copy of the HSV image so that the scaled values
        # may exceed 255 before being clamped
        #-----------------------------------------------------------------------
        hsv = cv2.cvtColor(data, cv2.COLOR_BGR2HSV).astype(np.float32)
        factor = random.uniform(self.lower, self.upper)

        #-----------------------------------------------------------------------
        # The saturation is channel 1 of the last axis; indexing with a bare
        # [1] would select the second image row instead
        #-----------------------------------------------------------------------
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)

        data = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
        return data, label, gt
#-------------------------------------------------------------------------------
class ReorderChannelsTransform(Transform):
    """
    Randomly permute the three channels of the image
    """
    def __call__(self, data, label, gt):
        order = list(range(3))
        random.shuffle(order)
        permuted = data[:, :, order]
        return permuted, label, gt
#-------------------------------------------------------------------------------
def transform_box(box, orig_size, new_size, h_off, w_off):
    """
    Shift a box by the given offsets and express it relative to the new
    image size. Returns None if the center of the shifted box ends up
    outside of the new image.
    """
    #---------------------------------------------------------------------------
    # Absolute coordinates in the original image, moved by the offsets
    #---------------------------------------------------------------------------
    xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, orig_size)
    xmin, xmax = xmin+w_off, xmax+w_off
    ymin, ymax = ymin+h_off, ymax+h_off

    #---------------------------------------------------------------------------
    # Drop the box unless its center lands inside of the new image
    #---------------------------------------------------------------------------
    center_x = xmin + int((xmax-xmin)/2)
    center_y = ymin + int((ymax-ymin)/2)
    inside = 0 <= center_x < new_size.w and 0 <= center_y < new_size.h
    if not inside:
        return None

    center, size = abs2prop(xmin, xmax, ymin, ymax, new_size)
    return Box(box.label, box.labelid, center, size)
#-------------------------------------------------------------------------------
def transform_gt(gt, new_size, h_off, w_off):
    """
    Build a new Sample for an image of new_size: every box is passed through
    transform_box and the ones it rejects (None) are dropped
    """
    moved = (transform_box(box, gt.imgsize, new_size, h_off, w_off)
             for box in gt.boxes)
    kept = [box for box in moved if box is not None]
    return Sample(gt.filename, kept, new_size)
#-------------------------------------------------------------------------------
class ExpandTransform(Transform):
    """
    Place the image at a random position of a larger canvas filled with the
    mean value
    Parameters: max_ratio, mean_value
    """
    def __call__(self, data, label, gt):
        #-----------------------------------------------------------------------
        # Draw the expansion ratio, then the position of the original image
        # within the canvas
        #-----------------------------------------------------------------------
        ratio = random.uniform(1, self.max_ratio)
        src = gt.imgsize
        dst = Size(int(src.w*ratio), int(src.h*ratio))
        h_off = random.randint(0, dst.h-src.h)
        w_off = random.randint(0, dst.w-src.w)

        #-----------------------------------------------------------------------
        # Fill the canvas with the mean value and paste the input image
        #-----------------------------------------------------------------------
        canvas = np.zeros((dst.h, dst.w, 3))
        canvas[:, :] = np.array(self.mean_value)
        canvas[h_off:h_off+src.h, w_off:w_off+src.w, :] = data

        #-----------------------------------------------------------------------
        # Move the ground truth boxes along with the image
        #-----------------------------------------------------------------------
        return canvas, label, transform_gt(gt, dst, h_off, w_off)
#-------------------------------------------------------------------------------
class SamplerTransform(Transform):
    """
    Crop a random patch of the image that overlaps the ground truth enough
    Params: sample, max_trials, min_scale, max_scale, min_aspect_ratio,
            max_aspect_ratio, min_jaccard_overlap
    """
    def __call__(self, data, label, gt):
        """
        Return the input untouched if sampling is disabled, the cropped
        image with the adjusted ground truth if a patch was found, or None
        if no patch met the overlap constraint within max_trials attempts
        """
        if not self.sample:
            return data, label, gt

        source_boxes = anchors2array(gt.boxes, gt.imgsize)

        #-----------------------------------------------------------------------
        # Retry sampling a couple of times
        #-----------------------------------------------------------------------
        for _ in range(self.max_trials):
            #-------------------------------------------------------------------
            # Draw the scale and the aspect ratio; the ratio is clamped so
            # that neither the width nor the height exceeds 1
            #-------------------------------------------------------------------
            scale = random.uniform(self.min_scale, self.max_scale)
            ratio = random.uniform(self.min_aspect_ratio,
                                   self.max_aspect_ratio)
            ratio = min(max(ratio, scale**2), 1/(scale**2))
            width = scale*sqrt(ratio)
            height = scale/sqrt(ratio)

            #-------------------------------------------------------------------
            # Draw the position so that the patch stays within the image
            #-------------------------------------------------------------------
            cx = 0.5*width + random.uniform(0, 1-width)
            cy = 0.5*height + random.uniform(0, 1-height)

            #-------------------------------------------------------------------
            # Accept the patch if its best overlap with the ground truth
            # boxes reaches the required minimum
            #-------------------------------------------------------------------
            patch = np.array(prop2abs(Point(cx, cy), Size(width, height),
                                      gt.imgsize))
            best = compute_overlap(patch, source_boxes, 0).best
            if best and best.score >= self.min_jaccard_overlap:
                break
        else:
            # Every trial failed (or max_trials allowed for none)
            return None

        #-----------------------------------------------------------------------
        # Crop the patch and adjust the ground truth
        #-----------------------------------------------------------------------
        xmin, xmax, ymin, ymax = patch
        new_size = Size(xmax-xmin, ymax-ymin)
        cropped = data[ymin:ymax, xmin:xmax]
        return cropped, label, transform_gt(gt, new_size, -ymin, -xmin)
#-------------------------------------------------------------------------------
class SamplePickerTransform(Transform):
    """
    Run a bunch of sample transforms and return one of the produced samples;
    if none of them produced a sample, return the input unchanged
    Parameters: samplers
    """
    def __call__(self, data, label, gt):
        #-----------------------------------------------------------------------
        # Run every sampler on the same input; a None result means that
        # the sampler did not produce anything and is skipped
        #-----------------------------------------------------------------------
        produced = (sampler(data, label, gt) for sampler in self.samplers)
        samples = [sample for sample in produced if sample is not None]

        #-----------------------------------------------------------------------
        # random.choice raises an IndexError for an empty sequence, so fall
        # back to the unmodified input when there is nothing to pick from
        #-----------------------------------------------------------------------
        if not samples:
            return data, label, gt
        return random.choice(samples)
#-------------------------------------------------------------------------------
class HorizontalFlipTransform(Transform):
    """
    Mirror the image around the vertical axis and move the boxes accordingly
    """
    def __call__(self, data, label, gt):
        mirrored = cv2.flip(data, 1)
        # Only the x coordinate of a box center changes: x becomes 1-x
        boxes = [
            Box(box.label, box.labelid,
                Point(1-box.center.x, box.center.y), box.size)
            for box in gt.boxes
        ]
        return mirrored, label, Sample(gt.filename, boxes, gt.imgsize)