# Stop Memorizing: Understand the Core Differences Between SGD and Adam by Writing PyTorch Code

Optimizer questions almost never fail to show up in deep learning interviews. Yet most tutorials and interview guides stop at theoretical descriptions, leaving you just as confused after reading them. Today we take a different approach: we write the code ourselves, implementing SGD, Momentum, Adam, and friends in PyTorch, and comparing their actual behavior through visualization. You will find that abstract concepts such as noise, bias correction, and parameter smoothness become remarkably clear once you see them in code and charts.

## 1. Environment Setup and Data Generation

To do good work, one must first sharpen one's tools. Let's set up the experimental environment:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
```

To focus on the behavior of the optimizers themselves, we construct a simple linear regression task:

```python
# Generate synthetic data
X = torch.randn(1000, 1) * 5                          # input features
true_w = torch.tensor([[2.0]])                        # true weight
true_b = torch.tensor([1.0])                          # true bias
y = X @ true_w + true_b + torch.randn(X.shape) * 2    # add noise

# Visualize the data
plt.scatter(X.numpy(), y.numpy(), alpha=0.5)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Generated Data for Optimization Comparison")
plt.show()
```

## 2. Implementing Basic Optimizers from Scratch

### 2.1 Vanilla SGD

The update rule of standard SGD is the simplest:

```
θ = θ - η · ∇J(θ)
```

The PyTorch implementation:

```python
class VanillaSGD:
    def __init__(self, params, lr=0.01):
        self.params = list(params)
        self.lr = lr

    def step(self):
        for param in self.params:
            if param.grad is not None:
                param.data -= self.lr * param.grad.data

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()
```

Let's test it:

```python
# Initialize model and optimizer
model = nn.Linear(1, 1)
optimizer = VanillaSGD(model.parameters(), lr=0.01)

# Training loop
losses = []
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(X)
    loss = nn.MSELoss()(outputs, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Plot the loss curve
plt.plot(losses)
plt.title("Vanilla SGD Training Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.show()
```

### 2.2 SGD with Momentum

Momentum introduces a velocity variable that accumulates historical gradient information:

```
v = γ · v + η · ∇J(θ)
θ = θ - v
```

The implementation:

```python
class SGDMomentum:
    def __init__(self, params, lr=0.01, momentum=0.9):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.velocities = [torch.zeros_like(p) for p in self.params]

    def step(self):
        for i, param in enumerate(self.params):
            if param.grad is not None:
                self.velocities[i] = self.momentum * self.velocities[i] + self.lr * param.grad.data
                param.data -= self.velocities[i]

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()
```

Key parameter comparison:

| Parameter             | Vanilla SGD | SGD with Momentum |
|-----------------------|-------------|-------------------|
| Learning rate (lr)    | 0.01        | 0.01              |
| momentum              | -           | 0.9               |
| Convergence speed     | Slow        | Faster            |
| Oscillation amplitude | Large       | Smaller           |

## 3. Adam Optimizer in Depth

Adam combines the strengths of Momentum and RMSProp. Its core update equations:

```
m_t = β1 · m_{t-1} + (1 - β1) · g_t      # first-moment estimate
v_t = β2 · v_{t-1} + (1 - β2) · g_t²     # second-moment estimate
m̂_t = m_t / (1 - β1^t)                   # bias correction
v̂_t = v_t / (1 - β2^t)
θ_t = θ_{t-1} - η · m̂_t / (√v̂_t + ε)
```

The full implementation:

```python
class Adam:
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]
        self.t = 0

    def step(self):
        self.t += 1
        for i, param in enumerate(self.params):
            if param.grad is not None:
                self.m[i] = self.betas[0] * self.m[i] + (1 - self.betas[0]) * param.grad.data
                self.v[i] = self.betas[1] * self.v[i] + (1 - self.betas[1]) * param.grad.data**2
                # Bias correction
                m_hat = self.m[i] / (1 - self.betas[0]**self.t)
                v_hat = self.v[i] / (1 - self.betas[1]**self.t)
                param.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()
```
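Before moving on to the comparison experiment, it is worth sanity-checking the hand-rolled class against PyTorch's built-in version. Below is a minimal sketch (assuming the `Adam` class and the `X`, `y` data defined above) that runs both optimizers from identical initial weights; since the update formulas match, the resulting parameters should agree up to floating-point tolerance:

```python
import copy

# Two identical models so both optimizers start from the same weights
model_ours = nn.Linear(1, 1)
model_torch = copy.deepcopy(model_ours)

opt_ours = Adam(model_ours.parameters(), lr=0.001)
opt_torch = torch.optim.Adam(model_torch.parameters(), lr=0.001,
                             betas=(0.9, 0.999), eps=1e-8)

criterion = nn.MSELoss()
for _ in range(10):
    for model, opt in ((model_ours, opt_ours), (model_torch, opt_torch)):
        opt.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        opt.step()

# Should print True for both the weight and the bias
for p_ours, p_torch in zip(model_ours.parameters(), model_torch.parameters()):
    print(torch.allclose(p_ours, p_torch, atol=1e-6))
```

If the two diverge, the usual culprits are a missing bias correction or the placement of ε in the denominator.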
## 4. Optimizer Comparison Experiment

Now we run all three optimizers side by side and compare their behavior:

```python
def train_with_optimizer(optimizer_class, **kwargs):
    model = nn.Linear(1, 1)
    optimizer = optimizer_class(model.parameters(), **kwargs)
    losses = []
    param_trajectory = []
    for epoch in range(100):
        optimizer.zero_grad()
        outputs = model(X)
        loss = nn.MSELoss()(outputs, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        param_trajectory.append([p.detach().clone() for p in model.parameters()])
    return losses, param_trajectory

# Train and record results
sgd_losses, sgd_traj = train_with_optimizer(VanillaSGD, lr=0.01)
momentum_losses, momentum_traj = train_with_optimizer(SGDMomentum, lr=0.01, momentum=0.9)
adam_losses, adam_traj = train_with_optimizer(Adam, lr=0.01)
```

Visualize the results:

```python
plt.figure(figsize=(12, 5))

# Loss curve comparison
plt.subplot(1, 2, 1)
plt.plot(sgd_losses, label="Vanilla SGD")
plt.plot(momentum_losses, label="SGD with Momentum")
plt.plot(adam_losses, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Comparison")
plt.legend()

# Parameter update trajectories
plt.subplot(1, 2, 2)
sgd_w = [p[0][0].item() for p in sgd_traj]
momentum_w = [p[0][0].item() for p in momentum_traj]
adam_w = [p[0][0].item() for p in adam_traj]
plt.plot(sgd_w, label="Vanilla SGD")
plt.plot(momentum_w, label="SGD with Momentum")
plt.plot(adam_w, label="Adam")
plt.axhline(y=true_w.item(), color="r", linestyle="--", label="True Value")
plt.xlabel("Epoch")
plt.ylabel("Weight Value")
plt.title("Parameter Update Trajectory")
plt.legend()

plt.tight_layout()
plt.show()
```

Key observations:

- **Convergence speed**: Adam is usually the fastest, Momentum second, and vanilla SGD the slowest.
- **Parameter smoothness**: Adam's parameter updates are the smoothest; SGD's fluctuate the most.
- **Noise handling**: Both Momentum and Adam effectively dampen the influence of gradient noise.

## 5. Practical Advice for Choosing an Optimizer

Based on the experimental results and real-world project experience, here is a selection guide:

- **Small datasets / simple models**: vanilla SGD is sufficient and easy to tune.
- **Medium-scale data**: SGD with Momentum (β = 0.9).
- **Large-scale data / complex models**: Adam is the safe choice.

Special notes:

- Adam's default parameters (β1 = 0.9, β2 = 0.999) work well on most tasks.
- Learning rate: Adam's is usually set one to two orders of magnitude smaller than SGD's (e.g., 0.1 for SGD vs. 0.001 for Adam).
- For tasks that demand very high final precision (such as GAN training), SGD + Momentum can be more stable.

Typical hyperparameter settings:

| Optimizer      | Learning rate range | Key parameters       |
|----------------|---------------------|----------------------|
| Vanilla SGD    | 0.1 - 0.01          | -                    |
| SGD + Momentum | 0.1 - 0.01          | momentum = 0.9       |
| Adam           | 0.001 - 0.0001      | β1 = 0.9, β2 = 0.999 |

In real projects, I usually start with Adam to get a baseline quickly, then try SGD-family optimizers for fine-grained tuning. On vision tasks Adam tends to behave more stably, while in NLP some studies suggest that SGD + Momentum can achieve better final results; the sketch below shows how to wire up these recommendations with PyTorch's built-in optimizers.
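In production code you would normally reach for `torch.optim` rather than hand-rolled classes. Here is a minimal sketch of the typical settings from the table above, reusing the `X` and `y` data defined earlier; the learning rates are just the starting points suggested above, not tuned values:

```python
# Built-in counterparts of the optimizers implemented in this article,
# configured with the typical settings from the table above.
model = nn.Linear(1, 1)

vanilla_sgd  = torch.optim.SGD(model.parameters(), lr=0.1)                 # Vanilla SGD
sgd_momentum = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)   # SGD + Momentum
adam         = torch.optim.Adam(model.parameters(), lr=0.001,
                                betas=(0.9, 0.999))                        # Adam

# Only one optimizer is used at a time; the training loop is
# identical no matter which one you pick.
optimizer = adam
criterion = nn.MSELoss()
for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
```

Swapping `optimizer = adam` for `sgd_momentum` is all it takes to run the baseline-then-fine-tune workflow described above.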