Google AI 推出 TabFM:用于零样本分类和回归的混合注意力表格基础模型
Google Research 发布了 TabFM,一种专门为表格数据设计的基础模型。它通过上下文学习实现零样本分类和回归,无需针对每个数据集进行训练、超参数调优或特征工程。该模型结合了 TabPFN 的行/列注意力机制和 TabICL 的上下文学习方法,并在数百万个合成数据集上训练。在 TabArena 基准测试中,TabFM 的表现优于经过充分调优的 XGBoost 等传统方法。
Google Research 正式发布了 TabFM,这是一个为表格数据量身打造的基础模型。表格数据是企业数据基础设施的核心,广泛应用于客户流失预测、金融欺诈检测等场景。传统上,基于树的方法(如 XGBoost、AdaBoost 和随机森林)在这一领域占据主导地位,但它们需要大量的超参数调优和特征工程,耗时耗力。TabFM 的目标就是打破这一瓶颈。
TabFM 将零样本学习的理念引入表格数据。它采用上下文学习(In-Context Learning, ICL)技术,可以在一个前向传播中完成对新数据集的预测,无需更新模型权重或进行任何额外训练。其架构融合了 TabPFN 和 TabICL 两种方法:通过交替的行和列注意力机制捕捉特征之间的交互关系,并通过行压缩技术降低计算成本。
为了训练这样一个大规模模型,Google 的研究团队使用了数亿个由结构因果模型(SCMs)动态生成的合成数据集。这些数据集涵盖了广泛的数据分布和复杂特征关系,使模型能够很好地泛化到真实世界的数据上。
在评估方面,TabFM 在 TabArena 基准上进行了测试,该基准包含 38 个分类数据集和 13 个回归数据集,样本量从 700 到 150,000 不等。两个配置版本——普通 TabFM 和 TabFM-Ensemble——均表现出色,甚至超过了经过充分调优的 XGBoost 等工业级监督算法。TabFM-Ensemble 通过添加交叉特征和 SVD 特征,并使用非负最小二乘法求解最优权重,进一步提升了性能。
TabFM 目前已开源,可在 Hugging Face 和 GitHub 上获取。安装过程简单,需克隆仓库并使用 CPU 或 GPU 版本的 JAX。以下是一个简单的使用示例:加载预训练模型后,创建分类器,准备包含年龄、职业和收入等特征的数据集,然后调用 fit 和 predict 方法。注意,fit 方法仅对训练数据进行编码,并不训练模型权重。除了分类,TabFM 还支持回归任务,例如房价预测。
Google 还计划通过 BigQuery 的 AI.PREDICT SQL 命令提供 TabFM 的访问接口,这将是该模型在企业级应用中的重要一步。