Mahjax:一個用於JAX中強化學習的高性能GPU加速麻將模擬器
Mahjax是一個在JAX中實現的完全向量化立直麻將環境,可利用GPU進行大規模並行化,吞吐量達到在8塊NVIDIA A100 GPU上每秒200萬步(無紅寶牌規則)和100萬步(有紅寶牌規則)。該環境支持從零開始(tabula rasa)的強化學習訓練,並附有高質量可視化工具,實驗驗證了訓練智能體可以有效提升排名。
文章情報
要點
- Mahjax是基於JAX的完全向量化立直麻將模擬器,支持GPU並行化。
- 在8塊NVIDIA A100 GPU上,每秒可處理多達200萬步(無紅寶牌規則)。
- 無需人類數據,支持從零開始的強化學習訓練。
- 包含可視化工具,便於調試與智能體交互。
為甚麼重要
這條新聞值得關注,因為Mahjax是基於JAX的完全向量化立直麻將模擬器,支持GPU並行化。
技術影響
可能影響 Agent 架構、工具調用、工作流自動化和產品集成。
近日,由Soichiro Nishimori等六位研究人員共同開發的Mahjax模擬器正式發佈。該模擬器基於JAX框架,針對立直麻將這一複雜的不完全信息博弈進行了深度優化。立直麻將因其多人蔘與、信息不完全、隨機性高以及狀態空間極大等特點,成為強化學習領域極具挑戰性的研究對象,其複雜性堪比現實世界中的許多決策問題。傳統方法通常依賴人類對局日誌進行監督學習來預訓練策略,但Mahjax支持從頭開始的“白板學習”(tabula rasa),類似於AlphaZero系列算法,具有更強的通用性和泛化能力。
Mahjax的核心創新在於其完全向量化的設計,這使得大規模並行仿真成為可能。通過在JAX中實現整個環境,研究人員能夠充分利用GPU進行加速。在八塊NVIDIA A100 GPU的測試平台上,Mahjax在無紅寶牌規則下達到了每秒200萬步的驚人吞吐量,即使在引入紅寶牌規則後,吞吐量也保持在每秒100萬步。這一性能為大規模強化學習實驗提供了堅實基礎,使得在合理時間內進行大量對局訓練成為現實。
除了強大的計算性能,團隊還開發了一款高質量的可視化工具,便於研究人員對訓練過程進行調試,並與訓練後的智能體進行直觀交互。這大大降低了研究門檻,使得非專業用户也能輕鬆上手。為了驗證環境的實用性,研究人員在該環境中訓練了強化學習智能體,並將其與基線策略進行對比。實驗結果表明,經過訓練的智能體能夠顯著提升其在比賽中的排名,證明了Mahjax作為強化學習研究平台的有效性和價值。
Mahjax的出現,不僅為麻將AI研究提供了新的工具,更為在更廣泛的複雜博弈中應用從零開始的強化學習開闢了道路。相關論文已提交至arXiv,編號為2605.20577,感興趣的讀者可以獲取更多技術細節。