AI News HubLIVE
站内改写6 分で読了

Parallax: ソフトマックスを維持し、学習された共分散補正ブランチを追加するパラメータ化された局所線形注意機構

Parallaxは、ソフトマックス注意を維持し、学習された共分散補正ブランチを追加する新しい注意機構で、局所線形注意(LLA)のクエリごとの共役勾配解法を置き換えます。FlashAttentionのキー・バリューストリームを再利用することで演算強度を2倍にし、0.6Bおよび1.7BスケールのLLM事前学習でより低いパープレキシティを達成します。ただし、その利得はMuonオプティマイザーに強く依存し、AdamWでは大幅に減少します。

ソースMarkTechPost著者: Asif Razzaq

Transformerの注意機構は2017年以来ほとんど変わっていません。ほとんどの効率化研究はソフトマックス注意を完全に置き換えようとしてきました。新しい論文は別のアプローチを取っています。ソフトマックス注意を維持し、補正ブランチを追加するのです。

ノースウェスタン大学、Tilde Research、ワシントン大学の研究チームは、「Parallax」と呼ばれるパラメータ化された局所線形注意を導入しました。これはLLMの事前学習にスケールし、Muonと共同設計されています。

Parallaxは計算を削減することで効率を追求するのではありません。意図的に計算を追加し、その計算を現代のGPUでより安価に実行できるようにします。

Parallaxとは

Parallaxは局所線形注意(LLA)に基づいています。LLAはテスト時回帰フレームワークに由来します。このフレームワークは注意をキー・バリューペア上の回帰解法と見なします。

この見方では、キーは訓練データ点、値はラベル、クエリはテスト点です。ソフトマックス注意はNadaraya-Watsonと呼ばれるノンパラメトリック推定量で、各クエリに対して局所定数関数を適合させます。

LLAはその局所定数推定を局所線形推定にアップグレードします。研究チームはこれが厳密に小さい積分平均二乗誤差をもたらすことを証明しています。利点は連想記憶におけるより良いバイアス・バリアンストレードオフです。

しかし、LLAには大規模化における問題があります。その正確な前方伝播は、各クエリに対して線形システムを解く必要があり、並列共役勾配(CG)解法を使用します。CG解法は3つの問題を引き起こします:集中的なI/O、難しい正則化と表現力のトレードオフ、低精度非互換性です。

Parallaxは解法を除去します。代わりに、追加の射影行列を学習します。研究チームはこれをρi = WR xiと書き表します。ここでWRはレイヤー入力から直接KV共分散をプローブする学習可能な行列です。

したがって、Parallaxは局所線形原理を維持しつつ、クエリごとの解法を学習されたクエリのような射影器に置き換えます。これにより、よりシンプルで効率的、かつ実装が容易になります。

メカニズムの動作

ParallaxはLLAをソフトマックス注意に加法的補正を加えたものとして再定式化します。出力はソフトマックス注意出力から射影共分散項を引いたものに等しくなります。研究論文の表記では、その項はKV共分散に学習されたプローブρiを掛けたものです。

研究チームはまた、LLAの一部である境界増幅因子を削除し、ゼロに設定しています。これは安定性のために必要です。プローブがパラメータ化されると、元の幾何学的解釈は崩れます。因子を残すと、スケーリングが発散したり符号が反転したりする可能性があります。

Parallaxは注意機構のファミリーの中に位置づけられます。研究チームは論文内で、帯域幅、プローブ構造、アフィン構造の3つの軸に沿ってこれを整理しています。一方の極端では、プローブノルムがゼロになるとParallaxは正確にソフトマックス注意に退化します。

WR = 0に設定すると、Parallax層はソフトマックス注意とまったく同じ動作をします。したがって、事前学習済みTransformerチェックポイントはWRを追加して微調整することで変換できます。

ハードウェアの議論

ParallaxはFlashAttentionのストリーミング構造を継承しています。同じキー・バリューストリームを再利用する1つの共分散ブランチを追加します。

研究チームは前方伝播を2つの並列スコアリングブランチに拡張します。両方のブランチはオンライン最大値、リスケーリング係数、KおよびVタイルを共有します。したがって、Parallaxはイテレーションごとに追加のI/Oを必要としません。

重要な特性はより高い演算強度(AI)です。AIは浮動小数点演算と高帯域幅メモリトラフィックの比率です。KV作業が支配的な領域では、Parallaxは演算強度を約2倍にします。同じメモリストリームを再利用しながら計算を追加します。

これにより、注意はより計算バウンドな領域に移行します。これはまさに、現代のハードウェアでカーネル最適化が役立つ領域です。

研究チームはNVIDIA Hopper GPU上でCuTeDSLを使用してデコードカーネルをプロトタイプ化しました。Hopperのテンソルコア行列乗算命令は、少なくとも64行のタイルで動作します。デコードステップは1つのクエリ行のみを供給します。したがって、QKおよびRK積は、標準注意がすでに発行する命令内で共同計算できます。

彼らはH200 GPU上でBF16精度でFlashAttention 2および3に対してプロファイリングを行いました。バッチサイズ1から2,048、コンテキスト長128から32,768を網羅しました。プロトタイプカーネルはすべての構成でFlashAttentionと同等またはそれを上回る性能を示しました。下図は、計算一致設定で1.54倍、I/O一致設定で1.14倍の高速化を示しています。

