
Commit de96478

add se
1 parent 9286d9f commit de96478

3 files changed: +17 -226 lines changed


change_detection_pytorch/base/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -9,5 +9,4 @@
 from .heads import (
     SegmentationHead,
     ClassificationHead,
-    SegmentationOCRHead,
 )

change_detection_pytorch/base/heads.py

Lines changed: 1 addition & 16 deletions
@@ -1,5 +1,5 @@
 import torch.nn as nn
-from .modules import Flatten, Activation, OCR
+from .modules import Flatten, Activation


 class SegmentationHead(nn.Sequential):
@@ -11,21 +11,6 @@ def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, up
         super().__init__(conv2d, upsampling, activation)


-class SegmentationOCRHead(nn.Module):
-
-    def __init__(self, in_channels, out_channels, activation=None, upsampling=1, align_corners=True):
-        super().__init__()
-        self.ocr_head = OCR(in_channels, out_channels)
-        self.upsampling = nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=align_corners) if upsampling > 1 else nn.Identity()
-        self.activation = Activation(activation)
-
-    def forward(self, x):
-        coarse_pre, pre = self.ocr_head(x)
-        coarse_pre = self.activation(self.upsampling(coarse_pre))
-        pre = self.activation(self.upsampling(pre))
-        return [coarse_pre, pre]
-
-
 class ClassificationHead(nn.Sequential):

     def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):

change_detection_pytorch/base/modules.py

Lines changed: 16 additions & 209 deletions
@@ -1,7 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np

 try:
     from inplace_abn import InPlaceABN
@@ -119,7 +117,6 @@ class ECAM(nn.Module):
     Ensemble Channel Attention Module for UNetPlusPlus.
     Fang S, Li K, Shao J, et al. SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images[J].
     IEEE Geoscience and Remote Sensing Letters, 2021.
-
     Not completely consistent, to be improved.
     """
     def __init__(self, in_channels, out_channels, map_num=4):
@@ -142,218 +139,26 @@ def forward(self, x):
         return out


-class ModuleHelper:
-
-    @staticmethod
-    def BNReLU(num_features, bn_type=None, **kwargs):
-        return nn.Sequential(
-            nn.BatchNorm2d(num_features, **kwargs),
-            nn.ReLU()
-        )
-
-    @staticmethod
-    def BatchNorm2d(*args, **kwargs):
-        return nn.BatchNorm2d
-
-
-class SpatialGather_Module(nn.Module):
-    """
-    Aggregate the context features according to the initial
-    predicted probability distribution.
-    Employ the soft-weighted method to aggregate the context.
-    """
-    def __init__(self, cls_num=0, scale=1):
-        super(SpatialGather_Module, self).__init__()
-        self.cls_num = cls_num
-        self.scale = scale
-
-    def forward(self, feats, probs):
-        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
-        probs = probs.view(batch_size, c, -1)
-        feats = feats.view(batch_size, feats.size(1), -1)
-        feats = feats.permute(0, 2, 1) # batch x hw x c
-        probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw
-        ocr_context = torch.matmul(probs, feats)\
-            .permute(0, 2, 1).unsqueeze(3)# batch x k x c
-        return ocr_context
-
-
-class _ObjectAttentionBlock(nn.Module):
-    '''
-    The basic implementation for object context block
-    Input:
-        N X C X H X W
-    Parameters:
-        in_channels : the dimension of the input feature map
-        key_channels : the dimension after the key/query transform
-        scale : choose the scale to downsample the input feature maps (save memory cost)
-        bn_type : specify the bn type
-    Return:
-        N X C X H X W
-    '''
-    def __init__(self,
-                 in_channels,
-                 key_channels,
-                 scale=1,
-                 bn_type=None):
-        super(_ObjectAttentionBlock, self).__init__()
-        self.scale = scale
-        self.in_channels = in_channels
-        self.key_channels = key_channels
-        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
-        self.f_pixel = nn.Sequential(
-            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-        )
-        self.f_object = nn.Sequential(
-            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-        )
-        self.f_down = nn.Sequential(
-            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
-        )
-        self.f_up = nn.Sequential(
-            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
-                      kernel_size=1, stride=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
-        )
-
-    def forward(self, x, proxy):
-        batch_size, h, w = x.size(0), x.size(2), x.size(3)
-        if self.scale > 1:
-            x = self.pool(x)
-
-        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
-        query = query.permute(0, 2, 1)
-        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
-        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
-        value = value.permute(0, 2, 1)
-
-        sim_map = torch.matmul(query, key)
-        sim_map = (self.key_channels**-.5) * sim_map
-        sim_map = F.softmax(sim_map, dim=-1)
-
-        # add bg context ...
-        context = torch.matmul(sim_map, value)
-        context = context.permute(0, 2, 1).contiguous()
-        context = context.view(batch_size, self.key_channels, *x.size()[2:])
-        context = self.f_up(context)
-        if self.scale > 1:
-            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
-
-        return context
-
-
-class ObjectAttentionBlock2D(_ObjectAttentionBlock):
-    def __init__(self,
-                 in_channels,
-                 key_channels,
-                 scale=1,
-                 bn_type=None):
-        super(ObjectAttentionBlock2D, self).__init__(in_channels,
-                                                     key_channels,
-                                                     scale,
-                                                     bn_type=bn_type)
-
-
-class SpatialOCR_Module(nn.Module):
-    """
-    Implementation of the OCR module:
-    We aggregate the global object representation to update the representation for each pixel.
-    """
-    def __init__(self,
-                 in_channels,
-                 key_channels,
-                 out_channels,
-                 scale=1,
-                 dropout=0.1,
-                 bn_type=None):
-        super(SpatialOCR_Module, self).__init__()
-        self.object_context_block = ObjectAttentionBlock2D(in_channels,
-                                                           key_channels,
-                                                           scale,
-                                                           bn_type)
-        _in_channels = 2 * in_channels
-
-        self.conv_bn_dropout = nn.Sequential(
-            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
-            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
-            nn.Dropout2d(dropout)
-        )
-
-    def forward(self, feats, proxy_feats):
-        context = self.object_context_block(feats, proxy_feats)
-
-        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
-
-        return output
-
-
-class OCR(nn.Module):
+class SEModule(nn.Module):
     """
-    Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation
-    https://arxiv.org/pdf/1909.11065.pdf
+    Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]
+    //Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 7132-7141.
     """
-    def __init__(self, in_channels, num_classes, ocr_mid_channels=512, ocr_key_channels=256):
-
-        super().__init__()
-        pre_stage_channels = in_channels
-        last_inp_channels = np.int(np.sum(pre_stage_channels))
-
-        self.conv3x3_ocr = nn.Sequential(
-            nn.Conv2d(last_inp_channels, ocr_mid_channels,
-                      kernel_size=3, stride=1, padding=1),
-            nn.BatchNorm2d(ocr_mid_channels),
-            nn.ReLU(inplace=True),
-        )
-        self.ocr_gather_head = SpatialGather_Module(num_classes)
-
-        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
-                                                 key_channels=ocr_key_channels,
-                                                 out_channels=ocr_mid_channels,
-                                                 scale=1,
-                                                 dropout=0.05,
-                                                 )
-        self.cls_head = nn.Conv2d(
-            ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
-
-        self.aux_head = nn.Sequential(
-            nn.Conv2d(last_inp_channels, last_inp_channels,
-                      kernel_size=1, stride=1, padding=0),
-            nn.BatchNorm2d(last_inp_channels),
+    def __init__(self, in_channels, reduction=16):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(in_channels, in_channels // reduction, bias=False),
             nn.ReLU(inplace=True),
-            nn.Conv2d(last_inp_channels, num_classes,
-                      kernel_size=1, stride=1, padding=0, bias=True)
+            nn.Linear(in_channels // reduction, in_channels, bias=False),
+            nn.Sigmoid()
         )

     def forward(self, x):
-
-        out_aux_seg = []
-
-        # ocr
-        out_aux = self.aux_head(x)
-        # compute contrast feature
-        feats = self.conv3x3_ocr(x)
-
-        context = self.ocr_gather_head(feats, out_aux)
-        feats = self.ocr_distri_head(feats, context)
-
-        out = self.cls_head(feats)
-
-        out_aux_seg.append(out_aux)
-        out_aux_seg.append(out)
-
-        return out_aux_seg
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)


 class ArgMax(nn.Module):
@@ -412,6 +217,8 @@ def __init__(self, name, **params):
             self.attention = CBAMSpatial(**params)
         elif name == 'cbam':
             self.attention = CBAM(**params)
+        elif name == 'se':
+            self.attention = SEModule(**params)
         else:
             raise ValueError("Attention {} is not implemented".format(name))

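For reference, a minimal, hypothetical usage sketch of the squeeze-and-excitation block introduced by this commit (not part of the diff itself). It assumes only what the hunks above show: SEModule(in_channels, reduction=16) now lives in change_detection_pytorch/base/modules.py, and the attention dispatcher patched in the last hunk builds it when given name='se'.

import torch

from change_detection_pytorch.base.modules import SEModule

# Hypothetical usage sketch (not part of this commit): the SE block added above
# squeezes each channel to a scalar via global average pooling, runs the
# Linear-ReLU-Linear-Sigmoid bottleneck, and rescales the input channels.
se = SEModule(in_channels=64, reduction=16)

x = torch.randn(2, 64, 32, 32)   # N x C x H x W feature map
y = se(x)                        # channel-wise reweighting only
assert y.shape == x.shape        # spatial size and channel count are unchanged

# The new 'se' branch in the attention dispatcher (last hunk) is assumed to build
# the same module from keyword params, e.g. name='se' with in_channels=64.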