AI News HubLIVE
站内改写2 分鐘閱讀

PyTorch 性能分析(第2部分):從 nn.Linear 到融合 MLP

本文是 PyTorch 性能分析系列的第二部分,深入探討了 nn.Linear 層的內部機制,包括轉置操作、融合偏置的 epilogue 技術,以及 torch.compile 對單個線性層的影響。隨後,文章剖析了一個包含 GeGLU 激活的多層感知機(MLP)的性能特徵,展示了 GPU 內核的調度和執行過程。

PyTorch 性能分析(第2部分):從 nn.Linear 到融合 MLP

在系列的第一部分中,我們使用 torch.add(torch.matmul(x, w), b) 學習瞭如何讀取 PyTorch 性能分析器的跟蹤信息,並討論了 CPU 調度鏈、啓動開銷、計算受限與開銷受限的差異,以及 torch.compile 的一些內部機制。

在本部分,我們向上邁進一步,用 nn.Linear(bias=True) 替換手動編寫的矩陣乘法與加法對。這是深度學習模型的基本構建塊。接着,我們堆疊三個這樣的層(具體示例),並在其間加入激活函數,形成一個多層感知機(MLP)塊。

用於本文的腳本位於:02_linear.py03_simple_mlp.py03_kernels_mlp.py。建議在新標籤頁中打開它們,並邊閲讀代碼邊跟蹤。我們使用 NVIDIA A100-SXM4-80GB GPU 運行腳本。

從矩陣乘法-加法到線性層

nn.Linear 是一個模塊封裝器,內部執行相同的矩陣乘法和加法。唯一的區別在於它擁有自己的權重和偏置作為參數,並暴露了用户熟悉的 forward 方法。

linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)

操作可寫為:

y = x @ w.T + b

運行 02_linear.py 並檢查性能分析跟蹤。

轉置操作的作用

放大跟蹤圖,我們發現在 aten::addmm 之前有一個 aten::t(轉置)操作。nn.Linear 對權重參數進行轉置,然後與輸入相乘。重要的是,aten::t 並不實際複製或重排數據,它僅重寫 CPU 上的張量元數據(形狀和步長),不啓動 GPU 內核。

為什麼沒有單獨的乘法和加法內核?

線性層的調度鏈中沒有 aten::add,因為偏置加法已通過所謂的 epilogue 融入了矩陣乘法內核。Epilogue 是 GEMM(通用矩陣乘法)內核在最後執行的一個小型計算,在將結果寫回 HBM(高帶寬內存)之前完成。添加偏置、應用激活函數或縮放常數都是典型的 epilogue。其目的是避免二次加載或寫入 HBM,因為內存流量會使操作變得昂貴。

nn.Linear 調用 torch.nn.functional.linear,後者進一步調用 aten::linearaten::linear 注意到傳入了偏置,於是調度 aten::addmm(bias, x, weight),而不是分別執行 matmul 和 add。addmm 計算:

out = x @ weight.T + bias

GPU 上運行的 cuBLAS GEMM 內核內置了偏置加法變體,aten::addmm 即選擇該變體。偏置加法不會作為獨立內核出現,因為它已是 matmul 內核回寫的一部分——這正是 epilogue 的含義。

--compile 能優化單個線性層嗎?

比較 eager 模式和編譯模式對單個 nn.Linear 前向傳播的性能分析跟蹤,你會發現:GPU 上運行相同的 cuBLAS GEMM 內核,CPU 上運行相同的 aten::addmm 操作,只是編譯模式下 CPU 行多了幾行。對於單個帶偏置的 GEMM,torch.compile 幾乎沒有優化空間。編譯需要多個操作才可能進行融合。

轉置去哪裏了?內核佈局與預操作

在 eager 模式的 CPU 調度鏈中,aten::linear 內部是 aten::t 後跟 aten::addmm。而編譯模式直接調用 aten::addmm,沒有轉置操作。張量存儲為連續的一維數組,形狀和步長是元數據。aten::t 通過交換步長來創建視圖,不復制數據。編譯時,Inductor 跟蹤視圖鏈,預先計算步長,並直接發射硬編碼了步長的 aten::addmm 調用,從而消除了 CPU 開銷。GPU 仍然執行相同的數學運算,內核名稱相同:cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8,其中的 tn 表示佈局描述符,t 表示轉置,n 表示非轉置。

堆疊三個線性層:MLP

現在,我們分析一個多層感知機,使用 GeGLU 激活變體。該 MLP 包含三個線性層:gate_projup_projdown_proj,前向函數還包括 GeLU 激活和逐元素乘法。

執行 03_simple_mlp.py 並查看跟蹤。預期可以看到三個 aten::linear 調度和兩個逐點內核(GeLU 和乘法)。實際跟蹤確認了這一點:每個前向傳播 GPU 運行恰好 5 個內核。三個 GEMM 內核在啓動前會執行額外的 cudaOccupancyMaxActiveBlocksPerMultiprocessor 調用(occupancy 查詢),而逐點操作直接啓動。

通過本文,讀者應能更深入地理解 PyTorch 性能分析,為後續優化打下基礎。