AI News HubLIVE
站内改写2 分鐘閱讀

MiniMax稀疏注意力(MSA):基於109B引數MoE模型訓練的雙分支塊稀疏注意力機制

MiniMax提出MSA,一種基於分組查詢注意力(GQA)的稀疏注意力方法。透過輕量級索引分支選擇每個查詢的top-k鍵值塊,主分支僅關注這些塊,在1M上下文下將每token注意力計算量降低28.4倍,同時在下游基準測試中匹配GQA效能。

來源MarkTechPost作者: Asif Razzaq

MiniMax近日釋出了MSA(MiniMax稀疏注意力),這是一種直接在分組查詢注意力(GQA)上構建的稀疏注意力方法。其目標是解決長上下文場景下softmax注意力的二次成本瓶頸。研究團隊在具有原生多模態資料的109B引數混合專家模型上進行了測試,同時開源了推理核心併發布了生產模型MiniMax-M3。

MSA將注意力分解為兩個階段:索引分支和主分支。索引分支決定每個查詢應讀取哪些鍵值塊,主分支僅對這些塊執行精確的softmax注意力。選擇以塊為單位(預設塊大小128個token),每個查詢和GQA組保留k=16個塊,固定每個查詢的預算為2048個鍵值token。與密集GQA注意力的O(N)複雜度相比,MSA的複雜度為O(kBk),隨上下文長度增長,計算差距不斷擴大。

索引分支僅向標準GQA層新增兩個投影矩陣。它定義一個索引查詢頭和一個共享的索引鍵頭,對可見鍵token評分,然後透過最大池化聚合到塊級別,使用Top-k選擇得分最高的塊。本地塊始終包含在內,防止遺漏近鄰區域。主分支則收集所選塊中的因果可見token,應用縮放點積softmax注意力。

由於Top-k選擇不可微分,MSA透過KL對齊損失訓練索引投影,使索引分支分佈匹配主分支注意力模式。三種機制穩定稀疏訓練:梯度分離(阻止梯度傳播到骨幹網路)、索引器預熱(前40B token使用全注意力)和強制本地塊(保留一個槽位給相鄰上下文)。研究團隊透過消融實驗最佳化了最終方案。

MSA支援兩種訓練路徑:MSA-PT(從頭訓練,含40B token預熱)和MSA-CPT(從2.6T token的密集GQA檢查點轉換,繼續訓練400B token)。演算法配合兩個核心設計實現實際加速:無指數Top-k選擇和KV外部稀疏注意力與查詢收集。開源核心fmha_sm100針對NVIDIA SM100 GPU,支援BF16、FP8等精度。

與現有稀疏方法相比,MSA的獨特之處在於每組獨立進行Top-k塊選擇,保持KV讀取連續性的同時允許各組獨立檢索。實驗表明,在3T token預算下,稀疏模型與全注意力基線競爭力相當。例如,MMLU得分分別為67.0(全注意力)、67.2(MSA-PT)和66.8(MSA-CPT)。

MSA適用於長上下文場景:長週期智慧體、倉庫級程式碼推理、永續性記憶和長影片理解。推理核心可透過Hugging Face庫使用,但要求SM100 GPU和CUDA工具包。

優勢包括:在1M上下文下每token注意力計算降低28.4倍,實際加速14.2倍(預填充)和7.6倍(解碼),設計簡潔,支援多種訓練模式,核心開源。劣勢:核心僅限NVIDIA SM100,部分任務存在效能差距,加速假設特定配置,KL損失增加訓練複雜度,結果基於自有評估。