从零手写 FlashAttention(PyTorch实现 + 原理推导)
本文基于一个最小 PyTorch 示例手写实现 FlashAttention的核心计算流程并详细解释其数值稳定性和分块计算原理。1. 标准 Attention 回顾标准 Attention 的计算公式Attention(Q,K,V)softmax(QKT)V Attention(Q,K,V) softmax(QK^T)VAttention(Q,K,V)softmax(QKT)Vimporttorch querytorch.randn(1,12,10)keytorch.randn(1,12,10)valuetorch.randn(1,12,10)logitstorch.einsum(bqd,bkd-bqk,query,key)probstorch.nn.functional.softmax(logits,dim-1)softmax_outputtorch.einsum(bqk,bkd-bqd,probs,value)2. FlashAttention 核心思想FlashAttention 的核心目标避免显式存储整个 attention matrixQK^T关键手段分块计算block-wise在线 Softmaxonline softmax3. 数值稳定 Softmaxsoftmax(xj)exj−m∑kexk−m,mmax(x) softmax(x_j) \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m max(x)softmax(xj)∑kexk−mexj−m,mmax(x)4. 核心递推mimax(mi−1,mij) m_i max(m_{i-1}, m_{ij})mimax(mi−1,mij)lili−1emi−1−mi∑exij−mi l_i l_{i-1} e^{m_{i-1} - m_i} \sum e^{x_{ij} - m_i}lili−1emi−1−mi∑exij−mioioi−1emi−1−mi∑(exij−miVj) o_i o_{i-1} e^{m_{i-1} - m_i} \sum (e^{x_{ij} - m_i} V_j)oioi−1emi−1−mi∑(exij−miVj) 关键细节深入理解很多人在理解这里时容易卡住为什么需要对历史的oi−1o_{i-1}oi−1做rescale我们一步一步拆解1️⃣oi−1o_{i-1}oi−1并不是最终正确的值在第i−1i-1i−1次循环时我们用的是局部最大值mi−1m_{i-1}mi−1所以 softmax 实际是exi−1∑exi−1exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}}∑exi−1exi−1∑exi−1−mi−1exi−1−mi−1 注意这里的归一化是基于局部 block 的尺度2️⃣ 当进入第iii个 block 时发生了什么我们得到了新的最大值mimax(mi−1,mij) m_i max(m_{i-1}, m_{ij})mimax(mi−1,mij) 这个mim_imi更接近全局最大值3️⃣ 问题的本质此时出现一个不一致项目 使用的 maxoi−1o_{i-1}oi−1mi−1m_{i-1}mi−1当前 blockmim_imi 如果直接相加会导致不同尺度的指数项被混合数值错误4️⃣ 解决方法统一尺度rescale我们需要把旧的oi−1o_{i-1}oi−1从ex−mi−1 e^{x - m_{i-1}}ex−mi−1转换到ex−mi e^{x - m_i}ex−mi变换方式ex−mi−1ex−mi⋅emi−mi−1 e^{x - m_{i-1}} e^{x - m_i} \cdot e^{m_i - m_{i-1}}ex−mi−1ex−mi⋅emi−mi−1 因此oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i}oi−1→oi−1⋅emi−1−mi5️⃣ 对应代码o_i o_i_1 * torch.exp(m_i_1 - m_i)[…, None] torch.einsum(‘bqk,bkd-bqd’, exp_term, v_i)含义是第一项旧结果 rescale 到新尺度第二项当前 block 的贡献6️⃣ 一个直观理解可以把整个过程理解为我们在不断修正历史让所有累积值都统一到当前最稳定的坐标系最大值下随着循环进行mim_imi会逐步逼近全局最大值所有历史贡献都会被重新缩放到这个统一尺度7️⃣ 最终结果当所有 block 处理完mim_imi 全局最大值oi/lio_i / l_ioi/li 完整 softmax 结果5. PyTorch实现flash_softmax_outputs[]q_chunks4q_chunk_sizequery.shape[1]//q_chunks k_chunks3k_chunk_sizekey.shape[1]//k_chunksforiinrange(q_chunks):q_iquery[:,i*q_chunk_size:(i1)*q_chunk_size]m_i_1torch.full((q_i.shape[0],q_i.shape[1]),-float(inf))l_i_1torch.zeros_like(m_i_1)o_i_1torch.zeros((q_i.shape[0],q_i.shape[1],value.shape[-1]))forjinrange(k_chunks):k_ikey[:,j*k_chunk_size:(j1)*k_chunk_size]# (B, K_block, D)v_ivalue[:,j*k_chunk_size:(j1)*k_chunk_size]# (B, K_block, Dv)logits_itorch.einsum(nqd,nkd-nqk,q_i,k_i)# (B, Q_block, K_block)# ---- 更新 m ----m_ijtorch.max(logits_i,dim-1)[0]# (B, Q_block)m_itorch.maximum(m_i_1,m_ij)# 计算Softmax分子e^(x_i - m_i)exp_termtorch.exp(logits_i-m_i[...,None])# (B, Q_block, K_block)# 更新Softmax分母# rescale * 旧的softmax分母 新的softmax分母l_il_i_1*torch.exp(m_i_1-m_i)exp_term.sum(dim-1)# ---- 更新 O关键----# rescale * 旧的logit * v 新的logit * vo_io_i_1*torch.exp(m_i_1-m_i)[...,None]torch.einsum(nqk,nkd-nqd,exp_term,v_i)# ---- 状态更新 ----m_i_1m_i l_i_1l_i o_i_1o_i# ---- 最后除以Softmax分母----outputo_i/l_i[...,None]flash_softmax_outputs.append(output)flash_softmax_outputstorch.cat(flash_softmax_outputs,dim1)6. 正确性验证torch.allclose(softmax_output,flash_softmax_outputs)7. 总结FlashAttention 本质分块计算在线 softmax动态重标定rescale复杂度从 O(N^2) 降到 O(N)