-
Notifications
You must be signed in to change notification settings - Fork 0
/
ROIProcessor.lua
174 lines (141 loc) · 5.85 KB
/
ROIProcessor.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
--[[
Process ROI samples.
]]
local utils = require 'fastrcnn.utils'
local boxoverlap = utils.box.boxoverlap
-- Ensure the global `fastrcnn` namespace table exists before the class is
-- registered under it.
fastrcnn = fastrcnn or {}
---------------------------------------------------------------------------------------------------
-- Torch class `fastrcnn.ROIProcessor`; its methods are defined on this local.
local ROIProcessor = torch.class('fastrcnn.ROIProcessor')
--- Construct an ROI processor.
-- @param dataLoadFn data loader table; this class reads its `classLabel` and
--   `nfiles` fields and calls its `getGTBoxes(idx)` / `getFilename(idx)`
-- @param proposals per-image proposal tensors indexed by image index; only
--   columns 1-4 of each are used (see getROIBoxes)
-- @param opt options table; `frcnn_roi_augment_offset` is the maximum jitter
--   applied to proposals as a fraction of box width/height (0 or absent
--   disables augmentation)
function ROIProcessor:__init(dataLoadFn, proposals, opt)
  assert(dataLoadFn, 'Invalid input: dataLoadFn')
  assert(proposals, 'Invalid input: proposals')
  assert(opt, 'Invalid input: opt')
  self.dataLoadFn = dataLoadFn
  self.roidb = proposals
  self.classes = dataLoadFn.classLabel
  self.nFiles = dataLoadFn.nfiles
  -- An absent option means "no augmentation"; leaving it nil would make
  -- augmentRoiProposals fail on a nil-vs-number comparison.
  self.augment_offset = opt.frcnn_roi_augment_offset or 0
end
------------------------------------------------------------------------------------------------------------
--- Return the region-proposal boxes of image `idx` as a FloatTensor.
-- Only the first four columns of the stored proposals are kept.
-- An image without proposals (empty tensor) yields an empty FloatTensor
-- instead of an indexing error, so callers can test `:numel() == 0`.
-- @param idx image index
-- @return FloatTensor of boxes (columns 1-4), possibly empty
function ROIProcessor:getROIBoxes(idx)
  local proposals = self.roidb[idx]
  -- Torch7 cannot apply a two-dimensional range index to an empty tensor.
  if proposals:numel() == 0 then
    return torch.FloatTensor()
  end
  return proposals[{{}, {1, 4}}]:float()
end
------------------------------------------------------------------------------------------------------------
--- Delegate the ground-truth lookup for image `idx` to the data loader.
-- @param idx image index
-- @return every value returned by `dataLoadFn.getGTBoxes(idx)`, unchanged
function ROIProcessor:getGTBoxes(idx)
  local loader = self.dataLoadFn
  return loader.getGTBoxes(idx)
end
------------------------------------------------------------------------------------------------------------
--- Delegate the file-name lookup for image `idx` to the data loader.
-- @param idx image index
-- @return every value returned by `dataLoadFn.getFilename(idx)`, unchanged
function ROIProcessor:getFilename(idx)
  local loader = self.dataLoadFn
  return loader.getFilename(idx)
