loss.backward() 和 梯度累积
loss.backward()在做什么这一行调用的是 PyTorch 的反向传播机制——核心作用是从 loss 出发沿着前向构建的计算图反向走对所有requires_gradTrue的参数计算梯度∂loss / ∂param并累加到param.grad上。一、背景前向时 PyTorch 在做什么当kd_algorithm.training_step(micro_batch)执行学生模型前向、算 logits、算 loss 时PyTorch自动构建了一张计算图computation graph每个张量记录了我是怎么从哪些张量算出来的grad_fn这张图把loss和所有可训练参数θ连接起来例如简化后input → [Linear w₁] → h₁ → [Linear w₂] → logits → [CE loss] → loss ↑ ↑ ↑ w₁.grad_fn w₂.grad_fn loss.grad_fn二、loss.backward()反向走这张图1. 计算梯度链式法则按拓扑逆序遍历计算图对每个节点应用链式法则∂loss/∂logits → ∂loss/∂h₁ → ∂loss/∂w₂ → ∂loss/∂w₁ ...每经过一个算子就调用其backward实现如 matmul 的 backward 是矩阵乘转置等把上游传下来的梯度乘上局部雅可比再传给更上游。2. 把梯度写到param.grad最终对每个叶子参数pp.requires_gradTrue且没有grad_fnp.grad ∂loss/∂p # 累加不是覆盖注意是累加——这就是为什么需要optimizer.zero_grad()在每次 step 后把 grad 清零否则梯度会一直叠加。利用了累加特性来实现梯度累积连续多次backward()不调zero_grad梯度自然累加最后一次累积满后才调optimizer.step()zero_grad()。3. 释放计算图默认情况下反向走完后 PyTorch 会释放保存的中间激活值用于反向计算的 saved tensors节省显存。再次对同一张图 backward 会报错除非传retain_graphTrue。三、在 FSDP 下的特殊行为普通单卡loss.backward()只算梯度。但FSDPFully Sharded Data Parallel它在 backward 阶段会做更多事阶段FSDP 行为前向各 rank 持有 param 的一个shard前向时 all-gather 完整参数 → 计算 → 释放完整参数反向反向再次 all-gather 参数计算梯度 →reduce-scatter把每个参数的梯度 reduce求和/平均并 scatter 回该 param shard 所在的 rank结果每个 rank 只在自己负责的 param shard 上得到对应梯度p.grad梯度也被分片存储所以loss.backward()在 FSDP 上下文里还隐含了跨 rank 通信reduce-scatter 把数据并行的梯度聚合并分片相当于普通 DDP 的all-reduce但通信对象是 sharded 梯度。四、整体作用与定位backward()是梯度计算阶段——它不更新参数只把这一步学生模型应该怎么调整以梯度形式存到p.grad里等优化器 step 时再用。loss.backward() 沿计算图反向应用链式法则对所有可训练参数计算∂loss/∂p并累加到p.grad在 FSDP 下还会自动完成跨 rank 的梯度 reduce-scatter让每个 rank 拿到自己 param shard 对应的梯度。它是算梯度参数更新由后续optimizer.step()完成。梯度累积Gradient Accumulation的作用与差异一、什么是梯度累积核心思想把一个大 batch拆成 N 个小 batchmicro_batch连续做 N 次前向 反向但不更新参数让梯度在p.grad上自然累加累加 N 次后再做一次optimizer.step()。def backward(self, loss, model, optimizer, **kwargs): self.step (self.step 1) % self.accumulated_gradient loss loss / self.accumulated_gradient # ← 先除以 N loss.backward() # ← 梯度累加到 p.grad def optimizer_step(self, optimizer, model, scheduler, **kwargs): if self.step 0: # ← 累满 N 次才真正更新 ... optimizer.step() optimizer.zero_grad() scheduler.step()累积 N 次 backward → 1 次 optimizer.step。二、为什么要累积——主要作用1.在显存受限下实现大 batch 训练最核心动机显存放不下 batch_size64 怎么办拆成 8 次 batch_size8 累积效果约等于batch_size64方案单步显存峰值等效 batch显存友好一次 batch64高容易 OOM64✗累积 8×8低只到 batch864✓LLM 训练里大 batch对收敛稳定性、梯度信噪比、scaling law 都有帮助但显存往往是瓶颈累积是绕过显存限制拿到大有效 batch的标准手段。2.匹配分布式并行的语义每个 prompt batch 生成一批 rollout被切成多个 micro_batch 喂给学生 actor。让 actor 在内部把这 N 个 micro_batch累积成一次更新意味着一次 rollout → 一次梯度更新——保持 on-policy 的语义清晰避免一份 rollout 数据用多次。三、累积 vs 不累积的具体区别数学上梯度等价在常见条件下假设 loss 是样本平均reductionmean不累积batchNL (1/N) Σᵢ Lᵢ ∂L/∂θ (1/N) Σᵢ ∂Lᵢ/∂θ累积 N 次每次 micro_batch1带loss / N归一化每次 backward 累加: (1/N) ∂Lᵢ/∂θ N 次后 p.grad Σᵢ (1/N) ∂Lᵢ/∂θ (1/N) Σᵢ ∂Lᵢ/∂θ ✓例如配置train_batch_size128、micro_train_batch_size4、world_size8num_micro_batches self.args.train.train_batch_size // self.args.train.micro_train_batch_size数据并行下每 GPU 一次 forward 跑 4 条样本累积步数 128 / 4 / 8 4每张卡反向 4 次后再 step 一次等效全局 batch 4 × 4 × 8 128梯度累积 把大 batch 一次更新拆成小 batch 多次累加 最后一次更新用计算时间换显存空间数学上在样本平均损失 无 BN 的前提下与不累积等价。区别主要在显存峰值、吞吐、BN 统计、通信次数和更新频率上LLM 训练里因为用 LayerNorm二者基本等价所以累积是大模型训练框架的默认手段。梯度累积为什么能节省显存显存峰值由单次前向 反向时同时驻留的中间激活值activations决定而梯度累积让每一步只跑一个小 micro_batch激活值只按小 batch 算 → 峰值显存大幅下降。累加的只是参数级别的梯度p.grad它的大小与 batch 无关。一、训练时显存的几大组成一次训练步显存占用大致为组成大小是否随 batch 增长说明模型参数θ否与模型大小有关与 batch 无关梯度p.grad否形状与参数相同与 batch 无关优化器状态Adam: m, v否~2× 参数大小与 batch 无关前向激活值activations是随 batch 线性增长反向计算梯度时必须保留的中间张量临时 buffer / workspace部分通信缓冲、cuDNN workspace 等LLM 训练里最大的显存开销往往是activations——因为每层 transformer 的输入、attention 中间结果Q/K/V、attention scores都要保存以便反向传播。它的大小约为activations ≈ batch_size × seq_len × hidden_size × num_layers × const显存峰值就是发生在反向开始前——所有层的 activations 都还驻留在 GPU 上的那一刻。二、不累积 vs 累积的显存对比不累积一次大 batchNforward (batchN) → 保存 ~N 倍激活值 ─┐ ├─ 显存峰值 ∝ N backward ← 消耗激活算梯度 ─────┘ optimizer.step激活值正比于 NN 大就 OOM。累积micro_batchN/k累积 k 次循环 k 次 forward (batchN/k) → 保存 ~(N/k) 倍激活值 ─┐ ├─ 每轮峰值 ∝ N/k backward ← 消耗激活算梯度 ─────────┘ ✅ 激活值 backward 后立即释放 ✅ p.grad 累加大小不变 optimizer.step累积 k 次后才执行一次每轮反向结束后 PyTorch自动释放该轮的激活值默认retain_graphFalse所以下一轮重新分配的激活值占用同一片显存不会叠加。→峰值显存 ∝ N/k是不累积方案的 1/k。三、为什么累加梯度自己不会占额外显存loss.backward() # 在已有的 p.grad 上做 in-placep.grad是和p形状一致的张量第一次 backward 就会分配好后续 backward 是in-place 累加到这块已有显存上累积 1 次 vs 累积 100 次p.grad占用一模一样所以累积的代价是时间多跑 k 次 forwardbackward收益是激活值显存峰值降到 1/k——而梯度本身不变大。四、一个直观的数字例子假设模型 7B 参数bf16 →参数 14 GB梯度 14 GBAdam 状态28 GBfp32 mv单卡 80GB 显存去掉以上 56 GB剩 ~24 GB 给 activations bufferbatch_size64、seq_len4096 时 activations 估 ~80 GB →OOMbatch_size8、seq_len4096 时 activations 估 ~10 GB → 装得下不累积要 OOM累积 8 次 micro_batch8每轮峰值 ≈ 56 10 66 GB ✓ 不 OOM8 次反向后p.grad仍是 14 GB不变累积满 8 次再optimizer.step()等效 batch64五、与 FSDP / 梯度检查点的关系显存优化技术分三类——梯度累积只是其中之一常常组合使用技术削减的部分代价梯度累积activations按 1/k时间多次 forwardgradient checkpointing重计算activations按 ~1/√L 或更多时间反向时重做部分前向FSDP参数/梯度/优化器状态分片params grads optim states按 1/world_size通信开销六、为什么梯度求和不会爆可能有人会担心累积 k 次梯度p.grad会越来越大不会——p.grad是数值上累加张量大小恒定不变形状和 dtype 不变。变化的只是里面存的数值从 0 累加到最终的累计梯度不消耗额外显存空间。显存峰值主要被 forward 期间保留的 activations 撑大而 activations 与 batch 大小成正比梯度累积通过把大 batch 拆成多个小 micro_batch 串行跑每轮的 activations 在 backward 后立刻释放使峰值降到 1/k累加发生在已分配好的p.grad上in-place不引入额外显存——所以用时间换空间成立。optimizer.step()的作用核心作用用loss.backward()累积在p.grad中的梯度按优化算法如 AdamW、SGD的更新规则实际修改模型参数p.data。这是训练循环里唯一真正改动模型权重的一步。一、做了什么简化的伪代码以 SGD 为例for group in optimizer.param_groups: lr group[lr] for p in group[params]: if p.grad is None: continue p.data - lr * p.grad # 实际更新参数不同优化器的更新规则不同优化器更新公式简化SGDp ← p - lr · gSGDMomentumv ← μv g; p ← p - lr · vAdam / AdamW维护一阶/二阶矩估计 m, vp ← p - lr · m̂ / (√v̂ ε) - lr · wd · p本项目用的是 AdamWfsdp_strategy.py中通过create_optimizer创建所以optimizer.step()在内部读取p.grad更新一阶矩m和二阶矩v保存在optimizer.state[p]做 bias correction计算更新量并写回p.data应用 weight decayAdamW 把 weight decay 单独从梯度中解耦二、在训练循环中的位置def optimizer_step(self, optimizer, model, scheduler, **kwargs): if self.step 0: # 累积满才更新 if self.max_norm 0.0: ...clip_grad_norm_(model.parameters(), self.max_norm) # ① 梯度裁剪 optimizer.step() # ② 真正更新参数 ← 你问的这行 optimizer.zero_grad() # ③ 清零梯度准备下一轮累积 if scheduler: scheduler.step() # ④ 更新学习率完整训练步顺序forward → loss → backward → [grad_norm 监控] → clip_grad_norm_ → optimizer.step() → zero_grad → scheduler.step() ↑ ↑ ↑ 算梯度 限制梯度大小 用梯度更新参数注意三个关键点if self.step 0守卫梯度累积下只有累满accumulated_gradient次 backward 才执行 step——这就是多次累加 一次更新的实现顺序clip 在 step 之前——必须先把梯度限制住再用梯度更新step 后立即zero_grad清空p.grad否则下一轮 backward 还会在旧梯度上累加相当于把已用过的梯度重复用三、与 FSDP 的关系在 FSDP 下每个 rank 只持有参数的一个 shard对应的p.grad也是分片的。optimizer.step()在 FSDP 中每个 rank 只更新自己的 param shard——不需要跨 rank 通信因为梯度已在 backward 阶段通过 reduce-scatter 分到了对应 shard 所在 rank优化器状态Adam 的 m, v也是分片存储的与 param shard 一一对应——这是 FSDP 相对 DDP 节省显存的关键之一DDP 下每个 rank 都要存完整的 m, v所以optimizer.step()在 FSDP 上是个本地操作性能开销与单卡基本一致。四、step()vsbackward()的角色对比操作改变什么输出loss.backward()计算并累加p.grad不改p.dataoptimizer.step()修改p.data真正更新模型更新优化器内部状态m, v 等不改p.gradoptimizer.zero_grad()把p.grad清零准备下一轮累积可以说backward 是算账step 是扣款。前者只是在.grad里写下应该这样调整后者才把这个调整真正落到模型权重上。optimizer.step()是按优化算法这里是 AdamW读取累积在p.grad里的梯度并真正更新模型参数p.data的步骤——这是整个训练循环里唯一改动模型权重的地方前面的 backward 只是算梯度没有 step 模型不会变。