📅  最后修改于: 2023-12-03 15:26:53.126000             🧑  作者: Mango
残差网络(ResNet)是2015年由Kaiming He等人提出的深度卷积神经网络结构,在ImageNet Large-Scale Visual Recognition Challenge比赛中获得了冠军。该网络通过引入“残差块”(Residual block)的结构,使得网络可以在更深的层数下学习到更加复杂的特征,从而提高了网络的表现。
残差块是残差网络的核心,其结构如下图所示:
其中,$\mathcal{F}(x)$表示一个普通的卷积层+非线性激活函数的操作,$H(x)$表示一个跳连接(shortcut connection),用来跳过一层卷积操作,直接将输入$x$接到输出$y$之上。
跳连接的目的是解决深度神经网络的梯度消失问题。在传统神经网络中,每一层都会对输入进行一次非线性的变换,这样随着网络的加深,梯度也随之指数级地减小,导致网络很难优化。而跳连接能让信息在不同的层之间流动,使得梯度能够更好地传递,从而加速网络的训练。
残差网络是由多个残差块组成的网络结构,如下图所示:
可以看到,残差网络将网络分为多个stage,每个stage内部使用相同的残差块。在stage之间还会添加池化层或降采样层,来进一步减小特征图的尺寸。
以下是一个简单的残差块的实现代码:
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
identity = self.shortcut(identity)
out += identity
out = self.relu(out)
return out
要使用该模块构建残差网络,只需要堆叠多个残差块即可,例如:
import torch.nn as nn
class ResNet(nn.Module):
def __init__(self):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.stage1 = nn.Sequential(
ResidualBlock(64, 64),
ResidualBlock(64, 64)
)
self.stage2 = nn.Sequential(
ResidualBlock(64, 128, stride=2),
ResidualBlock(128, 128)
)
self.stage3 = nn.Sequential(
ResidualBlock(128, 256, stride=2),
ResidualBlock(256, 256)
)
self.stage4 = nn.Sequential(
ResidualBlock(256, 512, stride=2),
ResidualBlock(512, 512)
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, 1000)
def forward(self, x):
x = self.stem(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
残差网络是一种非常有效的深度神经网络结构,其跳连接的设计能够解决深度网络的梯度消失问题。由于其出色的表现,现在已经成为很多图像任务的标准模型之一。