注意力机制模块:动态稀疏注意力 S²Attention 在 ConvNeXt 中的实现,QKV 计算量减半
一、引言:为什么需要重新思考注意力机制在动手写代码之前,先搞清楚我们为什么要关注动态稀疏注意力。本质上,这源于一个无法忽视的事实:注意力机制中的QKV投影权重,占了大模型总权重的50%以上,在推理阶段QKV的存储量会随上下文长度线性增长,计算量则呈平方级攀升。打个比方,标准的多头自注意力机制就像要求会场里的每个人都必须和其余的每个人单独交流 —— 当人数从几百人增加到几千人时,这种交流的成本是不可持续的。其实,Transformer架构从诞生之日起就背上了一个“原罪”:标准自注意力机制的时间复杂度和空间复杂度都是 O(N²)(这里的 N 是token/特征序列的长度)。对于大语言模型而言,这意味着当你把上下文长度从 4K 扩展到 128K 时,计算开销不是线性增长,而是爆炸式增长。根据业界研究人员的长期观察,在标准Transformer的一轮前向推理中,大约 25% 的时间用于计算 QKV 矩阵,约 8% 用于计算注意力输出矩阵,剩下的约 66% 用于 FFN(前馈网络)。这意味着注意力子模块虽然只占总计算时间的三分之一左右,但却是显存消耗的主要来源 —— 因为中间矩阵需要显式存储完整的注意力权重图。学术界和工业界对这一困局的回应是多方面的。一部分工作聚焦于 KV Cache 的压缩与管理——例如 MQA(Multi-Query Attention)和 GQA(Grouped-Query Attention)通过让多个注意力头共享同一组 K/V 投影来降低显存占用。另一部分工作则在推理侧发力,包括投机解码和内核级优化(典型代表