end
------------------------------------------------------------------------------------------------------------
--- Augment region proposals by jittering box coordinates.
-- Offsets (dx, dy) are taken from the symmetric grid
-- {-k*0.1, ..., -0.1, 0, 0.1, ..., k*0.1} in each axis, excluding (0, 0).
-- Here k is the number of whole 0.1 steps that fit in `self.augment_offset`.
-- For each pair a copy of `boxes` is appended, with columns 1 and 3 shifted
-- by dx * (col3 - col1) and columns 2 and 4 by dy * (col4 - col2).
-- @param boxes 2-D tensor, one box per row, coordinates in columns 1-4
-- @return `boxes` itself when augmentation is disabled (offset below 0.1)
--   or `boxes` is empty; otherwise a new tensor with N * (2k+1)^2 rows whose
--   first N rows are a copy of `boxes`, followed by the jittered copies
function ROIProcessor:augmentRoiProposals(boxes)
  local step = 0.1
  local max_offset = self.augment_offset or 0
  -- Count grid steps with integers. The previous torch.range + float
  -- for-loop accumulated rounding error; e.g. (0.3 - 0.1) / 0.1 < 2 truncates
  -- one step, and the loop dropped its positive endpoint. The resulting jitter
  -- grid was asymmetric, and offsets in (0, 0.1) raised an error. The epsilon
  -- absorbs the same rounding in this division.
  local nsteps = math.floor(max_offset / step + 1e-6)
  if nsteps < 1 or boxes:numel() == 0 then
    return boxes
  end
  -- Box extents do not depend on the offset, so compute them once.
  local widths = boxes:select(2, 3) - boxes:select(2, 1)
  local heights = boxes:select(2, 4) - boxes:select(2, 2)
  -- Collect every copy and concatenate once; growing the result with cat
  -- inside the loop re-copies all previous rows each time.
  local chunks = {boxes:clone()}
  for i = -nsteps, nsteps do
    for j = -nsteps, nsteps do
      if i ~= 0 or j ~= 0 then
        local shifted = boxes:clone()
        local offx = torch.mul(widths, i * step)
        local offy = torch.mul(heights, j * step)
        shifted:select(2, 1):add(offx)
        shifted:select(2, 2):add(offy)
        shifted:select(2, 3):add(offx)
        shifted:select(2, 4):add(offy)
        chunks[#chunks + 1] = shifted
      end
    end
  end
  return torch.cat(chunks, 1)
end
------------------------------------------------------------------------------------------------------------
--- Build the ROI record for image `idx`.
-- Ground-truth boxes are placed first, followed by the image's region
-- proposals. The combined set then goes through augmentRoiProposals, which
-- keeps those rows first and may append jittered copies after them.
-- @param idx image index
-- @return nil when the loader returns nil ground-truth boxes, or when the
--   image has neither ground-truth nor proposal boxes. Otherwise a table with:
--   * boxes          - all boxes (ground truth first), one row per box
--   * gt             - ByteTensor, 1 for the leading ground-truth rows, else 0
--   * class          - CharTensor, class of the leading ground-truth rows, else 0
--   * overlap_class  - FloatTensor (#boxes x #classes), per class the largest
--                      boxoverlap with any ground-truth box of that class
--   * overlap        - FloatTensor, largest boxoverlap over all ground-truth boxes
--   * correspondance - LongTensor, index of that ground-truth box (0 if overlap is 0)
--   * label          - IntTensor, class of the corresponding ground-truth box (0 if none)
--   * size()         - number of boxes
function ROIProcessor:getProposals(idx)
  -- fetch object boxes, classes
  local gt_boxes, gt_classes = self.dataLoadFn.getGTBoxes(idx)
  if gt_boxes == nil then
    return nil
  end

  -- fetch roi proposal boxes; ground-truth rows always come first
  local boxes = self:getROIBoxes(idx)
  local all_boxes
  if boxes:numel() > 0 and gt_boxes:numel() > 0 then
    all_boxes = torch.cat(gt_boxes, boxes, 1)
  elseif boxes:numel() == 0 then
    all_boxes = gt_boxes
  else
    all_boxes = boxes
  end
  -- No boxes at all: nothing to label, and sizing an empty tensor would fail.
  if all_boxes:numel() == 0 then
    return nil
  end
  local num_gt_boxes = #gt_classes

  -- augment the number of roi proposals (jittered copies are appended)
  all_boxes = self:augmentRoiProposals(all_boxes)
  local total = all_boxes:size(1)

  local rec = {}
  -- Only the leading num_gt_boxes rows are ground truth. Proposals and all
  -- jittered copies are flagged 0 / class 0, so both tensors have one entry
  -- per box. Previously gt ignored the augmented rows, and class was padded
  -- with size-1 zeros instead of size-num_gt_boxes.
  rec.gt = torch.ByteTensor(total):fill(0)
  rec.class = torch.CharTensor(total):fill(0)
  if num_gt_boxes > 0 then
    rec.gt:narrow(1, 1, num_gt_boxes):fill(1)
    rec.class:narrow(1, 1, num_gt_boxes):copy(torch.CharTensor(gt_classes))
  end

  -- box overlap with every ground-truth box, plus per-class running maximum
  rec.overlap_class = torch.FloatTensor(total, #self.classes):fill(0)
  if num_gt_boxes > 0 then
    local overlap = torch.FloatTensor(total, num_gt_boxes):fill(0)
    for g = 1, num_gt_boxes do
      local o = boxoverlap(all_boxes, gt_boxes[g])
      local class_col = rec.overlap_class[{{}, gt_classes[g]}] -- view into overlap_class
      local improved = class_col:lt(o)
      class_col[improved] = o[improved]
      overlap[{{}, g}] = o
    end
    -- correspondence: best-overlapping ground-truth box per row
    local best, corr = overlap:max(2)
    rec.overlap = torch.squeeze(best, 2)
    rec.correspondance = torch.squeeze(corr, 2)
    rec.correspondance[rec.overlap:eq(0)] = 0
  else
    rec.overlap = torch.FloatTensor(total):fill(0)
    rec.correspondance = torch.LongTensor(total):fill(0)
  end

  -- class label taken from the corresponding ground-truth box
  rec.label = torch.IntTensor(total):fill(0)
  for i = 1, total do
    local corr = rec.correspondance[i]
    if corr > 0 then
      rec.label[i] = gt_classes[corr]
    end
  end

  rec.boxes = all_boxes
  function rec:size()
    return all_boxes:size(1)
  end
  return rec
end