Parallax:一種參數化局部線性注意力機制,保留Softmax並添加學習協方差修正分支
Parallax是一種新型注意力機制,保留Softmax注意力並添加一個學習的協方差修正分支,替代了局部線性注意力(LLA)中的逐查詢求解器。它通過重用FlashAttention的鍵值流,將算術強度提高一倍,並在0.6B和1.7B規模的LLM預訓練中實現了更低的困惑度。然而,其優勢高度依賴於Muon優化器,在AdamW下增益顯著縮小。
Transformer的注意力機制自2017年以來幾乎沒有變化。大多數效率工作試圖直接替換Softmax注意力。一篇新論文采取了不同的路線:保留Softmax注意力,並增加一個修正分支。
來自西北大學、Tilde Research和華盛頓大學的研究團隊提出了一種參數化的局部線性注意力機制,稱為“Parallax”,它能夠擴展到LLM預訓練,並與Muon優化器協同設計。
Parallax並非通過削減計算來追求效率。它有意增加計算量,然後使這些計算在現代GPU上運行得更便宜。
什麼是Parallax?
Parallax建立在局部線性注意力(LLA)的基礎上。LLA源於測試時迴歸框架,該框架將注意力視為鍵值對上的迴歸求解器。在此視角下,鍵是訓練數據點,值是標籤,查詢是測試點。Softmax注意力是一種稱為Nadaraya-Watson的非參數估計器,它為每個查詢擬合一個局部常數函數。
LLA將該局部常數估計升級為局部線性估計。研究團隊證明這產生了嚴格更小的積分均方誤差。其優勢在於聯想記憶的更好偏差-方差權衡。
但LLA在大規模應用中存在問題。其精確前向傳播需要為每個查詢求解一個線性系統,使用並行共軛梯度(CG)求解器。CG求解器帶來三個問題:密集的I/O、困難的正則化-表達權衡以及低精度不兼容。
Parallax移除了求解器。取而代之,它學習一個額外的投影矩陣。研究團隊將其寫為ρi = WR xi,其中WR是一個可學習矩陣,直接從層輸入探測KV協方差。
因此,Parallax保留了局部線性原理,只是用學習的、類似查詢的投影器替代了逐查詢求解。這使其更簡單、更高效,且更易於實現。
機制如何運作?
Parallax將LLA重新表述為Softmax注意力加上一個加性修正。輸出等於Softmax注意力輸出減去一個投影協方差項。在研究論文的符號中,該項是KV協方差乘以學習的探針ρi。
研究團隊還丟棄了LLA中的一個組件,稱為邊界放大因子,並將其設為零。這對於穩定性是必要的。一旦探針變為參數化,原有的幾何解釋就失效了。保留該因子可能導致縮放發散或符號翻轉。
Parallax位於一個注意力機制家族中。研究團隊在論文中通過三個軸對其進行組織:帶寬、探針構造和仿射結構。在一個極端,當探針範數為零時,Parallax精確退化為Softmax注意力。
設置WR = 0可以使Parallax層的行為與Softmax注意力完全相同。因此,預訓練的Transformer檢查點可以通過添加WR並微調進行轉換。
硬件論證
Parallax繼承了FlashAttention的流式結構。它添加了一個協方差分支,重用相同的鍵值流。
研究團隊將前向傳播擴展為兩個並行的評分分支。兩個分支共享在線最大值、縮放因子以及K和V的tile。因此,Parallax每次迭代不需要額外的I/O。
關鍵屬性是更高的算術強度(AI)。AI是浮點運算與高帶寬內存流量的比率。在KV工作佔主導地位的場景中,Parallax大約將算術強度提高一倍。它在重用相同內存流的同時增加了計算。
這將注意力推向更接近計算受限的狀態。這正是內核優化在現代硬件上有幫助的狀態。
研究團隊在NVIDIA Hopper GPU上使用CuTeDSL原型化了一個解碼內核。Hopper的張量核心矩陣乘法指令操作至少64行的tile。一個解碼步驟僅提供一行查詢。因此,QK和RK乘積可以聯合計算,在標準注意力已經發出的指令之內。
他們在H200 GPU上以BF16精度對FlashAttention 2和3進行了性能分析。他們掃描了從1到2,048的批次大小和從128到32,768的上下文長度。原型內核在所有配置下與FlashAttention性能相當或更優。下圖標註了在計算匹配設置下1.54倍的加速和在I/O匹配設置下1.14倍的加速。
https://arxiv.org/pdf/2605.29157
實驗展示
研究團隊在合成任務和0.6B及1.7B規模的LLM預訓練上驗證了Parallax。模型使用了torchtitan倉庫中的Qwen-3架構。他們在Ultra-FineWeb數據集上訓練,上下文長度為4096。基線包括Softmax注意力(Transformer)、Mamba、Gated DeltaNet、MesaNet和Kimi DeltaAttention。
在MAD-Benchmark上,Parallax取得了最高整體準確率,平均0.716。它持續改善了如上下文召回和選擇性複製等召回導向的任務。在壓縮和記憶任務上保持競爭力。
在語言建模上,使用Muon的Parallax在兩個規模上均取得了最佳困惑度。它還取得了最高的平均下游準確率。在1.7B規模,Parallax得分為62.45,而Transformer為61.43。
兩個對照實驗測試了增益的來源。參數匹配的Transformer僅縮小了一小部分差距。計算匹配的Parallax仍然擊敗了兩個基線。論文認為這表明增益來自機制本身,而非額外參數或計算。
優化器轉折
一個核心發現是優化器與架構的交互。Parallax在Muon下顯示出巨大優勢。在AdamW下,優勢顯著縮小甚至消失。
Muon是一種用於隱藏層矩陣參數的最新優化器。它使用動量緩衝區的極座標因子,因此更新的條件數恰好為1。先前的研究表明這產生了條件更好的權重矩陣。
研究團隊將差距追溯到修正分支。他們定義了修正輸出比(COR)。在Muon下,最深層的COR超過8。在AdamW下,COR保持在4以下。
WR投影受到不成比例的影響。在AdamW下,其穩定秩崩潰,而在Muon下保持高秩。一個門控實驗確認了該模式。在AdamW下,模型學會抑制修正分支而非使用它。
研究團隊稱這是注意力機制中架構-優化器協同設計的第一個實證演示。他們未聲稱Muon與WSD是最優方案。附錄中的消融實驗顯示,在衰減階段優勢縮小。
分數差異
Parallax還產生與Softmax注意力不同的分數分佈。其每token權重可以取負值並超過1的幅度。標準Softmax權重無法做到這一點。
研究團隊報告了三個效應。Parallax可以從不相關token中主動減去值分量。它顯著減少了第一個token上的注意力匯聚。其基礎Softmax熵保持較高,從而產生更分散的注意力權重。
優勢與劣勢及開放問題
優勢:
- 保持Softmax注意力完整,因此預訓練Transformer可通過添加WR並微調轉換。
- 通過重用FlashAttention的鍵值流,每次迭代不增加額外I/O。
- 算術強度翻倍,原型解碼內核在性能上與FlashAttention 2/3相當或更優。
- 在參數匹配和計算匹配控制下顯示一致的困惑度和下游增益。
劣勢與開放問題:
- 增益高度依賴Muon;在AdamW下優勢基本消失。
- 優化器依賴的確切原因仍是一個開放問題。
- 結果止於1.7B規模,未涉及MoE、更長上下文或更大規模運行。
- 優勢在WSD衰減階段減弱,僅通過權重衰減退火部分修復。
關鍵要點
- Parallax保留Softmax注意力並添加一個學習的協方差修正分支,替代了LLA的逐查詢共軛梯度求解器。
- 它在重用相同KV流的同時將算術強度提高一倍,解碼內核與FlashAttention 2/3性能相當或更優。
- 在0.6B和1.7B規模上,在參數匹配和計算匹配控制下均顯示一致的困惑度和下游增益。
- 增益高度依賴Muon;在AdamW下優勢顯著縮小或消失。
- 設置WR = 0可精確恢復Softmax注意力,因此預訓練Transformer可通過添加WR並微調轉換。
- 查看論文和代碼庫。另外,歡迎在Twitter上關注我們,別忘了加入我們的150k+ ML SubReddit和訂閲我們的新聞通訊。等等,你在Telegram上嗎?現在你也可以在Telegram上加入我們。
- 需要與我們合作推廣你的GitHub倉庫、Hugging Face頁面、產品發佈或網絡研討會等?聯繫我們。