📜  残差网络简介(1)

📅  最后修改于: 2023-12-03 15:26:53.126000             🧑  作者: Mango

残差网络简介

简介

残差网络(ResNet)是2015年由Kaiming He等人提出的深度卷积神经网络结构,在ImageNet Large-Scale Visual Recognition Challenge比赛中获得了冠军。该网络通过引入“残差块”(Residual block)的结构,使得网络可以在更深的层数下学习到更加复杂的特征,从而提高了网络的表现。

残差块

残差块是残差网络的核心,其结构如下图所示:

Residual Block

其中,$\mathcal{F}(x)$表示一个普通的卷积层+非线性激活函数的操作,$H(x)$表示一个跳连接(shortcut connection),用来跳过一层卷积操作,直接将输入$x$接到输出$y$之上。

跳连接的目的是解决深度神经网络的梯度消失问题。在传统神经网络中,每一层都会对输入进行一次非线性的变换,这样随着网络的加深,梯度也随之指数级地减小,导致网络很难优化。而跳连接能让信息在不同的层之间流动,使得梯度能够更好地传递,从而加速网络的训练。

残差网络

残差网络是由多个残差块组成的网络结构,如下图所示:

ResNet

可以看到,残差网络将网络分为多个stage,每个stage内部使用相同的残差块。在stage之间还会添加池化层或降采样层,来进一步减小特征图的尺寸。

实现

以下是一个简单的残差块的实现代码:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        identity = self.shortcut(identity)

        out += identity
        out = self.relu(out)

        return out

要使用该模块构建残差网络,只需要堆叠多个残差块即可,例如:

import torch.nn as nn

class ResNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.stage1 = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 64)
        )

        self.stage2 = nn.Sequential(
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 128)
        )

        self.stage3 = nn.Sequential(
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 256)
        )

        self.stage4 = nn.Sequential(
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 512)
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 1000)

    def forward(self, x):
        x = self.stem(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
总结

残差网络是一种非常有效的深度神经网络结构,其跳连接的设计能够解决深度网络的梯度消失问题。由于其出色的表现,现在已经成为很多图像任务的标准模型之一。