AI News HubLIVE
站内改写

Nous Research提出灯塔注意力机制:一种仅用于训练的基于选择的分层注意力,在长上下文预训练中实现1.4–1.7倍加速

Nous Research发布了灯塔注意力机制,这是一种仅用于训练的基于选择的分层注意力机制,在预训练期间包裹标准缩放点积注意力,之后移除。与之前仅池化键和值的方法不同,灯塔注意力对称地池化查询、键和值,形成多分辨率金字塔,将注意力调用从O(N·S·d)降低到O(S²·d),并对小型密集子序列运行标准FlashAttention。在530M参数的Llama-3风格模型上,以98K上下文测试,相较于cuDNN SDPA基线实现了1.40–1.69倍端到端加速,且最终训练损失持平或更低。

文章情报

工程师进阶

要点

  • 灯塔注意力是一种训练时使用的分层注意力机制,通过对称池化查询、键和值构建多分辨率金字塔,大幅减少计算量。
  • 在530M参数的Llama-3模型上,98K上下文长度下实现1.4-1.7倍加速,且最终损失低于或等于密集基线。
  • 采用两阶段训练:第一阶段使用灯塔注意力,第二阶段切换到标准密集注意力以恢复模型能力,实验显示损失迅速恢复并降低。
  • 该方法可扩展至100万token训练,通过上下文并行性在32个Blackwell GPU上实现无特殊内核的扩展。

为什么重要

这条新闻值得关注,因为灯塔注意力是一种训练时使用的分层注意力机制,通过对称池化查询、键和值构建多分辨率金字塔,大幅减少计算量。

技术影响

可能影响模型选型、推理成本、产品能力和评测基准。

训练大型语言模型处理长序列时,注意力机制的计算成本是一个众所周知的难题。每个Transformer核心的缩放点积注意力(SDPA)在计算和内存上均呈二次方规模Θ(N²),其中N是序列长度。FlashAttention通过IO感知的分块技术解决了内存瓶颈,避免了在高速内存中实例化完整的N×N注意力矩阵,但底层的Θ(N²)计算规模并未改变。Nous Research的研究人员提出了一种名为“灯塔注意力”(Lighthouse Attention)的新方法,专门在预训练阶段解决这一瓶颈,实现了相对于cuDNN支持的SDPA基线1.40倍到1.69倍的端到端加速,同时最终训练损失持平或更低。

现有稀疏注意力方法的核心问题

理解灯塔注意力为何有效,需要先了解现有稀疏注意力方法的局限性。大多数先前工作,如NSA、HISA、DSA、MoBA,做出了两个相同的设计选择。首先,它们仅池化键和值侧,而保留查询的全分辨率(非对称压缩)。其次,它们的选择逻辑位于自定义注意力内核内部,这意味着团队无法重复利用现代GPU张量核心优化的密集注意力内核。此外,训练时稀疏方法面临一个推理时方法没有的特殊问题:推理时稀疏方法仅针对其密集骨干网络评估,最多与骨干网络性能相当;而训练时稀疏方法面临更严峻的考验——训练结束后,得到的权重是否仍能产生具备能力的密集注意力模型?灯塔注意力将此问题作为其核心正确性标准。

灯塔注意力在这两个设计选择上采取了不同路径。它跨多层金字塔对称地池化查询、键和值,并将选择过程完全置于注意力内核之外。选择之后,系统将选中的条目收集到一个连续的密集子序列中,并对其运行标准的FlashAttention——与密集基线使用的内核相同。

四阶段流水线的工作原理

灯塔注意力层包裹但不修改缩放点积注意力。该流水线包含四个阶段。

第一阶段,平均池化从Q、K、V构建L级金字塔。池化因子为p时,金字塔的第ℓ层包含N/p^ℓ个令牌,每个令牌汇总p^ℓ个基础位置。关键在于,相同的池化应用于所有三个投影,从而在每个层级生成一致的(Q^(ℓ), K^(ℓ), V^(ℓ))三元组。金字塔构建的总成本为Θ(N)时间和内存。

第二阶段,一个无参数评分器使用每个头的ℓ₂范数为每个金字塔条目分配两个标量分数:一个作为查询分数(∥Q^(ℓ)_i∥₂),一个作为键分数(∥K^(ℓ)_i∥₂)。较粗的层级通过最大池化继承较细层级的分数,因此粗粒度跨度能够捕获其最强令牌的重要性。然后,一个融合的分块双调top-K内核联合跨所有金字塔层级选择k个条目。一个值得注意的设计细节:最粗的金字塔层级始终完整保留——它成本低廉且保证每个基础位置至少有一个贡献者;剩余的选择预算分配给较细层级。此外,分块双调设计产生的是分层top-K而非严格的全局top-K:分数流被划分为固定大小的块,每个块维护一个寄存器内的top-m缓冲区,因此如果全局得分最高的k个条目聚集在一个块中,其中一些将被来自其他块得分较低的条目替换。结果是在整个序列上实现更平衡的注意力覆盖,避免了选择坍缩到狭窄范围内。

top-K步骤是离散且不可微的——没有直通估计器,也没有Gumbel softmax。选择索引不携带梯度。梯度仅流经收集到的Q、K、V条目进入WQ、WK、WV,因此投影层学会产生对选择有用的值,而非擅长选择的分数。

第三阶段,选中的条目被收集到一个长度为S = N/p^(L−1) + (L−1)·p·k的连续子序列中,并传递给标准FlashAttention。当N = 1,000,000,L = 4,p = 4,k = 4,096时,S ≈ 65,000,远小于N。收集过程的一个关键属性是保证组装的子序列中没有“空洞”或空位。这一点很重要,因为灯塔注意力也压缩了查询:序列中的间隙意味着这些缺失的令牌在反向传播中没有梯度路径,可能导致训练不稳定。保持查询全分辨率的非对称方法不会面临此问题,但灯塔注意力的对称设计要求收集的子序列保持完全密集。

