SENet

今天我们来介绍SENet。SENet以极大的优势获得了最后一届ImageNet 2017竞赛 Image Classification任务的冠军。

介绍

​ 卷积核作为卷积神经网络的核心,通常被看做是在局部感受野上,将空间上(spatial)的信息和特征维度上(channel-wise)的信息进行聚合的信息聚合体。卷积神经网络由一系列卷积层、非线性层和下采样层构成,这样它们能够从全局感受野上去捕获图像的特征来进行图像的描述。

img

图1.卷积操作

​ 有不少工作被提出以提升网络的性能,像是ResNet更改了网络结构,增加了跳接;Inception结构通过不同大小的卷积核融入了多尺度信息,聚合多种不同感受野上的特征来获得性能增益。而作者的想法是,从特征通道之间的关系来提升性能。

​	作者提出了Squeeze-and-Excitation Networks(简称SENet)。这个结构中有两个比较重要的结构:Squeeze和Excitation。作者希望通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

img

图2.SE模块。给定一个输入x,其特征通道数为c_1,通过一系列卷积等一般变换后得到一个特征通道数为c_2的特征。

​	首先是Squeeze操作,顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。具体实现是global average pooling:

image-20210205132908988

其中uc是卷积操作之后的feature。

​	其次是Excitation操作,它是一个类似于循环神经网络中门的机制。通过参数 W 来为每个特征通道生成权重,其中参数 W 被学习用来显式地建模特征通道间的相关性。

image-20210205133155454

σ表示sigmoid激活。

​	最后是一个Reweight的操作,将Excitation的输出的权重看做是经过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上。

image-20210205133220145
img

图3.Inception模型和ResNet模型加入SE模块后的结果

代码

代码来自 GitHub 仓库 weiaicunzai/pytorch-cifar100:

