OlmPool:小型架构选择如何叠加破坏长上下文扩展
OlmPool是一个包含26个模型的受控套件,展示了即使训练数据和扩展方法保持不变,小型架构选择也可能叠加起来使长上下文扩展变得更加困难。
大多数语言模型都是在短文本序列上进行训练的,然后通过额外的长文档训练(称为上下文扩展)来处理更长的输入。由于Llama 3的预训练数据是专有的,很难判断其扩展的易用性来自架构还是训练数据。研究者假设相同的扩展方法可以迁移到其他架构,但OlmPool的研究表明,事实往往并非如此。
OlmPool是艾伦人工智能研究所(Ai2)开发的受控模型套件,包含26个7B模型,所有模型在相同数据上预训练1400亿个token,然后使用相同的长上下文数据混合和过程扩展到64K上下文。唯一变化的是架构。研究聚焦于四个影响注意力机制的架构选择:QK归一化、分组查询注意力(GQA)、滑动窗口注意力和预训练上下文长度。
研究发现,这些选择单独作用时影响温和:QK归一化影响最大,移除它并切换归一化顺序可在HELMET基准上带来6分提升;GQA和较短的预训练上下文长度各自造成较小下降;滑动窗口注意力单独影响约1分。但当它们组合时,效果远超各部分之和——例如,在已有GQA的模型上添加滑动窗口注意力平均导致9分下降。最差的配置组合了两种或更多限制注意力灵活性的选择。
更关键的是,标准训练信号几乎无法预测长上下文性能。训练损失、验证困惑度和16个短上下文基准都无法识别哪些模型在32K或64K上下文中表现更好。即使在同一基准的8K分割上,分数也无法预测扩展后两位数的波动。扩展后表现看似相同的模型,在HELMET上可能相差26分以上。
研究还发现,Llama 3的架构组合表现强劲,但并非在所有情况下最优——其他几种模型明显超越它。这表明Llama 3的长上下文成功主要源于架构,验证于Llama的扩展方法可能需要针对其他模型族进行调整。此外,架构导致的差距不会随着更多数据而消失:即使经过500亿token的上下文扩展(占训练总量的26%),最差架构仍无法达到Llama架构在10亿token后的性能。
注意力模式分析揭示了原因:没有QK归一化的模型会产生更强的注意力汇聚——输入早期位置持续获得大量关注,即使它们与当前预测无关。更强的汇聚与更好的长上下文表现相关,表明它是模型在没有QK归一化时支持长距离检索的默认策略。在针在该测试中,有QK归一化的模型对目标信息的关注度较低,与其较弱的长上下文表现一致。
OlmPool套件已完整发布,包含所有26个模型在预训练和上下文扩展各阶段的38个检查点。这项研究强调,架构选择的组合可能产生远低于预期的长上下文性能,且这种结果无法从标准训练信号中预知。研究者希望该资源能帮助开发更好的上下文扩展方法,并研究早期预训练中的其他现象。