现在我来做学习ViT (Vision Transformer)的学习笔记~
原论文:2010.11929 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
我看的解析教学:https://www.bilibili.com/video/BV15P4y137jb
我看的代码实现:https://www.bilibili.com/video/BV1K94y177Ka
1. 概况
提到CV,人们往往想起CNN。AlexNet的提出,引导人们大量在CV领域使用CNN。然而,ViT的结论是:transformer可以在CV任务上做的很好。ViT打破了CV与NLP的专业壁垒,深刻影响了多模态领域。什么是多模态?“模态”是指信息的不同形式或感官通道,例如图像(视觉模态)、图像(视觉模态)、音频(听觉模态)等等,“多模态”就是将多种模态的信息融合在一起进行理解、生成或推理。ViT 把图像表示做成了类似语言的 token 序列,这为“多模态模型”的融合提供了处理方式。
2. 论文摘要
从论文的题目 An Image is Worth 16x16 Words 就能看出来,将图片看成很多很多patch(方格),每一个patch的固定大小是 16 × 16。
前人只将自注意力有限地应用在CNN上,而论文的作者证明transformer也可以做的很好,同时需要更少的训练资源。(当然,是相对的:2500天 TPUv3 )
为什么在作者之前的人不把transformer用到CV里?因为图像的像素点转成的序列过长,为图像边长的平方。很多人努力去把序列变短,例如,先卷积对图像进行下采样,得到较小的特征图,元素数量就大大减少,然后再使用 Transformer;还有人利用“局部自注意力”(Local Attention ),比如 Window Attention,一个很好的实现是Swin Transformer,人们想到,复杂度高的本质是整张图太大,那就不用整张图了,可以把整个图像分成很多小窗口,通过控制小window窗口大小来控制序列长度,只在局部窗口中做自注意力;此外有“轴自注意力”,人们想到,能不能将图片这种2D的矩阵作为两个1D的向量处理 ?N = H × W,在H上做一次,在W上做一次,降低了时间复杂度。
我个人对第一种尝试(卷积 + Attention)的理解是:将卷积的结果作用自注意力;
我个人对第二种尝试(Local Attention)的理解是:将自注意力视为filter,得到卷积的结果;
我个人对第三种尝试(轴 Attention)的理解是:将二维图像的注意力计算拆解为两个一维方向的注意力操作,分别在“行方向”和“列方向”上进行。 具体流程是,先固定列方向,即每一列作为一个序列,进行 列维度上的自注意力 (每一列内部的所有像素之间交互)。类比 NLP,就是每一列像一个句子。再固定行方向 进行 行维度上的自注意力 。经过行和列注意力更新后的图像就可以传到后续的层了,比如FC和softmax等分类头。
ViT论文作者认为,这些自注意力操作都进行了特殊的变形,没有办法在现在的硬件上加速,很难训练。现在作者要介绍Vision Transformer了。Vision Transformer把图片打成很多个16 × 16的patch。将一个一个patch展平flatten后(从 16×16×3 → 768),再通过一个共享的FC层进行映射,得到固定维度的 embedding 向量,传给transformer。可以把patch当成是NLP里面的单词tokens。注意要用有监督的方式训练 。如果不进行大规模监督预训练,Transformer 在中小数据集上训练效果会很差。这个我在之前的笔记里也总结过。
这么“简单”的思路,之前没有人做? 2020年的一篇paper用了2×2的patch,只在CIFAR-10上做了测试,没有在大规模的数据集上进行预训练,因此没有得到一个更标准的transformer,也就没有得到比CNN更好的结果。ViT论文作者在大实验室的加持下,证明了在大规模数据的预训练下,ViT完全可以取代CNN的地位。
这再一次强调了transformer需要大规模的数据。在CV领域,作者认为transformer缺少CNN有的归纳偏置 。归纳偏置指的是先验知识或提前做好的假设。CNN常常有两个归纳偏置,locality假设图片上相邻的区域有相邻的特征(因为filter在图片上滑动),translation equivariance(平移同变性)认为平移与卷积的顺序不影响结果,即
$$
f(g(x)) = g(f(x))
$$
直观的解释是,不管图片里的同样一个物体平移到哪里,只要是同样的卷积核,输出永远是一样的。有了这些先验信息,CNN只需要相对少的数据。
此外,我对归纳偏置 进行了进一步的思考。事实上,是CNN的结构决定了这些性质。人们先想出这些性质,再去附加设计到CNN的结构上,不需要为天然的CNN刻意添加这些性质。一个卷积核在图像上只看同一个区域,并且在图形上滑动,那么这个模型就天然拥有局部性(Locality)和平移同变性(Translation equivariance) ,因为它只看小区域,并且滑动时输出与位置无关。
而 Transformer 结构是注意力,默认每个位置和其他所有位置都可能有关 ,在预训练之前,结构里没有任何关于“局部性”或“图像位置”的线索,必须靠大量数据去自己学出这些规律,同时可能会学到更多人们没有设计给CNN的规律。
作者在更大的数据集上做了测试,证明了transformer在大数据预训练下效果拔群。
3. 先看看论文的结论
论文的结论是,他们在CV任务上对一个标标准准的transformer进行了非常直接 的应用(即:除了运用patch切割,以及运用位置编码,可以看作用了一些轻微的图像归纳偏置以外,没有其他改动)。
这个论文只用ViT做了分类任务,结果已经战胜了CNN。他们还留下了分割、检测等主流CV任务留给后续研究。还有,自监督的训练方式优化也需要进一步研究。
4. 论文的主体:设计与方法
这里前面的内容与引言相同。现在我来学习新的内容。先来看看模型的总览图:
我们从图的左下角开始。首先,给定一张图,把它达成九宫格小patch,然后组成一个序列。展平的patch经过线性投射层(就是一个FC层,用大写E表示)得到一个特征。这就是patch embedding。之后,也要像原来的transformer一样加上位置编码进行position embedding。现在整体的token有原本的图像信息和位置信息了。现在传给transformer encoder,但是用哪一个作为输出呢?借鉴BERT,加入特殊字符(图中为星号),位置信息为零。作者相信,此特殊字符可以在与序列中其他input的交互中学到有用的信息,并用来做分类任务。再接上MLP分类头即可完成分类任务。
作者在右边放上了标准的transformer encoder结构。可以注意到,这与Attention Is All You Need 中的encoder相比,normalization被放到了多头注意力和多层感知机的前面 。我在之前的笔记里也提到过,对于这个encoder的设计有很多种。
MLP head起到了原来decoder的作用,但是不是复杂的生成任务,而是分类任务,所以较为简单。
flatten的过程我询问了ChatGPT,它向我详细讲解:
🧩 先理解数据的组织方式(H × W × C)
假设这个 patch 是 RGB 图像,所以它是:
16 行(高度) × 16 列(宽度) × 3 个通道(R、G、B)
你可以把它理解为:
- 每个像素有3个数(R,G,B)。
- 一共有 16×16 = 256 个像素。
- 所以,每个 patch 其实包含了 256 个 RGB 像素点 × 每个点3个通道 = 768 个数值 。
🔁 展平的方式(flattening)
最常见的展平方法是:
- 先按行 遍历图片,从上到下;
- 每行从左到右 ;
- 对每个像素,把它的 R、G、B 三个值依次放入一个数组 中。
举个小例子(3×1×3 的 patch):
假设我们有一个 tiny patch:
1
2
3 [[[R11, G11, B11]],
[[R21, G21, B21]],
[[R31, G31, B31]]]这个 patch 的大小是 3×1×3,相当于竖着的三格,每格一个 RGB 值。
展平后我们就得到一个:
1 [R11, G11, B11, R21, G21, B21, R31, G31, B31]这就是展平成一维向量的过程。
🔁 对于 16×16×3 的 patch:
你可以类比上面的例子,只不过像素更多,它的展平就是把:
1
2
3 [ [ [R,G,B], [R,G,B], ..., [R,G,B] ],
...
[ [R,G,B], [R,G,B], ..., [R,G,B] ] ]总共 256 个
[R,G,B]
的小三元组,一个接一个拼接起来 ,变成一个[R,G,B,R,G,B,...]
长度为 768 的向量。
所以,E 的维度是768 × 768,前面的768是16 × 16 × 3,是一个patch的维度;后一个768称为D ,可以把D 改得更大。
patch组成的序列称为X ,X的大小就是196(patch的个数)× 768(patch的维度),那么X乘E 就是196 × 768的矩阵与768 × 768的矩阵相乘,最后得到的矩阵式196 × 768,这就是FC层E的输出。将这个输出理解为196个token,每个token向量的维度是768。 再加上前面的cls token特殊字符(cls token是BERT的Class token的简写 ,我第一反应是命令行里的清屏),它的维度也是768,那么最终进入transformer encoder的序列长度为196 + 1 = 197,维度为768。
这里有一个重点。 先复习一下,位置编码position embedding是怎么做的?是一个有多个向量的矩阵,每个向量代表某个位置的信息,维度为768,将对应某个位置的向量与原始patch经过FC层的输出相加 (要注意是相加)即可。
不过,与我之前的NLP的Transformer笔记矛盾的是 ,老师在视频里提到,这个矩阵是可以被学习 的。NLP的原始Transformer确实用的是正弦/余弦函数生成的位置编码,不可学习 ,但具有良好的泛化能力。在 ViT 里,实验却证明:可学习的位置编码效果更好 ,而且实现也简单,所以 ViT 默认就用了它。
但是,这是为什么?我询问ChatGPT,得到了这样的解释:
✅ NLP 中为什么用不可学习的位置编码?
原始 Transformer 的设计者考虑的是:
自然语言句子的长度变化非常大(可能是5词,也可能是500词),所以用一种**数学函数(sin/cos)**来生成“可泛化”的位置编码。
正弦函数有这些特点:
- 可以推广到任意长度;
- 不需要额外的参数;
- 每个位置编码都是独一无二的;
- 不会因训练数据太短而“记不住”长距离位置信息。
这对 NLP 是有意义的。
✅ ViT 中为什么用可学习的位置编码?
图像就不一样了!
- 一张图片切成 patch 后,patch 的数量是固定的(比如 14×14 = 196)。
- 所以模型不用泛化到更长的序列,更关注每个 patch 之间的相对空间结构。
这时,用一个 可以训练的 position embedding,让模型自己学“第17个 patch 在图中是哪里”,更符合图像处理的需求。
📌 一个小类比:
NLP 的 sin/cos 编码就像你拿尺子给每个单词标记“第几个词”。
ViT 的 learnable position 就像你让模型自己“看图说话”,记住哪个 patch 对应的是之前拆分的“九宫格”的哪一个格子。
5. 细节注意
作者进行了一些消融实验。什么是消融实验?ChatGPT告诉我:
消融实验 是一种常见的深度学习实验方法,用来分析某个组件是否真的重要。
具体做法是:
- 把模型里的某个模块拿掉(或替换/简化)
- 再看模型性能变化多少
所以你可以把它理解为:
“我把某个零件拆掉,看模型还能不能跑得好。”
①cls token
之前的CV并不是用cls token的。以res50为例,最后一个stage出来一个14 × 14的feature map,对它进行GAP(GAP就是 global average pulling,是把图像的每个通道变为一个数值的平均池化操作。可以看作卷积时候,以图片大小为kernal size。因此,最后的输出会是一个长度为通道数量的一维向量) 。
对于transformer的encoder,输入多少个向量就输出多少个向量,那为什么不对输出结果进行GAP、而是用cls token呢?ViT作者的实验表明两种都可以,但作者希望与NLP的transformer尽可能保持一致。数据如下:
注意到不同方法的learning rate设置的不同。视频里老师讲到这里时,强调炼丹技术必须要过硬 。
cls里的参数怎么来的?这里有一个总结:
属性 | 是否 |
---|---|
是人为设计固定的? | ❌ 不是 |
是模型参数吗? | ✅ 是 |
是随机初始化的吗? | ✅ 是 |
会参与训练吗? | ✅ 会 |
能自动学会总结图像? | ✅ 能 |
② Positional Embedding
论文中一直用的是1D positional embedding。其实,作者尝试了好几种不同的 positional embedding 方案 来做实验,看看哪种对模型效果最好。
1. No positional embedding(不加位置编码)
不加任何位置信息,模型就只看到一堆 patch 向量 —— 把它当成“顺序无关的一堆 patch”。
结果最差。
2. 1D positional embedding(默认用法,这篇论文从头到尾在用的方法)
把 patch 看成是一个序列,用一维的位置编码表示每个 patch 的“顺序”。
这种就是 ViT 原始论文中默认采用的方式,比如:
- 把 14×14 的 patch 排成 196 长度的序列
- 每个 patch 加一个 learnable 的向量,表示“你是第几个 patch”
3. 2D positional embedding(二维位置编码)
分别为横轴和纵轴学习一组 embedding,然后拼接 concat成最终的位置向量。
比如:
- patch(3, 5) 表示第3行,第5列的 patch
- 第3行用 Y-embedding 查一个向量
- 第5列用 X-embedding 查一个向量
- 然后拼起来:
[X_embed_5 ; Y_embed_3] → [384 + 384 = 768]
这种方式直接把图像看成二维网格。
4. Relative positional embedding(相对位置编码)
不是说“你在第几行第几列”,而是“你和我之间的距离是多少”。
原理是:
- 对每一对 patch(Query 和 Key)计算它们的“相对距离”
- 比如 patch A 和 patch B 相隔 3 个位置,就查一个 offset embedding,比如:
Embedding(offset = +3)
- 然后把这个 offset embedding 用作 注意力打分的额外 bias
这种方法来源于 NLP 中的“相对位置编码”思想。
结果是:
较为奇怪的是1-D与2-D相差不大。作者的解释是,ViT不是在像素块上做的,而是在patch上做的。排列组合这些小方块的位置信息1D与2D是没有太大的区别。
③公式描述此模型
④归纳偏置与hybrid architecture
ViT比CNN少很多归纳偏置。先验知识贯穿CNN始终。既然ViT全局能力强,而CNN更加data efficient,能否将它们杂交?我们现在不把图片打成patch了,让CNN去处理图片,以Res50为例,最后得到的14 × 14的特征图也是196个元素,也可以去与E 操作。
⑤微调(fine-tuning)
对于一个预训练好的transformer encoder,是不适合为更大尺寸的图片分类的。如果patch size不变,那就会增长序列长度。理论上(显存足够)的transformer 可以处理任意长度序列,但是预训练的位置编码可能就没用了,位置信息变得不正确了。ViT的作者发现,进行2D的插值可以解决这个问题。
我询问ChatGPT如何进行2D插值,它回答:
2D 插值就是:
把原来 $14 \times 14$ 的 patch 位置编码,看成是一张 14x14 的特征图,然后把它 “拉伸”成 24x24 的新特征图 。然后你就可以从这个“放大图”中取出新的 576 个 patch 的位置编码啦!
🔍 实现上会怎么做?
- 你先把预训练好的位置编码 $14 \times 14 \times D$ reshape 成一个“图片”。
- 用图像插值方法(比如
torch.nn.functional.interpolate
)拉伸成 $24 \times 24 \times D$。- reshape 回 $576 \times D$,再加上 cls token 的编码(这个一般不会插值,用原来的)。
💡 总结一句话:
2D 插值 就是把你原来 $14\times14$ 的位置编码当成一张“图片”,拉伸到 $24\times24$,让 ViT 能处理更大图时还保留“位置信息”。
我追问 torch.nn.functional.interpolate
的具体做法,得知,例如F.interpolate(..., mode='bilinear')
就是把每个新位置的值当成周围 4 个像素值的加权平均,用几何距离当权重。
注意这样的操作只是临时的解决方案,如果序列变化过大,插值操作无法得到较好的结果。
插值来适应序列尺寸改变、patch的抽取这两个方法,是(位置编码时利用1D positional embedding的)Vision Transformer唯一的利用2D信息归纳偏置的地方。
⑥自监督
ViT作者做了个小小的实验来训练自监督的ViT。一年后,MAE证明自监督的训练ViT效果很好。老师在视频的最后讲解了这个内容。
作者借鉴BERT,使用masked patch prediction ,意思是有一张图片已经分好patch了,随机将这些patch抹掉,通过模型尝试把patch重建出来。这是对CV与NLP大一统的尝试。
后来,有一些利用contrastive learning(对比学习)来训练ViT的论文。我之后再来学习它们。
⑦模型细节
patch embedding所用的E 与CNN浅层的filter很像,都类似gabor filter。
position embedding学到了很有用的位置信息。
上面这张图是position embedding之间的“相似度矩阵”,具体用的是 余弦相似度(cosine similarity) :
$$
\text{cosine_similarity}(\mathbf{p}_i, \mathbf{p}_j) = \frac{\mathbf{p}_i \cdot \mathbf{p}_j}{|\mathbf{p}_i||\mathbf{p}_j|}
$$
$$
其中 \mathbf{p}_i 是第 i 个 patch 的位置嵌入向量。
$$
图中一共有 7 × 7 = 49 个小图 ,每个小图本身是一个 heatmap。每个 heatmap 表示此位置(row, column)的 patch 的 position embedding 向量 和 所有位置的 position embedding 向量 的 余弦相似度 。
这说明在1D的position embedding下,模型依然学到了二维的关系。这解释了为什么之前1D position embedding与2D position embedding区别不大。
⑧attention机制
其中,五颜六色的点是16个注意力头。纵坐标是平均注意力距离,是这么定义的:
$$
\text{MeanAttentionDistance} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N} a_{ij} \cdot \text{dist}(p_i, p_j)
$$
$$
其中: N 是 token 的总数(即图像被划分为的 patch 数量);
$$
$$
a_{ij} 是 query token i 对于 key token j 的注意力权重;
$$
$$
p_i = (x_i, y_i) 是 token i 在图像中的二维坐标;
$$
$$
\text{dist}(p_i, p_j) 表示 token i 与 token j 之间的像素距离,常用欧几里得距离定义为:
$$
$$
\text{dist}(p_i, p_j) = \sqrt{(x_i - x_j)^2 + (y_i - y_j)^2}
$$
这张图表明,attention机制在网络的浅层就能注意到全局的信息了,而不是像CNN那样,浅层网络只能看到附近的像素。网络的后半部分平均注意力距离已经非常远了,说明它学到的是语义性的特征。
此外,作者又做了一个很好的可视化:
这是用网络最后一层 输出的output token(对应CLS的输出)对输入图像中各位置的注意力。
6. 有趣的讨论
有学者将ViT中的self attention换成MLP,还是可以工作的很好。
有学者认为transformer真正能work的原因,是它的架构,而不是特殊的算子。因此他把self attention换成了一个甚至不能学习的池化操作的模型,这个模型也能在CV领域取得一定的效果。
这样的结构架构被称为metaformer。
ChatGPT这样和我说:
我来先从PoolFormer 告诉你——这正是 PoolFormer 惊悚之处,它把一个你以为“没用”的非学习操作,直接塞进 Transformer 架构,结果还真的 “能打” !
不是 attention 本身 让 Transformer 强,而是这种 “先混合局部/全局信息 + 再MLP处理 + 加残差 + 正则化” 的结构太强大。
也就是说:
- Self-Attention 是一种 Token Mixing;
- 但 平均池化、卷积、MLP 其实都能“混信息”;
- 只要结构(MetaFormer 框架)不变,就算用一个“傻瓜 Token Mixer”,也能学出不错的特征!
这就像是你突然发现:
原来神秘的魔法并不是魔杖的功劳,而是咒语结构太牛逼了。
💡 一句话解释:
MetaFormer 是对 Transformer 架构的一个“抽象框架总结”,
强调:Transformer 的成功关键不一定是 self-attention,而是它的整体结构!
🔍 它到底是啥?
MetaFormer 并不是一个具体的模型,而是一个“架构模式(框架) ”,把 Transformer 的基本结构抽象成三个模块:
1
2
3
4
5
6
7
8
9
10
11
12 ↓ 输入特征
┌─────────────┐
│ Token Mixer │ ← 这部分原来是 Self-Attention
└─────────────┘
↓
┌─────────────┐
│MLP(前馈网络) │
└─────────────┘
↓
┌─────────────┐
│LayerNorm/Add│
└─────────────┘✅ 只要符合这种结构的模型,都可以被称为 MetaFormer 。
🔧 为什么提出 MetaFormer?
之前大家都以为 Transformer 成功是因为它用了 self-attention 。
但一些工作(比如 PoolFormer)发现:
- 就算把 self-attention 换成简单的东西(比如卷积、池化),只要结构保持不变,性能依然很强!
于是 MetaFormer 就诞生了:
🚨 Self-attention ≠ 唯一核心;
✅ 是这种 “token mixer + FFN” 的结构带来了强大表达能力!
🧪 举例说明:
模型名 Token Mixer 用了啥 属于 MetaFormer 吗? ViT Self-Attention ✅ 是 Mixer MLP ✅ 是 ConvNeXt 卷积 ✅ 是 PoolFormer Pooling ✅ 是 ResNet 卷积 + 特殊结构 ❌ 不算严格意义的 MetaFormer
🧠 总结记忆:
- MetaFormer = 模型结构模板 ,不是具体实现;
- 关键点:Token Mixer + MLP + LayerNorm/Add ;
- Self-Attention 只是 Token Mixer 的一种,不是唯一;
- 提出它是为了说明:Transformer 的成功 结构比机制更重要 !
这些研究很值得思考~ 以上就是我对论文AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE 的学习笔记。😊