Parallax:一种参数化局部线性注意力机制,保留Softmax并添加学习协方差修正分支
Parallax是一种新型注意力机制,保留Softmax注意力并添加一个学习的协方差修正分支,替代了局部线性注意力(LLA)中的逐查询求解器。它通过重用FlashAttention的键值流,将算术强度提高一倍,并在0.6B和1.7B规模的LLM预训练中实现了更低的困惑度。然而,其优势高度依赖于Muon优化器,在AdamW下增益显著缩小。
Transformer的注意力机制自2017年以来几乎没有变化。大多数效率工作试图直接替换Softmax注意力。一篇新论文采取了不同的路线:保留Softmax注意力,并增加一个修正分支。
来自西北大学、Tilde Research和华盛顿大学的研究团队提出了一种参数化的局部线性注意力机制,称为“Parallax”,它能够扩展到LLM预训练,并与Muon优化器协同设计。
Parallax并非通过削减计算来追求效率。它有意增加计算量,然后使这些计算在现代GPU上运行得更便宜。
什么是Parallax?
Parallax建立在局部线性注意力(LLA)的基础上。LLA源于测试时回归框架,该框架将注意力视为键值对上的回归求解器。在此视角下,键是训练数据点,值是标签,查询是测试点。Softmax注意力是一种称为Nadaraya-Watson的非参数估计器,它为每个查询拟合一个局部常数函数。
LLA将该局部常数估计升级为局部线性估计。研究团队证明这产生了严格更小的积分均方误差。其优势在于联想记忆的更好偏差-方差权衡。
但LLA在大规模应用中存在问题。其精确前向传播需要为每个查询求解一个线性系统,使用并行共轭梯度(CG)求解器。CG求解器带来三个问题:密集的I/O、困难的正则化-表达权衡以及低精度不兼容。
Parallax移除了求解器。取而代之,它学习一个额外的投影矩阵。研究团队将其写为ρi = WR xi,其中WR是一个可学习矩阵,直接从层输入探测KV协方差。
因此,Parallax保留了局部线性原理,只是用学习的、类似查询的投影器替代了逐查询求解。这使其更简单、更高效,且更易于实现。
机制如何运作?
Parallax将LLA重新表述为Softmax注意力加上一个加性修正。输出等于Softmax注意力输出减去一个投影协方差项。在研究论文的符号中,该项是KV协方差乘以学习的探针ρi。
研究团队还丢弃了LLA中的一个组件,称为边界放大因子,并将其设为零。这对于稳定性是必要的。一旦探针变为参数化,原有的几何解释就失效了。保留该因子可能导致缩放发散或符号翻转。
Parallax位于一个注意力机制家族中。研究团队在论文中通过三个轴对其进行组织:带宽、探针构造和仿射结构。在一个极端,当探针范数为零时,Parallax精确退化为Softmax注意力。
设置WR = 0可以使Parallax层的行为与Softmax注意力完全相同。因此,预训练的Transformer检查点可以通过添加WR并微调进行转换。
硬件论证
Parallax继承了FlashAttention的流式结构。它添加了一个协方差分支,重用相同的键值流。
研究团队将前向传播扩展为两个并行的评分分支。两个分支共享在线最大值、缩放因子以及K和V的tile。因此,Parallax每次迭代不需要额外的I/O。
关键属性是更高的算术强度(AI)。AI是浮点运算与高带宽内存流量的比率。在KV工作占主导地位的场景中,Parallax大约将算术强度提高一倍。它在重用相同内存流的同时增加了计算。
这将注意力推向更接近计算受限的状态。这正是内核优化在现代硬件上有帮助的状态。
研究团队在NVIDIA Hopper GPU上使用CuTeDSL原型化了一个解码内核。Hopper的张量核心矩阵乘法指令操作至少64行的tile。一个解码步骤仅提供一行查询。因此,QK和RK乘积可以联合计算,在标准注意力已经发出的指令之内。
他们在H200 GPU上以BF16精度对FlashAttention 2和3进行了性能分析。他们扫描了从1到2,048的批次大小和从128到32,768的上下文长度。原型内核在所有配置下与FlashAttention性能相当或更优。下图标注了在计算匹配设置下1.54倍的加速和在I/O匹配设置下1.14倍的加速。
https://arxiv.org/pdf/2605.29157
实验展示
研究团队在合成任务和0.6B及1.7B规模的LLM预训练上验证了Parallax。模型使用了torchtitan仓库中的Qwen-3架构。他们在Ultra-FineWeb数据集上训练,上下文长度为4096。基线包括Softmax注意力(Transformer)、Mamba、Gated DeltaNet、MesaNet和Kimi DeltaAttention。
在MAD-Benchmark上,Parallax取得了最高整体准确率,平均0.716。它持续改善了如上下文召回和选择性复制等召回导向的任务。在压缩和记忆任务上保持竞争力。
在语言建模上,使用Muon的Parallax在两个规模上均取得了最佳困惑度。它还取得了最高的平均下游准确率。在1.7B规模,Parallax得分为62.45,而Transformer为61.43。
两个对照实验测试了增益的来源。参数匹配的Transformer仅缩小了一小部分差距。计算匹配的Parallax仍然击败了两个基线。论文认为这表明增益来自机制本身,而非额外参数或计算。
优化器转折
一个核心发现是优化器与架构的交互。Parallax在Muon下显示出巨大优势。在AdamW下,优势显著缩小甚至消失。
Muon是一种用于隐藏层矩阵参数的最新优化器。它使用动量缓冲区的极坐标因子,因此更新的条件数恰好为1。先前的研究表明这产生了条件更好的权重矩阵。
研究团队将差距追溯到修正分支。他们定义了修正输出比(COR)。在Muon下,最深层的COR超过8。在AdamW下,COR保持在4以下。
WR投影受到不成比例的影响。在AdamW下,其稳定秩崩溃,而在Muon下保持高秩。一个门控实验确认了该模式。在AdamW下,模型学会抑制修正分支而非使用它。
研究团队称这是注意力机制中架构-优化器协同设计的第一个实证演示。他们未声称Muon与WSD是最优方案。附录中的消融实验显示,在衰减阶段优势缩小。
分数差异
Parallax还产生与Softmax注意力不同的分数分布。其每token权重可以取负值并超过1的幅度。标准Softmax权重无法做到这一点。
研究团队报告了三个效应。Parallax可以从不相关token中主动减去值分量。它显著减少了第一个token上的注意力汇聚。其基础Softmax熵保持较高,从而产生更分散的注意力权重。
优势与劣势及开放问题
优势:
- 保持Softmax注意力完整,因此预训练Transformer可通过添加WR并微调转换。
- 通过重用FlashAttention的键值流,每次迭代不增加额外I/O。
- 算术强度翻倍,原型解码内核在性能上与FlashAttention 2/3相当或更优。
- 在参数匹配和计算匹配控制下显示一致的困惑度和下游增益。
劣势与开放问题:
- 增益高度依赖Muon;在AdamW下优势基本消失。
- 优化器依赖的确切原因仍是一个开放问题。
- 结果止于1.7B规模,未涉及MoE、更长上下文或更大规模运行。
- 优势在WSD衰减阶段减弱,仅通过权重衰减退火部分修复。
关键要点
- Parallax保留Softmax注意力并添加一个学习的协方差修正分支,替代了LLA的逐查询共轭梯度求解器。
- 它在重用相同KV流的同时将算术强度提高一倍,解码内核与FlashAttention 2/3性能相当或更优。
- 在0.6B和1.7B规模上,在参数匹配和计算匹配控制下均显示一致的困惑度和下游增益。
- 增益高度依赖Muon;在AdamW下优势显著缩小或消失。
- 设置WR = 0可精确恢复Softmax注意力,因此预训练Transformer可通过添加WR并微调转换。
- 查看论文和代码库。另外,欢迎在Twitter上关注我们,别忘了加入我们的150k+ ML SubReddit和订阅我们的新闻通讯。等等,你在Telegram上吗?现在你也可以在Telegram上加入我们。
- 需要与我们合作推广你的GitHub仓库、Hugging Face页面、产品发布或网络研讨会等?联系我们。