MLP-Mixer:An all-MLP Architecture for Vision

MLP-Mixer:An all-MLP Architecture for Vision

本文是谷歌研究院大脑研究团队提出的新网络架构。MLP-Mixer抛弃了视觉领域常用的卷积网络,仅仅用MLP(Multilayer perceptron,多层感知机)来构建整个网络。团队指出,这种新的架构无需卷积、注意力机制,MLP-Mixer仅需MLP即可达到与CNN、Transformer相媲美的性能。

摘要

​ 在计算机视觉领域,卷积神经网络已经成为几乎是必要的模型;近来基于注意力的网络,比如Vision Transformer也变得很流行。在本文中,作者指出,卷积和注意力机制虽然都能提供良好的性能,但是并不是达成良好性能的必要路径。作者提出了只使用多层感知机的网络:MLP-Mixer。这个网络有两种类型的网络层:一种独立作用于每个图像块,也就是在图像块内整合每个像素的信息;另一种作用于不同图像块之间,亦即整合不同图像块之间的空间信息。利用大规模数据集或是先进的正则化手段,MLP-Mixer能够在图像分类基准线上取得有竞争力的结构。虽然没有取得最先进的结果,但是作者希望这个发现能够激发对CNN和Transformers之外的研究。

介绍

​ 在计算机视觉发展史上,更大规模的数据集和更强大的算力经常会导致研究范式的转变,卷积神经网络和Vision Transformer(ViT)就是很好的例子。ViT延续了移除人工设计的特征和归纳偏置的趋势,能够更加自主地从原始数据学习。

​ 作者提出了MLP-Mixer,这种网络结构不包含卷积或是自注意力机制。MLP-Mixer完全基于两种多层感知机:一种独立作用于每个图像块;另一种作用于不同图像块之间。在更细节的部分,Mixer只依赖于基础的矩阵乘法,数据排布变换和非线性映射。

image-20210506194516827

图1.MLP-Mixer总览

​ 如图1所示,MLP-Mixer接受的输入是被线性分割的图片块(image patches,也称为tokens),形式是“patches×channels”(前者是图像块数,后者是通道数)。Mixer使用两种网络层:(1)channel-mixing MLPs,这种网络层单独处理每个token,即采用每一行作为输入;

(2)token-mixing MLPs,允许不同空间位置(不同token)之间的信息交换,以每一列作为输入。

​ 这种架构也可以视为一种特殊的卷积神经网络,其使用1×1卷积进行channal mixing,用单通道全感受野的depth-wise卷积和参数共享进行token mixing。

Mixer 架构

​ 目前深度视觉框架包含以两种方式混合特征的网络层:(i)在给定的空域位置;(ii)不同的空域位置。在卷积神经网络中,(ii)是用N×N的卷积和池化实现的,更深的层具有更大的感受野。同时,1×1卷积也扮演(i)的角色。在Vision Transformers和其他的自注意力架构中,自注意力层能够同时进行(i)和(ii)。Mixer背后的想法是,将(i)channel-mixing和(ii)token-mixing两者分开,不过两者都是用MLPs进行实现的。

​ 图1给出了MLP-Mixer的示意图,Mixer以S个不重叠的图像块作为输入,每一个图像块都被投影到期望的隐层维度C。这会得到一张二维实值输入表$X \in \mathbb{R}^{S\times C}$。如果原始图像的大小是(H,W),每个图像块大小是(P,P),则总的块数是$S = HW/P^2$.每一个图像块都使用同样的投影矩阵进行线性投影。Mixer包含多个等尺寸的层,每一层都由两种MLP模块组成:(1)token-mixing模块在X的每一列上进行操作;(2)channal-mixing模块:在X的每一行上进行操作。每个MLP块包括两个全连接层和非线性层,模块可被描述如下:

image-20210506204214822

其中$\sigma$是逐元非线性激活函数(GELU)。$D_s和D_C$分别表示token-mixing和channal-mixing的隐层宽度,这两者都是可调的超参数。注意到$D_S$的选择独立于图像块的数量,因此网络的计算复杂度与输入块的数量成线性关系,而不是ViT的二次关系。

