Nous Research提出燈塔注意力機制:一種僅用於訓練的基於選擇的分層注意力,在長上下文預訓練中實現1.4–1.7倍加速
Nous Research釋出了燈塔注意力機制,這是一種僅用於訓練的基於選擇的分層注意力機制,在預訓練期間包裹標準縮放點積注意力,之後移除。與之前僅池化鍵和值的方法不同,燈塔注意力對稱地池化查詢、鍵和值,形成多解析度金字塔,將注意力呼叫從O(N·S·d)降低到O(S²·d),並對小型密集子序列執行標準FlashAttention。在530M引數的Llama-3風格模型上,以98K上下文測試,相較於cuDNN SDPA基線實現了1.40–1.69倍端到端加速,且最終訓練損失持平或更低。
文章情報
要點
- 燈塔注意力是一種訓練時使用的分層注意力機制,透過對稱池化查詢、鍵和值構建多解析度金字塔,大幅減少計算量。
- 在530M引數的Llama-3模型上,98K上下文長度下實現1.4-1.7倍加速,且最終損失低於或等於密集基線。
- 採用兩階段訓練:第一階段使用燈塔注意力,第二階段切換到標準密集註意力以恢復模型能力,實驗顯示損失迅速恢復並降低。
- 該方法可擴充套件至100萬token訓練,透過上下文並行性在32個Blackwell GPU上實現無特殊核心的擴充套件。
為什麼重要
這條新聞值得關注,因為燈塔注意力是一種訓練時使用的分層注意力機制,透過對稱池化查詢、鍵和值構建多解析度金字塔,大幅減少計算量。
技術影響
可能影響模型選型、推理成本、產品能力和評測基準。
訓練大型語言模型處理長序列時,注意力機制的計算成本是一個眾所周知的難題。每個Transformer核心的縮放點積注意力(SDPA)在計算和記憶體上均呈二次方規模Θ(N²),其中N是序列長度。FlashAttention透過IO感知的分塊技術解決了記憶體瓶頸,避免了在高速記憶體中例項化完整的N×N注意力矩陣,但底層的Θ(N²)計算規模並未改變。Nous Research的研究人員提出了一種名為“燈塔注意力”(Lighthouse Attention)的新方法,專門在預訓練階段解決這一瓶頸,實現了相對於cuDNN支援的SDPA基線1.40倍到1.69倍的端到端加速,同時最終訓練損失持平或更低。
現有稀疏注意力方法的核心問題
理解燈塔注意力為何有效,需要先了解現有稀疏注意力方法的侷限性。大多數先前工作,如NSA、HISA、DSA、MoBA,做出了兩個相同的設計選擇。首先,它們僅池化鍵和值側,而保留查詢的全解析度(非對稱壓縮)。其次,它們的選擇邏輯位於自定義注意力核心內部,這意味著團隊無法重複利用現代GPU張量核心最佳化的密集註意力核心。此外,訓練時稀疏方法面臨一個推理時方法沒有的特殊問題:推理時稀疏方法僅針對其密集骨幹網路評估,最多與骨幹網路效能相當;而訓練時稀疏方法面臨更嚴峻的考驗——訓練結束後,得到的權重是否仍能產生具備能力的密集註意力模型?燈塔注意力將此問題作為其核心正確性標準。
燈塔注意力在這兩個設計選擇上採取了不同路徑。它跨多層金字塔對稱地池化查詢、鍵和值,並將選擇過程完全置於注意力核心之外。選擇之後,系統將選中的條目收集到一個連續的密集子序列中,並對其執行標準的FlashAttention——與密集基線使用的核心相同。
四階段流水線的工作原理
燈塔注意力層包裹但不修改縮放點積注意力。該流水線包含四個階段。
第一階段,平均池化從Q、K、V構建L級金字塔。池化因子為p時,金字塔的第ℓ層包含N/p^ℓ個令牌,每個令牌彙總p^ℓ個基礎位置。關鍵在於,相同的池化應用於所有三個投影,從而在每個層級生成一致的(Q^(ℓ), K^(ℓ), V^(ℓ))三元組。金字塔構建的總成本為Θ(N)時間和記憶體。
第二階段,一個無引數評分器使用每個頭的ℓ₂範數為每個金字塔條目分配兩個標量分數:一個作為查詢分數(∥Q^(ℓ)_i∥₂),一個作為鍵分數(∥K^(ℓ)_i∥₂)。較粗的層級透過最大池化繼承較細層級的分數,因此粗粒度跨度能夠捕獲其最強令牌的重要性。然後,一個融合的分塊雙調top-K核心聯合跨所有金字塔層級選擇k個條目。一個值得注意的設計細節:最粗的金字塔層級始終完整保留——它成本低廉且保證每個基礎位置至少有一個貢獻者;剩餘的選擇預算分配給較細層級。此外,分塊雙調設計產生的是分層top-K而非嚴格的全域性top-K:分數流被劃分為固定大小的塊,每個塊維護一個暫存器內的top-m緩衝區,因此如果全域性得分最高的k個條目聚集在一個塊中,其中一些將被來自其他塊得分較低的條目替換。結果是在整個序列上實現更平衡的注意力覆蓋,避免了選擇坍縮到狹窄範圍內。
top-K步驟是離散且不可微的——沒有直通估計器,也沒有Gumbel softmax。選擇索引不攜帶梯度。梯度僅流經收集到的Q、K、V條目進入WQ、WK、WV,因此投影層學會產生對選擇有用的值,而非擅長選擇的分數。
第三階段,選中的條目被收集到一個長度為S = N/p^(L−1) + (L−1)·p·k的連續子序列中,並傳遞給標準FlashAttention。當N = 1,000,000,L = 4,p = 4,k = 4,096時,S ≈ 65,000,遠小於N。收集過程的一個關鍵屬性是保證組裝的子序列中沒有“空洞”或空位。這一點很重要,因為燈塔注意力也壓縮了查詢:序列中的間隙意味著這些缺失的令牌在反向傳播中沒有梯度路徑,可能導致訓練不穩定。保持查詢全解析度的非對稱方法不會面臨此問題,但燈塔注意力的對稱設計要求收集的子序列保持完全密集。
第四階段,每個輸出條目透過確定性整數原子散射核心散射回其代表的p^ℓ個基礎位置,並透過移動p^ℓ − 1來保持因果性。每個位置的實際扇入被限制為L,與k無關。
對稱池化如何改變計算
將查詢與鍵和值一起池化,將訓練時的注意力呼叫計算從O(N S d)改變為O(S² d)。由於在長上下文中S ≪ N,這產生了延遲優勢。在單個NVIDIA B200上以512K上下文(bfloat16,B=1,H=8,頭維度128,L=3,p=4,稀疏比≈1:64)進行基準測試,燈塔注意力相對於cuDNN支援的SDPA,前向傳播快21倍,前後向組合傳播快17.3倍。
從漸近角度來看,設定L = log_p(N/k)使得收集的子序列大小S = Θ(k log N),因此密集FlashAttention呼叫的成本為Θ(k² log² N d) —— 在固定k下關於N是多項對數複雜度。結合線性成本階段(金字塔構建、評分、散射回),在固定k下每層總計算量為Θ(T d) —— 與線性注意力和狀態空間模型相同的漸近類別,同時在選中的子序列上保留了softmax注意力的召回特性。
推理則是不同的約束。自迴歸解碼一次呈現一個查詢,這違反了所有查詢在一次前向傳播中共同出現的假設。燈塔注意力是僅訓練方法,對稱池化設計無法直接在推理中使用。
兩階段訓練配方與可恢復性
實驗設定使用了530M引數的Llama-3風格解碼器(d_model=1024,30層,8頭,頭維度128,FFN寬度1536,位元組級分詞器),在C4資料集上以98,304令牌上下文訓練,使用AdamW最佳化器(學習率2×10⁻³,β1=0.9,β2=0.95,權重衰減0.1,線性預熱2k步,梯度範數裁剪1,bfloat16,FSDP)。一個對實踐者重要的實現細節:30層中,層{0,1,28,29}全程保留密集SDPA——只有其餘26層使用燈塔注意力。這些燈塔注意力層內部的注意力呼叫使用與密集基線相同的cuDNN支援的SDPA核心。
訓練方法為兩階段。階段1在大部分步驟預算中啟用燈塔注意力選擇。階段2在密集SDPA下(相同最佳化器狀態、相同資料載入器)從階段1檢查點恢復訓練,進行短時間收尾。如果階段1破壞了模型的密集註意力能力,階段2的恢復將會失敗。
但實驗沒有失敗。在總預算16,000步(約50.3B令牌)下,測試了三個分界點(10k+6k,11k+5k,12k+4k),並與從頭訓練的密集SDPA基線進行比較。在每個恢復點,訓練損失暫時上升1.12–1.57 nats,因為模型首次執行其未訓練過的注意力,然後在大約1,000–1,500步SDPA內恢復並降至密集基線之下。到第16,000步,所有三個恢復的燈塔注意力執行最終損失在0.6980–0.7102之間,而密集基線為0.7237,同時在相同令牌預算下,壁鐘時間花費22.5至27.0小時,而從頭訓練密集SDPA需要37.9小時。
消融實驗與吞吐量
完整消融網格涵蓋評分器型別、池化因子p、金字塔層數L和top-K預算k。關鍵發現:投影範數評分器在任一方向上與稀釋softmax注意力評分器相差約0.01(無統一勝者),但在B200小時上大約便宜9%,因為它跳過了金字塔上的注意力傳遞。在匹配預算下,較淺的金字塔(L=3)始終優於較深的(L=4、L=5)。在測試範圍內,較小的k值在恢復後產生較低的損失——網格中最低損失配置為L=3,p=2,k=1536,使用稀釋評分器,最終損失達到0.6825——這一反直覺的結果被研究團隊歸因於在此令牌預算規模下分層選擇起到了正則化作用。
階段1吞吐量在消融網格中範圍為84,000至126,000令牌/秒/GPU,而密集SDPA約為46,000。投影範數評分器在L=3,p=4,k=1536時達到範圍上限126,000令牌/秒/GPU,因為它完全跳過了稀釋注意力傳遞。
長上下文檢索
為補充基於損失的可恢復性結果,研究團隊執行了簡化版“大海撈針”評估:在0-100%深度、4K到96K令牌上下文長度下,將單個口令數字隱藏於隨機字母數字填充中,檢索評分作為十個數字令牌的單令牌argmax(隨機機率:10%)。測試了四種燈塔注意力配置(k ∈ {1536, 2048},評分器 ∈ {稀釋, 範數},L=3,p=4),並與從頭訓練的密集SDPA基線比較。四個燈塔注意力執行中有三個匹配或超過密集基線的平均檢索率0.72:k=2048稀釋達到0.76,k=1536稀釋達到0.73,k=2048範數匹配基線0.72。只有k=1536範數下降至0.65。網格中呈現一個模式:較大的k是檢索效能的主要軸,範數評分器在相同k下對檢索的損害大於對訓練損失的損害。實際含義是,最優配置取決於下游任務是損失驅動型還是檢索驅動型。
上下文並行性擴充套件
對於超過約100K令牌的序列,燈塔注意力在上下文並行性下執行。金字塔池化、評分和top-K在每個rank上本地分片執行,無需跨rank通訊,因為最粗的池化視窗(例如64個令牌)遠小於分片大小。收集的子序列是密集的,因此參與標準的環形注意力,無需稀疏感知的集合操作——這是稀疏索引方法無法做到的,除非專門針對稀疏佈局進行工程處理。上下文並行性因環形旋轉引入約10%的每rank吞吐量開銷,但燈塔注意力與SDPA的加速比得以保持。該方法可擴充套件至跨32個Blackwell GPU(4節點,CP度為8)的100萬令牌訓練,無需更改內部注意力核心。