shmem:昇腾NPU多卡共享内存的底层机
前言做多卡推理时模型太大一张卡放不下要把模型参数分到4张卡上。推理时每张卡都要访问其他卡的参数如果每次都走hccl AllGather延迟受不了——4卡AllGather 1MB数据要1.2ms一次推理可能要做10次AllGather光通信就12ms比计算还慢。shmem让每张卡能直接读其他卡的HBM像读本地内存一样。本地HBM访问200nsshmem远端HBM访问2.3μs走HCCShccl AllGather 1.2ms。差了500倍。这篇文章拆解shmem的设计理念、三层架构和性能特征。设计理念让多卡像单卡一样编程shmem的核心抽象是全局地址空间Global Address Space。单卡编程时你分配一块HBM拿到一个指针直接读写。多卡编程时shmem给每张NPU的HBM分配一个全局地址任何NPU都能用这个全局地址直接读写远端NPU的HBM。对编程者来说多卡编程变成了单卡编程——你不需要管数据在哪张卡上用shmem_get/shmem_put读写就行shmem内部帮你做地址翻译和数据搬运。这跟CUDA的思路不同。CUDA多卡编程要显式调用cudaMemcpyPeer你得知道源卡和目标卡的ID。shmem把数据在卡几这个细节隐藏了你只需要知道全局地址。三层架构拆解shmem分三层接口层、调度层、传输层。接口层OpenSHMEM API兼容shmem的API兼容OpenSHMEM标准学过MPI/OpenSHMEM的人上手很快#includeshmem.h// 1. 初始化shmem_init();// 2. 分配全局内存对称分配每张NPU都分配相同大小double*local_buf(double*)shmem_malloc(1024*sizeof(double));// 3. 写远端内存putinttarget_pe1;// 目标NPU的编号shmem_double_put(local_buf,local_buf,1024,target_pe);// 把本地的local_buf写到PE 1的local_buf// 4. 读远端内存getshmem_double_get(local_buf,local_buf,1024,target_pe);// 从PE 1的local_buf读到本地的local_buf// 5. 同步shmem_barrier_all();// 6. 释放shmem_free(local_buf);shmem_finalize();Python接口更简洁importshmemimporttorch# 1. 初始化shmem.init()# 2. 分配全局内存对称分配bufshmem.empty(1024,dtypetorch.float16,devicenpu:0)# 3. 写远端shmem.put(buf,target_pe1)# 把buf写到PE 1# 4. 读远端shmem.get(buf,target_pe2)# 从PE 2读buf# 5. 同步shmem.barrier()调度层地址翻译和路由选择调度层做两件事地址翻译和路由选择。地址翻译全局地址 → 本地HBM / 远端HBM全局地址空间 PE 0: [0x0000_0000_0000 - 0x0000_00FF_FFFF] ← 本地HBM PE 1: [0x0000_0100_0000 - 0x0000_01FF_FFFF] ← 远端HBM PE 2: [0x0000_0200_0000 - 0x0000_02FF_FFFF] ← 远端HBM PE 3: [0x0000_0300_0000 - 0x0000_03FF_FFFF] ← 远端HBM当PE 0访问地址0x0000_0100_0000时调度层识别出这是PE 1的HBM选择合适的路由把请求发出去。路由选择根据拓扑选择最优路径源 → 目标路由带宽延迟同一HCCS域HCCS直连200 GB/s2.3 μs不同HCCS域同节点PCIe Switch64 GB/s5.1 μs不同节点RDMARoCEv225 GB/s8.7 μsshmem自动选择最优路由编程者不需要关心拓扑。传输层HCCS / PCIe / RDMA传输层是实际搬数据的部分三种传输方式各有适用场景HCCS昇腾高速互连同一节点内的NPU互连带宽最高200GB/s延迟最低2.3μs。这是shmem性能最好的场景。PCIe Switch同一节点内跨HCCS域的NPU通信带宽中等64GB/s延迟中等5.1μs。RDMARoCEv2跨节点的NPU通信带宽最低25GB/s受网卡限制延迟最高8.7μs。shmem的RDMA传输复用了hixl的单边通信能力。实战用shmem做多卡模型参数共享4卡推理场景每张卡放1/4的模型参数推理时用shmem_get读其他卡的参数。importshmemimporttorch shmem.init()peshmem.my_pe()# 当前NPU编号0/1/2/3n_peshmem.n_pe()# 总NPU数量4# 1. 每张卡加载1/4的模型参数model_shardload_model_shard(shard_idpe)# 加载本卡负责的参数分片# 2. 注册为shmem全局内存param_bufshmem.register(model_shard.parameters())# 3. 推理时读取其他卡的参数defget_remote_param(target_pe,param_name):从目标NPU读取参数# 先拿到目标PE上param_buf的地址信息remote_addrshmem.query_addr(target_pe,param_name)# 用shmem_get读取local_copytorch.empty_like(remote_addr.shape,dtyperemote_addr.dtype,devicenpu:0)shmem.get(local_copy,remote_addr,target_pe)returnlocal_copy# 4. 推理definference(input_ids):# 每层的参数可能分布在不同卡上hiddenembedding(input_ids)# embedding在本卡forlayer_idinrange(32):target_pelayer_id%n_pe# 参数分片策略layer i在PE (i%4)上wqget_remote_param(target_pe,flayer{layer_id}.wq)wkget_remote_param(target_pe,flayer{layer_id}.wk)wvget_remote_param(target_pe,flayer{layer_id}.wv)hiddenattention(hidden,wq,wk,wv)returnhidden shmem.finalize()性能数据访问延迟对比访问方式延迟vs 本地HBM本地HBM200 ns1xshmem远端HBMHCCS2.3 μs11.5xshmem远端HBMPCIe5.1 μs25.5xshmem远端HBMRDMA8.7 μs43.5xhccl AllGather4卡1MB1.2 ms6000xshmem远端访问比本地HBM慢10-40倍但比hccl AllGather快500倍以上。带宽对比访问方式带宽本地HBM1.2 TB/sshmem远端HBMHCCS200 GB/s16.7%本地shmem远端HBMPCIe64 GB/s5.3%本地shmem远端HBMRDMA25 GB/s2.1%本地远端带宽只有本地的2-17%所以shmem不适合频繁读写远端大块数据的场景适合偶尔读一次、读完本地缓存的模式。踩坑实录坑1远端带宽比本地低很多问题把所有模型参数都放在PE 0其他3张卡通过shmem_get读参数。推理性能只有本地模式的1/4。原因HCCS带宽200GB/s只有本地HBM1.2TB/s的17%4张卡同时读PE 0HCCS带宽被打满。解决方案参数均匀分布到4张卡上每张卡只读1/4的远端参数HCCS带宽够用# ❌ 错误做法所有参数放PE 0# PE 1/2/3 每次推理都要shmem_get全部参数 → HCCS带宽打满# ✅ 正确做法参数均匀分布# PE i 存第i组参数推理时每张卡只shmem_get 1/4远端参数坑2缓存一致性开销问题PE 0更新了参数PE 1读到的还是旧值。原因shmem远端读取走HCCS/RDMA数据可能在本地L2 Cache中有缓存。PE 0更新参数后PE 1的缓存没有失效。解决方案写完参数后做一次shmem_barrier()确保所有PE看到最新数据# PE 0: 更新参数model_shard.update_parameters(new_params)shmem_barrier()# 所有PE同步缓存失效# PE 1: 读取参数保证拿到最新值paramget_remote_param(0,layer0.wq)坑3shmem_malloc是对称分配问题4张卡调用shmem_malloc(1024)每张卡都分配1024字节总共4KB。但你想让PE 0分配4KB、其他PE不分配。原因shmem_malloc是对称分配Symmetric Allocation——所有PE必须分配相同大小的内存地址在所有PE上也是对称的。这是OpenSHMEM标准的要求。解决方案如果需要非对称分配用shmem.register()注册已有的torch tensor# 对称分配所有PE分配相同大小bufshmem.empty(1024,dtypetorch.float16,devicenpu:0)# 非对称分配注册已有tensorlocal_tensortorch.randn(2048,dtypetorch.float16,devicenpu:0)bufshmem.register(local_tensor)# 只注册本PE的tensorshmem在CANN架构中的位置shmem位于CANN五层架构的第4层昇腾计算执行层跟hccl和hixl并列第4层昇腾计算执行层 ├─ Runtime ├─ HCCL集合通信 ├─ hixl单边通信 ├─ shmem共享内存← 你在这里 └─ DVPP / AIPPshmem底层复用了hixl的RDMA传输能力上层被hccl的AllGather/ReduceScatter调用。结尾shmem让多卡编程更简单——像单卡一样用全局地址读写远端HBM不用手动调hccl通信原语。但简单不代表无脑用远端访问的延迟10-40x本地和带宽2-17%本地是硬约束。合理的使用模式是参数均匀分布、偶尔远端读取、读完本地缓存。如果你在做多卡推理且模型太大单卡放不下shmem是最干净的解决方案——不用写复杂的通信逻辑一个shmem_get搞定。但如果你只是做数据并行训练hccl就够了不需要引入shmem的复杂性。https://atomgit.com/cann/shmem