注意力探針
注意力探針是一種用於分類語言模型內部狀態的新方法,通過注意力層聚合隱藏狀態,避免了對多個token進行池化。實驗表明,多頭注意力探針(特別是8頭)在多數數據集上優於均值探針,訓練代碼已開源。
EleutherAI 研究團隊提出了一種名為“注意力探針”(Attention Probes)的新方法,用於分類語言模型的內部狀態。傳統線性探針通常基於每個token或通過池化(如均值池化或取最後一個token)壓縮多個token的潛在向量來進行訓練。注意力探針則引入一個注意力層來收集隱藏狀態,從而避免了池化帶來的信息損失。
該方法的核心偽代碼顯示,注意力探針具有多個頭,每個頭為每個token計算一個注意力對數,並通過softmax得到注意力概率。與標準多頭注意力不同,這裏每個頭只關注一個token(類似於交叉注意力中只有一個查詢token)。此外,還添加了一個可學習的位置偏置(類似ALiBi),使得注意力可以偏向序列中的特定位置。最終,通過值投影和加權求和得到探針輸出。
相關工作方面,McKenzie等人(2025)曾提出類似架構,但僅使用單頭且無位置偏置;Kantamneni等人(2025)最早展示了注意力探針,但作為次要方法。本研究使用的數據集包括MOSAIC(基於Gemma 2B和Gemma 2 2B模型)以及Neurons-In-A-Haystack(NiAH)數據集。訓練時,探針使用AdamW優化器(注意力探針)或LBFGS(均值/末位token探針),並進行超參數搜索。
實驗結果顯示,在MOSAIC數據集上,均值探針優於末位token探針,而8頭注意力探針(AdamW訓練)整體優於均值探針,且始終優於AdamW訓練的均值探針。單頭注意力探針表現中等。在NiAH數據集上,注意力探針未明顯優於末位token探針,但即使單頭注意力探針也比均值探針有所改進。此外,隨着頭數增加,注意力熵也增加,且熵值高度依賴於數據集。
通過分析最大激活示例,研究者發現注意力探針有時會關注與任務相關的詞語(如性別相關詞)。總體而言,注意力探針的性能與均值或末位token探針相當,但在多數情況下更優,尤其是多頭配置。訓練代碼已開源,可通過pip安裝。