NASNet

Today we introduce NASNet. Neural Architecture Search (NAS) refers to searching for an optimal network architecture, typically with reinforcement learning. Because every candidate architecture has to be trained before its quality is known, finding the best architecture is extremely expensive: with the original NAS method, finding a suitable architecture on CIFAR-10 reportedly took about 800 GPUs running for 28 days.

Introduction

NASNet's main contribution is making NAS practical for large datasets. The strategy is to first search for a network cell on a small dataset (CIFAR-10) and then transfer the model to a large dataset (ImageNet) by stacking copies of that cell.


Figure 1. The NAS search loop. A controller built from a recurrent neural network (RNN) samples a network architecture A with probability p; architecture A is then trained on the dataset and its accuracy R on the validation set is measured. The accuracy R is used to update the controller's parameters, and the loop repeats.

Let us expand a bit on NAS. Its most important component is the controller shown in Figure 1, which drives the architecture search; the controller's goal is to find better architectures faster. The controller is itself a neural network, trained with reinforcement learning, and the validation accuracy R of the architecture A it selects serves as the reward.

Figure 2. The controller's sampling procedure. At each step the controller samples one piece of the network structure. Each block selects two existing hidden states as inputs, which allows skip connections like those in ResNet. The right side shows an example produced by this procedure. The authors use B = 5 blocks per cell.
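
To make Figure 2's procedure concrete, here is a toy sketch of how one block is assembled: two existing hidden states are chosen as inputs, one operation is applied to each, and the results are added. The operation dictionary below is only an illustrative subset of the fixed set shown in Figure 3 below, and the choices are made with Python's random module purely for demonstration; in the actual method the controller RNN makes these choices.

import random
import torch
import torch.nn as nn

def candidate_ops(channels):
    """An illustrative subset of the fixed operation set the controller chooses from
    (the paper's full list also includes identity, separable convolutions up to 7x7, etc.)."""
    return {
        'conv_3x3':     nn.Conv2d(channels, channels, 3, padding=1),
        'conv_5x5':     nn.Conv2d(channels, channels, 5, padding=2),
        'avg_pool_3x3': nn.AvgPool2d(3, stride=1, padding=1),
        'max_pool_3x3': nn.MaxPool2d(3, stride=1, padding=1),
    }

def sample_block(hidden_states, channels):
    """One block of a cell (Figure 2): pick two existing hidden states, apply one
    operation to each, and combine the results by addition."""
    ops = candidate_ops(channels)
    in1, in2 = random.choices(hidden_states, k=2)       # two inputs, possibly the same one
    op1, op2 = random.choices(list(ops.values()), k=2)  # two operations from the fixed set
    new_state = op1(in1) + op2(in2)
    hidden_states.append(new_state)                     # later blocks may select it as input
    return new_state

# Toy usage: build B = 5 blocks on top of two initial hidden states.
states = [torch.randn(1, 16, 32, 32), torch.randn(1, 16, 32, 32)]
for _ in range(5):
    sample_block(states, channels=16)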

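
As for how the controller itself is trained (Figure 1), below is a minimal sketch of a REINFORCE-style update in which the validation accuracy R acts as the reward. TinyController and train_and_evaluate are toy stand-ins invented for illustration; the real controller is an RNN and the real reward comes from fully training each sampled architecture.

import torch
import torch.nn as nn

class TinyController(nn.Module):
    """Toy stand-in for the controller: one categorical choice per architecture decision."""
    def __init__(self, num_decisions=10, num_choices=5):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_decisions, num_choices))

    def sample(self):
        dist = torch.distributions.Categorical(logits=self.logits)
        arch = dist.sample()                    # one choice index per decision
        return arch, dist.log_prob(arch).sum()  # architecture A and log p(A)

def train_and_evaluate(arch):
    """Placeholder for 'train architecture A and return its validation accuracy R'."""
    return torch.rand(1).item()

controller = TinyController()
optimizer = torch.optim.Adam(controller.parameters(), lr=3.5e-4)
baseline = 0.0                                  # moving average of R, reduces variance

for step in range(100):
    arch, log_prob = controller.sample()        # controller samples architecture A
    reward = train_and_evaluate(arch)           # validation accuracy R of A

    # REINFORCE: raise log p(A) for architectures whose accuracy beats the baseline
    loss = -(reward - baseline) * log_prob
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    baseline = 0.95 * baseline + 0.05 * reward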

Figure 3. The search space of operations on hidden states: each operation applied to a hidden state is chosen from the fixed set above, which greatly shrinks the search space. A convolution alone has four hyperparameters (number of filters, kernel height, kernel width, stride), each with a wide range of values, so enumerating arbitrary hyperparameter combinations to generate architectures would be prohibitively expensive.

In addition, NASNet uses the Scheduled DropPath technique. DropPath randomly drops branches inside a cell with probability P to reduce overfitting; "Scheduled" means that P increases linearly as training progresses, since the longer the model trains, the more prone it becomes to overfitting.

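The reference implementation in the next section does not include DropPath, but the idea can be sketched as a small module, assuming training progress is tracked externally: during training, a branch's output is zeroed out with probability P, and P is ramped up linearly from 0 as training progresses.

import torch
import torch.nn as nn

class ScheduledDropPath(nn.Module):
    """Drop an entire branch with probability p during training, where p grows
    linearly from 0 to max_p over the course of training."""

    def __init__(self, max_p=0.3):
        super().__init__()
        self.max_p = max_p
        self.progress = 0.0                    # fraction of training completed

    def set_progress(self, progress):
        self.progress = float(progress)        # e.g. current_epoch / total_epochs

    def forward(self, x):
        p = self.max_p * self.progress         # linearly scheduled drop probability
        if not self.training or p == 0.0:
            return x
        # one keep/drop decision per sample in the batch
        keep = (torch.rand(x.size(0), 1, 1, 1, device=x.device) > p).to(x.dtype)
        return x * keep / (1.0 - p)            # rescale so the expected value is unchanged

In the full model, each branch inside a cell would be wrapped in such a module, with set_progress updated once per epoch.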

Figure 4. The top shows the two basic cells found by the search (the normal cell and the reduction cell); the bottom shows how these two cells are stacked to form the complete network.

Code

The implementation below comes from https://github.com/weiaicunzai/pytorch-cifar100.

import torch
import torch.nn as nn

class SeperableConv2d(nn.Module):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        super().__init__()
        self.depthwise = nn.Conv2d(
            input_channels,
            input_channels,
            kernel_size,
            groups=input_channels,
            **kwargs
        )

        self.pointwise = nn.Conv2d(
            input_channels,
            output_channels,
            1
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)

        return x

class SeperableBranch(nn.Module):

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        """Adds 2 blocks of [relu-separable conv-batchnorm]."""
        super().__init__()
        self.block1 = nn.Sequential(
            nn.ReLU(),
            SeperableConv2d(input_channels, output_channels, kernel_size, **kwargs),
            nn.BatchNorm2d(output_channels)
        )

        self.block2 = nn.Sequential(
            nn.ReLU(),
            SeperableConv2d(output_channels, output_channels, kernel_size, stride=1, padding=int(kernel_size / 2)),
            nn.BatchNorm2d(output_channels)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)

        return x

class Fit(nn.Module):
    """Make the cell outputs compatible

    Args:
        prev_filters: filter number of tensor prev, needs to be modified
        filters: filter number of normal cell branch output filters
    """

    def __init__(self, prev_filters, filters):
        super().__init__()
        self.relu = nn.ReLU()

        self.p1 = nn.Sequential(
            nn.AvgPool2d(1, stride=2),
            nn.Conv2d(prev_filters, int(filters / 2), 1)
        )

        #make sure there is no information loss
        self.p2 = nn.Sequential(
            nn.ConstantPad2d((0, 1, 0, 1), 0),
            nn.ConstantPad2d((-1, 0, -1, 0), 0), #cropping
            nn.AvgPool2d(1, stride=2),
            nn.Conv2d(prev_filters, int(filters / 2), 1)
        )

        self.bn = nn.BatchNorm2d(filters)

        self.dim_reduce = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(prev_filters, filters, 1),
            nn.BatchNorm2d(filters)
        )

        self.filters = filters

    def forward(self, inputs):
        x, prev = inputs
        if prev is None:
            return x

        #image size does not match
        elif x.size(2) != prev.size(2):
            prev = self.relu(prev)
            p1 = self.p1(prev)
            p2 = self.p2(prev)
            prev = torch.cat([p1, p2], 1)
            prev = self.bn(prev)

        elif prev.size(1) != self.filters:
            prev = self.dim_reduce(prev)

        return prev


class NormalCell(nn.Module):

    def __init__(self, x_in, prev_in, output_channels):
        super().__init__()

        self.dem_reduce = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(x_in, output_channels, 1, bias=False),
            nn.BatchNorm2d(output_channels)
        )

        self.block1_left = SeperableBranch(
            output_channels,
            output_channels,
            kernel_size=3,
            padding=1,
            bias=False
        )
        self.block1_right = nn.Sequential()

        self.block2_left = SeperableBranch(
            output_channels,
            output_channels,
            kernel_size=3,
            padding=1,
            bias=False
        )
        self.block2_right = SeperableBranch(
            output_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias=False
        )

        self.block3_left = nn.AvgPool2d(3, stride=1, padding=1)
        self.block3_right = nn.Sequential()

        self.block4_left = nn.AvgPool2d(3, stride=1, padding=1)
        self.block4_right = nn.AvgPool2d(3, stride=1, padding=1)

        self.block5_left = SeperableBranch(
            output_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias=False
        )
        self.block5_right = SeperableBranch(
            output_channels,
            output_channels,
            kernel_size=3,
            padding=1,
            bias=False
        )

        self.fit = Fit(prev_in, output_channels)

    def forward(self, x):
        x, prev = x

        #return transformed x as new x, and original x as prev
        #only prev tensor needs to be modified
        prev = self.fit((x, prev))

        h = self.dem_reduce(x)

        x1 = self.block1_left(h) + self.block1_right(h)
        x2 = self.block2_left(prev) + self.block2_right(h)
        x3 = self.block3_left(h) + self.block3_right(h)
        x4 = self.block4_left(prev) + self.block4_right(prev)
        x5 = self.block5_left(prev) + self.block5_right(prev)

        return torch.cat([prev, x1, x2, x3, x4, x5], 1), x

class ReductionCell(nn.Module):

    def __init__(self, x_in, prev_in, output_channels):
        super().__init__()

        self.dim_reduce = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(x_in, output_channels, 1),
            nn.BatchNorm2d(output_channels)
        )

        #block1
        self.layer1block1_left = SeperableBranch(output_channels, output_channels, 7, stride=2, padding=3)
        self.layer1block1_right = SeperableBranch(output_channels, output_channels, 5, stride=2, padding=2)

        #block2
        self.layer1block2_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.layer1block2_right = SeperableBranch(output_channels, output_channels, 7, stride=2, padding=3)

        #block3
        self.layer1block3_left = nn.AvgPool2d(3, 2, 1)
        self.layer1block3_right = SeperableBranch(output_channels, output_channels, 5, stride=2, padding=2)

        #block5
        self.layer2block1_left = nn.MaxPool2d(3, 2, 1)
        self.layer2block1_right = SeperableBranch(output_channels, output_channels, 3, stride=1, padding=1)

        #block4
        self.layer2block2_left = nn.AvgPool2d(3, 1, 1)
        self.layer2block2_right = nn.Sequential()

        self.fit = Fit(prev_in, output_channels)

    def forward(self, x):
        x, prev = x
        prev = self.fit((x, prev))

        h = self.dim_reduce(x)

        layer1block1 = self.layer1block1_left(prev) + self.layer1block1_right(h)
        layer1block2 = self.layer1block2_left(h) + self.layer1block2_right(prev)
        layer1block3 = self.layer1block3_left(h) + self.layer1block3_right(prev)
        layer2block1 = self.layer2block1_left(h) + self.layer2block1_right(layer1block1)
        layer2block2 = self.layer2block2_left(layer1block1) + self.layer2block2_right(layer1block2)

        return torch.cat([
            layer1block2, #https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739
            layer1block3,
            layer2block1,
            layer2block2
        ], 1), x


class NasNetA(nn.Module):

    def __init__(self, repeat_cell_num, reduction_num, filters, stemfilter, class_num=100):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, stemfilter, 3, padding=1, bias=False),
            nn.BatchNorm2d(stemfilter)
        )

        self.prev_filters = stemfilter
        self.x_filters = stemfilter
        self.filters = filters

        self.cell_layers = self._make_layers(repeat_cell_num, reduction_num)

        self.relu = nn.ReLU()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(self.filters * 6, class_num)

    def _make_normal(self, block, repeat, output):
        """make normal cell
        Args:
            block: cell type
            repeat: number of repeated normal cell
            output: output filters for each branch in normal cell
        Returns:
            stacked normal cells
        """

        layers = []
        for r in range(repeat):
            layers.append(block(self.x_filters, self.prev_filters, output))
            self.prev_filters = self.x_filters
            self.x_filters = output * 6 #concatenate 6 branches

        return layers

    def _make_reduction(self, block, output):
        """make reduction cell
        Args:
            block: cell type
            output: output filters for each branch in reduction cell
        Returns:
            reduction cell
        """

        reduction = block(self.x_filters, self.prev_filters, output)
        self.prev_filters = self.x_filters
        self.x_filters = output * 4 #stack for 4 branches

        return reduction

    def _make_layers(self, repeat_cell_num, reduction_num):

        layers = []
        for i in range(reduction_num):
            layers.extend(self._make_normal(NormalCell, repeat_cell_num, self.filters))
            self.filters *= 2
            layers.append(self._make_reduction(ReductionCell, self.filters))

        layers.extend(self._make_normal(NormalCell, repeat_cell_num, self.filters))

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.stem(x)
        prev = None
        x, prev = self.cell_layers((x, prev))
        x = self.relu(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def nasnet():

    #stem filters must be 44; it's a workaround in this pytorch implementation and can't be changed to another number
    return NasNetA(4, 2, 44, 44)
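
A quick smoke test of the network above on CIFAR-sized inputs:

if __name__ == '__main__':
    net = nasnet()                      # NasNetA(repeat_cell_num=4, reduction_num=2, filters=44, stemfilter=44)
    images = torch.randn(2, 3, 32, 32)  # a batch of two 32x32 RGB images
    logits = net(images)
    print(logits.shape)                 # expected: torch.Size([2, 100]) for the default 100 classes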