0%

ViT论文复现笔记

首先,对ViT的结构做一个回顾

对于图片的读入,与CNN的卷积操作不同(不过,我们在下面的代码实现中,还是在用卷积对patch进行划分),ViT对图片进行embedding处理。

在原始论文里,Vision Transformer将图片打成很多个16 × 16的patch。将一个一个patch展平(flatten)后(从 16×16×3 → 768),再通过一个它们共享的FC层进行映射,得到固定维度的 embedding 向量,传给transformer。

可以把patch当成是NLP里面的单词tokens。

原论文中的结构如下:

1

我们从图的左下角开始。先给定一张图,把它分成“九宫格”一样的小patch。我们使用卷积核对它划分,具体是把输入变成 (B, C, H’, W’) 的形状,其中,

B: batch size
C: 通道数
H': patch 的高度个数
W': patch 的宽度个数

然后组合concatenate变成右边一个序列。之后,每个patch会被展平。

我们来具体看一下展平。在等会的代码实现中,我们使用的是

1
nn.Flatten(start_dim=2, end_dim=3)

展平的patch经过线性投射层(就是一个FC层,用大写E 表示)得到一个特征。这个过程就是patch embedding。

之后,也要像原来的transformer一样,加上位置编码来进行position embedding。现在整体的token就有原本的图像信息和位置信息了。

现在传给transformer encoder,但是用哪一个作为输出呢?借鉴BERT,加入特殊字符(图中为星号),位置信息为零。作者相信,此特殊字符可以在与序列中其他input的交互中学到有用的信息,并用来做分类任务。事实证明这是非常有用的,我在这篇笔记的最后会利用此特殊字符(CLS)来进行attention机制可视化的实现。

最后,再接上MLP分类头即可完成分类任务。

作者在右边放上了标准的transformer encoder结构。可以注意到,这与Attention Is All You Need 中的encoder相比,normalization被放到了多头注意力和多层感知机的前面 。我在之前的笔记里也提到过,对于这个encoder的设计有很多种。

MLP head起到了原来decoder的作用,因为不是复杂的生成任务,而是分类任务,所以它较为简单。

但需要注意的是,这里位置编码position embedding是怎么做的?答案是:是一个有多个向量的矩阵,每个向量代表某个位置的信息,维度为768,将对应某个位置的向量与原始patch经过FC层的输出相加 (要注意是相加)即可。

不过,与我之前的NLP的Transformer笔记矛盾的是 ,这个矩阵是可以被学习 的。NLP的原始Transformer确实用的是正弦/余弦函数生成的位置编码,不可学习 ,但具有良好的泛化能力,因为机器通过学习会关注每个 patch 之间的相对空间结构。在 ViT 里,实验却证明:可学习的位置编码效果更好 ,甚至可以从一维学习到二维的图像位置特征(就是说,机器可以通过学习,理解某个位置编号对应图像的某个二维位置,即这个patch为几行几列等等),而且实现也简单,所以 ViT 默认就用了它。

我复习了ViT的结构之后,就开始尝试复现它的代码。 过程如下。


论文复现与代码实现

