AI News HubLIVE
站内改写

Nous Research、Lighthouse Attentionを提案:トレーニング専用の選択ベース階層型注意力機構により、長コンテキストで1.4~1.7倍の事前学習高速化を実現

Nous Researchは、Lighthouse Attentionを公開しました。これはトレーニング専用の選択ベース階層型注意力機構で、事前学習中に標準のスケーリングドット積注意をラップし、その後削除されます。キーと値のみをプールする従来の方法とは異なり、LighthouseはQ、K、Vを対称的にマルチ解像度ピラミッドでプールし、注意呼び出しをO(N·S·d)からO(S²·d)に削減し、小さな密な部分列で標準FlashAttentionを実行します。530MパラメータのLlama-3スタイルモデルで98Kコンテキストでテストされ、cuDNN SDPAベースラインに対して1.40~1.69倍のエンドツーエンドの壁時計速度向上を達成し、最終トレーニング損失は同等かそれ以下です。

記事インテリジェンス

エンジニア上級

要点

  • Lighthouse Attentionはトレーニング専用の階層型注意機構で、マルチ解像度ピラミッドでQ、K、Vを対称的にプールし、計算をO(N·S·d)からO(S²·d)に削減します。
  • 530MパラメータのLlama-3モデルで98Kコンテキストにおいて、cuDNN SDPAベースラインに対して1.4~1.7倍の高速化を達成し、最終損失は同等かそれ以下です。
  • 2段階トレーニングアプローチ(Lighthouse後、密な注意で回復)により、モデルは密な注意機能を維持し、損失は約1,000ステップで回復し、より低い最終損失に達します。
  • コンテキスト並列性により、カスタムカーネルなしで32台のBlackwell GPU上で100万トークンのトレーニングにスケーリング可能です。

重要な理由

このニュースが重要なのは、Lighthouse Attentionはトレーニング専用の階層型注意機構で、マルチ解像度ピラミッドでQ、K、Vを対称的にプールし、計算をO(N·S·d)からO(S²·d)に削減しますためです。

技術的影響

モデル選定、推論コスト、プロダクト能力、評価基準に影響する可能性があります。

大規模言語モデルを長いシーケンスで訓練する際、注意力の計算コストは周知の問題です。すべてのTransformerの中核であるスケーリングドット積注意(SDPA)は、計算とメモリの両方でシーケンス長Nに対して2次スケーリングΘ(N²)を持ちます。FlashAttentionは、IO対応タイル化によって高帯域幅メモリに完全なN×N注意行列を具体化するのを回避し、メモリフットプリントを大幅に削減しましたが、根本的なΘ(N²)計算スケーリングは変わりません。Nous Researchの研究者は、事前学習段階でこのボトルネックに対処する新しい手法「Lighthouse Attention(灯台注意)」を導入し、cuDNN対応SDPAベースラインに対して1.40倍から1.69倍のエンドツーエンド壁時計高速化を達成し、最終訓練損失は同等かそれ以下です。

既存のスパース注意法の中心的な問題

Lighthouseがなぜそのように機能するかを理解するには、既存のスパース注意法の限界を知る必要があります。ほとんどの先行研究(NSA、HISA、DSA、MoBAなど)は、同じ2つの設計選択を行っています。第一に、キーと値のみをプールし、クエリは全解像度のままにします(非対称圧縮)。第二に、それらの選択ロジックはカスタム注意カーネル内に存在するため、現代のGPUテンソルコア向けに最適化された密注意カーネルを再利用できません。さらに、トレーニング時スパース法は推論時法にはない特別な問題に直面します。推論時スパース法は、その密なバックボーンに対してのみ評価され、せいぜいバックボーンと同等の性能です。一方、トレーニング時スパース法はより厳しいテストに直面します。トレーニングが終了した後、得られた重みが依然として有能な密注意モデルを生成するかどうかです。Lighthouseはこの問題を中心的な正しさの基準としています。

Lighthouseはこれらの設計選択に対して異なるアプローチを取ります。クエリ、キー、値を多層ピラミッド全体で対称的にプールし、選択を完全に注意カーネルの外部に配置します。選択後、システムは選択されたエントリを連続した密な部分列に集め、それに対して標準のFlashAttentionを実行します。これは密なベースラインと同じカーネルです。

4段階パイプラインの仕組み

Lighthouse注意層は、スケーリングドット積注意をラップしますが、変更はしません。パイプラインは4つの段階からなります。

第1段階では、平均プーリングがQ、K、VからLレベルのピラミッドを構築します。プーリング係数pで、ピラミッドの第ℓ層はN/p^ℓ個のトークンを持ち、各トークンはp^ℓ個の基本位置を要約します。重要なのは、同じプーリングが3つの射影すべてに適用され、すべてのレベルで一貫した(Q^(ℓ), K^(ℓ), V^(ℓ))トリプレットを生成することです。ピラミッド構築の総コストはΘ(N)時間とメモリです。

