MLSysのためのモダンGPUプログラミング
本記事では、機械学習システム向けのモダンGPUプログラミングに関する包括的なガイドを紹介します。GPUハードウェアアーキテクチャ、プログラミングモデル、そしてTIRx DSLを使用したGEMMやFlashAttentionなどの主要カーネルの段階的最適化をカバーしています。CMUのコースに基づき、Blackwell GPU世代を対象としています。
機械学習システムにおいて、モダンGPUプログラミングは極めて重要です。これらのシステムでは、パフォーマンスは少数のGPUカーネルの品質に依存することがよくあります。アテンションカーネル、LLMプリフィル・デコードカーネル、低精度ブロックスケーリングGEMM、融合MoE層、その他の大規模融合カーネルはすべて、トレーニングとサービングの両方でエンドツーエンドの速度を直接形作ります。しかし、これらのカーネルを高速にするには、最適化テクニックのリスト以上のものが必要です。モダンGPUはもはや同じ古い設計の単純なバリエーションではありません。最近のアーキテクチャは、よりリッチなメモリ空間、新しいアクセスパターン、ますます特殊化された実行ユニットを導入しています。うまくプログラムするには、ハードウェアの明確なメンタルモデルと、高性能カーネルがどのように構築されるかについての実践的な理解の両方が必要です。この本はその両方を培うことを目的としています。
この本は単純な進行に従います。最初にGPUハードウェアを理解し、次に使用するプログラミングモデルを学び、最後に最先端のカーネルを段階的に構築します。主なターゲットはBlackwell世代であり、主な実行例は高速行列乗算(GEMM)とFlashAttentionです。その過程で、GPU最適化の核心要素であるデータレイアウト、非同期データ移動、非同期調整も学びます。
この教材はカーネギーメロン大学の機械学習システムコースシリーズから生まれました。アイデアを学びやすく実行しやすくするために、この本ではTIRx Python DSLを使用して実際のGPUカーネル例を段階的に構築します。TIRxはハードウェアに密接しており、実行可能なコードを通じて学びながら低レベルの制御について推論できます。
第I部: GPUの理解:GPUの全体的な構成、高速カーネル記述の一般的なレシピ、データレイアウト、非同期メモリ操作、調整などの主要概念を紹介します。この本の残りの部分が依存するハードウェア直感を構築します。
第II部: TIRxの概要:コード例全体の基礎となるTIRxの主要要素を紹介します。
第III部: GEMM: タイル化からSOTAへ:TMAパイプライン、永続スケジューリング、ワープ専門化、2-CTAクラスターを通じて、タイル化GEMMの最適化を完全にガイドします。
第IV部: Flash Attention 4:第III部のテクニックを使用して構築された完全なアテンションカーネル:2つのMMAとその間のソフトマックス、オンラインソフトマックス再スケーリング、因果マスキング、GQA。
付録:TIRx言語リファレンスとコンパイラ内部。
この本では、順次単一タイルGEMMからマルチコンシューマワープ専門化までの9つのステップからなる最適化パス、およびFlash Attention 4の具体的な実装(Tile-Primitiveグラフ、ワープの役割とスコープ、2つのMMAフェーズ、TMEMレイアウトと再利用、バリアの接続、パイプライン構造、再スケーリングとライトバック、因果マスキング、GQAサポート、Tileスケジューリングなど)を詳細に説明しています。