先来回顾一下,对于数据的处理、模型的建立,需要经过哪些过程?

  • 构建模型:

    • 首先,构建transformer block,class TransformerBlock(nn.Module):
      • 初始化TransformerBlock类:def __init__(self, dim, heads, mlp_dim, dropout=0.1): ,其中dim是输入的维度,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率 ,batch_first=True表示输入张量的第一个维度是batch_size。
      • 定义TransformerBlock类的forward方法,用于前向传播:def forward(self, x): ,x代表输入的特征张量。
    • 然后,构建ViT模型:class ViT(nn.Module):
      • ViT类的初始化方法,用于初始化ViT类:def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072, dropout=0.1): ,image_size是输入的图像尺寸,patch_size是patch的大小,num_classes是分类的数量,dim是嵌入的维度,depth是Transformer块的数量,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率。
      • 定义ViT的forward方法,用于前向传播:def forward(self, x):x 是输入图像张量,通常形状是[batch_size, channels, height, width]
  • 实例化model:model = ViT(image_size=32, patch_size=4, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=2048) ,针对CIFAR-10数据集调整参数:图像尺寸32x32,patch大小4x4,类别数10,嵌入维度512,6个TransformerBlock,8个注意力头,MLP维度2048

  • 定义损失函数和优化器:criterion = nn.CrossEntropyLoss() 交叉熵损失
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) AdamW优化器,带权重衰减
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) 余弦退火学习率调度器,控制学习率的

  • 数据增强和预处理:

    • 官方的统计值是:
      cifar10_mean = (0.4914, 0.4822, 0.4465)
      cifar10_std = (0.2470, 0.2435, 0.2616)
    • 设置train_transformtest_transform ,用来等会改变
    • 加载 CIFAR-10 数据集trainset = torchvision.datasets.CIFAR10(……) trainloader = torch.utils.data.DataLoader(……) testset = torchvision.datasets.CIFAR10(……) testloader = torch.utils.data.DataLoader(……)
  • 设置设备:device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device) 将模型移动到GPU或CPU。

  • 训练模型:for epoch in range(10):…… ,边训练边评估,model.eval()with torch.no_grad():……

以上就是程序的整体流程,现在来做具体实现:

一、导入库

1
2
3
4
5
import torch
import torch.nn as nn # 导入torch.nn模块,用于定义神经网络层
import torch.optim as optim # 导入torch.optim模块,用于定义优化器
import torchvision # 导入torchvision模块,用于加载和处理数据集
import torchvision.transforms as transforms # 导入torchvision.transforms模块,用于定义数据增强和预处理

二、构建模型

① 构建transformer block

1
class TransformerBlock(nn.Module):

先初始化TransformerBlock类,

1
def __init__(self, dim, heads, mlp_dim, dropout=0.1): # dim是输入的维度,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率

调用父类的构造函数,继承nn.Module

1
super().__init__()

初始化 TransformerBlock 的各个子模块,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 第一个层归一化
self.norm1 = nn.LayerNorm(dim)
# 多头注意力机制,batch_first=True表示输入张量的第一个维度是batch_size
self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
# 第二个层归一化
self.norm2 = nn.LayerNorm(dim)
# 多层感知机,用于特征转换
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim), # 第一个全连接层,扩大维度
nn.GELU(), # GELU激活函数
nn.Dropout(dropout), # Dropout防止过拟合
nn.Linear(mlp_dim, dim), # 第二个全连接层,恢复原始维度
nn.Dropout(dropout) # Dropout防止过拟合
)

定义TransformerBlock类的forward方法,用于向前传播,

1
2
3
4
5
6
7
8
9
def forward(self, x): #x代表输入的特征张量
# 残差连接 + 多头注意力
x_norm = self.norm1(x) # x_norm是归一化后的特征张量
attn_output, _ = self.attn(x_norm, x_norm, x_norm) # 多头注意力机制,返回的attn_output是(batch_size, seq_len, dim),有两个值,第一个是注意力输出,第二个是注意力权重
x = x + attn_output # 残差连接

# 残差连接 + MLP
x = x + self.mlp(self.norm2(x))
return x

② 构建Vision Transformer模型

1
class ViT(nn.Module):

初始化ViT类:

1
def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072, dropout=0.1): #这一段是ViT类的初始化方法,用于初始化ViT类,image_size是输入的图像尺寸,patch_size是patch的大小,num_classes是分类的数量,dim是嵌入的维度,depth是Transformer块的数量,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率

继承父类,

1
super(ViT, self).__init__()

把参数保存为实例属性,

1
2
3
4
# 保存基本参数
self.image_size = image_size
self.patch_size = patch_size
self.num_classes = num_classes

断言异常情况,

1
2
# 确保图像尺寸能被patch大小整除
assert image_size % patch_size == 0, "图像尺寸必须能被patch大小整除"