https://arxiv.org/pdf/2605.29157

実験結果

研究チームは合成タスクおよび0.6Bおよび1.7BスケールのLLM事前学習でParallaxを検証しました。モデルはtorchtitanリポジトリのQwen-3アーキテクチャを使用しました。Ultra-FineWebデータセットでコンテキスト長4096で訓練しました。ベースラインにはソフトマックス注意(Transformer)、Mamba、Gated DeltaNet、MesaNet、Kimi DeltaAttentionが含まれます。

MAD-Benchmarkでは、Parallaxは平均0.716で最高の全体的な精度を達成しました。In-Context-RecallやSelective-Copyingなどの想起指向タスクで一貫して改善しました。圧縮および記憶タスクでも競争力を維持しました。

言語モデリングでは、Muonを使用したParallaxが両方のスケールで最良のパープレキシティを達成しました。また、最高の平均下流精度を記録しました。1.7Bでは、Parallaxは62.45の平均スコアを記録し、Transformerの61.43を上回りました。

2つの対照実験で、利得の源泉をテストしました。パラメータ一致のTransformerはギャップのごく一部しか埋めませんでした。計算一致のParallaxは依然として両方のベースラインを上回りました。論文は、これは余分なパラメータや計算ではなく、メカニズム自体によるものだと主張しています。

オプティマイザーのひねり

中心的な発見は、オプティマイザーとアーキテクチャの相互作用です。ParallaxはMuonのもとで大きな優位性を示します。AdamWのもとでは、その優位性は著しく縮小するか、消滅します。

Muonは隠れ層の行列パラメータ向けの最近のオプティマイザーです。モーメントバッファの極因子を使用するため、更新の条件数は正確に1になります。先行研究では、これによりより良い条件付けの重み行列が生成されることが示されています。

研究チームはギャップを補正ブランチに帰着させます。彼らは補正出力比(COR)を定義します。Muonのもとでは、最深層でCORが8を超えます。AdamWのもとでは、CORは4未満にとどまります。

WR射影は不均衡に影響を受けます。AdamWのもとではその安定ランクが崩壊しますが、Muonのもとでは高いままです。ゲーティング実験がパターンを確認します。AdamWのもとでは、モデルは補正ブランチを使用する代わりに抑制することを学習します。

研究チームは、これを注意機構における強いアーキテクチャ・オプティマイザー共同設計の最初の実証的デモンストレーションと呼んでいます。彼らはMuonとWSDが最適なレシピであるとは主張していません。付録のアブレーションは、減衰フェーズ中に優位性が縮小することを示しています。

スコアの違い

Parallaxはまた、ソフトマックス注意とは異なるスコア分布を生成します。そのトークンごとの重みは負の値を取り得、大きさが1を超えることもあります。標準的なソフトマックス重みはこれができません。

研究チームは3つの効果を報告しています。Parallaxは無関係なトークンから値成分を積極的に減算できます。最初のトークン上の注意シンクを大幅に減少させます。そのベースとなるソフトマックスエントロピーはより高く、より拡散した注意重みを与えます。

長所と短所、未解決の問題

長所:

  • ソフトマックス注意をそのまま維持するため、事前学習済みTransformerはWRを追加して微調整することで変換可能。
  • FlashAttentionのキー・バリューストリームを再利用することで、イテレーションごとに追加のI/Oを必要としない。
  • 演算強度を2倍にし、プロトタイプデコードカーネルはFlashAttention 2/3と同等またはそれ以上。
  • パラメータ一致および計算一致の制御下で一貫したパープレキシティと下流利得を示す。

短所と未解決の問題:

  • 利得はMuonに大きく依存し、AdamWでは優位性がほぼ消失。
  • オプティマイザー依存の正確な原因は未解決の問題。
  • 結果は1.7Bスケールまでで、MoE、より長いコンテキスト、またはより大規模な実行はなし。
  • 優位性はWSD減衰フェーズ中に侵食され、重み減衰アニーリングによって部分的にしか修正されない。

主要なポイント

  • Parallaxはソフトマックス注意を維持し、学習された共分散補正ブランチを追加し、LLAのクエリごとの共役勾配解法を置き換える。
  • 同じKVストリームを再利用しながら演算強度を2倍にし、デコードカーネルはFlashAttention 2/3と同等またはそれを上回る。
  • 0.6Bおよび1.7Bスケールで、パラメータ一致および計算一致の制御下でも一貫したパープレキシティと下流利得を示す。
  • 利得はMuonに強く依存し、AdamWではその優位性が著しく縮小または消失する。
  • WR = 0に設定するとソフトマックス注意が正確に復元されるため、事前学習済みTransformerはWRを追加して微調整することで変換可能。
  • 論文とリポジトリをチェックしてください。また、Twitterでフォローし、150k+ ML SubRedditに参加し、ニュースレターを購読することをお忘れなく。Telegramにも参加できます。
  • GitHubリポジトリ、Hugging Faceページ、製品リリース、ウェビナーなどのプロモーションでパートナーになりたい場合は、お問い合わせください。