第2段階では、パラメータフリーのスコアラーが各ピラミッドエントリに、ヘッドごとのℓ₂ノルムを使用して2つのスカラースコア(クエリスコア(∥Q^(ℓ)_i∥₂)とキースコア(∥K^(ℓ)_i∥₂))を割り当てます。粗いレベルは最大プーリングを介して細かいレベルからスコアを継承するため、粗いスパンはその最も強いトークンの重要性を捉えます。次に、融合されたチャンク化ビトニックtop-Kカーネルが、すべてのピラミッドレベルにわたって共同でk個のエントリを選択します。注目すべき設計詳細の1つ:最も粗いピラミッドレベルは常に完全に保持されます。これはコストが低く、すべての基本位置に少なくとも1つの貢献者を保証します。残りの選択予算はより細かいレベルに割り当てられます。さらに、チャンク化ビトニック設計は、厳密なグローバルtop-Kではなく、層別top-Kを生成します。スコアストリームは固定サイズのチャンクに分割され、各チャンクはレジスタ内のtop-mバッファを維持するため、グローバルに最もスコアの高いk個のエントリが1つのチャンクに集中した場合、一部は他のチャンクからの低スコアエントリに置き換えられます。その結果、シーケンス全体でよりバランスの取れた注意カバレッジが得られ、選択が狭い範囲に崩壊するのを防ぎます。

top-Kステップは離散的で微分不可能です。ストレートスルー推定器もGumbel softmaxもありません。選択インデックスは勾配を運びません。勾配は収集されたQ、K、Vエントリを介してWQ、WK、WVにのみ流れるため、射影は選択に優れたスコアではなく、選択されたときに有用な値を生成することを学習します。

第3段階では、選択されたエントリが長さS = N/p^(L−1) + (L−1)·p·kの連続した部分列に集められ、標準のFlashAttentionに渡されます。N = 1,000,000、L = 4、p = 4、k = 4,096の場合、S ≈ 65,000であり、Nよりはるかに小さいです。収集プロセスは、組み立てられた部分列に「穴」や空きスペースがないことを保証します。これは、Lighthouseもクエリを圧縮するため重要です。シーケンスのギャップは、それらの欠落トークンが逆伝播中に勾配パスを持たず、トレーニングの不安定性を引き起こす可能性があるためです。クエリを全解像度のままにする非対称手法はこの問題に直面しませんが、Lighthouseの対称設計では、収集された部分列が完全に密である必要があります。

第4段階では、各出力エントリは、それが表すp^ℓ個の基本位置に、決定論的な整数アトミック散乱カーネルを介して散乱され、因果性を維持するためにp^ℓ − 1のシフトが適用されます。位置ごとのファンインはkに関係なくLに制限されます。

対称プーリングが計算を変える理由

クエリをキーと値とともにプーリングすることで、トレーニング時の注意呼び出しの計算がO(N S d)からO(S² d)に変わります。長いコンテキストではS ≪ Nであるため、これがレイテンシの利点を生み出します。単一のNVIDIA B200で512Kコンテキスト(bfloat16、B=1、H=8、ヘッド次元128、L=3、p=4、スパース比≈1:64)でベンチマークしたところ、LighthouseはcuDNN対応SDPAに対して、フォワードパスで21倍、フォワード+バックワードパスで17.3倍高速でした。

漸近的観点から、L = log_p(N/k)と設定すると、収集された部分列サイズS = Θ(k log N)となり、密なFlashAttention呼び出しのコストはΘ(k² log² N d)(固定kでNに対して多対数)になります。線形コストの段階(ピラミッド構築、スコアリング、散乱)と組み合わせると、固定kでの層ごとの総計算量はΘ(T d)であり、線形注意や状態空間モデルと同じ漸近クラスでありながら、選択された部分列上でソフトマックス注意の想起特性を保持します。

推論は異なる制約があります。自己回帰デコードは一度に1つのクエリを提示するため、すべてのクエリが1つのフォワードパスで同時に発生するという仮定に違反します。Lighthouseはトレーニング専用の手法であり、対称プーリング設計は推論時に直接使用できません。

2段階トレーニングレシピと回復可能性

実験設定では、530MパラメータのLlama-3スタイルデコーダ(d_model=1024、30層、8ヘッド、ヘッド次元128、FFN幅1536、バイトレベルトークナイザ)を使用し、C4データセットで98,304トークンコンテキストでトレーニングしました。オプティマイザはAdamW(学習率2×10⁻³、β1=0.9、β2=0.95、重み減衰0.1、線形ウォームアップ2kステップ、勾配ノルムクリップ1、bfloat16、FSDP)です。実務者にとって重要な実装詳細:30層のうち、層{0,1,28,29}は全体を通して密なSDPAを保持し、他の26層のみがLighthouseを使用します。それら26のLighthouse層内の内部注意呼び出しは、密なベースラインと同じcuDNN対応SDPAカーネルを使用します。

