面向機器學習系統的現代GPU程式設計
本文介紹了一本關於現代GPU程式設計的綜合性指南,面向機器學習系統,涵蓋GPU硬體架構、程式設計模型以及使用TIRx DSL逐步最佳化關鍵核心(如GEMM和FlashAttention)的方法。內容基於卡內基梅隆大學的課程,以Blackwell GPU為目標。
現代GPU程式設計對於機器學習系統至關重要。在這些系統中,效能往往取決於少數GPU核心的質量。注意力核心、LLM預填充和解碼核心、低精度塊縮放GEMM、融合MoE層以及其他大型融合核心直接影響訓練和推理的端到端速度。然而,要使這些核心快速執行,需要的不僅僅是一系列最佳化技巧。現代GPU不再是同一舊設計的簡單變體。最近的架構引入了更豐富的記憶體空間、新的訪問模式和日益專門的執行單元。為了良好地程式設計這些GPU,我們既需要清晰的硬體心智模型,也需要理解高效能核心是如何構建的實用知識。本書旨在培養這兩方面的能力。
本書遵循簡單的遞進順序:首先理解GPU硬體,然後學習我們將使用的程式設計模型,最後逐步構建最先進的核心。我們的主要目標是Blackwell代,主要執行示例是快速矩陣乘法(GEMM)和FlashAttention。在此過程中,我們還將研究GPU最佳化的核心要素:資料佈局、非同步資料移動和非同步協調。
本書內容源於卡內基梅隆大學的機器學習系統課程系列。為了便於研究和執行,本書使用TIRx Python DSL逐步構建真實的GPU核心示例。TIRx貼近硬體,使我們能夠在透過可執行程式碼學習的同時,推理低層控制。
第一部分:理解GPU:介紹GPU的整體組織、編寫快速核心的通用方法以及關鍵概念,如資料佈局、非同步記憶體操作和協調。建立本書其餘部分依賴的硬體直覺。
第二部分:TIRx概述:介紹TIRx的關鍵元素,作為全書程式碼示例的基礎。
第三部分:GEMM:從分塊到SOTA:透過TMA流水線、持久排程、warp specialized和2-CTA叢集,完整指導最佳化分塊GEMM。
第四部分:Flash Attention 4:使用第三部分的技術構建完整的注意力核心:兩個MMA,中間有softmax、線上softmax重新縮放、因果掩碼和GQA。
附錄:TIRx語言參考和編譯器內部。
本書詳細介紹了最佳化路徑,包括9個步驟,從順序單塊GEMM到多消費者warp specialized,以及Flash Attention 4的具體實現,如Tile-Primitive圖、warp角色和作用域、兩MMA階段、TMEM佈局和重用、屏障連線、流水線結構、重新縮放和寫回、因果掩碼、GQA支援、Tile排程等。