Google AIがTabFMを発表:ゼロショット分類と回帰のためのハイブリッド注意機構テーブル基盤モデル
Google Researchは、表形式データ向けの基盤モデルTabFMを発表しました。インコンテキスト学習によりゼロショットで分類と回帰を実行し、データセットごとのトレーニング、ハイパーパラメータ調整、特徴量エンジニアリングは不要です。予測は単一のフォワードパスで行われます。
Google Researchは、表形式データ専用に構築された基盤モデルTabFMを発表しました。表データは企業のデータインフラの中核をなし、顧客離脱予測や金融詐欺検出などのタスクに使用されます。これまで、XGBoostなどの勾配ブースティング木がこの分野で主流でしたが、これらの手法はデータセットごとにハイパーパラメータ調整や特徴量エンジニアリングが必要で時間がかかっていました。TabFMはこのボトルネックを解消することを目指しています。
TabFMはゼロショット学習の概念を表データに適用します。インコンテキスト学習(ICL)技術を用いて、新しいデータセットに対して重みを更新することなく、単一のフォワードパスで予測を生成します。そのアーキテクチャはTabPFNとTabICLの両方を融合したもので、行と列の注意機構を交互に適用して特徴間の相互作用を捉え、行圧縮技術で計算コストを削減します。
大規模な基盤モデルを学習するために、Googleの研究チームは構造的因果モデル(SCM)から動的に生成された数億の合成データセットを使用しました。これらのデータセットは多様なデータ分布と複雑な特徴関係をカバーし、モデルが実世界のデータにうまく一般化できるようにしています。
評価では、TabFMは38の分類データセットと13の回帰データセット(サンプルサイズ700〜150,000)を含むTabArenaベンチマークでテストされました。プレーンなTabFMとTabFM-Ensembleの2つの構成が評価され、いずれも十分にチューニングされたXGBoostなどの産業用教師ありアルゴリズムを上回るパフォーマンスを示しました。TabFM-Ensembleはクロス特徴量とSVD特徴量を追加し、非負最小二乗法で最適な重みを計算することで性能をさらに向上させています。
TabFMは現在、Hugging FaceとGitHubで公開されています。インストールはリポジトリをクローンしてローカルにインストールします。基本的なインストールではCPUのみのJAXを使用し、CUDAエクストラを指定することでGPU実行が可能です。以下は使用例です:プリトレーニング済みモデルをロードし、TabFMClassifierを作成し、年齢、職業、収入などの特徴量を持つデータフレームを準備し、fitメソッドとpredictメソッドを呼び出します。fitメソッドは訓練データをエンコードするだけで、モデルの重みを訓練しません。分類に加えて、回帰タスクにも対応しており、例えば住宅価格の予測が可能です。
Google BigQueryでは、まもなくAI.PREDICT SQLコマンドを通じてTabFMにアクセスできるようになる予定です。