トレーニングアプローチは2段階です。段階1では、ステップ予算の大部分でLighthouse選択を有効にしてトレーニングします。段階2では、密なSDPA(同じオプティマイザ状態、同じデータローダ)で段階1のチェックポイントから再開し、短いテールを実行します。段階1でモデルの密な注意能力が空洞化されていた場合、段階2の回復は失敗するはずです。

しかし、失敗しませんでした。合計16,000ステップ(約50.3Bトークン)の予算で、3つの分割ポイント(10k+6k、11k+5k、12k+4k)をテストし、スクラッチからの密なSDPAベースラインと比較しました。各再開ポイントで、モデルがトレーニングされていない注意力を初めて実行するため、トレーニング損失が一時的に1.12~1.57 nats上昇し、その後約1,000~1,500ステップのSDPA内で回復し、密なベースラインを下回りました。16,000ステップまでに、3つの再開されたLighthouse実行はすべて0.6980~0.7102の最終損失に達し、密なベースラインの0.7237に対して、同じトークン予算で壁時計時間は22.5~27.0時間であり、スクラッチからの密なSDPAの37.9時間と比較して短縮されました。

アブレーションとスループット

完全なアブレーショングリッドは、スコアラー型、プーリング係数p、ピラミッドレベル数L、top-K予算kをカバーしています。主な発見:射影ノルムスコアラーは、希釈ソフトマックス注意スコアラーとどちらの方向でも約0.01以内(一様な勝者なし)ですが、B200時間で約9%安価です。これはピラミッド上の注意パスをスキップするためです。一致する予算では、浅いピラミッド(L=3)が深いもの(L=4、L=5)よりも一貫して優れています。テスト範囲内では、小さいk値が回復後の損失を低くします。グリッド全体で最低損失の構成はL=3、p=2、k=1536で希釈スコアラーを使用した場合で、最終損失0.6825に達しました。この直感に反する結果は、研究チームがこのトークン予算規模での階層的選択の正則化効果に起因するとしています。

段階1のスループットはアブレーショングリッド全体で84,000~126,000トークン/秒/GPUであり、密なSDPAの約46,000に対してです。射影ノルムスコアラーはL=3、p=4、k=1536で、希釈注意パスを完全にスキップすることにより、範囲の上限である126,000トークン/秒/GPUに達しました。

長コンテキスト検索

損失ベースの回復可能性結果を補完するため、研究チームは簡略化された「干し草の山の中の針」評価を実行しました。単一のパスキー桁をランダムな英数字フィラーに隠し、深さ0~100%、コンテキスト長4K~96Kトークンで、検索は10桁トークンに対する1トークンのargmaxとしてスコアリングされます(ランダム確率:10%)。4つのLighthouse構成(k ∈ {1536, 2048}、スコアラー ∈ {希釈, ノルム}、L=3、p=4)を、スクラッチからの密なSDPAベースラインに対してテストしました。4つのLighthouse実行のうち3つは、密なベースラインの平均検索率0.72に一致またはそれを上回りました。k=2048希釈で0.76、k=1536希釈で0.73、k=2048ノルムでベースラインの0.72に一致しました。k=1536ノルムのみが0.65に低下しました。グリッド全体でパターンが現れました:より大きなkが検索性能の主要な軸であり、ノルムスコアラーは同じkでトレーニング損失よりも検索をより損なう。実用的な意味は、最適な構成は下流タスクが損失駆動型か検索駆動型かによって異なることです。

コンテキスト並列性スケーリング

約100Kトークンを超えるシーケンスの場合、Lighthouseはコンテキスト並列性の下で実行されます。ピラミッドプーリング、スコアリング、top-Kは各ランクでローカルに実行され、ランク間通信は不要です。これは、最も粗いプールウィンドウ(例:64トークン)がシャードサイズよりも数桁小さいためです。収集された部分列は密であるため、スパース認識集合演算なしで標準のリング注意に参加できます。これはスパースインデックスベースの手法では、スパースレイアウトに特化したエンジニアリングなしにはできません。コンテキスト並列性は、リングローテーションによりランクあたり約10%のスループットオーバーヘッドを導入しますが、Lighthouse対SDPAの高速化比は維持されます。この手法は、32台のBlackwell GPU(4ノード、CP度8)にわたる100万トークンのトレーニングに、内部注意カーネルを変更することなくスケーリング可能です。