Google AI 推出 TabFM:用於零樣本分類和迴歸的混合注意力表格基礎模型
Google Research 發佈了 TabFM,一種專門為表格數據設計的基礎模型。它通過上下文學習實現零樣本分類和迴歸,無需針對每個數據集進行訓練、超參數調優或特徵工程。該模型結合了 TabPFN 的行/列注意力機制和 TabICL 的上下文學習方法,並在數百萬個合成數據集上訓練。在 TabArena 基準測試中,TabFM 的表現優於經過充分調優的 XGBoost 等傳統方法。
Google Research 正式發佈了 TabFM,這是一個為表格數據量身打造的基礎模型。表格數據是企業數據基礎設施的核心,廣泛應用於客户流失預測、金融欺詐檢測等場景。傳統上,基於樹的方法(如 XGBoost、AdaBoost 和隨機森林)在這一領域佔據主導地位,但它們需要大量的超參數調優和特徵工程,耗時耗力。TabFM 的目標就是打破這一瓶頸。
TabFM 將零樣本學習的理念引入表格數據。它採用上下文學習(In-Context Learning, ICL)技術,可以在一個前向傳播中完成對新數據集的預測,無需更新模型權重或進行任何額外訓練。其架構融合了 TabPFN 和 TabICL 兩種方法:通過交替的行和列注意力機制捕捉特徵之間的交互關係,並通過行壓縮技術降低計算成本。
為了訓練這樣一個大規模模型,Google 的研究團隊使用了數億個由結構因果模型(SCMs)動態生成的合成數據集。這些數據集涵蓋了廣泛的數據分佈和複雜特徵關係,使模型能夠很好地泛化到真實世界的數據上。
在評估方面,TabFM 在 TabArena 基準上進行了測試,該基準包含 38 個分類數據集和 13 個迴歸數據集,樣本量從 700 到 150,000 不等。兩個配置版本——普通 TabFM 和 TabFM-Ensemble——均表現出色,甚至超過了經過充分調優的 XGBoost 等工業級監督算法。TabFM-Ensemble 通過添加交叉特徵和 SVD 特徵,並使用非負最小二乘法求解最優權重,進一步提升了性能。
TabFM 目前已開源,可在 Hugging Face 和 GitHub 上獲取。安裝過程簡單,需克隆倉庫並使用 CPU 或 GPU 版本的 JAX。以下是一個簡單的使用示例:加載預訓練模型後,創建分類器,準備包含年齡、職業和收入等特徵的數據集,然後調用 fit 和 predict 方法。注意,fit 方法僅對訓練數據進行編碼,並不訓練模型權重。除了分類,TabFM 還支持迴歸任務,例如房價預測。
Google 還計劃通過 BigQuery 的 AI.PREDICT SQL 命令提供 TabFM 的訪問接口,這將是該模型在企業級應用中的重要一步。