AI News HubLIVE
站内改写2 分钟阅读

PyTorch 性能分析(第2部分):从 nn.Linear 到融合 MLP

本文是 PyTorch 性能分析系列的第二部分,深入探讨了 nn.Linear 层的内部机制,包括转置操作、融合偏置的 epilogue 技术,以及 torch.compile 对单个线性层的影响。随后,文章剖析了一个包含 GeGLU 激活的多层感知机(MLP)的性能特征,展示了 GPU 内核的调度和执行过程。

PyTorch 性能分析(第2部分):从 nn.Linear 到融合 MLP

在系列的第一部分中,我们使用 torch.add(torch.matmul(x, w), b) 学习了如何读取 PyTorch 性能分析器的跟踪信息,并讨论了 CPU 调度链、启动开销、计算受限与开销受限的差异,以及 torch.compile 的一些内部机制。

在本部分,我们向上迈进一步,用 nn.Linear(bias=True) 替换手动编写的矩阵乘法与加法对。这是深度学习模型的基本构建块。接着,我们堆叠三个这样的层(具体示例),并在其间加入激活函数,形成一个多层感知机(MLP)块。

用于本文的脚本位于:02_linear.py03_simple_mlp.py03_kernels_mlp.py。建议在新标签页中打开它们,并边阅读代码边跟踪。我们使用 NVIDIA A100-SXM4-80GB GPU 运行脚本。

从矩阵乘法-加法到线性层

nn.Linear 是一个模块封装器,内部执行相同的矩阵乘法和加法。唯一的区别在于它拥有自己的权重和偏置作为参数,并暴露了用户熟悉的 forward 方法。

linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)

操作可写为:

y = x @ w.T + b

运行 02_linear.py 并检查性能分析跟踪。

转置操作的作用

放大跟踪图,我们发现在 aten::addmm 之前有一个 aten::t(转置)操作。nn.Linear 对权重参数进行转置,然后与输入相乘。重要的是,aten::t 并不实际复制或重排数据,它仅重写 CPU 上的张量元数据(形状和步长),不启动 GPU 内核。

为什么没有单独的乘法和加法内核?

线性层的调度链中没有 aten::add,因为偏置加法已通过所谓的 epilogue 融入了矩阵乘法内核。Epilogue 是 GEMM(通用矩阵乘法)内核在最后执行的一个小型计算,在将结果写回 HBM(高带宽内存)之前完成。添加偏置、应用激活函数或缩放常数都是典型的 epilogue。其目的是避免二次加载或写入 HBM,因为内存流量会使操作变得昂贵。

nn.Linear 调用 torch.nn.functional.linear,后者进一步调用 aten::linearaten::linear 注意到传入了偏置,于是调度 aten::addmm(bias, x, weight),而不是分别执行 matmul 和 add。addmm 计算:

out = x @ weight.T + bias

GPU 上运行的 cuBLAS GEMM 内核内置了偏置加法变体,aten::addmm 即选择该变体。偏置加法不会作为独立内核出现,因为它已是 matmul 内核回写的一部分——这正是 epilogue 的含义。

--compile 能优化单个线性层吗?

比较 eager 模式和编译模式对单个 nn.Linear 前向传播的性能分析跟踪,你会发现:GPU 上运行相同的 cuBLAS GEMM 内核,CPU 上运行相同的 aten::addmm 操作,只是编译模式下 CPU 行多了几行。对于单个带偏置的 GEMM,torch.compile 几乎没有优化空间。编译需要多个操作才可能进行融合。

转置去哪里了?内核布局与预操作

在 eager 模式的 CPU 调度链中,aten::linear 内部是 aten::t 后跟 aten::addmm。而编译模式直接调用 aten::addmm,没有转置操作。张量存储为连续的一维数组,形状和步长是元数据。aten::t 通过交换步长来创建视图,不复制数据。编译时,Inductor 跟踪视图链,预先计算步长,并直接发射硬编码了步长的 aten::addmm 调用,从而消除了 CPU 开销。GPU 仍然执行相同的数学运算,内核名称相同:cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8,其中的 tn 表示布局描述符,t 表示转置,n 表示非转置。

堆叠三个线性层:MLP

现在,我们分析一个多层感知机,使用 GeGLU 激活变体。该 MLP 包含三个线性层:gate_projup_projdown_proj,前向函数还包括 GeLU 激活和逐元素乘法。

执行 03_simple_mlp.py 并查看跟踪。预期可以看到三个 aten::linear 调度和两个逐点内核(GeLU 和乘法)。实际跟踪确认了这一点:每个前向传播 GPU 运行恰好 5 个内核。三个 GEMM 内核在启动前会执行额外的 cudaOccupancyMaxActiveBlocksPerMultiprocessor 调用(occupancy 查询),而逐点操作直接启动。

通过本文,读者应能更深入地理解 PyTorch 性能分析,为后续优化打下基础。