CATLASS MLA【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlassCATLASS MLA是基于CATLASS Gemm API实现的亲和昇腾AtlasA2硬件的Flash-MLA算子算子的结构可以分为以下几部分Tiling计算Kernel实现具体有两种实现通用的mla_kernel.cpp以及特化的mla_kernel_tp1_spec.cppKernel中依赖适合Flash-MLA运算的Block组件使用的Block组件依赖模板库提供的Tile组件。TilingTiling计算的逻辑位于mla.cpp文件中在调用算子前需要准备好tiling计算所需的各项参数赋值给MLAInfo结构体并调用GetMLATilingParam函数。mla.cpp中提供了一个示例// 准备Tiling计算所需的中间结构体以及Host侧空间 MLATiling::MLAInfo mlaInfo; ... MLATiling::GetMLATilingParam(mlaInfo, blockDim, (uint32_t *)tilingHost);GetMLATilingParam函数中调用了两个函数GetMLATilingCommon与GetMLATilingSpec分别对应了通用场景下和特化场景下的分核逻辑Kernel本算子提供了两种Kernel实现通用的mla_kernel.cpp在qHeadNum为16/32/64场景分别对应模型侧TP8/4/2场景性能更优。特化的mla_kernel_tp1_spec.cpp在qHeadNum为128场景对应模型侧TP1场景性能更优。mla_kernel.cpp具有以下特性采用FlashAttention的四阶段计算流程对于输入的Q, QRope, K, KRope进行切块后运算。对输入序列长度kvSeqlen按照blockSize为单位进行切块每次Attention运算的基块为一个block使能提前下发一个基块的QK Mmad与softmax让不同基块的CUBE与VECTOR阶段互相掩盖。在同一基块的QK与PV的矩阵乘之间由于K与V共用同一段数据使能K常驻在L1 buffer上减少搬入带宽占用。mla_kernel_tp1_spec.cpp具有以下特性采用FlashAttention的四阶段计算流程对于输入的Q, QRope, K, KRope进行切块后运算。对输入序列长度kvSeqlen按照blockSize为单位进行切块每次Attention运算的基块为四个block使能提前下发一个基块的QK Mmad与softmax让不同基块的CUBE与VECTOR阶段互相掩盖。由于基块大小的放大该Kernel的PV Mmad阶段的搬出数据量降低减少了搬出带宽占用相应的由于硬件buffer大小限制取消了K的常驻。在本算子中使用了Block和Tile层级组件来组装Kernel具体步骤为组装attention计算中的两个BlockMmadQK,PV以及三个BlockEpiloguesoftmax, rescaleO, flashDecoding。将Block组合在一起构建成MLAKernel并在Kernel类中完成对各个Block的循环调用。这一过程也体现在Kernel入口的代码中以mla_kernel.cpp为例// GEMM Block模块实现Flash MLA的Q * K^T using DispatchPolicyQK Gemm::MmadAtlasA2MLAQK; using QType Gemm::GemmTypeElementQ, LayoutQ; using KType Gemm::GemmTypeElementK, LayoutK; using SType Gemm::GemmTypeElementS, LayoutS; using BlockMmadQK Gemm::Block::BlockMmadDispatchPolicyQK, L1TileShape, L0TileShape, QType, KType, SType; // Epilogue Block模块, 实现Flash MLA中当前S基块的softmax using PType Gemm::GemmTypeElementP, LayoutP; using MaskType Gemm::GemmTypeElementMask, LayoutMask; using EpilogueMLASoftmax Epilogue::Block::BlockEpilogueEpilogue::EpilogueAtlasA2MLASoftmax, PType, SType, MaskType; // GEMM Block模块实现Flash MLA的P * V using DispatchPolicyPV Gemm::MmadAtlasA2MLAPV; using VType Gemm::GemmTypeElementV, LayoutV; using OTmpType Gemm::GemmTypeElementOTmp, LayoutOTmp; using BlockMmadPV Gemm::Block::BlockMmadDispatchPolicyPV, L1TileShape, L0TileShape, PType, VType, OTmpType; // Epilogue Block模块, 实现Flash MLA中当前O基块的更新 using OType Gemm::GemmTypeElementO, LayoutO; using OUpdateType Gemm::GemmTypeElementUpdate, LayoutUpdate; using EpilogueMLARescaleO Epilogue::Block::BlockEpilogueEpilogue::EpilogueAtlasA2MLARescaleO, OType, OUpdateType, OTmpType; // Epilogue Block模块, 实现Flash MLA中flash decoding using lType Gemm::GemmTypeElementUpdate, LayoutUpdate; constexpr uint32_t ComputeEleNum 6144; using EpilogueMLAFDRescaleO Epilogue::Block::BlockEpilogueEpilogue::EpilogueAtlasA2MLAFDRescaleOComputeEleNum, OType, lType; // Kernel level using MLAKernel MLAKernelBlockMmadQK, BlockMmadPV, EpilogueMLASoftmax, EpilogueMLARescaleO, EpilogueMLAFDRescaleO;Block Mmad算子总共使用了两类Block Mmad组件分别为BlockMmadQK为BlockMmad模板类的偏特化用于处理Flash-MLA中的Q与K的矩阵乘操作头文件block_mmad_mla_qk.hpp中的实现对应通用的mla_kernel.cpp头文件block_mmad_mla_qk_tp1_spec.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。BlockMmadPV为BlockMmad模板类的偏特化用于处理Flash-MLA中的P与V的矩阵乘操作头文件block_mmad_mla_pv.hpp中的实现对应通用的mla_kernel.cpp头文件block_mmad_mla_pv_tp1_spec.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。Block Epilogue算子总共使用了三类Block Epilogue组件分别为EpilogueMLASoftmax为BlockEpilogue模板类的偏特化用于处理Flash-MLA中的online softmax操作头文件block_epilogue_mla_softmax.hpp中的实现对应通用的mla_kernel.cpp头文件block_epilogue_mla_tp1_softmax.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。EpilogueMLARescaleO为BlockEpilogue模板类的偏特化用于处理Flash-MLA中的rescaleO操作头文件block_epilogue_mla_rescale_o.hpp中的实现对应通用的mla_kernel.cpp头文件block_epilogue_mla_tp1_rescale_o.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。EpilogueMLAFDRescaleO为BlockEpilogue模板类的偏特化用于处理Flash-MLA中的flashDecoding操作如有必要头文件block_epilogue_mla_fd_rescale_o.hpp中的实现为mla_kernel.cpp与mla_kernel_tp1_spec.cpp两者共用。Tile Mmad Tile Copy在通用Kernel使用的Block组件中使用了位于tile_mmad.hpp中的tileMmad组件和位于tile_copy.hpp中的tileCopy组件例如using TileMmad TileMmad_; using CopyGmToL1A typename TileCopy_::CopyGmToL1A; using CopyGmToL1B typename TileCopy_::CopyGmToL1B; using CopyL1ToL0A typename TileCopy_::CopyL1ToL0A; using CopyL1ToL0B typename TileCopy_::CopyL1ToL0B; using CopyL0CToGm typename TileCopy_::CopyL0CToGm; using ElementAccumulator typename Gemm::helper::ElementAccumulatorSelectorElementA, ElementB::ElementAccumulator;【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考