# Stop Memorizing the ResNet Architecture! Build a ResNet-18 by Hand in PyTorch and Truly Understand Residual Connections
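*Hand-Building ResNet-18 in PyTorch: Seeing the Essence of Residual Connections Through Code*

Since it appeared in 2015, the residual network (ResNet) has been a cornerstone model of computer vision. Yet many developers understand it only at the surface level of "skip connections." When they actually implement it, they run into the details. Why do some residual blocks use a 1x1 convolution? How are mismatched dimensions handled? What exactly is the difference between a Basic Block and a Bottleneck Block? In this post we build a ResNet-18 from scratch in PyTorch and settle these core questions at the code level.

## 1. The Design Philosophy of Residual Networks

Deep neural networks perform well on image recognition, but adding more layers to a plain network eventually stops helping. In the original ResNet paper, a 56-layer plain network had higher *training* and test error on CIFAR-10 than a 20-layer one. Because training error rises as well, this "degradation" is not caused by overfitting. It is often blamed on vanishing gradients, where the gradient signal shrinks as it propagates backward through many layers. The ResNet authors, however, trained their plain networks with BatchNorm, checked that the backward gradients had healthy norms, and described the problem as an optimization difficulty. ResNet's innovation is the residual learning framework: the network learns the residual between input and output (the part that changes) instead of learning the full mapping directly.

The core formula of a residual block is simple and elegant:

$$\text{output} = F(x) + x$$

Here $F(x)$ is the residual mapping to be learned and $x$ is the identity mapping carried by the shortcut. If an identity mapping is already optimal for a block, the layers only need to push $F(x)$ toward 0, which is easier than fitting an identity with a stack of nonlinear layers. The block then reduces to an identity mapping, so the extra depth does not degrade performance.

A few key points matter when implementing this idea in PyTorch:

- When the dimensions of $F(x)$ and $x$ differ (the channel count, or the spatial size when the stride is 2), the shortcut uses a 1x1 convolution to match them.
- Inside the block, the standard convolution-BN-ReLU combination is used.
- After the addition, the output passes through one more ReLU activation.

## 2. Building the Basic Block: The Core Component of ResNet-18

ResNet-18 uses the Basic Block structure. Each residual block contains two 3x3 convolutional layers. We start by implementing this building block:

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1  # channel expansion factor

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First 3x3 convolution (downsamples when stride=2)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second 3x3 convolution (always stride 1)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Register ReLU once instead of building a new module on every forward pass
        self.relu = nn.ReLU(inplace=True)

        # Shortcut: an empty Sequential passes its input through unchanged (identity)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            # Projection shortcut: a 1x1 conv matches channels and spatial size
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels * self.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)  # no activation before the addition

        # Match dimensions, then add the shortcut
        residual = self.shortcut(residual)
        out += residual
        out = self.relu(out)

        return out
```

Several technical details in this implementation are worth noting:

- **Dimension matching:** when the input and output dimensions differ, a 1x1 convolution on the shortcut adjusts both the channel count and the spatial size. In ResNet-18 this happens in the first block of stages 2 to 4, where the channels double and the stride is 2.
- **Batch normalization:** every convolution is followed by BatchNorm, the standard setup in modern CNNs. This is also why the convolutions use `bias=False`: BN subtracts the per-channel mean and adds its own learnable shift, so a convolution bias would be redundant.
- **Residual addition:** the output of the second BN is not activated before the addition. ReLU is applied only after the shortcut is added, which is the design from the original paper.

> Tip: the `expansion` attribute of the Basic Block keeps its interface consistent with the Bottleneck Block. In the Basic Block its value is 1.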
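The following minimal sketch checks both shortcut paths and the claim that a block with $F(x) = 0$ reduces to an identity mapping. So that it runs on its own, it re-declares a compact copy of the `BasicBlock` above. Zeroing the scale and shift of the last BN is only our demo trick for forcing $F(x) = 0$ and is not part of the article's code. torchvision's `ResNet` offers a related `zero_init_residual` option. The input is drawn non-negative to mimic the output of a preceding ReLU, which is an assumption of the demo.

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Compact copy of the Basic Block: conv-BN-ReLU, conv-BN, add shortcut, ReLU."""
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()  # identity by default
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))


torch.manual_seed(0)
x = torch.rand(2, 64, 56, 56)  # non-negative, like the output of a previous ReLU

same = BasicBlock(64, 64)             # identity shortcut
down = BasicBlock(64, 128, stride=2)  # projection shortcut (1x1 conv, stride 2)
print(same(x).shape)  # torch.Size([2, 64, 56, 56])
print(down(x).shape)  # torch.Size([2, 128, 28, 28])
print(len(same.shortcut), len(down.shortcut))  # 0 2

# Force F(x) = 0 by zeroing the last BN's scale and shift: the block then
# computes ReLU(0 + x), which equals x for non-negative inputs.
with torch.no_grad():
    same.bn2.weight.zero_()
    same.bn2.bias.zero_()
    print(torch.equal(same(x), x))  # True
```

With $F(x)$ switched off, the stride-1 block passes its input through unchanged, which is the "degenerate to identity" case from Section 1. The stride-2 block can only be added to its input after the 1x1 projection has reshaped the shortcut to 128 x 28 x 28.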
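## 3. Assembling the Complete ResNet-18 Architecture

With the Basic Block in place, we can assemble the complete ResNet-18. ResNet's network structure follows a general pattern:

- An initial convolution layer with a large kernel (7x7, stride 2), followed by max-pool downsampling
- Four stages of stacked residual blocks
- Global average pooling and a fully connected layer

```python
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (224x224 input -> 56x56)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages of residual blocks
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```

A ResNet-18 instance is created like this:

```python
def resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
```

Here `[2, 2, 2, 2]` means that each of the four stages contains 2 Basic Blocks, which gives 2 × 4 = 8 residual blocks in total. Each Basic Block has 2 convolutional layers, for 16 convolutions. Adding the initial convolution and the final fully connected layer gives exactly 18 weighted layers. The 1x1 shortcut convolutions, BN, and pooling layers are not counted.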
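The sketch below checks the 18-layer count and the roughly 11.7M parameters usually quoted for ResNet-18. It re-declares compact copies of `BasicBlock` and `ResNet` so it runs standalone. The rule for which modules count as "layers" (convolutions larger than 1x1 plus the FC layer) is our own encoding of the counting convention described above.

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        return self.fc(torch.flatten(self.avgpool(x), 1))


def resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


model = resnet18()

# All learnable parameters (conv, BN, and FC weights/biases)
num_params = sum(p.numel() for p in model.parameters())

# Weighted layers in the "18" count: 7x7/3x3 convs plus the FC layer.
# 1x1 shortcut convs, BN, and pooling are excluded.
weighted = [
    m for m in model.modules()
    if isinstance(m, nn.Linear)
    or (isinstance(m, nn.Conv2d) and m.kernel_size != (1, 1))
]

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

print(num_params)     # 11689512
print(len(weighted))  # 18
print(logits.shape)   # torch.Size([1, 1000])
```

The 11,689,512 parameters round to the 11.7M usually quoted for ResNet-18. Inside the network, a 224x224 image shrinks to 56x56 after the stem and to 7x7 after `layer4`, and global average pooling then reduces that to a single 512-dimensional vector for the classifier.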
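## 4. Training Tips and Visualization for Residual Networks

Implementing the network structure is only the first step. To get good results from ResNet, pay attention to a few points during training.

### 4.1 Initialization Strategy

Very deep networks are sensitive to parameter initialization. A common choice, and the one torchvision's ResNet uses, is Kaiming (He) initialization for the convolutions, which was derived for ReLU networks. The BN layers start with weight 1 and bias 0:

```python
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming normal init scaled for ReLU, based on each conv's fan-out
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.BatchNorm2d):
            # Start BN as an identity transform: scale 1, shift 0
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


model = resnet18()
initialize_weights(model)
```

### 4.2 Learning-Rate Scheduling

Cosine annealing with warm restarts (`CosineAnnealingWarmRestarts`) is a common choice that often works well. The original paper used a simpler step schedule, dividing the learning rate by 10 whenever the error plateaued.

```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# Cosine-decay the LR from 0.1 toward 0 over T_0=10 epochs, then restart;
# call scheduler.step() once per epoch, after that epoch's optimizer steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
```

### 4.3 Visualizing Gradient Flow

To see how residual connections keep gradients from vanishing, we can plot the gradient norm of each layer:

```python
import matplotlib.pyplot as plt


def plot_gradient_flow(model):
    """Plot the L2 norm of each conv/linear weight gradient.

    Call after loss.backward() and before optimizer.zero_grad().
    """
    gradients = []
    for name, param in model.named_parameters():
        # dim() > 1 keeps conv/linear weights and skips BN weights and biases
        if param.grad is not None and "weight" in name and param.dim() > 1:
            gradients.append(param.grad.norm().item())

    plt.figure(figsize=(10, 5))
    plt.plot(gradients, alpha=0.7, color="b", marker="o")
    plt.title("Gradient flow")
    plt.xlabel("Layers (input -> output)")
    plt.ylabel("Gradient L2 norm")
    plt.yscale("log")  # norms can span several orders of magnitude
    plt.show()
    return gradients
```

Compared with a plain CNN of the same depth, ResNet's gradient norms are typically more uniform across layers. Even the layers far from the loss, near the input, still receive a strong gradient signal.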
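To make the effect of `T_0=10` concrete, the sketch below records the learning rate at the start of each of 21 epochs. A single dummy parameter stands in for `model.parameters()` so the snippet runs without a model or data. That substitution, and the per-epoch stepping, are assumptions of the demo.

```python
import torch

# Dummy parameter in place of model.parameters(): we only inspect the schedule
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

lrs = []
for epoch in range(21):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    # ... one epoch of forward/backward passes would go here ...
    optimizer.step()  # no-op here (no gradients), keeps the usual call order
    scheduler.step()  # advance the cosine schedule by one epoch

for epoch in (0, 5, 9, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {lrs[epoch]:.4f}")
```

The learning rate falls from 0.1 to 0.05 at epoch 5 and to about 0.0024 at epoch 9. It then jumps back to 0.1 at epochs 10 and 20. Setting the optional `T_mult` argument to 2 or more would make each cycle longer than the previous one.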
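## 5. ResNet Variants and Practical Choices

We implemented ResNet-18, but the ResNet family includes several other important variants:

| Model | Layers | Residual block type | Params (M) | ImageNet Top-1 Acc |
|---|---|---|---|---|
| ResNet-18 | 18 | Basic Block | 11.7 | 69.8% |
| ResNet-34 | 34 | Basic Block | 21.8 | 73.3% |
| ResNet-50 | 50 | Bottleneck | 25.6 | 76.2% |
| ResNet-101 | 101 | Bottleneck | 44.5 | 77.4% |
| ResNet-152 | 152 | Bottleneck | 60.2 | 78.0% |

Suggested choices for different scenarios:

- **Lightweight applications:** ResNet-18/34, suited to resource-constrained or real-time systems
- **Balanced applications:** ResNet-50, a good trade-off between accuracy and computation
- **High-performance applications:** ResNet-101/152, when maximum accuracy is the goal

The Bottleneck Block is implemented much like the Basic Block. The difference is that it replaces the two 3x3 convolutions with a 1x1, 3x3, 1x1 stack. The first 1x1 convolution reduces the channel count, the 3x3 convolution operates on the reduced width, and the last 1x1 convolution expands the channels back up by a factor of 4:

```python
class Bottleneck(nn.Module):
    expansion = 4  # output channels are 4x the internal (bottleneck) width

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 1x1 convolution: reduce channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution on the reduced width. The stride sits here, as in
        # torchvision's "ResNet v1.5"; the original implementation put it on conv1.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution: expand channels
        self.conv3 = nn.Conv2d(
            out_channels, out_channels * self.expansion,
            kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.relu = nn.ReLU(inplace=True)

        # Shortcut: identity unless the shape changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels * self.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)  # no activation before the addition

        residual = self.shortcut(residual)
        out += residual
        out = self.relu(out)

        return out
```

In real projects, I usually start with ResNet-50 as the baseline because it balances accuracy and computational efficiency well. When higher accuracy is needed, I consider ResNet-101, keeping in mind that it significantly increases training time.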
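Why does the Bottleneck pay off in deeper models? The sketch below compares the parameter count of a 64-channel Basic Block with that of a Bottleneck whose input and output are 256 channels wide (internal width 64), mirroring the side-by-side comparison in the ResNet paper. It re-declares compact copies of both blocks so it runs standalone. The `projection` helper is a hypothetical name introduced here only to shorten the shortcut code.

```python
import torch
import torch.nn as nn


def projection(in_ch, out_ch, stride):
    """Hypothetical helper: 1x1 conv + BN shortcut if shapes differ, else identity."""
    if stride != 1 or in_ch != out_ch:
        return nn.Sequential(nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                             nn.BatchNorm2d(out_ch))
    return nn.Sequential()


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = projection(in_channels, out_channels, stride)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width, out_full = out_channels, out_channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, width, 1, bias=False)       # reduce
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, stride, 1, bias=False)  # 3x3 at low width
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_full, 1, bias=False)          # expand
        self.bn3 = nn.BatchNorm2d(out_full)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = projection(in_channels, out_full, stride)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + self.shortcut(x))


def count_params(module):
    return sum(p.numel() for p in module.parameters())


basic = BasicBlock(64, 64)    # 64 -> 64 channels (ResNet-18/34 style)
bottle = Bottleneck(256, 64)  # 256 -> 64 -> 256 channels (ResNet-50+ style)

print(count_params(basic))   # 73984
print(count_params(bottle))  # 70400

x = torch.randn(2, 256, 56, 56)
print(bottle(x).shape)  # torch.Size([2, 256, 56, 56])
```

The bottleneck block handles inputs and outputs that are 4x wider while using slightly fewer parameters. Its expensive 3x3 convolution runs at only 64 channels. This cheap widening is what keeps 50-, 101-, and 152-layer ResNets affordable.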