PyTorch 性能分析(第2部分):从 nn.Linear 到融合 MLP
本文是 PyTorch 性能分析系列的第二部分,深入探讨了 nn.Linear 层的内部机制,包括转置操作、融合偏置的 epilogue 技术,以及 torch.compile 对单个线性层的影响。随后,文章剖析了一个包含 GeGLU 激活的多层感知机(MLP)的性能特征,展示了 GPU 内核的调度和执行过程。
PyTorch 性能分析(第2部分):从 nn.Linear 到融合 MLP
在系列的第一部分中,我们使用 torch.add(torch.matmul(x, w), b) 学习了如何读取 PyTorch 性能分析器的跟踪信息,并讨论了 CPU 调度链、启动开销、计算受限与开销受限的差异,以及 torch.compile 的一些内部机制。
在本部分,我们向上迈进一步,用 nn.Linear(bias=True) 替换手动编写的矩阵乘法与加法对。这是深度学习模型的基本构建块。接着,我们堆叠三个这样的层(具体示例),并在其间加入激活函数,形成一个多层感知机(MLP)块。
用于本文的脚本位于:02_linear.py、03_simple_mlp.py 和 03_kernels_mlp.py。建议在新标签页中打开它们,并边阅读代码边跟踪。我们使用 NVIDIA A100-SXM4-80GB GPU 运行脚本。
从矩阵乘法-加法到线性层
nn.Linear 是一个模块封装器,内部执行相同的矩阵乘法和加法。唯一的区别在于它拥有自己的权重和偏置作为参数,并暴露了用户熟悉的 forward 方法。
linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)操作可写为:
y = x @ w.T + b运行 02_linear.py 并检查性能分析跟踪。
转置操作的作用
放大跟踪图,我们发现在 aten::addmm 之前有一个 aten::t(转置)操作。nn.Linear 对权重参数进行转置,然后与输入相乘。重要的是,aten::t 并不实际复制或重排数据,它仅重写 CPU 上的张量元数据(形状和步长),不启动 GPU 内核。
为什么没有单独的乘法和加法内核?
线性层的调度链中没有 aten::add,因为偏置加法已通过所谓的 epilogue 融入了矩阵乘法内核。Epilogue 是 GEMM(通用矩阵乘法)内核在最后执行的一个小型计算,在将结果写回 HBM(高带宽内存)之前完成。添加偏置、应用激活函数或缩放常数都是典型的 epilogue。其目的是避免二次加载或写入 HBM,因为内存流量会使操作变得昂贵。
nn.Linear 调用 torch.nn.functional.linear,后者进一步调用 aten::linear。aten::linear 注意到传入了偏置,于是调度 aten::addmm(bias, x, weight),而不是分别执行 matmul 和 add。addmm 计算:
out = x @ weight.T + biasGPU 上运行的 cuBLAS GEMM 内核内置了偏置加法变体,aten::addmm 即选择该变体。偏置加法不会作为独立内核出现,因为它已是 matmul 内核回写的一部分——这正是 epilogue 的含义。
--compile 能优化单个线性层吗?
比较 eager 模式和编译模式对单个 nn.Linear 前向传播的性能分析跟踪,你会发现:GPU 上运行相同的 cuBLAS GEMM 内核,CPU 上运行相同的 aten::addmm 操作,只是编译模式下 CPU 行多了几行。对于单个带偏置的 GEMM,torch.compile 几乎没有优化空间。编译需要多个操作才可能进行融合。
转置去哪里了?内核布局与预操作
在 eager 模式的 CPU 调度链中,aten::linear 内部是 aten::t 后跟 aten::addmm。而编译模式直接调用 aten::addmm,没有转置操作。张量存储为连续的一维数组,形状和步长是元数据。aten::t 通过交换步长来创建视图,不复制数据。编译时,Inductor 跟踪视图链,预先计算步长,并直接发射硬编码了步长的 aten::addmm 调用,从而消除了 CPU 开销。GPU 仍然执行相同的数学运算,内核名称相同:cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8,其中的 tn 表示布局描述符,t 表示转置,n 表示非转置。
堆叠三个线性层:MLP
现在,我们分析一个多层感知机,使用 GeGLU 激活变体。该 MLP 包含三个线性层:gate_proj、up_proj 和 down_proj,前向函数还包括 GeLU 激活和逐元素乘法。
执行 03_simple_mlp.py 并查看跟踪。预期可以看到三个 aten::linear 调度和两个逐点内核(GeLU 和乘法)。实际跟踪确认了这一点:每个前向传播 GPU 运行恰好 5 个内核。三个 GEMM 内核在启动前会执行额外的 cudaOccupancyMaxActiveBlocksPerMultiprocessor 调用(occupancy 查询),而逐点操作直接启动。
通过本文,读者应能更深入地理解 PyTorch 性能分析,为后续优化打下基础。