构建其他模块,保存为实例属性,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 计算patch数量
self.num_patches = (image_size // patch_size) ** 2

# 将图像转换为patch嵌入

self.to_patch_embedding = nn.Sequential(
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), # 使用卷积进行patch划分
nn.Flatten(start_dim=2, end_dim=3), # 展平patch
)

# 分类token,用于最终的分类
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 可学习的位置编码(符合原ViT)
self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))

# Dropout层,防止过拟合
self.dropout = nn.Dropout(dropout)

# 堆叠多个Transformer块
self.transformer = nn.Sequential(*[
TransformerBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)
])

# 最终的分类头
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim), # 层归一化
nn.Linear(dim, num_classes) # 全连接层输出类别
)

定义前向传播方法,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def forward(self, x): # 定义ViT的forward方法,用于前向传播
# 将输入图像转换为patch嵌入
x = self.to_patch_embedding(x)
x = x.transpose(1, 2)

# 添加分类token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

# 添加位置编码
x = x + self.pos_embedding
# 应用dropout
x = self.dropout(x)

# 通过Transformer块
x = self.transformer(x)
# 使用分类token进行最终分类
x = self.mlp_head(x[:, 0])
return x

三、实例化model

1
2
3
4
5
6
7
8
9
10
# 创建ViT模型实例
# 针对CIFAR-10数据集调整参数:
# - 图像尺寸32x32
# - patch大小4x4
# - 类别数10
# - 嵌入维度512
# - 6个Transformer块
# - 8个注意力头
# - MLP维度2048
model = ViT(image_size=32, patch_size=4, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=2048)

四、定义损失函数和优化器

1
2
3
4
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) # AdamW优化器,带权重衰减
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # 余弦退火学习率调度器

五、数据增强和预处理:

① 保存官方的统计值

1
2
3
4
# 数据增强和预处理
# 官方统计值
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)

② 设置train_transformtest_transform ,待会传递给dataset。

1
2
3
4
5
6
7
8
9
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪并填充
transforms.ToTensor(), # 转为张量
transforms.Normalize( # 标准化
cifar10_mean,
cifar10_std
)
])

但是,要注意测试集的预处理不需要水平翻转、裁剪、填充等操作。

1
2
3
4
5
6
7
8
# 测试集的预处理(只做 ToTensor + Normalize)
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
cifar10_mean,
cifar10_std
)
])

③ 加载 CIFAR-10 数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 3. 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(
root='D:/CIFAR10',
train=True,
download=True,
transform=train_transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
root='D:/CIFAR10',
train=False,
download=True,
transform=test_transform # 注意这里是 test_transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=128, shuffle=False, num_workers=2
)

六、设置设备

1
2
3
# 设置设备(GPU或CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) # 将模型移动到GPU或CPU

七、训练并评估模型

先开始循环,

1
2
3
4
5
# 训练循环
for epoch in range(10): # 训练10轮
model.train() # 设置为训练模式
running_loss = 0.0
for i, data in enumerate(trainloader, 0):

获取训练集数据,转移到设备,

1
2
inputs, labels = data # 获取训练集数据
inputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPU或CPU

前向传播,

1
2
3
4
# 前向传播
optimizer.zero_grad() # 梯度清零
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失

反向传播,

1
2
# 反向传播
loss.backward()

下一步,要进行梯度裁剪,防止梯度爆炸。具体解释:对 model 中的所有参数的梯度进行裁剪(clip) 。如果梯度的 L2 范数总值超过了 1.0 ,就把它们按比例缩小,使总范数不超过 1.0。这样做可以防止梯度过大(爆炸)导致模型训练不稳定 。补充:对于一个向量
$$
\mathbf{x} = [x_1, x_2, …, x_n]
$$
它的 L2 范数 是:
$$
||\mathbf{x}|_2 = \sqrt{x_1^2 + x_2^2 + \cdots + x_n^2}
$$
再补充一个问题:为什么会发生梯度爆炸?以此ViT模型为例,

  • 网络太深:ViT 通常有 很多层 Transformer Block ,层数为depth。每一层都要参与反向传播,链式法则的乘法就非常多。如果每一层大于1的偏导数较多,多层累计起来,最终就可能指数级爆炸。

  • mlp 全连接层的激活函数 GELU 会放大梯度。

