AI News HubLIVE
站内改写

PyTorch 性能分析(第一部分):torch.profiler 入门指南

本文是 PyTorch 性能分析系列的第一篇,从最简单的矩阵乘加操作开始,引导读者学习如何使用 torch.profiler 进行性能分析,包括设置分析器、解读分析表和追踪数据,以及理解 CPU 和 GPU 活动之间的时间关系。文章还讨论了预热和优化开销等问题。

文章情报

投资人进阶

要点

  • torch.profiler 可以生成性能分析表和时间线追踪,帮助识别热点和瓶颈。
  • 小矩阵乘法容易导致开销受限,增大矩阵规模可转为计算受限。
  • 预热步骤对于消除启动开销至关重要,确保分析结果准确。
  • CPU 和 GPU 活动之间存在时间偏移,反映了内核启动和同步的延迟。

为什么重要

这条新闻值得关注,因为torch.profiler 可以生成性能分析表和时间线追踪,帮助识别热点和瓶颈。

技术影响

可能影响模型选型、推理成本、产品能力和评测基准。

在优化 PyTorch 模型时,性能分析是不可或缺的一步。无论你是想提高大语言模型(LLM)的每秒 token 数,缩短推理时间,还是仅仅想了解训练循环为何比预期慢,最终都需要通过性能分析来定位问题。然而,性能分析有着较高的学习门槛:追踪数据是密集的彩色矩形,事件名称令人望而生畏,大多数教程假设读者已经具备读取能力。因此,即使我们知道应该进行性能分析,打开追踪文件也常常被推迟或交给他人处理。本文旨在降低这一门槛。

本系列文章将从 PyTorch 性能分析的基础开始,逐步培养读者读取分析追踪数据并用其指导优化的能力。本部分(第一部分)从最简单的操作——矩阵乘法加偏置——开始,教你如何理解分析器返回的信息。后续部分将扩展到 nn.Linear 和小型 MLP,并最终应用于大语言模型中的 Transformer。

准备工作:核心概念

在开始之前,需要了解两个定义:

  • GPU 内核(kernel):是在 GPU 上并行运行的程序。
  • CPU 调度并启动这些内核

当使用 PyTorch 操作时,它会被自动转换为一个或多个在 GPU 上执行的内核。

创建性能分析代码

我们从最简单的操作开始:一个矩阵乘法加一个偏置加法。代码如下:

def fn(x, w, b):
    return torch.add(torch.matmul(x, w), b)

使用 torch.profiler 进行性能分析的步骤如下:

  1. 准备好要分析的代码(如上)。
  2. torch.profiler.record_function 标注算法(可选但推荐)。
  3. 将代码包装在 torch.profiler.profile 上下文管理器中。
  4. 导出分析结果。

分析器产生两个不同的产物:

  • 分析表:提供事件的统计摘要,回答“什么占用了最多时间”。
  • 追踪数据:提供时间维度的执行视图,回答“操作何时发生以及为何发生”。

解读分析表

运行脚本时,会生成包含分析表的 .txt 文件和包含追踪数据的 .json 文件。分析表的第一列是触发的事件,其他列是 CPU、GPU 等设备上的时间。通过观察事件消耗的时间,可以直观地了解哪些是热点。注意“自 CPU/CUDA 时间”与“总 CPU/CUDA 时间”的区别:前者仅包括事件本身的时间,后者包括其所有子事件。

例如,对于 64×64 的矩阵,自 CUDA 时间仅为 23.104 微秒,而自 CPU 时间为 2.314 毫秒,说明 GPU 大部分时间空闲,这是典型开销受限的情况。通过将矩阵大小增加到 4096×4096,自 CUDA 时间变为 4.495 毫秒,CPU 时间为 4.908 毫秒,GPU 时间显著增加,说明从开销受限转向了计算受限

解读追踪数据

追踪数据可以用 Perfetto UI 打开。CPU 和 GPU 有各自的泳道,条形宽度表示事件持续时间,垂直嵌套表示调用层次。在 64×64 的追踪中,可以观察到以下现象:

  • ProfileStep#2 耗时明显更长:这是因为没有预热 GPU,导致启动开销被记录。预热步骤(先运行几次代码再进行分析)可以消除这种一次性开销。
  • CPU 和 GPU 泳道之间存在约 2.5 毫秒的偏移:这是 CPU 启动内核后,GPU 实际开始执行之前的延迟,可能涉及内存传输、内核排队等。

总结与展望

本部分介绍了 torch.profiler 的基本用法以及如何解读分析表和追踪数据。通过调整矩阵大小和预热,可以观察到从开销受限到计算受限的转变。在后续部分中,我们将扩展这些概念到更复杂的网络,并利用分析结果指导优化。

通过本部分的学习,你应该能够:

  • 设置 torch.profiler 并理解其输出。
  • 读取分析表和时间线追踪(CPU 泳道、GPU 泳道以及两者之间的间隙)。
  • 跟踪从 Python 调用到 CUDA 内核的完整事件链。
  • 理解 torch.compile 带来的变化。