MiniMax稀疏注意力(MSA):基於109B參數MoE模型訓練的雙分支塊稀疏注意力機制
MiniMax提出MSA,一種基於分組查詢注意力(GQA)的稀疏注意力方法。通過輕量級索引分支選擇每個查詢的top-k鍵值塊,主分支僅關注這些塊,在1M上下文下將每token注意力計算量降低28.4倍,同時在下游基準測試中匹配GQA性能。
MiniMax近日發佈了MSA(MiniMax稀疏注意力),這是一種直接在分組查詢注意力(GQA)上構建的稀疏注意力方法。其目標是解決長上下文場景下softmax注意力的二次成本瓶頸。研究團隊在具有原生多模態數據的109B參數混合專家模型上進行了測試,同時開源了推理內核併發布了生產模型MiniMax-M3。
MSA將注意力分解為兩個階段:索引分支和主分支。索引分支決定每個查詢應讀取哪些鍵值塊,主分支僅對這些塊執行精確的softmax注意力。選擇以塊為單位(默認塊大小128個token),每個查詢和GQA組保留k=16個塊,固定每個查詢的預算為2048個鍵值token。與密集GQA注意力的O(N)複雜度相比,MSA的複雜度為O(kBk),隨上下文長度增長,計算差距不斷擴大。
索引分支僅向標準GQA層添加兩個投影矩陣。它定義一個索引查詢頭和一個共享的索引鍵頭,對可見鍵token評分,然後通過最大池化聚合到塊級別,使用Top-k選擇得分最高的塊。本地塊始終包含在內,防止遺漏近鄰區域。主分支則收集所選塊中的因果可見token,應用縮放點積softmax注意力。
由於Top-k選擇不可微分,MSA通過KL對齊損失訓練索引投影,使索引分支分佈匹配主分支注意力模式。三種機制穩定稀疏訓練:梯度分離(阻止梯度傳播到骨幹網絡)、索引器預熱(前40B token使用全注意力)和強制本地塊(保留一個槽位給相鄰上下文)。研究團隊通過消融實驗優化了最終方案。
MSA支持兩種訓練路徑:MSA-PT(從頭訓練,含40B token預熱)和MSA-CPT(從2.6T token的密集GQA檢查點轉換,繼續訓練400B token)。算法配合兩個內核設計實現實際加速:無指數Top-k選擇和KV外部稀疏注意力與查詢收集。開源內核fmha_sm100針對NVIDIA SM100 GPU,支持BF16、FP8等精度。
與現有稀疏方法相比,MSA的獨特之處在於每組獨立進行Top-k塊選擇,保持KV讀取連續性的同時允許各組獨立檢索。實驗表明,在3T token預算下,稀疏模型與全注意力基線競爭力相當。例如,MMLU得分分別為67.0(全注意力)、67.2(MSA-PT)和66.8(MSA-CPT)。
MSA適用於長上下文場景:長週期智能體、倉庫級代碼推理、持久性記憶和長視頻理解。推理內核可通過Hugging Face庫使用,但要求SM100 GPU和CUDA工具包。
優勢包括:在1M上下文下每token注意力計算降低28.4倍,實際加速14.2倍(預填充)和7.6倍(解碼),設計簡潔,支持多種訓練模式,內核開源。劣勢:內核僅限NVIDIA SM100,部分任務存在性能差距,加速假設特定配置,KL損失增加訓練複雜度,結果基於自有評估。