MoonMath AI、AMD MI300X向けHIPアテンションカーネルをオープンソース公開 – あらゆる形状と丸めモードでAITER v3を上回る
MoonMath AIチームは、AMD MI300X GPU向けのbf16フォワードアテンションカーネルをリリースしました。HIPで記述され、MITライセンスでオープンソース化されています。単一命令アセンブリラッパーと8ウェーブパイプラインなどの革新的技術により、テストしたすべての形状と丸めモードでAMDの最適化カーネルAITER v3を凌駕し、幾何平均で1.08倍から1.18倍の高速化を達成。主要な高速化はメモリ配置(KをLDS、VをL1、Qとアキュムレータをレジスタに配置)によるものです。また、実際のSGLang PRに統合され、Wan2.1ビデオ拡散モデルのエンドツーエンド性能を品質低下なしで1.23倍向上させました。
MoonMath AIチームは、AMD MI300X GPU向けのbf16フォワードアテンションカーネルをリリースしました。このカーネルはHIPで記述され(手書きアセンブリではなく)、MITライセンスでオープンソース公開されています。MoonMath.aiチームによると、テストしたすべての形状においてAMDの最適化カーネルAITER v3を上回る性能を発揮します。ベアメタルアクセスはAMDクラウドプロバイダーHotAisleから提供されました。
アテンションは、各トランスフォーマー内の融合softmax(QKᵀ/√d)·V操作です。MI300XはAMDのCDNA3データセンターGPUで、ISAターゲットはgfx942です。このカーネルはこのハードウェアでのみ動作します。
カーネルとは、GPUの多数のコア上で直接実行され、特定の計算(ここではアテンション演算)を可能な限り高速に実行する小さなプログラムです。このカーネルはMI300X上でのみbf16フォワードアテンションを計算し、BSHDまたはBHSDレイアウトをサポートし、転置は不要で、ヘッド次元は128に固定されています。任意のシーケンス長(クロスアテンションを含む)に対応します。
現在の制限として、因果マスク、GQA、可変長バッチ処理はサポートされていません。出力はbf16で、gfx942ハードウェア専用です。数値的には厳密に制御されており、3つの丸めモードすべてがAITERのモード別丸め規則に一致し、各有限出力はAITERとの差が1 bf16 ULP以内、NaNおよびInf処理はビット単位で同一、結果は決定論的です。
核となる手法:単一命令アセンブリラッパー この手法は一般的なジレンマを回避します。コンパイラ組み込み関数はコードを整理しますが、コンパイラがオペランドを並べ替えたり名前を変更したりする可能性があります。一方、生のインラインアセンブリは制御を提供しますが、手動でレジスタとアドレスを管理する必要があります。MoonMathは単一命令をdevice forceinline関数でラップし、拡張asm制約でオペランドを記述します。チームがオペコードを選択し、コンパイラがレジスタ割り当てとデータフロー追跡を行います。
例えば、asm mfma関数では、"+v"(c)制約によりアキュムレータの入力と出力が同じVGPRに結び付けられ、コピー命令が発行されません。これにより、カーネルは通常のHIPに近くなりながら、機械を1命令ずつ制御できます。
アーキテクチャ:8ウェーブ、2グループ、2バリア CDNA3コンピュートユニットは4つのSIMDユニットを持ちます。教科書的には4ウェーブですが、MoonMathはブロックあたり8ウェーブを2グループ(各4ウェーブ)で実行します。2つのグループは同じQ*K、softmax、O += P*Vシーケンスを実行しますが、位相がずれています。一方のグループが行列コアを占有している間、もう一方はsoftmaxを実行しロードを発行し、その後交換することで行列コアがアイドルになることを防ぎます。反復あたり2つのs_barrierがあります。1つは位相の引き継ぎ時、もう1つは反復境界にあります。残りの同期はカウンタ単位の待機で処理されます。これはFlashAttention-3の行列乗算とsoftmaxの交互実行に似ていますが、プロデューサー/コンシューマーウェーブ分割は採用していません。CDNA3ではすべてのメモリ移動が非同期であるため、専用のプロデューサーウェーブは不要です。
データの配置と16×16×16の選択 高速化の大部分はメモリ配置に由来します。KはHBMからLDSにストリーミングされ、ダブルバッファリングされ、全8ウェーブで共有されます。VはL1に保持され、PV行列乗算のたびに読み取られます。Qとアキュムレータはレジスタに配置されます。チームは32×32×8ではなく16×16×16 MFMAを選択しました。両方の形状でスループットは同一ですが、より小さなタイルではレーンあたり4つのfp32要素(16ではなく)を累積するため、アキュムレータの負荷が低くなり、より深いプリフェッチと3つ目のQタイルの余地が生まれます。
その後、2つの改良により差がさらに広がりました。3つ目のQタイル(3Q)により、ロードされたKおよびVタイルあたりのデータ再利用性が向上しました。Flash-DecodingスタイルのテールKV分割により、MI300Xの304 CUにわたる端数のラウンドが救われました。これらの改良は連鎖的に効果を発揮します。VをL1に移動することでLDSが解放され、そこに3つ目のQタイルが配置されました。
ベンチマーク テストはMI300X上でbf16、ヘッド次元128で実施されました。各形状は3つの丸めモード(RTNE、RTNA、RTZ)で測定されました。結果はMoonMathカーネルの優位性を示しています。AITERとの幾何平均は1.18倍(RTNE)、1.15倍(RTNA)、1.08倍(RTZ)でした。Modular MAXとの比較では幾何平均1.44倍から1.49倍、形状別では最大1.59倍の高速化を達成しました。
実際のユースケース このカーネルはpipでインストールでき、シンプルなAPIを公開します。呼び出し元のストリーム上で起動するため、大規模なパイプライン内でオーバーラップします。具体的なユースケースとしてビデオ拡散があります。チームはLiteAttentionサポートを追加し、SGLang拡散にPRを送信しました。Wan2.1-T2V-1.3B-Diffusersでは、アテンションをAITERからliteattention_rocmに切り替えたところ、MI300X上でエンドツーエンドの生成性能が1.23倍向上し、画質の低下は見られませんでした。BSHDレイアウトは拡散テンソルに直接適合し、クロスアテンションは任意のKV長とパディングなしで機能します。
主要なポイント
- MI300X向けbf16フォワードアテンションカーネル、HIP記述、MITライセンス。
- すべての形状と丸めモードでAITER v3を上回る(幾何平均1.18倍/1.15倍/1.08倍)。
- 単一命令アセンブリラッパーによりオペコード制御を維持しつつコンパイラがレジスタ割り当て。
- メモリ配置が主要な高速化要因:KをLDS、VをL1、Qをレジスタに配置。
- 実際のSGLang PRによりWan2.1ビデオ拡散が1.23倍高速化、品質低下なし。
詳細な技術情報は原文をご参照ください。Twitterでフォロー、150k+ ML SubRedditへの参加、ニュースレターの購読もお忘れなく。