多令牌残差预测
多令牌残差预测(MRP)是一种针对扩散语言模型的轻量级模块,通过预测相邻去噪步骤之间的残差而非完整分布,实现了在静态和动态两种推理场景下的性能提升。静态模式下可实现无质量损失的加速(高达1.56倍),动态模式下可恢复因激进阈值解码而损失的高达+16个百分点的准确率。
多令牌残差预测(Multi-token Residual Prediction, MRP)是一项由Modal Research与纽约大学上海分校HeavyBall Research合作提出的创新技术,旨在优化扩散语言模型(DLM)的推理效率。该研究发表于一篇博客文章,详细介绍了MRP的设计原理、实验结果及其在两种推理场景下的双重应用。
MRP的核心思想源于对多令牌预测(MTP)的改进。在自回归模型中,MTP通过轻量级模块从主干网络的隐藏状态预测多个后续令牌,结合推测解码实现加速。然而,直接将该方法应用于扩散语言模型时遇到了困难:简单的蒸馏头在预测多个步骤后精度急剧下降。例如,在SDAR-4B模型上,直接MTP在第一步达到84.8%的GSM8K准确率,但到第四步时已降至1.9%。
研究团队发现,问题在于完整分布预测的难度。他们转而预测相邻去噪步骤之间的残差,即当前步骤分布与下一步分布之间的差异。由于扩散过程的马尔可夫性质,每一步只改变少量位置,因此残差信号复杂度低,容易学习。基于这一洞察,他们设计了MRP模块:一个轻量级3层Transformer,附加在冻结的DLM主干上,读取主干隐藏状态并预测残差对数几率,然后将其加到主干的原始输出上。训练目标采用残差版本的KL散度,仅在仍掩码的位置上最小化预测与真实分布的差异。
实验表明,MRP在静态和动态两种推理场景中均表现出色。在静态模式下,即每步去掩码固定数量的令牌,MRP可用作推测解码的草稿模型,实现无损加速。在SGLang实现中,SDAR-8B模型在GSM8K上达到90.4%准确率和1.40倍吞吐量提升。若采用直接解码(跳过验证),速度提升可达1.89倍,同时仅带来微小质量损失。在动态模式下,即基于置信度阈值批量化去掩码,MRP可恢复因低阈值而过度揭露令牌导致的准确率损失。例如,在τ=0.5时,SDAR-1.7B模型在GSM8K上的准确率从41.6%提升至59.1%(+17.5点)。
MRP的另一个优势是灵活性:用户可根据应用需求在无损加速与有损高速之间自由选择。这一控制权对于拥有自主推理栈的开发者尤为重要,因为封闭API通常固定了解码策略。研究团队已在SDAR-1.7B、4B和8B模型上进行了广泛验证,覆盖GSM8K、MATH500、HumanEval和MBPP等基准,并提供了开源代码和SGLang实现。