# AXLearn: Modular, Hardware-Agnostic Large Model Training
> [!info] Talk metadata
> - **会議:** [[MLSys2026]] Day 4 (May 21 / Thu)、Industry Track Oral "LLM Training 4"(14:45 PDT 開始)
> - **登壇者:** Xiyou Zhou(Apple Foundation Model チーム、スライド表紙より)
> - **全著者:** Mark Lee*¹, Chang Lan*¹, Tom Gunter†¹, John Peebles†¹, Hanzhi Zhou†¹, Kelvin Zou†¹, Sneha Bangalore¹, Chung-Cheng Chiu¹, Nan Du¹, Xianzhi Du¹, Philipp Dufter¹, Liang He¹, Ruixuan Hou¹, Haoshuo Huang¹, Dongseong Hwang¹, Xiang Kong¹, Jinhao Lei¹, Tao Lei¹, Meng Li¹, Li Li¹, Jiarui Lu¹, Zhiyun Lu¹, Yiping Ma¹, David Qiu¹, Vivek Rathod¹, Senyu Tong¹, Zhucheng Tu¹, Chong Wang¹, Jianyu Wang¹, Yongqiang Wang¹, Zirui Wang¹, Floris Weers¹, Sam Wiseman¹, Guoli Yin¹, Bowen Zhang¹, Xiyou Zhou¹, Danyang Zhuo§², Cheng Leong¹, Ruoming Pang‡¹(*共同筆頭著者、†コア著者、§客員研究者、‡連絡先著者)
> - **所属:** ¹Apple、²Duke University
> - **システム名:** AXLearn
> - **関連リンク:** https://github.com/apple/axlearn(Apache 2.0 ライセンス)
> [!abstract] 概要(論文 PDF アブストラクト・忠実日本語訳)
> AXLearn は大規模深層学習モデルのスケーラブルかつ高性能な訓練を実現するプロダクションシステムである。既存の最先端深層学習システムと比較して、AXLearn はモジュラリティとハードウェア非依存訓練のサポートに独自の焦点を当てている。AXLearn のソフトウェアコンポーネント間の内部インタフェースは厳密なカプセル化に従っており、異なるコンポーネントを組み立てて迅速なモデル開発と異なるハードウェアインフラ上での実験を容易にする。AXLearn はシステム内のコンポーネントをスケールしても一定の計算量を維持し、最先端の訓練システムにおける線形または二次的な計算量と対比される。これにより、たとえば RoPE(回転位置埋め込み)を AXLearn に統合する際にわずか 10 行のコードで済み、他のシステムでは数百のモジュールにまたがる数百行の変更が必要となる。同時に、AXLearn は最先端の訓練システムと同等の性能を達成する。Apple における AXLearn の開発・運用経験も共有する。
> [!note] 出典に関する注記
> 本ノートは**論文 PDF とスライド PDF** のみに基づく。音声文字起こしは無いため Q&A セクションは設けない。数値はスライドを一次資料、論文を補完資料として記載した。
## 設計哲学: モジュラリティとハードウェア非依存
### 問題意識
Apple は数十億のユーザに AI 機能を提供しており、GPU・TPU・AWS Trainium を含む複数のハードウェアプラットフォーム上で数百のアーキテクチャバリアントを訓練する必要がある(スライド p.2)。既存システムはこの 2 つの要件を同時に満たせない。
- **Megatron-LM** は GPU のみ、**MaxText** は GPU + TPU だが Trainium 非対応(スライド p.3、論文 Table 1)。
- ほとんどのフレームワークはサブタイピング(継承)に依存しており、機能追加が祖先モジュールに波及する。DeepSpeed で FFN を MoE に置換する場合、分離すると 4 行の差分だが、実際の QwenV2 の MoE 対応差分は **200 行以上**、Apple の内部バリアント規模では **4,000 行**に達する(スライド p.4、論文 §2.1)。
### LoC 計算量: システムの進化しやすさを測る新指標
AXLearn は、ある機能を新たに追加するために既存インタフェースに必要な**漸近的なコード行数変化(LoC-Complexity)**を提案する(論文 §2.1)。$N$ はシステム内のモジュール数、$M$ は機能のバリアント数を表す。
| システム | LoC-Complexity(RoPE) | LoC-Complexity(MoE) | RoPE 推定 LoC | MoE 推定 LoC |
|---|---|---|---|---|
| Megatron-LM | $O(NM)$ | $O(N)$ | 400 | 20 |
| DeepSpeed | $O(NM)$ | $O(NM)$ | 320 | 4,000 |
| TorchTitan | $O(NM)$ | $O(NM)$ | 240 | 400 |
| Flax | $O(NM)$ | N/A | 600 | N/A |
| Praxis (Pax) | $O(NM)$ | $O(M)$ | 300 | 5 |
| MaxText | $O(NM)$ | $O(NM)$ | 200 | 300 |
| **AXLearn** | **$O(1)$** | **$O(1)$** | **0** | **0** |
(論文 Table 2・スライド p.5 より。AXLearn は既存インタフェースへの変更行数がゼロ。)
### 厳密なカプセル化
AXLearn の核心設計は**厳密なカプセル化(strict encapsulation)**である(論文 §1、スライド p.6)。
- **全モジュールが差し替え可能:** 入力パイプライン・モデル・オプティマイザ・チェックポインタ・トレーナループのいずれも独立して交換できる。
- **コンフィグは木構造で階層的に合成:** 親は子のハイパーパラメータを検査せず、入出力の形状のみで合意する。`TransformerLayer` のコンフィグは `AttentionLayer.Config` と `FeedForwardLayer.Config` を子として持ち、子の実装を切り替えても親は変更不要(論文 §4.1)。
- **コンフィグモディファイア:** コンフィグ木を走査し、対象のサブツリーだけを親に触れずに書き換える(論文 §4.1)。約 10 行の `replace_config` スニペットで FFN を MoE に置換し、同一コードを **1,000 以上の実験設定**に再利用している(スライド p.7、論文 §4.1)。
## アーキテクチャと実装
### システム全体像
AXLearn は (1) **AXLearn Composer** と (2) **AXLearn Runtime** の 2 主要コンポーネントから成る(論文 Fig.2、スライド p.9)。
**Composer(コンフィグ → JAX プログラム):**
- ユーザスクリプトとレイヤライブラリからコンフィグを生成し、メッシュ形状の選択・シャーディングアノテーション付与・リマテリアライゼーション戦略・AOT コンパイルを適用して JAX プログラムを生成する。
- バックエンドごとのカスタムカーネル選択もこの段階で行われる。
**Runtime(クラスタ上のオーケストレーション):**
- 監視・非同期チェックポイント(S3 / GCS 対応)・障害検知とリカバリを担う。
- Kubernetes 上でスライスレベルのホットスワップを実装。
**基盤:** JAX + XLA + GSPMD 上に構築される。1 つの JAX プログラムが全バックエンドで動作する(スライド p.9)。
### ハードウェア非依存の実現
XLA だけでは不十分であり、AXLearn は追加の工夫を施している(スライド p.10、論文 §4.2)。
- **コンフィグベースの並列化:** FSDP・パイプライン並列・エキスパート並列・シーケンス並列・テンソル並列をコンフィグだけで指定でき、モデルコードの変更は不要(論文 §4.2)。
- **メッシュルール:** アクセラレータ種別からコンフィグモディファイアへのマッピング。TPU v5e なら FSDP をスライス内・データ並列をスライス間に設定し、H100 なら 8-way テンソル並列をノード内・FSDP をノード間に設定する、といったバックエンド固有の最適化を簡潔に表現する(論文 Appendix A)。
- **リマテリアライゼーションタグ:** アテンション QKV プロジェクションや出力などの共通リマテリアライゼーション地点に名前付きタグを付与し、ハードウェアごとに保存・再計算・CPU オフロードを選択できる(論文 §4.2)。
- **カスタムカーネル:** FlashAttention 等をブラックボックスノードとしてドロップイン置換。GPU では cuDNN、AWS Trainium では NKI カーネル、TPU では SplashAttention (Pallas) を透過的にディスパッチする(論文 §4.2)。
- **AWS Trainium2 を大規模に対応した最初の深層学習システム**(スライド p.10、論文 §2.2)。
### InvocationContext: モジュール境界を超えた状態管理
JAX の関数型制約の下で、ニューラルネットワーク訓練に不可欠な状態管理(モデルパラメータ、PRNG、サマリ、出力)を `InvocationContext` という抽象で実現する(論文 §4.3、Fig.3)。親モジュールが子を呼び出すとコンテキストがスタックにプッシュされ、PRNG キーの分割とデータストアの生成が透過的に行われる。子の返却時にサマリと出力が親に集約される。これによりモジュール実装は他モジュールの存在を全く知る必要がない。
### AOT コンパイル
JAX の Ahead-of-Time コンパイルをネイティブサポートし、単一ホスト上で訓練プログラムのメモリ・FLOPS 使用率を実行せずに解析できる(論文 §4.2)。OOM などの分散実行時エラーをローカルで事前検出でき、限られたクラウド TPU 容量のもとでの開発スケーリングに貢献した。
## 評価
### 訓練性能(論文 Table 3・スライド p.11)
同一ハードウェア・バッチサイズ 1,024 で比較。AXLearn は**唯一、4 種類のバックエンド全てで動作する**システムである。
**Llama2-7B:**
| ハードウェア | システム | イテレーション時間 (s) | MFU | スループット (tokens/s) |
|---|---|---|---|---|
| 32 × H100-8 | MaxText | 1.4 | 54.7% | 3.0M |
| 32 × H100-8 | AXLearn | 1.4 | 54.2% | 3.0M |
| tpu-v5p-512 | MaxText | 2.7 | 61.6% | 1.6M |
| tpu-v5p-512 | AXLearn | 2.5 | 66.2% | 1.7M |
| 64 × Trainium2-16 | AXLearn | 1.2 | 24.2% | 3.5M |
**Llama2-70B:**
| ハードウェア | システム | イテレーション時間 (s) | MFU | スループット (tokens/s) |
|---|---|---|---|---|
| 64 × H100-8 | MaxText | 9.4 | 39.1% | 446K |
| 64 × H100-8 | AXLearn | 9.2 | 40.0% | 456K |
| tpu-v5p-1024 | AXLearn | 11.6 | 68.0% | 360K |
| 64 × Trainium2-16 | AXLearn | 11.2 | 25.0% | 374K |
**Qwen-3 30B-A3B (MoE):**
| ハードウェア | システム | イテレーション時間 (s) | MFU | スループット (tokens/s) |
|---|---|---|---|---|
| tpu-v5p-1024 | MaxText | 13.0 | 31.3% | 1.3M |
| tpu-v5p-1024 | AXLearn | 12.9 | 31.6% | 1.3M |
| 64 × B200-8 | Megatron-LM | 4.1 | 20.2% | 4.1M |
| 64 × B200-8 | AXLearn | 4.3 | 19.2% | 3.9M |
TPU では AXLearn が最先端性能を達成し、MaxText をやや上回る。H100 / B200 GPU では Megatron-LM が PyTorch の細粒度スケジューリング能力により若干優位だが、AXLearn は XLA によるハードウェア非依存性をトレードオフとして受け入れている(論文 §7.2)。
### スケーラビリティ(論文 §7.2、Fig.4)
プロダクションモデルでの弱スケーリング実験を実施。70B パラメータモデル(Model A)は 256→4,096 チップで MFU が 63.0%→52.4%、150B パラメータモデル(Model B)は 8,192→32,768 チップで MFU が 40.6%→37.6% と、**ほぼ線形のスケーリング**を達成(スライド p.11 下部注記)。
### 推論性能(論文 Table 4・スライド p.13)
TPU 上で vLLM との比較。vLLM の TPU サポートは実験段階であったため、公正な比較ではないが、モジュラーな訓練フレームワークがプロダクション級の推論性能を追加努力なく達成できることを示す。
| モデル | システム | TTFT | TPOT | スループット |
|---|---|---|---|---|
| Llama2-7B (TPU v5p-8) | vLLM | 538.6 ms | 22.4 ms | 1,117 tok/s |
| Llama2-7B (TPU v5p-8) | AXLearn | 40.1 ms | 9.1 ms | 3,125 tok/s |
| Llama2-70B (TPU v6e-8) | vLLM | 80 s | 189.8 ms | 705 tok/s |
| Llama2-70B (TPU v6e-8) | AXLearn | 150.5 ms | 28.1 ms | 1,139 tok/s |
7B モデルで TTFT 500 倍・TPOT 6 倍の高速化、スループットは 2.8 倍。70B モデルではスループット 1.6 倍(論文 §7.2)。
### 障害復旧(論文 §7.3、Fig.5・スライド p.12)
32,768 TPU 上のプロダクション訓練ジョブで、ハードウェア障害からの復旧を計測。
- **スライスレベルホットスワップ:** 障害検知後 **4 分**で完了(Kubernetes のスペアレプリカを活用)。
- **チェックポイント復元:** ホットスワップ完了後 **9 分**で復元。
- **合計ダウンタイム:** **21 分**(ホットスワップ+チェックポイント復元+最終チェックポイント以降の進捗喪失を含む)。
- チェックポイントは非同期書き込みのため、チェックポイント作成中のスループット低下はない。
## Apple での運用経験
### 開発の歴史(論文 §7.4)
- **2021 年後半:** 少人数のチームで PyTorch ベースの開発を開始。Transformer アーキテクチャの収斂により、モジュラー設計でレイヤを再利用できるという着想が出発点。
- **GSPMD(2021 年発表)**の登場を機に JAX/XLA との深い統合を選択。コンパイラファースト設計により、シャーディングの伝播をグラフレベルで自動化し、ハードウェアプロバイダのコンパイラ最適化がコード変更なしに性能向上をもたらすと判断。
- 当初は Google Cloud TPU のみサポートしていた JAX/XLA が、現在は AWS・GCP で対等な性能を達成し、Trainium2 もネイティブ対応するに至った。
- `InvocationContext` の設計は、命令型プログラミングから関数型プログラミングへの移行に伴う使いやすさの課題(可変状態の禁止・コールスタックを通じたパラメータの受け渡し)を解決するために導入された。
### 運用上の教訓
- **リソース競合の回避:** GCP の限られた TPU 容量のもとで、低リソース使用率や予防可能なエラーによるジョブの浪費が顕著だった。XLA コンパイルはデバイス抽象上で動作するため、AOT コンパイルによりローカルマシンで OOM やサブオプティマルなシャーディングを事前検出できるようになった。
- **ゴールデンコンフィグテスト:** 訓練コンフィグを人間可読形式にシリアライズしコード変更とともにコミットすることで、一貫したレビュー差分の生成・関連コードオーナーのトリガ・実験履歴の追跡可能性を実現。
- **障害対応:** パブリッククラウドの不透明な障害(ハードウェア故障・ICI 障害・サイレントデータ破壊・カーネルパニック・ファイルシステムスロットリングなど)に対し、複数レイヤのレジリエンス機構を構築。Google・Amazon・NVIDIA との協力が不可欠だった。
### 現在の規模(論文 §7.4・スライド p.14)
- 少人数から**数百名のエンジニア**が利用する規模に成長。100 万から 1 兆パラメータ規模のモデルを訓練。
- 同時に **10,000 以上の実験**が、AWS・GCP・Azure・オンプレミスを含む数十の異種クラスタ上で稼働。
- AXLearn で訓練されたモデルの一部は **10 億人以上のユーザ**が利用する製品(Apple Intelligence)に搭載されている。インテリジェントアシスタント・マルチモーダル理解と生成・コードインテリジェンスなどを含む。
## 結論・Takeaway
1. **厳密なカプセル化がモジュラリティを解放する。** コンフィグを木構造で合成し、モジュールをドロップイン置換可能にする設計により、LoC-Complexity を $O(1)$ に抑えた。
2. **LoC-Complexity は通常の LoC カウントが見落とすシステムの進化しやすさを捉える指標である。** 他の全システムは $O(NM)$ であるのに対し、AXLearn のみ $O(1)$。
3. **1 つのコードベースで全バックエンドを対応。** GPU・TPU・B200・Trainium2 を XLA + カスタムカーネル + メッシュルールで統一。同じ Llama2-70B コンフィグが H100・TPU v5p・Trainium2 上で動作し、メッシュ形状の変更だけで済む(スライド p.10)。
4. **コンパイラファースト設計(XLA + GSPMD)への賭けが功を奏した。** 新しいハードウェアが登場してもレイヤのモジュラリティが維持され、コンパイラの進化がコード変更なしに性能を向上させる。
5. **訓練コードからプロダクション級推論を最小限の追加工数で導出できる。** KV キャッシュのカプセル化により、連続バッチング・分離プリフィル/デコード・ページドキャッシュを再実装なしに組み込める。
> オープンソース: https://github.com/apple/axlearn(Apache 2.0、JAX + XLA + GSPMD)