ViT 中哪些策略是为了防止爆炸的?

机制 在 ViT 中的作用
LayerNorm 保持输入输出分布稳定
Dropout 防止过拟合、限制激活放大
梯度裁剪(clip_grad) 显式限制最大梯度
warm-up 学习率 避免刚开始训练时“梯度突然暴涨”
初始化方式 保证参数初始值不会导致爆炸传播

现在,来实现梯度裁剪,

1
2
# 梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

之后,更新参数,

1
optimizer.step() # 更新参数

先累加损失,再每100个batch打印一次结果,

1
2
3
4
5
# 打印训练信息
running_loss += loss.item()
if i % 100 == 99: # 每100个batch打印一次
print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}")
running_loss = 0.0

里层循环结束,现在在外层循环里更新学习率。这个scheduler是学习率调度器。在深度学习中,通常会随着训练的进行逐步 调整学习率 ,这种做法叫做 学习率调度 (Learning Rate Scheduling)。目的是:

  • 避免训练初期学习率过大 ,导致模型不稳定。

  • 防止训练后期学习率过小 ,让模型能更好地收敛。

调度器会按照规则动态更新学习率。放在外层循环里,根据 epoch 的进展来更新学习率。虽然scheduleroptimizer 是两个不同的实例,但是在定义时 optimizer 被传递给了schedulerscheduler.step() 让调度器通过访问 optimizer 的参数组(optimizer.param_groups)来更新学习率,就是

1
2
# 更新学习率
scheduler.step()

现在不退出外层循环,而是在这个epoch的基础下,在测试集上评估模型,先做好基础设置,

1
2
3
4
5
# 在测试集上评估模型
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad(): # 不计算梯度

获取测试集数据,评估,

1
2
3
4
5
6
7
for data in testloader: # 获取测试集数据
inputs, labels = data # 获取测试集数据
inputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPU或CPU
outputs = model(inputs) # 前向传播
_, predicted = outputs.max(1) # 获取预测结果
total += labels.size(0) # 计算总样本数
correct += predicted.eq(labels).sum().item() # 计算正确样本数

打印基于这个epoch的结果,

1
print(f"Epoch {epoch + 1} - Accuracy: {100 * correct / total:.2f}%")

八、保存模型权重

1
2
torch.save(model.state_dict(), 'vit_cifar10.pth')
print("Model saved to vit_cifar10.pth")

基于运行代码的实验结果与总结

现在我开始运行程序,并记录一下实验现象。

下载cifar中…

2

犯了低级错误,num_workers > 0 ,使用 多进程 ,但是忘了写 _main_

3

赶快补上,

1
2
3
4
5
6
7
if __name__ == "__main__":
model = ViT(image_size=32, patch_size=4, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=2048)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) # AdamW优化器,带权重衰减
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # 余弦退火学习率调度器
…………

开始运行,

4