"""senet in pytorch



[1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu

Squeeze-and-Excitation Networks
https://arxiv.org/abs/1709.01507
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicResidualSEBlock(nn.Module):
    """Basic (two 3x3 conv) residual block with a Squeeze-and-Excitation gate.

    The main branch's output is rescaled channel-wise by weights computed
    from a globally-pooled channel descriptor, then summed with the shortcut.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride, r=16):
        """
        Args:
            in_channels: channels of the input feature map.
            out_channels: base channel count of the block.
            stride: stride of the first conv (downsamples when > 1).
            r: reduction ratio of the excitation bottleneck MLP.
        """
        super().__init__()

        width = out_channels * self.expansion

        # Main branch: two 3x3 convolutions, each followed by BN + ReLU.
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, 3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case project with a strided 1x1 conv.
        if stride == 1 and in_channels == width:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width, 1, stride=stride),
                nn.BatchNorm2d(width),
            )

        # Squeeze: one scalar per channel (global average pool).
        # Excitation: bottleneck MLP ending in a sigmoid, yielding per-channel
        # gates in (0, 1).
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(width, width // r),
            nn.ReLU(inplace=True),
            nn.Linear(width // r, width),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply residual branch, SE reweighting, shortcut add, final ReLU."""
        identity = self.shortcut(x)
        feat = self.residual(x)

        batch, channels = feat.size(0), feat.size(1)
        scale = self.excitation(self.squeeze(feat).view(batch, channels))
        # Broadcast the (N, C, 1, 1) gates over the spatial dimensions.
        out = feat * scale.view(batch, channels, 1, 1) + identity
        return F.relu(out)

class BottleneckResidualSEBlock(nn.Module):
    """Bottleneck (1x1 - 3x3 - 1x1 conv) residual block with an SE gate.

    The final channel count is ``out_channels * expansion`` (4x), as in the
    standard ResNet bottleneck design.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, r=16):
        """
        Args:
            in_channels: channels of the input feature map.
            out_channels: bottleneck width; output has 4x this many channels.
            stride: stride of the middle 3x3 conv (downsamples when > 1).
            r: reduction ratio of the excitation bottleneck MLP.
        """
        super().__init__()

        width = out_channels * self.expansion

        # Main branch: reduce (1x1), transform (3x3, strided), expand (1x1).
        # NOTE(review): the trailing ReLU after the last BN differs from the
        # canonical ResNet bottleneck (which activates only after the
        # addition); kept as-is to preserve this implementation's behavior.
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, 1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        # Squeeze to (N, C) then excite to per-channel sigmoid gates.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(width, width // r),
            nn.ReLU(inplace=True),
            nn.Linear(width // r, width),
            nn.Sigmoid(),
        )

        # Identity shortcut unless the shape changes; then project via 1x1 conv.
        if stride == 1 and in_channels == width:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width, 1, stride=stride),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        """Apply residual branch, SE reweighting, shortcut add, final ReLU."""
        identity = self.shortcut(x)
        feat = self.residual(x)

        batch, channels = feat.size(0), feat.size(1)
        scale = self.excitation(self.squeeze(feat).view(batch, channels))
        # Broadcast the (N, C, 1, 1) gates over the spatial dimensions.
        out = feat * scale.view(batch, channels, 1, 1) + identity
        return F.relu(out)

class SEResNet(nn.Module):
    """ResNet-style backbone built from SE residual blocks.

    CIFAR variant: a single 3x3 stem conv (no 7x7 conv / max-pool), four
    stages of blocks, global average pooling, and a linear classifier.

    Args:
        block: residual block class; must expose an integer ``expansion``
            attribute and accept ``(in_channels, out_channels, stride)``.
        block_num: sequence of four ints — blocks per stage.
        class_num: number of output classes (default 100, i.e. CIFAR-100).
    """

    def __init__(self, block, block_num, class_num=100):
        super().__init__()

        self.in_channels = 64

        # Stem: 3x3 conv keeping spatial size (suited to 32x32 inputs).
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(block, block_num[0], 64, 1)
        self.stage2 = self._make_stage(block, block_num[1], 128, 2)
        self.stage3 = self._make_stage(block, block_num[2], 256, 2)
        self.stage4 = self._make_stage(block, block_num[3], 512, 2)

        # _make_stage leaves self.in_channels at the final stage's output width.
        self.linear = nn.Linear(self.in_channels, class_num)

    def forward(self, x):
        """Return class logits of shape (N, class_num)."""
        x = self.pre(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # Global average pool to (N, C, 1, 1), flatten, classify.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)

        return self.linear(x)

    def _make_stage(self, block, num, out_channels, stride):
        """Stack ``num`` blocks; only the first downsamples (with ``stride``).

        Updates ``self.in_channels`` as a side effect so consecutive stage
        constructions chain correctly.
        """
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels * block.expansion

        # Fix: the original `while num - 1:` countdown loops forever when
        # num == 0 (num - 1 == -1 is truthy); range() handles that edge
        # case and is the idiomatic form.
        for _ in range(num - 1):
            layers.append(block(self.in_channels, out_channels, 1))

        return nn.Sequential(*layers)

def seresnet18():
    """SE-ResNet-18: basic SE blocks with a [2, 2, 2, 2] stage layout."""
    return SEResNet(block=BasicResidualSEBlock, block_num=[2, 2, 2, 2])

def seresnet34():
    """SE-ResNet-34: basic SE blocks with a [3, 4, 6, 3] stage layout."""
    return SEResNet(block=BasicResidualSEBlock, block_num=[3, 4, 6, 3])

def seresnet50():
    """SE-ResNet-50: bottleneck SE blocks with a [3, 4, 6, 3] stage layout."""
    return SEResNet(block=BottleneckResidualSEBlock, block_num=[3, 4, 6, 3])

def seresnet101():
    """SE-ResNet-101: bottleneck SE blocks with a [3, 4, 23, 3] stage layout."""
    return SEResNet(block=BottleneckResidualSEBlock, block_num=[3, 4, 23, 3])

def seresnet152():
    """SE-ResNet-152: bottleneck SE blocks with a [3, 8, 36, 3] stage layout."""
    return SEResNet(block=BottleneckResidualSEBlock, block_num=[3, 8, 36, 3])

其他参考:Momenta(作者团队)的官方专栏https://zhuanlan.zhihu.com/p/32733549