OPD Reverse KL
一、 OPD在线策略蒸馏On-Policy Distillation, 简称OPD。在语言模型LLM的知识蒸馏中通常有两个模型学生模型Student策略为参数为我们要优化的对象。教师模型Teacher策略为参数固定提供监督信号。逆向 KL 散度Reverse KL, RKL的公式为。与前向 KLForward KL,相比逆向 KL 具有 模式寻求Mode-seeking的特性。它会让学生模型倾向于只在教师模型概率较高的地方生成文本从而减少模型产生幻觉Hallucination或语无伦次的概率使生成的文本更加确定和精准。二、 公式 1Full-vocab Reverse KL 损失函数在生成第 t 个 token 时的损失函数符号解释* V整个词表Vocabulary。* x输入提示词Prompt。*在 t 时刻之前已经生成的历史 token 序列。*教师模型可能额外享有的输入信息例如更丰富的上下文、思维链提示或参考答案。*学生模型在当前上下文下预测词表中每个词 v 的概率分布。*教师模型预测的概率分布。物理意义这个公式计算的是在当前步骤 t学生模型分布与教师模型分布在全词表Full-vocab上的逆向 KL 散度。因为求期望的权重项是学生模型的概率所以它是一种 在线/在策略On-policy 的评估方式——它关注的是“站在学生模型自己的视角下其当前输出与教师的偏差”。三、 公式 2梯度的推导与化简对上述损失函数关于学生模型参数求导。为了书写简便我们将简记为将简记为。目标是计算。1. 梯度推导过程损失函数为由于只有学生模型包含参数教师模型与无关。利用导数的乘积法则对求导我们分别处理这两项第二项因为对其求导得将这一结果代回第二项中与外面的相乘消去分母得到第一项使用对数导数技巧Log-derivative trick常用于强化学习即。将这两项重新整合利用对数导数技巧把第二项也写成含有的形式为了与图片中的形式完全一致我们将括号里的项取负号倒过来代入后即得到公式四、 恒等式的消去作用里面的 1 在全词表下会被这条恒等式抵消。为什么该恒等式成立因为概率分布在全词表上的和恒等于 1即对其两边求导再利用对数导数技巧即可得到消除后的简化梯度这意味着公式中括号里的常数 -1展开后与外面的负号结合变成 1在对整个词表求和时其贡献为 0。因此实际计算时梯度可以简化为五、 物理意义与直观理解如果我们把简化后的梯度写成策略梯度Policy Gradient中常见的形式考虑最小化损失函数参数更新方向为负梯度这相当于一种自带基线Baseline的策略梯度算法1. 动作空间在当前步骤学生模型在全词表 V 上进行探索。2. 权重项Reward对于词表中的每一个词 v其受到的奖励/惩罚因子为当时说明学生模型低估了该词的概率。此时则。梯度更新会提高该词的生成概率。* 当时说明学生模型过度自信高估了该词。此时。梯度更新会压低该词的生成概率。3. 全词表覆盖因为是 Full-vocab算法不仅对采样到的单个词进行更新而是同时对词表中的所有词进行推拉Push-Pull。这使得训练过程比单样本采样的策略梯度更加平滑和稳定。