训练和评估模块都是正常的,我开始观察记录每一个epoch的accuracy,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
Files already downloaded and verified
Files already downloaded and verified
using cuda...
[Epoch 1, Batch 100] Loss: 2.265
[Epoch 1, Batch 200] Loss: 1.999
[Epoch 1, Batch 300] Loss: 1.950
Epoch 1 - Accuracy: 31.45%
[Epoch 2, Batch 100] Loss: 1.940
[Epoch 2, Batch 200] Loss: 1.942
[Epoch 2, Batch 300] Loss: 1.922
Epoch 2 - Accuracy: 30.55%
[Epoch 3, Batch 100] Loss: 1.924
[Epoch 3, Batch 200] Loss: 1.934
[Epoch 3, Batch 300] Loss: 1.934
Epoch 3 - Accuracy: 32.89%
[Epoch 4, Batch 100] Loss: 1.877
[Epoch 4, Batch 200] Loss: 1.874
[Epoch 4, Batch 300] Loss: 1.907
Epoch 4 - Accuracy: 33.57%
[Epoch 5, Batch 100] Loss: 1.883
[Epoch 5, Batch 200] Loss: 1.890
[Epoch 5, Batch 300] Loss: 1.870
Epoch 5 - Accuracy: 34.53%
[Epoch 6, Batch 100] Loss: 1.860
[Epoch 6, Batch 200] Loss: 1.845
[Epoch 6, Batch 300] Loss: 1.838
Epoch 6 - Accuracy: 34.17%
[Epoch 7, Batch 100] Loss: 1.823
[Epoch 7, Batch 200] Loss: 1.822
[Epoch 7, Batch 300] Loss: 1.825
Epoch 7 - Accuracy: 36.81%
[Epoch 8, Batch 100] Loss: 1.794
[Epoch 8, Batch 200] Loss: 1.785
[Epoch 8, Batch 300] Loss: 1.796
Epoch 8 - Accuracy: 37.51%
[Epoch 9, Batch 100] Loss: 1.780
[Epoch 9, Batch 200] Loss: 1.773
[Epoch 9, Batch 300] Loss: 1.765
Epoch 9 - Accuracy: 38.64%
[Epoch 10, Batch 100] Loss: 1.759
[Epoch 10, Batch 200] Loss: 1.763
[Epoch 10, Batch 300] Loss: 1.768
Epoch 10 - Accuracy: 39.05%
Model saved to vit_cifar10.pth

进程已结束,退出代码为 0

Accuracy最后只达到39.05%,似乎非常不顺利。我们来分析一下原因:

ViT模型本身特点

  • ViT 是设计给大数据集的(如 ImageNet),它没有 CNN 的局部感受野和归纳偏置。
  • CIFAR-10 图像太小(32x32),切 patch 后的信息非常少,不容易学有效特征。

训练轮数太少

  • 模型还在下降 loss,accuracy 也在涨,明显是没收敛就被停了。

学习率可能过高或过低

  • transformer 非常敏感学习率。

我决定这样改正:

修改lr与weight_decay

1
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)

训练100轮,那么在100轮中lr逐渐退火到最低值,

1
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

同时改一下for循环,

1
for epoch in range(10):  # 训练10轮

最后,加上可视化,

1
2
3
4
import matplotlib.pyplot as plt

train_losses = []
test_accuracies = []

在每个epoch结束后,

1
2
3
4
train_losses.append(running_loss / num_train_batches)
correct = 0; total = 0
acc = 100 * correct / total
test_accuracies.append(acc)

训练结束后,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 训练结束后,画图
plt.figure(figsize=(10,4))

# Loss 曲线
plt.subplot(1,2,1)
plt.plot(range(1, len(train_losses)+1), train_losses)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Training Loss')

# Accuracy 曲线
plt.subplot(1,2,2)
plt.plot(range(1, len(test_accuracies)+1), test_accuracies)
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.title('Test Accuracy')

plt.tight_layout()
plt.show()

关注loss与accuracy的可视化后,我还想关注这个:对于某一个特定的图片,transformer是如何分配注意力权重的。具体来说,每个 TransformerBlock 内部的 nn.MultiheadAttention 会计算 Query 、Key 和 Value
$$
Q = XW_Q,\quad K = XW_K,\quad V = XW_V
$$
这里 X 就是所有 token 的矩阵,大小
$$
(\text{batch},,\text{seq_len},,\text{dim})
$$
注意力权重矩阵 A 就是:
$$
A = \mathrm{softmax}\Bigl(\frac{QK^T}{\sqrt{d}}\Bigr),
$$
其中 sqrt d 是缩放系数,这样A的形状是
$$
(\text{batch},,\text{heads},,\text{seq_len},,\text{seq_len})
$$
表示每个 head、每个 query token 对每个 key token 的注意力分数。可视化热力图后,我们可以更直观地展示 CLS token 对各 patch 的权重分布。为什么只看CLS与其他patch?因为ViT 最终用 CLS token 的输出去做分类。最后点明细节:展示的是CLS token 作为 Query 与各 Patch 作为 Key 之间的相似度权重。