​ 正如上述所提到的,相同的channel(token)-mixing MLP作用于X的每一行(列)。对MLP的参数进行绑定就是一种很自然的选择,它可以提供类似卷积特征的位置不变形。然而,在CNN中进行跨通道参数绑定并不常见。比如,CNN的分离卷积对每个通道采用不同的卷积核,这与本文MLP中的处理机制(所有通道采用相同的核)不相同。这种参数绑定可以避免架构随隐层维度C、序列长度S提升而增长过快,进而导致了显著的显存节省。令人惊讶的是:这种参数绑定机制并不会影响性能。

​ Mixer中的每一层(除了初始块投影层)采用相同尺寸的输入,这种“各向同性”设计类似于Transformer和RNN中定宽;这与CNN中金字塔结构(越深的层具有更低的分辨率、更多的通道数)不同。

​ 除了MLP外,Mixer还采用其他标准架构成分:跳接、LayerNorm。此外,不同于ViT,Mixer并没有采用position embedding,这是因为token-mixingMLP对于输入的顺序极为敏感。最后,Mixer采用了标注分类头,即全局均值池化+线性分类器。文末附录中贴出了利用JAX/Flax的实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import einops
2 import flax . linen as nn
3 import jax . numpy as jnp
4
5 class MlpBlock ( nn . Module ) :
6 mlp_dim : int
7 @nn . compact
8 def __call__ ( self , x ) :
9 y = nn . Dense ( self . mlp_dim ) (x )
10 y = nn . gelu ( y )
11 return nn . Dense ( x . shape [ -1]) ( y )
12
13 class MixerBlock ( nn . Module ) :
14 tokens_mlp_dim : int
15 channels_mlp_dim : int
16 @nn . compact
17 def __call__ ( self , x ) :
18 y = nn . LayerNorm () ( x )
19 y = jnp . swapaxes (y , 1 , 2)
20 y = MlpBlock ( self . tokens_mlp_dim , name =’ token_mixing ’) ( y )
21 y = jnp . swapaxes (y , 1 , 2)
22 x = x +y
23 y = nn . LayerNorm () ( x )
24 return x + MlpBlock ( self . channels_mlp_dim , name =’ channel_mixing ’) ( y )
25
26 class MlpMixer ( nn . Module ) :
27 num_classes : int
28 num_blocks : int
29 patch_size : int
30 hidden_dim : int
31 tokens_mlp_dim : int
32 channels_mlp_dim : int
33 @nn . compact
34 def __call__ ( self , x ) :
35 s = self . patch_size
36 x = nn . Conv ( self . hidden_dim , (s , s ) , strides =( s , s ) , name =’stem ’) ( x )
37 x = einops . rearrange (x , ’n h w c -> n (h w) c’)
38 for _ in range ( self . num_blocks ) :
39 x = MixerBlock ( self . tokens_mlp_dim , self . channels_mlp_dim ) ( x )
40 x = nn . LayerNorm ( name =’ pre_head_layer_norm ’) ( x )
41 x = jnp . mean (x , axis =1)
42 return nn . Dense ( self . num_classes , name =’head ’,
43 kernel_init = nn . initializers . zeros ) ( x)

实验结果

基于中等与大尺度数据的预训练,我们在不同下游分类任务上对所提MLP-Mixer的性能进行了评估。我们主要对以下三个问题比较感兴趣并进行重点分析。

  • 在下游任务上的精度
  • 总计预训练计算量,这对于在上游数据上从头开始训练模型非常重要;
  • 推理耗时,这对于实际应用非常重要。

image-20210506205556204

表1.在不同的大规模数据集上训练的结果

​ 在ImageNet-21k+额外正则技术预训练后,Mixer在ImageNet数据集取得非常强的性能:84.15%top1,比其他模型稍弱。此外,在推理速度和参数量总数上,Mixer相比其他模型(如ViT)有着较为明显的优势。

​ 以上。