AI News HubLIVE
站内改写4 分鐘閱讀

Parallax:一種引數化區域性線性注意力機制,保留Softmax並新增學習協方差修正分支

Parallax是一種新型注意力機制,保留Softmax注意力並新增一個學習的協方差修正分支,替代了區域性線性注意力(LLA)中的逐查詢求解器。它透過重用FlashAttention的鍵值流,將算術強度提高一倍,並在0.6B和1.7B規模的LLM預訓練中實現了更低的困惑度。然而,其優勢高度依賴於Muon最佳化器,在AdamW下增益顯著縮小。

來源MarkTechPost作者: Asif Razzaq

Transformer的注意力機制自2017年以來幾乎沒有變化。大多數效率工作試圖直接替換Softmax注意力。一篇新論文采取了不同的路線:保留Softmax注意力,並增加一個修正分支。

來自西北大學、Tilde Research和華盛頓大學的研究團隊提出了一種引數化的區域性線性注意力機制,稱為“Parallax”,它能夠擴充套件到LLM預訓練,並與Muon最佳化器協同設計。

Parallax並非透過削減計算來追求效率。它有意增加計算量,然後使這些計算在現代GPU上執行得更便宜。

什麼是Parallax?

Parallax建立在區域性線性注意力(LLA)的基礎上。LLA源於測試時迴歸框架,該框架將注意力視為鍵值對上的迴歸求解器。在此視角下,鍵是訓練資料點,值是標籤,查詢是測試點。Softmax注意力是一種稱為Nadaraya-Watson的非引數估計器,它為每個查詢擬合一個區域性常數函式。

LLA將該區域性常數估計升級為區域性線性估計。研究團隊證明這產生了嚴格更小的積分均方誤差。其優勢在於聯想記憶的更好偏差-方差權衡。

但LLA在大規模應用中存在問題。其精確前向傳播需要為每個查詢求解一個線性系統,使用並行共軛梯度(CG)求解器。CG求解器帶來三個問題:密集的I/O、困難的正則化-表達權衡以及低精度不相容。

Parallax移除了求解器。取而代之,它學習一個額外的投影矩陣。研究團隊將其寫為ρi = WR xi,其中WR是一個可學習矩陣,直接從層輸入探測KV協方差。

因此,Parallax保留了區域性線性原理,只是用學習的、類似查詢的投影器替代了逐查詢求解。這使其更簡單、更高效,且更易於實現。

機制如何運作?

Parallax將LLA重新表述為Softmax注意力加上一個加性修正。輸出等於Softmax注意力輸出減去一個投影協方差項。在研究論文的符號中,該項是KV協方差乘以學習的探針ρi。

研究團隊還丟棄了LLA中的一個元件,稱為邊界放大因子,並將其設為零。這對於穩定性是必要的。一旦探針變為引數化,原有的幾何解釋就失效了。保留該因子可能導致縮放發散或符號翻轉。

Parallax位於一個注意力機制家族中。研究團隊在論文中透過三個軸對其進行組織:頻寬、探針構造和仿射結構。在一個極端,當探針範數為零時,Parallax精確退化為Softmax注意力。

設定WR = 0可以使Parallax層的行為與Softmax注意力完全相同。因此,預訓練的Transformer檢查點可以透過新增WR並微調進行轉換。

硬體論證

Parallax繼承了FlashAttention的流式結構。它新增了一個協方差分支,重用相同的鍵值流。

研究團隊將前向傳播擴充套件為兩個並行的評分分支。兩個分支共享線上最大值、縮放因子以及K和V的tile。因此,Parallax每次迭代不需要額外的I/O。

關鍵屬性是更高的算術強度(AI)。AI是浮點運算與高頻寬記憶體流量的比率。在KV工作佔主導地位的場景中,Parallax大約將算術強度提高一倍。它在重用相同記憶體流的同時增加了計算。

這將注意力推向更接近計算受限的狀態。這正是核心最佳化在現代硬體上有幫助的狀態。

研究團隊在NVIDIA Hopper GPU上使用CuTeDSL原型化了一個解碼核心。Hopper的張量核心矩陣乘法指令操作至少64行的tile。一個解碼步驟僅提供一行查詢。因此,QK和RK乘積可以聯合計算,在標準注意力已經發出的指令之內。