通俗来说,就是模型更关注图中哪一个小块。

这个注意力可视化程序我会单独写一个来实现。

好,现在修改好代码后,开始重新训练模型。

现在结束了,数据与可视化结果如下:

Figure1

最后几个epoch的数据是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
[Epoch 96, Batch 100] Loss: 0.008
[Epoch 96, Batch 200] Loss: 0.008
[Epoch 96, Batch 300] Loss: 0.007
Epoch 96 - Accuracy: 81.97%
[Epoch 97, Batch 100] Loss: 0.009
[Epoch 97, Batch 200] Loss: 0.009
[Epoch 97, Batch 300] Loss: 0.010
Epoch 97 - Accuracy: 81.96%
[Epoch 98, Batch 100] Loss: 0.008
[Epoch 98, Batch 200] Loss: 0.010
[Epoch 98, Batch 300] Loss: 0.007
Epoch 98 - Accuracy: 82.01%
[Epoch 99, Batch 100] Loss: 0.008
[Epoch 99, Batch 200] Loss: 0.008
[Epoch 99, Batch 300] Loss: 0.010
Epoch 99 - Accuracy: 81.98%
[Epoch 100, Batch 100] Loss: 0.009
[Epoch 100, Batch 200] Loss: 0.009
[Epoch 100, Batch 300] Loss: 0.007
Epoch 100 - Accuracy: 81.97%

可以看到模型在90多的时候就逐步收敛了。现在模型的权重已经被我保存到了vit_cifar10.pth里面,让我再写一个程序来进行注意力可视化,只要把vit_cifar10.pth传递给它即可,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib

# 1. 指定一个支持中文的字体(Windows 下通常有 SimHei、Microsoft YaHei)
matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 用黑体
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False

…………(transformer block与之前一样)

class ViT(nn.Module):
…………(ViT类前面和之前一样,但下面要添加)
def forward_with_attn(self, x):
x = self.to_patch_embedding(x)
x = x.transpose(1, 2)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding
x = self.dropout(x)

attn_maps = []
for block in self.transformer:
x, attn = block(x, return_attn=True)
attn_maps.append(attn) # [B, num_heads, num_patches+1]
return x, attn_maps

…………(模型的实例化与之前一样)

# CIFAR-10 类别名
cifar10_classes = [
"airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"
]


# 可视化函数
def visualize_attention(model, image_path, device):
model.eval()
image = Image.open(image_path).convert("RGB")
transform = test_transform
input_tensor = transform(image).unsqueeze(0).to(device)

# 获取预测分类
with torch.no_grad():
outputs = model(input_tensor)
predicted_class = outputs.argmax(dim=1).item()
predicted_label = cifar10_classes[predicted_class]

# 获取 attention
_, attn_maps = model.forward_with_attn(input_tensor)

# 根据 attn_maps 的维度来取 cls→patch 的注意力
last_attn = attn_maps[-1] # 可能是 4D,也可能是 3D

attn = last_attn # 直接用
cls_attn = attn[:, 0, 1:]


# 多头平均
avg_attn = cls_attn.mean(dim=0)

grid_size = int(model.num_patches ** 0.5)
attn_map = avg_attn.reshape(grid_size, grid_size).cpu().detach().numpy()

# 显示图像
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].imshow(image)
axs[0].set_title(f"此图片被模型识别为: {predicted_label}")
axs[0].axis("off")

axs[1].imshow(image)
axs[1].imshow(attn_map, cmap='jet', alpha=0.5,
extent=(0, image.size[0], image.size[1], 0))
axs[1].set_title("Attention Map展示")
axs[1].axis("off")

plt.suptitle(image_path)
plt.tight_layout()
plt.show()


# 加载已训练模型
model.load_state_dict(torch.load("vit_cifar10.pth"))
model.forward_with_attn = ViT.forward_with_attn.__get__(model)
model.to(device)

