如何使用xFormers構建記憶體高效的Transformer:打包序列、GQA、ALiBi、SwiGLU和因果注意力
本教程詳細介紹瞭如何使用xFormers工具包在GPU上構建快速、記憶體高效的Transformer模型。我們從驗證記憶體高效注意力與標準注意力的對比開始,然後比較不同序列長度下的速度和記憶體消耗。接著探討了因果掩碼、變長序列打包、分組查詢注意力(GQA)和自定義ALiBi位置偏置。最後,我們將這些技術組合成一個可訓練的GPT風格模型,包含SwiGLU前饋層和自動混合精度訓練。
在本教程中,我們使用xFormers工具包在GPU上構建快速、記憶體高效的Transformer模型。首先,我們安裝xFormers並驗證GPU可用性,然後定義輔助函式測量CUDA執行時間和峰值記憶體消耗。接著,我們比較xFormers的記憶體高效注意力與標準注意力實現,發現兩者在數值上高度一致(僅存在FP16舍入誤差),但xFormers從不儲存完整的MxM分數矩陣。
在基準測試中,我們對序列長度從512到4096進行前向和後向傳播測試。結果顯示,標準注意力的記憶體消耗隨序列長度平方增長(每次長度加倍,記憶體增加約4倍),而xFormers幾乎呈線性增長,且速度保持穩定。我們還使用隱式下三角掩碼實現因果注意力,無需分配布林張量,並透過與參考實現比較驗證了正確性。
對於可變長度序列,我們使用BlockDiagonalMask將不同長度的序列打包在一起,防止跨序列的注意力,同時避免填充開銷。我們恢復了各個分段的輸出,並執行了打包的因果注意力,這正是vLLM等推理引擎批次處理不同長度請求的方式。然後,我們演示了分組查詢注意力(GQA),其中多個查詢頭共享較少的鍵值頭,從而減少KV快取大小——這是Llama/Mistral類模型在推理時使用的技術。
我們還構建了自定義的ALiBi張量,對每個注意力頭施加不同的線性位置懲罰,並與因果掩碼結合。xFormers可以直接接受這種自定義偏置張量。最後,我們構建了一個緊湊的GPT風格Transformer,包含xFormers因果注意力、殘差連線、層歸一化和SwiGLU前饋層。模型透過自動混合精度(AMP)在合成資料上進行400步訓練,展示了從零開始訓練一個完整因果Transformer的流程。該示例可直接替換為實際資料和分詞器進行擴充套件。
本教程涵蓋了xFormers的核心功能(第1-3節)和高階技巧(第4-6節),為構建記憶體高效的Transformer模型提供了實用指南。