第四阶段,每个输出条目通过确定性整数原子散射内核散射回其代表的p^ℓ个基础位置,并通过移动p^ℓ − 1来保持因果性。每个位置的实际扇入被限制为L,与k无关。

对称池化如何改变计算

将查询与键和值一起池化,将训练时的注意力调用计算从O(N S d)改变为O(S² d)。由于在长上下文中S ≪ N,这产生了延迟优势。在单个NVIDIA B200上以512K上下文(bfloat16,B=1,H=8,头维度128,L=3,p=4,稀疏比≈1:64)进行基准测试,灯塔注意力相对于cuDNN支持的SDPA,前向传播快21倍,前后向组合传播快17.3倍。

从渐近角度来看,设置L = log_p(N/k)使得收集的子序列大小S = Θ(k log N),因此密集FlashAttention调用的成本为Θ(k² log² N d) —— 在固定k下关于N是多项对数复杂度。结合线性成本阶段(金字塔构建、评分、散射回),在固定k下每层总计算量为Θ(T d) —— 与线性注意力和状态空间模型相同的渐近类别,同时在选中的子序列上保留了softmax注意力的召回特性。

推理则是不同的约束。自回归解码一次呈现一个查询,这违反了所有查询在一次前向传播中共同出现的假设。灯塔注意力是仅训练方法,对称池化设计无法直接在推理中使用。

两阶段训练配方与可恢复性

实验设置使用了530M参数的Llama-3风格解码器(d_model=1024,30层,8头,头维度128,FFN宽度1536,字节级分词器),在C4数据集上以98,304令牌上下文训练,使用AdamW优化器(学习率2×10⁻³,β1=0.9,β2=0.95,权重衰减0.1,线性预热2k步,梯度范数裁剪1,bfloat16,FSDP)。一个对实践者重要的实现细节:30层中,层{0,1,28,29}全程保留密集SDPA——只有其余26层使用灯塔注意力。这些灯塔注意力层内部的注意力调用使用与密集基线相同的cuDNN支持的SDPA内核。

训练方法为两阶段。阶段1在大部分步骤预算中启用灯塔注意力选择。阶段2在密集SDPA下(相同优化器状态、相同数据加载器)从阶段1检查点恢复训练,进行短时间收尾。如果阶段1破坏了模型的密集注意力能力,阶段2的恢复将会失败。

但实验没有失败。在总预算16,000步(约50.3B令牌)下,测试了三个分界点(10k+6k,11k+5k,12k+4k),并与从头训练的密集SDPA基线进行比较。在每个恢复点,训练损失暂时上升1.12–1.57 nats,因为模型首次运行其未训练过的注意力,然后在大约1,000–1,500步SDPA内恢复并降至密集基线之下。到第16,000步,所有三个恢复的灯塔注意力运行最终损失在0.6980–0.7102之间,而密集基线为0.7237,同时在相同令牌预算下,壁钟时间花费22.5至27.0小时,而从头训练密集SDPA需要37.9小时。

消融实验与吞吐量

完整消融网格涵盖评分器类型、池化因子p、金字塔层数L和top-K预算k。关键发现:投影范数评分器在任一方向上与稀释softmax注意力评分器相差约0.01(无统一胜者),但在B200小时上大约便宜9%,因为它跳过了金字塔上的注意力传递。在匹配预算下,较浅的金字塔(L=3)始终优于较深的(L=4、L=5)。在测试范围内,较小的k值在恢复后产生较低的损失——网格中最低损失配置为L=3,p=2,k=1536,使用稀释评分器,最终损失达到0.6825——这一反直觉的结果被研究团队归因于在此令牌预算规模下分层选择起到了正则化作用。

阶段1吞吐量在消融网格中范围为84,000至126,000令牌/秒/GPU,而密集SDPA约为46,000。投影范数评分器在L=3,p=4,k=1536时达到范围上限126,000令牌/秒/GPU,因为它完全跳过了稀释注意力传递。

长上下文检索

为补充基于损失的可恢复性结果,研究团队运行了简化版“大海捞针”评估:在0-100%深度、4K到96K令牌上下文长度下,将单个口令数字隐藏于随机字母数字填充中,检索评分作为十个数字令牌的单令牌argmax(随机概率:10%)。测试了四种灯塔注意力配置(k ∈ {1536, 2048},评分器 ∈ {稀释, 范数},L=3,p=4),并与从头训练的密集SDPA基线比较。四个灯塔注意力运行中有三个匹配或超过密集基线的平均检索率0.72:k=2048稀释达到0.76,k=1536稀释达到0.73,k=2048范数匹配基线0.72。只有k=1536范数下降至0.65。网格中呈现一个模式:较大的k是检索性能的主要轴,范数评分器在相同k下对检索的损害大于对训练损失的损害。实际含义是,最优配置取决于下游任务是损失驱动型还是检索驱动型。

上下文并行性扩展

对于超过约100K令牌的序列,灯塔注意力在上下文并行性下运行。金字塔池化、评分和top-K在每个rank上本地分片运行,无需跨rank通信,因为最粗的池化窗口(例如64个令牌)远小于分片大小。收集的子序列是密集的,因此参与标准的环形注意力,无需稀疏感知的集合操作——这是稀疏索引方法无法做到的,除非专门针对稀疏布局进行工程处理。上下文并行性因环形旋转引入约10%的每rank吞吐量开销,但灯塔注意力与SDPA的加速比得以保持。该方法可扩展至跨32个Blackwell GPU(4节点,CP度为8)的100万令牌训练,无需更改内部注意力内核。