import torch
import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np

try:
    from inplace_abn import InPlaceABN
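The except branch of this try block is cut off by the hunk boundary. For context, a typical guarded-import pattern for an optional dependency looks like the following; this is a hypothetical sketch, and the real fallback is not shown in this diff:

try:
    from inplace_abn import InPlaceABN  # optional fused BatchNorm + activation
except ImportError:
    InPlaceABN = None  # hypothetical fallback; callers would use nn.BatchNorm2d + nn.ReLU instead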
@@ -119,7 +117,6 @@ class ECAM(nn.Module):
    Ensemble Channel Attention Module for UNetPlusPlus.
    Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
    IEEE Geoscience and Remote Sensing Letters, 2021.
-
    Not completely consistent, to be improved.
    """
    def __init__(self, in_channels, out_channels, map_num=4):
@@ -142,218 +139,26 @@ def forward(self, x):
        return out


- class ModuleHelper:
-
-     @staticmethod
-     def BNReLU(num_features, bn_type=None, **kwargs):
-         return nn.Sequential(
-             nn.BatchNorm2d(num_features, **kwargs),
-             nn.ReLU()
-         )
-
-     @staticmethod
-     def BatchNorm2d(*args, **kwargs):
-         return nn.BatchNorm2d
-
-
- class SpatialGather_Module(nn.Module):
-     """
-     Aggregate the context features according to the initial
-     predicted probability distribution.
-     Employ the soft-weighted method to aggregate the context.
-     """
-     def __init__(self, cls_num=0, scale=1):
-         super(SpatialGather_Module, self).__init__()
-         self.cls_num = cls_num
-         self.scale = scale
-
-     def forward(self, feats, probs):
-         batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
-         probs = probs.view(batch_size, c, -1)
-         feats = feats.view(batch_size, feats.size(1), -1)
-         feats = feats.permute(0, 2, 1)  # batch x hw x c
-         probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
-         ocr_context = torch.matmul(probs, feats)\
-             .permute(0, 2, 1).unsqueeze(3)  # batch x k x c
-         return ocr_context
-
-
- class _ObjectAttentionBlock(nn.Module):
-     '''
-     The basic implementation for object context block
-     Input:
-         N X C X H X W
-     Parameters:
-         in_channels : the dimension of the input feature map
-         key_channels : the dimension after the key/query transform
-         scale : choose the scale to downsample the input feature maps (save memory cost)
-         bn_type : specify the bn type
-     Return:
-         N X C X H X W
-     '''
-     def __init__(self,
-                  in_channels,
-                  key_channels,
-                  scale=1,
-                  bn_type=None):
-         super(_ObjectAttentionBlock, self).__init__()
-         self.scale = scale
-         self.in_channels = in_channels
-         self.key_channels = key_channels
-         self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
-         self.f_pixel = nn.Sequential(
-             nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-             nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-         )
-         self.f_object = nn.Sequential(
-             nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-             nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-         )
-         self.f_down = nn.Sequential(
-             nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-         )
-         self.f_up = nn.Sequential(
-             nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
-                       kernel_size=1, stride=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
-         )
-
-     def forward(self, x, proxy):
-         batch_size, h, w = x.size(0), x.size(2), x.size(3)
-         if self.scale > 1:
-             x = self.pool(x)
-
-         query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
-         query = query.permute(0, 2, 1)
-         key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
-         value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
-         value = value.permute(0, 2, 1)
-
-         sim_map = torch.matmul(query, key)
-         sim_map = (self.key_channels ** -.5) * sim_map
-         sim_map = F.softmax(sim_map, dim=-1)
-
-         # add bg context ...
-         context = torch.matmul(sim_map, value)
-         context = context.permute(0, 2, 1).contiguous()
-         context = context.view(batch_size, self.key_channels, *x.size()[2:])
-         context = self.f_up(context)
-         if self.scale > 1:
-             context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
-
-         return context
-
-
- class ObjectAttentionBlock2D(_ObjectAttentionBlock):
-     def __init__(self,
-                  in_channels,
-                  key_channels,
-                  scale=1,
-                  bn_type=None):
-         super(ObjectAttentionBlock2D, self).__init__(in_channels,
-                                                      key_channels,
-                                                      scale,
-                                                      bn_type=bn_type)
-
-
- class SpatialOCR_Module(nn.Module):
-     """
-     Implementation of the OCR module:
-     We aggregate the global object representation to update the representation for each pixel.
-     """
-     def __init__(self,
-                  in_channels,
-                  key_channels,
-                  out_channels,
-                  scale=1,
-                  dropout=0.1,
-                  bn_type=None):
-         super(SpatialOCR_Module, self).__init__()
-         self.object_context_block = ObjectAttentionBlock2D(in_channels,
-                                                            key_channels,
-                                                            scale,
-                                                            bn_type)
-         _in_channels = 2 * in_channels
-
-         self.conv_bn_dropout = nn.Sequential(
-             nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
-             ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
-             nn.Dropout2d(dropout)
-         )
-
-     def forward(self, feats, proxy_feats):
-         context = self.object_context_block(feats, proxy_feats)
-
-         output = self.conv_bn_dropout(torch.cat([context, feats], 1))
-
-         return output
-
-
- class OCR(nn.Module):
+ class SEModule(nn.Module):
    """
-     Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation
-     https://arxiv.org/pdf/1909.11065.pdf
+     Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
+     //Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141.
    """
-     def __init__(self, in_channels, num_classes, ocr_mid_channels=512, ocr_key_channels=256):
-
-         super().__init__()
-         pre_stage_channels = in_channels
-         last_inp_channels = np.int(np.sum(pre_stage_channels))
-
-         self.conv3x3_ocr = nn.Sequential(
-             nn.Conv2d(last_inp_channels, ocr_mid_channels,
-                       kernel_size=3, stride=1, padding=1),
-             nn.BatchNorm2d(ocr_mid_channels),
-             nn.ReLU(inplace=True),
-         )
-         self.ocr_gather_head = SpatialGather_Module(num_classes)
-
-         self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
-                                                  key_channels=ocr_key_channels,
-                                                  out_channels=ocr_mid_channels,
-                                                  scale=1,
-                                                  dropout=0.05,
-                                                  )
-         self.cls_head = nn.Conv2d(
-             ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
-
-         self.aux_head = nn.Sequential(
-             nn.Conv2d(last_inp_channels, last_inp_channels,
-                       kernel_size=1, stride=1, padding=0),
-             nn.BatchNorm2d(last_inp_channels),
+     def __init__(self, in_channels, reduction=16):
+         super(SEModule, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
-             nn.Conv2d(last_inp_channels, num_classes,
-                       kernel_size=1, stride=1, padding=0, bias=True)
+             nn.Linear(in_channels // reduction, in_channels, bias=False),
+             nn.Sigmoid()
        )

    def forward(self, x):
-
-         out_aux_seg = []
-
-         # ocr
-         out_aux = self.aux_head(x)
-         # compute contrast feature
-         feats = self.conv3x3_ocr(x)
-
-         context = self.ocr_gather_head(feats, out_aux)
-         feats = self.ocr_distri_head(feats, context)
-
-         out = self.cls_head(feats)
-
-         out_aux_seg.append(out_aux)
-         out_aux_seg.append(out)
-
-         return out_aux_seg
+         b, c, _, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1, 1)
+         return x * y.expand_as(x)


class ArgMax(nn.Module):
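Between the two hunks, a minimal usage sketch of the newly added SEModule may help; the tensor sizes below are illustrative assumptions, not values from this commit:

# Illustrative only: exercises SEModule from this file with dummy shapes.
import torch

se = SEModule(in_channels=64, reduction=16)  # gate MLP: 64 -> 4 -> 64
x = torch.randn(2, 64, 32, 32)               # dummy NCHW feature map
y = se(x)                                    # each channel scaled by a sigmoid gate in (0, 1)
assert y.shape == x.shape                    # SE recalibrates channels; shape is preserved

Since the gate is computed from a global average pool followed by two small Linear layers, SE adds only about 2*c*c/reduction parameters per block, which is why it is a cheap drop-in attention.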
@@ -412,6 +217,8 @@ def __init__(self, name, **params):
            self.attention = CBAMSpatial(**params)
        elif name == 'cbam':
            self.attention = CBAM(**params)
+         elif name == 'se':
+             self.attention = SEModule(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

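With the dispatcher extended, the new option is reachable by name. A hedged sketch follows, assuming the enclosing class is the module's attention wrapper (here called Attention; its actual name and forward method fall outside this hunk):

# Hypothetical: `Attention` is an assumed name for the dispatcher class above;
# only its `__init__(self, name, **params)` signature is visible in the hunk.
att = Attention('se', in_channels=64, reduction=16)
feats = torch.randn(2, 64, 32, 32)
out = att.attention(feats)  # invokes SEModule directly; output shape matches input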