# 可视化指定图片 attention
image_paths = [
"custom_images/cat1107.png",
"custom_images/cat1129.png",
"custom_images/cat1205.png",
"custom_images/cat1206.png",
"custom_images/cat1212.png",
"custom_images/cat1633.png",
"custom_images/airplane105.png",
"custom_images/airplane165.png",
"custom_images/airplane168.png",
"custom_images/airplane170.png",
"custom_images/airplane352.png",
"custom_images/airplane370.png"
]

for path in image_paths:
visualize_attention(model, path, device)

运行程序,就能调用刚刚训练出来的模型权重,进行图片识别,

F1

F2

F3

F4

F5

F6

F7

F8

F9

F10

可以看出来,模型注意到的地方(偏红色标记)与物体的特征高度相关,而模型“忽视”的地方(偏蓝色标记)常常是图片的背景与无关噪声。这说明ViT通过注意力机制真正学习到了物体的特征。


附录Ⅰ:源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import torch
import torch.nn as nn # 导入torch.nn模块,用于定义神经网络层
import torch.optim as optim # 导入torch.optim模块,用于定义优化器
import torchvision # 导入torchvision模块,用于加载和处理数据集
import torchvision.transforms as transforms # 导入torchvision.transforms模块,用于定义数据增强和预处理

import matplotlib.pyplot as plt
import matplotlib

from torch.utils.data import random_split



# 指定一个支持中文的字体(Windows 下通常有 SimHei、Microsoft YaHei)
matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 用黑体
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False

# 定义Transformer块,这是ViT的核心组件
class TransformerBlock(nn.Module):
def __init__(self, dim, heads, mlp_dim, dropout=0.1): # dim是输入的维度,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率,这一段是在初始化TransformerBlock类
super().__init__() # 继承nn.Module类
# 第一个层归一化
self.norm1 = nn.LayerNorm(dim)
# 多头注意力机制,batch_first=True表示输入张量的第一个维度是batch_size
self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
# 第二个层归一化
self.norm2 = nn.LayerNorm(dim)
# 多层感知机,用于特征转换
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim), # 第一个全连接层,扩大维度
nn.GELU(), # GELU激活函数
nn.Dropout(dropout), # Dropout防止过拟合
nn.Linear(mlp_dim, dim), # 第二个全连接层,恢复原始维度
nn.Dropout(dropout) # Dropout防止过拟合
)

def forward(self, x, return_attn=False): #这一段是TransformerBlock类的forward方法,用于前向传播,x代表输入的特征张量
# 残差连接 + 多头注意力
x_norm = self.norm1(x) # x_norm是归一化后的特征张量
attn_output, attn_weights = self.attn(x_norm, x_norm, x_norm)# 多头注意力机制,返回的attn_output是(batch_size, seq_len, dim),有两个值,第一个是输出,第二个是权重
x = x + attn_output # 残差连接
# 残差连接 + MLP
x = x + self.mlp(self.norm2(x))
if return_attn:
return x, attn_weights
return x

# 定义Vision Transformer模型
class ViT(nn.Module):
def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072, dropout=0.1): #这一段是ViT类的初始化方法,用于初始化ViT类,image_size是输入的图像尺寸,patch_size是patch的大小,num_classes是分类的数量,dim是嵌入的维度,depth是Transformer块的数量,heads是注意力头的数量,mlp_dim是MLP的维度,dropout是Dropout的dropout率
super(ViT, self).__init__()
# 保存基本参数
self.image_size = image_size
self.patch_size = patch_size
self.num_classes = num_classes

# 确保图像尺寸能被patch大小整除
assert image_size % patch_size == 0, "图像尺寸必须能被patch大小整除"
# 计算patch数量
self.num_patches = (image_size // patch_size) ** 2

# 将图像转换为patch嵌入

self.to_patch_embedding = nn.Sequential(
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), # 使用卷积进行patch划分
nn.Flatten(start_dim=2, end_dim=3), # 展平patch
)

# 分类token,用于最终的分类
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 可学习的位置编码(符合原ViT)
self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))

# Dropout层,防止过拟合
self.dropout = nn.Dropout(dropout)