他們在H200 GPU上以BF16精度對FlashAttention 2和3進行了效能分析。他們掃描了從1到2,048的批次大小和從128到32,768的上下文長度。原型核心在所有配置下與FlashAttention效能相當或更優。下圖示註了在計算匹配設定下1.54倍的加速和在I/O匹配設定下1.14倍的加速。

https://arxiv.org/pdf/2605.29157

實驗展示

研究團隊在合成任務和0.6B及1.7B規模的LLM預訓練上驗證了Parallax。模型使用了torchtitan倉庫中的Qwen-3架構。他們在Ultra-FineWeb資料集上訓練,上下文長度為4096。基線包括Softmax注意力(Transformer)、Mamba、Gated DeltaNet、MesaNet和Kimi DeltaAttention。

在MAD-Benchmark上,Parallax取得了最高整體準確率,平均0.716。它持續改善了如上下文召回和選擇性複製等召回導向的任務。在壓縮和記憶任務上保持競爭力。

在語言建模上,使用Muon的Parallax在兩個規模上均取得了最佳困惑度。它還取得了最高的平均下游準確率。在1.7B規模,Parallax得分為62.45,而Transformer為61.43。

兩個對照實驗測試了增益的來源。引數匹配的Transformer僅縮小了一小部分差距。計算匹配的Parallax仍然擊敗了兩個基線。論文認為這表明增益來自機制本身,而非額外引數或計算。

最佳化器轉折

一個核心發現是最佳化器與架構的互動。Parallax在Muon下顯示出巨大優勢。在AdamW下,優勢顯著縮小甚至消失。

Muon是一種用於隱藏層矩陣引數的最新最佳化器。它使用動量緩衝區的極座標因子,因此更新的條件數恰好為1。先前的研究表明這產生了條件更好的權重矩陣。

研究團隊將差距追溯到修正分支。他們定義了修正輸出比(COR)。在Muon下,最深層的COR超過8。在AdamW下,COR保持在4以下。

WR投影受到不成比例的影響。在AdamW下,其穩定秩崩潰,而在Muon下保持高秩。一個門控實驗確認了該模式。在AdamW下,模型學會抑制修正分支而非使用它。

研究團隊稱這是注意力機制中架構-最佳化器協同設計的第一個實證演示。他們未聲稱Muon與WSD是最優方案。附錄中的消融實驗顯示,在衰減階段優勢縮小。

分數差異

Parallax還產生與Softmax注意力不同的分數分佈。其每token權重可以取負值並超過1的幅度。標準Softmax權重無法做到這一點。

研究團隊報告了三個效應。Parallax可以從不相關token中主動減去值分量。它顯著減少了第一個token上的注意力匯聚。其基礎Softmax熵保持較高,從而產生更分散的注意力權重。

優勢與劣勢及開放問題

優勢:

  • 保持Softmax注意力完整,因此預訓練Transformer可透過新增WR並微調轉換。
  • 透過重用FlashAttention的鍵值流,每次迭代不增加額外I/O。
  • 算術強度翻倍,原型解碼核心在效能上與FlashAttention 2/3相當或更優。
  • 在引數匹配和計算匹配控制下顯示一致的困惑度和下游增益。

劣勢與開放問題:

  • 增益高度依賴Muon;在AdamW下優勢基本消失。
  • 最佳化器依賴的確切原因仍是一個開放問題。
  • 結果止於1.7B規模,未涉及MoE、更長上下文或更大規模執行。
  • 優勢在WSD衰減階段減弱,僅透過權重衰減退火部分修復。

關鍵要點

  • Parallax保留Softmax注意力並新增一個學習的協方差修正分支,替代了LLA的逐查詢共軛梯度求解器。
  • 它在重用相同KV流的同時將算術強度提高一倍,解碼核心與FlashAttention 2/3效能相當或更優。
  • 在0.6B和1.7B規模上,在引數匹配和計算匹配控制下均顯示一致的困惑度和下游增益。
  • 增益高度依賴Muon;在AdamW下優勢顯著縮小或消失。
  • 設定WR = 0可精確恢復Softmax注意力,因此預訓練Transformer可透過新增WR並微調轉換。
  • 檢視論文和程式碼庫。另外,歡迎在Twitter上關注我們,別忘了加入我們的150k+ ML SubReddit和訂閱我們的新聞通訊。等等,你在Telegram上嗎?現在你也可以在Telegram上加入我們。
  • 需要與我們合作推廣你的GitHub倉庫、Hugging Face頁面、產品釋出或網路研討會等?聯絡我們。