AI News HubLIVE
站内改写2 分鐘閱讀

PyTorch 效能分析(第2部分):從 nn.Linear 到融合 MLP

本文是 PyTorch 效能分析系列的第二部分,深入探討了 nn.Linear 層的內部機制,包括轉置操作、融合偏置的 epilogue 技術,以及 torch.compile 對單個線性層的影響。隨後,文章剖析了一個包含 GeGLU 啟用的多層感知機(MLP)的效能特徵,展示了 GPU 核心的排程和執行過程。

PyTorch 效能分析(第2部分):從 nn.Linear 到融合 MLP

在系列的第一部分中,我們使用 torch.add(torch.matmul(x, w), b) 學習瞭如何讀取 PyTorch 效能分析器的跟蹤資訊,並討論了 CPU 排程鏈、啟動開銷、計算受限與開銷受限的差異,以及 torch.compile 的一些內部機制。

在本部分,我們向上邁進一步,用 nn.Linear(bias=True) 替換手動編寫的矩陣乘法與加法對。這是深度學習模型的基本構建塊。接著,我們堆疊三個這樣的層(具體示例),並在其間加入啟用函式,形成一個多層感知機(MLP)塊。

用於本文的指令碼位於:02_linear.py03_simple_mlp.py03_kernels_mlp.py。建議在新標籤頁中開啟它們,並邊閱讀程式碼邊跟蹤。我們使用 NVIDIA A100-SXM4-80GB GPU 執行指令碼。

從矩陣乘法-加法到線性層

nn.Linear 是一個模組封裝器,內部執行相同的矩陣乘法和加法。唯一的區別在於它擁有自己的權重和偏置作為引數,並暴露了使用者熟悉的 forward 方法。

linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)

操作可寫為:

y = x @ w.T + b

執行 02_linear.py 並檢查效能分析跟蹤。

轉置操作的作用

放大跟蹤圖,我們發現在 aten::addmm 之前有一個 aten::t(轉置)操作。nn.Linear 對權重引數進行轉置,然後與輸入相乘。重要的是,aten::t 並不實際複製或重排資料,它僅重寫 CPU 上的張量後設資料(形狀和步長),不啟動 GPU 核心。

為什麼沒有單獨的乘法和加法核心?

線性層的排程鏈中沒有 aten::add,因為偏置加法已透過所謂的 epilogue 融入了矩陣乘法核心。Epilogue 是 GEMM(通用矩陣乘法)核心在最後執行的一個小型計算,在將結果寫回 HBM(高頻寬記憶體)之前完成。新增偏置、應用啟用函式或縮放常數都是典型的 epilogue。其目的是避免二次載入或寫入 HBM,因為記憶體流量會使操作變得昂貴。

nn.Linear 呼叫 torch.nn.functional.linear,後者進一步呼叫 aten::linearaten::linear 注意到傳入了偏置,於是排程 aten::addmm(bias, x, weight),而不是分別執行 matmul 和 add。addmm 計算:

out = x @ weight.T + bias

GPU 上執行的 cuBLAS GEMM 核心內建了偏置加法變體,aten::addmm 即選擇該變體。偏置加法不會作為獨立核心出現,因為它已是 matmul 核心回寫的一部分——這正是 epilogue 的含義。

--compile 能最佳化單個線性層嗎?

比較 eager 模式和編譯模式對單個 nn.Linear 前向傳播的效能分析跟蹤,你會發現:GPU 上執行相同的 cuBLAS GEMM 核心,CPU 上執行相同的 aten::addmm 操作,只是編譯模式下 CPU 行多了幾行。對於單個帶偏置的 GEMM,torch.compile 幾乎沒有最佳化空間。編譯需要多個操作才可能進行融合。

轉置去哪裡了?核心佈局與預操作

在 eager 模式的 CPU 排程鏈中,aten::linear 內部是 aten::t 後跟 aten::addmm。而編譯模式直接呼叫 aten::addmm,沒有轉置操作。張量儲存為連續的一維陣列,形狀和步長是後設資料。aten::t 透過交換步長來建立檢視,不復制資料。編譯時,Inductor 跟蹤檢視鏈,預先計算步長,並直接發射硬編碼了步長的 aten::addmm 呼叫,從而消除了 CPU 開銷。GPU 仍然執行相同的數學運算,核心名稱相同:cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8,其中的 tn 表示佈局描述符,t 表示轉置,n 表示非轉置。

堆疊三個線性層:MLP

現在,我們分析一個多層感知機,使用 GeGLU 啟用變體。該 MLP 包含三個線性層:gate_projup_projdown_proj,前向函式還包括 GeLU 啟用和逐元素乘法。

執行 03_simple_mlp.py 並檢視跟蹤。預期可以看到三個 aten::linear 排程和兩個逐點核心(GeLU 和乘法)。實際跟蹤確認了這一點:每個前向傳播 GPU 執行恰好 5 個核心。三個 GEMM 核心在啟動前會執行額外的 cudaOccupancyMaxActiveBlocksPerMultiprocessor 呼叫(occupancy 查詢),而逐點操作直接啟動。

透過本文,讀者應能更深入地理解 PyTorch 效能分析,為後續最佳化打下基礎。