# 堆叠多个Transformer块
self.transformer = nn.Sequential(*[
TransformerBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)
])

# 最终的分类头
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim), # 层归一化
nn.Linear(dim, num_classes) # 全连接层输出类别
)

def forward(self, x): # 定义ViT的forward方法,用于前向传播
# 将输入图像转换为patch嵌入
x = self.to_patch_embedding(x)
x = x.transpose(1, 2)

# 添加分类token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

# 添加位置编码
x = x + self.pos_embedding
# 应用dropout
x = self.dropout(x)

# 通过Transformer块
x = self.transformer(x)
# 使用分类token进行最终分类
x = self.mlp_head(x[:, 0])
return x



if __name__ == "__main__":

train_losses = []
test_accuracies = []

# 创建ViT模型实例
# 针对CIFAR-10数据集调整参数:
# - 图像尺寸32x32
# - patch大小4x4
# - 类别数10
# - 嵌入维度512
# - 6个Transformer块
# - 8个注意力头
# - MLP维度2048
model = ViT(image_size=32, patch_size=4, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=2048)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.03) # AdamW优化器,带权重衰减
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) # 余弦退火学习率调度器

# 数据增强和预处理
# 官方统计值
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)

# 2. 预处理
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪并填充
transforms.ToTensor(), # 转为张量
transforms.Normalize( # 标准化
cifar10_mean,
cifar10_std
)
])

test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
cifar10_mean,
cifar10_std
)
])

# 3. 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(
root='D:/CIFAR10',
train=True,
download=True,
transform=train_transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
root='D:/CIFAR10',
train=False,
download=True,
transform=test_transform # 注意这里是 test_transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=128, shuffle=False, num_workers=2
)
# 设置设备(GPU或CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) # 将模型移动到GPU或CPU

print(f"using {device}...")

num_train_batches = len(trainloader)

# 训练循环
for epoch in range(100): # 训练100轮
model.train() # 设置为训练模式
running_loss = 0.0
epoch_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data # 获取训练集数据
inputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPU或CPU

# 前向传播
optimizer.zero_grad() # 梯度清零
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失

# 反向传播
loss.backward()
# 梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # 更新参数

# 打印训练信息
running_loss += loss.item() # 累加损失
epoch_loss += loss.item()
if i % 100 == 99: # 每100个batch打印一次
print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}")
running_loss = 0.0

# 更新学习率
scheduler.step()

# 记录训练可视化数据
train_losses.append(epoch_loss / num_train_batches)

# 在测试集上评估模型
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad(): # 不计算梯度
for data in testloader: # 获取测试集数据
inputs, labels = data # 获取测试集数据
inputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPU或CPU
outputs = model(inputs) # 前向传播
_, predicted = outputs.max(1) # 获取预测结果
total += labels.size(0) # 计算总样本数
correct += predicted.eq(labels).sum().item() # 计算正确样本数

print(f"Epoch {epoch + 1} - Accuracy: {100 * correct / total:.2f}%")

# 记录测试可视化数据
acc = 100 * correct / total
test_accuracies.append(acc)

# 保存模型权重
torch.save(model.state_dict(), 'vit_cifar10.pth')
print("Model saved to vit_cifar10.pth")

# 训练结束后,画图
plt.figure(figsize=(10,4))

# Loss 曲线
plt.subplot(1,2,1)
plt.plot(range(1, len(train_losses)+1), train_losses)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Training Loss')

# Accuracy 曲线
plt.subplot(1,2,2)
plt.plot(range(1, len(test_accuracies)+1), test_accuracies)
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.title('Test Accuracy')

plt.tight_layout()
plt.show()


附录Ⅱ:参考文献

  1. Dosovitskiy, A., Beyer, L., Kolesnikov, A., et al. “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale.” arXiv preprint arXiv:2010.11929, 2020.
    https://arxiv.org/abs/2010.11929
  2. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. arXiv preprint arXiv:1706.03762.
    https://arxiv.org/abs/1706.03762
  3. PyTorch 官方文档. “torch.nn.MultiheadAttention.” 2025.
    https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html