AI News HubLIVE
站内改写2 分钟阅读

MiniMax稀疏注意力(MSA):基于109B参数MoE模型训练的双分支块稀疏注意力机制

MiniMax提出MSA,一种基于分组查询注意力(GQA)的稀疏注意力方法。通过轻量级索引分支选择每个查询的top-k键值块,主分支仅关注这些块,在1M上下文下将每token注意力计算量降低28.4倍,同时在下游基准测试中匹配GQA性能。

来源MarkTechPost作者: Asif Razzaq

MiniMax近日发布了MSA(MiniMax稀疏注意力),这是一种直接在分组查询注意力(GQA)上构建的稀疏注意力方法。其目标是解决长上下文场景下softmax注意力的二次成本瓶颈。研究团队在具有原生多模态数据的109B参数混合专家模型上进行了测试,同时开源了推理内核并发布了生产模型MiniMax-M3。

MSA将注意力分解为两个阶段:索引分支和主分支。索引分支决定每个查询应读取哪些键值块,主分支仅对这些块执行精确的softmax注意力。选择以块为单位(默认块大小128个token),每个查询和GQA组保留k=16个块,固定每个查询的预算为2048个键值token。与密集GQA注意力的O(N)复杂度相比,MSA的复杂度为O(kBk),随上下文长度增长,计算差距不断扩大。

索引分支仅向标准GQA层添加两个投影矩阵。它定义一个索引查询头和一个共享的索引键头,对可见键token评分,然后通过最大池化聚合到块级别,使用Top-k选择得分最高的块。本地块始终包含在内,防止遗漏近邻区域。主分支则收集所选块中的因果可见token,应用缩放点积softmax注意力。

由于Top-k选择不可微分,MSA通过KL对齐损失训练索引投影,使索引分支分布匹配主分支注意力模式。三种机制稳定稀疏训练:梯度分离(阻止梯度传播到骨干网络)、索引器预热(前40B token使用全注意力)和强制本地块(保留一个槽位给相邻上下文)。研究团队通过消融实验优化了最终方案。

MSA支持两种训练路径:MSA-PT(从头训练,含40B token预热)和MSA-CPT(从2.6T token的密集GQA检查点转换,继续训练400B token)。算法配合两个内核设计实现实际加速:无指数Top-k选择和KV外部稀疏注意力与查询收集。开源内核fmha_sm100针对NVIDIA SM100 GPU,支持BF16、FP8等精度。

与现有稀疏方法相比,MSA的独特之处在于每组独立进行Top-k块选择,保持KV读取连续性的同时允许各组独立检索。实验表明,在3T token预算下,稀疏模型与全注意力基线竞争力相当。例如,MMLU得分分别为67.0(全注意力)、67.2(MSA-PT)和66.8(MSA-CPT)。

MSA适用于长上下文场景:长周期智能体、仓库级代码推理、持久性记忆和长视频理解。推理内核可通过Hugging Face库使用,但要求SM100 GPU和CUDA工具包。

优势包括:在1M上下文下每token注意力计算降低28.4倍,实际加速14.2倍(预填充)和7.6倍(解码),设计简洁,支持多种训练模式,内核开源。劣势:内核仅限NVIDIA SM100,部分任务存在性能差距,加速假设特定配置,KL损失增加训练复杂度,结果基于自有评估。