Mahjax:一個用於JAX中強化學習的高效能GPU加速麻將模擬器
Mahjax是一個在JAX中實現的完全向量化立直麻將環境,可利用GPU進行大規模並行化,吞吐量達到在8塊NVIDIA A100 GPU上每秒200萬步(無紅寶牌規則)和100萬步(有紅寶牌規則)。該環境支援從零開始(tabula rasa)的強化學習訓練,並附有高質量視覺化工具,實驗驗證了訓練智慧體可以有效提升排名。
文章情報
要點
- Mahjax是基於JAX的完全向量化立直麻將模擬器,支援GPU並行化。
- 在8塊NVIDIA A100 GPU上,每秒可處理多達200萬步(無紅寶牌規則)。
- 無需人類資料,支援從零開始的強化學習訓練。
- 包含視覺化工具,便於除錯與智慧體互動。
為什麼重要
這條新聞值得關注,因為Mahjax是基於JAX的完全向量化立直麻將模擬器,支援GPU並行化。
技術影響
可能影響 Agent 架構、工具呼叫、工作流自動化和產品整合。
近日,由Soichiro Nishimori等六位研究人員共同開發的Mahjax模擬器正式釋出。該模擬器基於JAX框架,針對立直麻將這一複雜的不完全資訊博弈進行了深度最佳化。立直麻將因其多人參與、資訊不完全、隨機性高以及狀態空間極大等特點,成為強化學習領域極具挑戰性的研究物件,其複雜性堪比現實世界中的許多決策問題。傳統方法通常依賴人類對局日誌進行監督學習來預訓練策略,但Mahjax支援從頭開始的“白板學習”(tabula rasa),類似於AlphaZero系列演算法,具有更強的通用性和泛化能力。
Mahjax的核心創新在於其完全向量化的設計,這使得大規模並行模擬成為可能。透過在JAX中實現整個環境,研究人員能夠充分利用GPU進行加速。在八塊NVIDIA A100 GPU的測試平臺上,Mahjax在無紅寶牌規則下達到了每秒200萬步的驚人吞吐量,即使在引入紅寶牌規則後,吞吐量也保持在每秒100萬步。這一效能為大規模強化學習實驗提供了堅實基礎,使得在合理時間內進行大量對局訓練成為現實。
除了強大的計算效能,團隊還開發了一款高質量的視覺化工具,便於研究人員對訓練過程進行除錯,並與訓練後的智慧體進行直觀互動。這大大降低了研究門檻,使得非專業使用者也能輕鬆上手。為了驗證環境的實用性,研究人員在該環境中訓練了強化學習智慧體,並將其與基線策略進行對比。實驗結果表明,經過訓練的智慧體能夠顯著提升其在比賽中的排名,證明了Mahjax作為強化學習研究平臺的有效性和價值。
Mahjax的出現,不僅為麻將AI研究提供了新的工具,更為在更廣泛的複雜博弈中應用從零開始的強化學習開闢了道路。相關論文已提交至arXiv,編號為2605.20577,感興趣的讀者可以獲取更多技術細節。