AI News HubLIVE
站内改写

Sakana AI 提出 DiffusionBlocks:一種將殘差網路轉換為可獨立訓練去噪模組的塊狀訓練框架

來自Sakana AI和東京大學的研究人員提出了DiffusionBlocks,這是一種塊狀訓練框架,可將Transformer網路劃分為獨立訓練的塊,從而將訓練記憶體減少B倍(B為塊數),同時在不同架構上保持效能。該方法透過將殘差連線解釋為擴散模型中的尤拉步驟,利用分數匹配目標實現塊級獨立訓練。

文章情報

工程師進階

要點

  • DiffusionBlocks透過將網路劃分為B個獨立訓練的塊,將訓練記憶體減少B倍,適用於多種架構。
  • 核心創新在於將殘差連線視為反向擴散過程的尤拉離散化步驟,從而為每個塊提供原則性的區域性訓練目標。
  • 實驗表明,在ViT、DiT、自迴歸和迴圈深度Transformer上,DiffusionBlocks實現了與端到端訓練相當的效能,記憶體和計算效率顯著提升。
  • 對於擴散模型,推理時每個去噪步驟僅啟用一個塊,進一步降低了計算成本。

為什麼重要

這條新聞值得關注,因為DiffusionBlocks透過將網路劃分為B個獨立訓練的塊,將訓練記憶體減少B倍,適用於多種架構。

技術影響

可能影響模型選型、推理成本、產品能力和評測基準。

Sakana AI 與東京大學的研究人員聯合提出了一種名為 DiffusionBlocks 的新型訓練框架,旨在解決深度神經網路訓練中的記憶體瓶頸問題。該框架將 Transformer 網路劃分為多個獨立的塊(block),每個塊均可獨立訓練,從而將訓練記憶體消耗降低至原來的 1/B(B 為塊數),同時保持與端到端訓練相當的效能。

在傳統的端到端反向傳播訓練中,需要儲存每一層的中間啟用值,記憶體消耗隨網路深度線性增長。儘管啟用檢查點(activation checkpointing)可以減少啟用記憶體,但無法減少引數、梯度和最佳化器狀態所佔用的記憶體。例如,使用 Adam 最佳化器時,每一層需要儲存引數、梯度以及動量與方差兩個狀態,總記憶體約為引數量的 4 倍。塊狀訓練透過將網路分為 B 塊並獨立訓練,將記憶體需求降低至約 1/B。然而,如何為每個塊設計合理的區域性目標一直是個挑戰。此前的方法如 Hinton 的正向-正向演算法(Forward-Forward)和貪婪逐層訓練依賴於特定的區域性目標,效能往往不如端到端訓練,且主要侷限於分類任務。

DiffusionBlocks 的核心見解在於利用殘差網路與擴散模型之間的深層聯絡。殘差網路的更新形式為 z_ℓ = z_{ℓ-1} + f_θℓ(z_{ℓ-1}),這相當於常微分方程的尤拉離散化。研究者證明,這一更新形式恰好對應於基於分數的擴散模型中的機率流 ODE(在方差爆炸(VE)框架下)。透過反向擴散過程的尤拉離散化,可以得到與殘差連線結構完全匹配的更新規則。因此,一疊殘差塊可以被視為離散化的去噪步驟。在基於分數的擴散模型中,分數匹配目標可以在每個噪聲水平上獨立最佳化,這意味著每個塊可以使用自己的區域性目標獨立訓練,無需塊間通訊。

將標準殘差網路轉換為 DiffusionBlocks 需要三步:首先,將 L 層網路劃分為 B 個連續塊;其次,定義噪聲分佈並分配噪聲區間給每個塊,推薦使用對數正態分佈;最後,透過向每個塊輸入新增目標的噪聲版本和條件歸一化(AdaLN)來實現噪聲條件化。訓練時每次迭代只取樣一個塊,其他塊不進行計算,記憶體消耗僅為 L/B 層。

在劃分策略上,DiffusionBlocks 採用等機率劃分而非均勻劃分。均勻劃分忽略了不同噪聲水平的去噪難度差異,而等機率劃分使每個塊處理相同機率質量下的區間,從而在 CIFAR-10 上實現了更低的 FID(38.03 vs 43.53)。

實驗評估覆蓋了五種架構和多個資料集。在 CIFAR-100 上使用 ViT,DiffusionBlocks 達到了 59.30% 的準確率,而端到端基線為 60.25%;在 CIFAR-10 上使用 DiT-S/2,FID 為 37.20 vs 39.83;在 ImageNet 256×256 上使用 DiT-L/2,FID 為 10.63 vs 12.09;在 text8 上使用掩碼擴散模型(MDM),BPC 為 1.45 vs 1.56;在 LM1B 上使用自迴歸 Transformer,MAUVE 為 0.71 vs 0.50;在 OpenWebText 上,MAUVE 為 0.82 vs 0.85。此外,對於迴圈深度模型 Huginn,DiffusionBlocks 以約 10 倍的計算縮減實現了 MAUVE 0.70 vs 0.49。這些結果表明,DiffusionBlocks 在記憶體大幅減少的同時,效能與端到端訓練相當,甚至在某些情況下更好。

與同期工作 NoProp 相比,DiffusionBlocks 是唯一結合連續時間公式和塊狀訓練的方法,在 CIFAR-100 上準確率達到 46.88%,僅比端到端反向傳播低不到 1 個百分點。

DiffusionBlocks 的優勢包括:基於分數匹配的原則性理論基礎,無需任務特定修改即可適用於多種架構,B 倍的記憶體減少,以及擴散模型推理時 B 倍的計算減少。等機率劃分顯著優於均勻劃分,且對於迴圈深度模型,它用單次前向傳播代替了多次 BPTT。此外,塊可以在 GPU 上並行訓練,通訊開銷為零。

侷限性包括:需要輸入和輸出維度匹配,目前無法應用於 U-Net 類架構;僅在從頭訓練的模型上驗證過;缺乏選擇最優塊數的原則性方法;噪聲條件化帶來輕微的時間開銷;在 OpenWebText 上某些指標略差。

總體而言,DiffusionBlocks 為訓練深層 Transformer 網路提供了一種記憶體高效且理論上合理的替代方案,尤其適用於擴散模型和迴圈深度模型。