PyTorch プロファイリング (第2部): nn.Linear から融合 MLP へ
本記事は PyTorch プロファイリングシリーズの第2部であり、nn.Linear レイヤーの内部機構(転置操作、バイアス融合エピローグ技術、torch.compile の影響)を掘り下げます。その後、GeGLU 活性化関数を含む多層パーセプトロン (MLP) のパフォーマンス特性を解析し、GPU カーネルのスケジューリングと実行を示します。
PyTorch プロファイリング (第2部): nn.Linear から融合 MLP へ
シリーズの第1部では、torch.add(torch.matmul(x, w), b) を使用して PyTorch プロファイラのトレースの読み方を学び、CPU ディスパッチチェーン、起動オーバーヘッド、計算律速とオーバーヘッド律速の違い、torch.compile の内部機構について議論しました。
第2部では、手書きの行列乗算と加算のペアを nn.Linear(bias=True) に置き換えます。これはすべてのディープラーニングモデルが使用する基本ブロックです。次に、3つのそのような層を(この例では)積み重ね、間に活性化関数を挟んで、多層パーセプトロン(MLP)ブロックを形成します。
このブログ記事のスクリプトは 02_linear.py、03_simple_mlp.py、03_kernels_mlp.py にあります。別のタブで開き、コードを追いながら読むとよいでしょう。NVIDIA A100-SXM4-80GB GPU を使用してスクリプトを実行します。
行列乗算-加算から Linear へ
nn.Linear は、同じ行列乗算と加算を内部で実行するモジュールラッパーです。唯一の違いは、重みとバイアスをパラメータとして保持し、PyTorch ユーザーが慣れ親しんだ forward メソッドを提供することです。
linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)操作は次のように書けます:
y = x @ w.T + b02_linear.py を実行し、プロファイルを確認します。
転置操作の役割
トレースを拡大すると、aten::addmm の前に aten::t(転置)操作があることがわかります。nn.Linear は重みパラメータを転置してから入力と乗算します。重要なのは、aten::t は実際にデータをコピーまたは再構成せず、CPU 上のテンソルメタデータ(形状とストライド)のみを書き換え、GPU カーネルを起動しないことです。
なぜ別々の mul と add カーネルがないのか?
線形層のディスパッチチェーンには aten::add がありません。これは、バイアス加算が エピローグ と呼ばれる機構によって行列乗算カーネルに融合されているためです。エピローグは、GEMM(汎用行列乗算)カーネルが結果を HBM(高帯域幅メモリ)に書き戻す直前に実行する小さな計算です。バイアスの追加、活性化関数の適用、定数によるスケーリングはすべて典型的なエピローグです。エピローグの目的は、HBM への2回目のロードや書き込みを避けることです。メモリトラフィックは操作を高コストにするからです。
nn.Linear は torch.nn.functional.linear を呼び出し、さらに aten::linear を呼び出します。aten::linear はバイアスが渡されたことを検出し、matmul と add を個別に実行する代わりに aten::addmm(bias, x, weight) をディスパッチします。addmm は次の計算を行います:
out = x @ weight.T + biasGPU 上で実行される cuBLAS GEMM カーネルにはバイアス加算バリアントが組み込まれており、aten::addmm はそれを選択します。バイアス加算は独立したカーネルとして現れません。なぜなら、それは matmul カーネルの書き戻しの一部だからです。これがまさにエピローグです。
--compile は単一の Linear を高速化できるか?
単一の nn.Linear の forward について、eager モードとコンパイルモードのプロファイルトレースを比較すると、次のことがわかります:GPU 上では同じ cuBLAS GEMM カーネル、CPU 上では同じ aten::addmm 操作、ただしコンパイルモードでは CPU 行にいくつかの追加行があります。単一のバイアス付き GEMM に対して、torch.compile にはほとんど最適化の余地がありません。コンパイルが融合を行うには、複数の操作が必要です。
転置はどこへ行った?カーネルレイアウトと事前操作
eager モードの CPU ディスパッチチェーンでは、aten::linear 内部は aten::t に続いて aten::addmm です。コンパイルモードでは、aten::addmm が直接呼び出され、転置はありません。テンソルはメモリ内に連続した1次元配列として格納され、形状とストライドはメタデータです。aten::t はストライドを交換することでビューを生成し、データをコピーしません。コンパイル時、Inductor はビューチェーンをトレースし、結果のストライドを事前計算して、そのストライドをハードコードした直接の aten::addmm 呼び出しを生成します。これにより CPU オーバーヘッドが排除されます。GPU は同じ計算を実行し、カーネル名は同じです:cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8。tn はレイアウト記述子で、t は転置、n は非転置を意味します。
3つの Linear を積み重ねる:MLP
ここでは、GeGLU 活性化バリアントを使用する多層パーセプトロン(MLP)をプロファイルします。この MLP には3つの線形層(gate_proj、up_proj、down_proj)があり、forward 関数には GeLU 活性化と要素ごとの乗算が含まれます。
03_simple_mlp.py を実行し、トレースを確認します。予想通り、3つの aten::linear ディスパッチと2つのポイントワイズカーネル(GeLU と乗算)が確認できます。各 forward パスで GPU はちょうど5つのカーネルを実行します。3つの GEMM カーネルは起動前に追加の cudaOccupancyMaxActiveBlocksPerMultiprocessor 呼び出し(オキュパンシークエリ)を実行し、ポイントワイズ演算は直接起動されます。
この記事を通じて、読者は PyTorch プロファイリングの理解を深め、今後の最適化の基盤を築くことができるでしょう。