> [!abstract] 概要
> 大規模モデルが多様なドメインで優れた性能を発揮できることは広く認められているが、こうした能力は少数の先進ユーザーと業界リーダーに限定されており、より広いコミュニティにとって暗黙的な技術的参入障壁となっている。本論文では、大規模モデル訓練の産業グレードソリューションとして PyTorch Fully Sharded Data Parallel (FSDP) を紹介する。FSDP はテンソル実装、ディスパッチャシステム、CUDA メモリキャッシュアロケータなど PyTorch の主要コンポーネントと緊密に共設計され、非侵入的なユーザー体験と高い訓練効率を提供する。さらに FSDP は多様なハードウェア構成にわたるリソース利用の最適化技法を一連のネイティブ機能として組み込んでいる。実験結果から、FSDP は DistributedDataParallel と同等の性能を達成しつつ、TFLOPS の観点で near-linear スケーラビリティを保ちながら大幅に大きなモデルをサポートできることが示された。
## 論文情報
| 項目 | 内容 |
|------|------|
| タイトル | PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel |
| 著者 | Yanli Zhao, Andrew Gu, Rohan Varma, Liang Luo, Chien-Chin Huang, Min Xu, Less Wright, Hamid Shojanazeri, Myle Ott, Sam Shleifer, Alban Desmaison, Can Balioglu, Pritam Damania, Bernard Nguyen, Geeta Chauhan, Yuchen Hao, Ajit Mathews, Shen Li |
| 所属 | Meta AI |
| 会議 | PVLDB Vol. 16, No. 12 (2023), pp. 3848–3860 |
| DOI | 10.14778/3611540.3611569 |
| URL | https://www.vldb.org/pvldb/vol16/p3848-huang.pdf |
| 実装 | https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py |
## 概要
PyTorch Fully Sharded Data Parallel (FSDP) は、ZeRO パラメータシャーディングの考え方を PyTorch のコアコンポーネントと緊密に共設計した産業グレードの大規模訓練ソリューションである。モデルを FSDP ユニットに分解し、各ユニット内のパラメータをフラット化・シャードすることで、単一 GPU に収まらないモデルの訓練を可能にする。PyTorch 2.0 のベータ機能として搭載されており、言語モデルから推薦システムモデルまで幅広く実戦検証されている。
## 問題設定
大規模モデルの訓練には主に 3 種類のアプローチが存在する。
1. **モデル複製 (Model Replication)**: DDP はすべてのデバイスにモデルのコピーを保持し、後退パスで AllReduce を実行して勾配同期する。しかし DDP はモデル全体を 1 GPU に収める必要があり、10 億パラメータ超では 40GB GPU でメモリ不足が生じる。
2. **モデル分割 (Model Partitioning)**: パイプライン並列や Tensor RPC はモデルをステージ/デバイスに分散できるが、モデルアーキテクチャへの依存やコード変更を要し、汎用性に欠ける。
3. **モデルシャーディング (Model Sharding)**: パラメータをシャードして各ランクが 1/W を保持する。計算には 2 方法あり、(a) シャードのまま計算してアクティベーションを通信(クリティカルパスに通信が入る)、(b) オンデマンドでパラメータをオールギャザーして完全マテリアライズ後に計算する。FSDP は(b)を採用する。
FSDP が対処する主要な設計課題は 4 つである。
- **ユーザー体験**: モデルが 1 GPU に収まらない場合の初期化問題を解決する必要がある。
- **ハードウェア異質性**: ノード内高帯域/ノード間低帯域の階層構造に適応する必要がある。
- **リソース利用**: 通信に起因する計算バブルを排除する必要がある。
- **メモリ計画**: CUDA キャッシュアロケータの断片化を抑制する必要がある。
## 提案手法
### モデル初期化
FSDP は **遅延初期化 (deferred initialization)** を導入する。ユーザーはモデルパラメータテンソルを「仮想」デバイス上に確保し、すべての初期化演算を記録する。モデルを GPU に移動する際に記録済み演算を自動リプレイすることで、GPU メモリを一切消費せずに大規模モデルのインスタンス生成が可能になる。その後、1 ユニットずつ GPU に移動してシャードし、次のユニットに進む。
サブモジュール初期化がクロスモジュール依存をもつ場合など、遅延初期化が使えない場面では GPU 上での完全モデル初期化と CPU 上でのストリーミング初期化の 2 代替手段が提供される。
### シャーディング戦略
シャーディング係数 F を導入し、3 種類の戦略を統一フレームワークで表現する。
- **フルシャーディング (F = W)**: 全ランク数でシャード。最小メモリフットプリントだが通信量が最大(DDP 比 1.5× )。
- **フル複製 (F = 1)**: DDP と同等の AllReduce ベースのデータ並列。
- **ハイブリッドシャーディング (1 < F < W)**: シャードグループとレプリケーショングループを組み合わせ。Fat-Tree ネットワークのノード内/ノード間帯域差を活用してクロスホストトラフィックを削減できる。
### FlatParameter によるコレクティブ効率化
AllGather の効率化には 2 つの条件がある。(1) NCCL は均等な入力テンソルサイズを要求する、(2) 入力サイズが大きいほど起動オーバーヘッド削減と帯域幅利用率向上につながる(33M 要素未満で急激に効率低下)。
この 2 条件を満たすため、FSDP は 1 ユニット内の全パラメータを **FlatParameter** という 1 次元テンソルに連結(シャーディング係数で割り切れるようにパディング)し、均等にシャードする。FlatParameter の勾配は同じレイアウトを継承し、AllGather/ReduceScatter をコピーなしで呼び出せる。ピーク時のパラメータメモリは O(Σψ_i/F + max ψ_i) となる。
### 通信最適化
4 つの主要な通信最適化がある。
**1. 計算通信オーバーラップ (Overlapping Communication and Computation)**
DDP の場合は後退 AllReduce が先行計算に続くため async-collective-and-wait() で重複できる。しかし FSDP の前向き AllGather は先行計算の後に発行されるため同じ手法が使えない。FSDP は別の CUDA ストリームで AllGather を発行し、デフォルトストリームの前計算への偽依存を回避することで重複を実現する。
**2. 後退プリフェッチ (Backward Prefetching)**
後退パスでは ReduceScatter と次の AllGather が同じ NCCL ストリームで逐次実行されるとクリティカルパスに露出する。FSDP は前向きパスのモジュール実行順序を記録し、後退パスの実行順序の代理として使う。現在の ReduceScatter より前に次の AllGather を発行することで 2 連続の通信コールを隠蔽する。GPT-175B で約 18% のスループット向上。
**3. 前向きプリフェッチ (Forward Prefetching)**
CPU 実行が遅いワークロードでは CPU スレッドが NCCL ストリームを満たすのに間に合わないことがある。静的グラフのモデルでは前の反復から実行順序を仮定し、現在の FSDP ユニットの前向き計算前に次の AllGather を発行できる。
**4. 勾配累積 (Gradient Accumulation)**
通信あり/なしの 2 変形を提供する。通信なしの変形は増加したメモリと引き換えに通信を減らし、エンドツーエンドのスループットを向上できる。
### メモリ管理
FSDP は **レートリミッター** で最大 2 つの処理中 AllGather を許可し、CPU スレッドの先行実行による CUDA キャッシュアロケータの過剰メモリ確保と cudaMalloc リトライを抑制する。これは AllGather 先行ストリーム(producer)とデフォルト計算ストリーム(consumer)の間でブロックが再利用できないことに起因する問題に対処する。
### ネイティブ混合精度
FlatParameter 単位で前向き前(pre-forward)にのみキャストを実施することで、演算子レベルのキャスト(torch.amp.autocast)より少ないキャスト回数を実現する。シャードされた FlatParameter をメモリに保持し、非シャード FlatParameter のみを動的確保するため、混合精度でのピークメモリは増加ではなく O(K_full * Σψ_i/F + K_low * max ψ_i) と低精度のキャストコストに抑えられる。勾配もシャードされるため、FSDP は専用のシャード済み勾配スケーラーを提供する。
### 実装詳細
API は 2 系統: `FullyShardedDataParallel` モデルラッパー(サブモジュールを FSDP ユニットで置換)と `fully_shard` モジュールアノテーター(モデル構造と完全修飾名を保持しフックで FSDP ロジックを注入)。後退パスへの統合には autograd エンジンの 3 種類のフック(Tensor の `register_hook`、`queue_callback`、AccumulateGrad) を活用し、非侵入的で正確な通信タイミングを実現する。
## 新規性
- ZeRO(DeepSpeed)のアイデアを PyTorch コアコンポーネント(テンソル、ディスパッチャ、CUDA アロケータ)と共設計し、フレームワーク内部依存を排除して安定性と汎用性を実現した。
- FlatParameter による複数パラメータの連結シャーディングで、均等サイズ・大きなコレクティブ発行という 2 条件を同時に満たす通信効率設計。
- AllGather を別 CUDA ストリームで発行することで FSDP の前向き AllGather と先行計算のオーバーラップを実現した(DDP とは逆方向の依存関係への対処)。
- レートリミッターで CUDA キャッシュアロケータのクロスストリームブロック再利用不可問題を制御する。
- シャーディング係数 F によりフル複製(DDP と等価)からフルシャード(ZeRO-3 に相当)まで統一フレームワークで表現し、ハイブリッドシャーディングでデータセンタートポロジへの適応を実現。
## 実験設定
- ハードウェア: 最大 512 枚の 80GB A100 GPU、2Tb/s RoCE ネットワーク
- モデル: T5-611M / T5-2B / T5-11B、minGPT-175B、DHEN 推薦モデル(768B スパース + 550M 密)
- 比較対象: DDP (DistributedDataParallel)
- 測定指標: TFLOPS/GPU、バッチあたりレイテンシ、ピークメモリ割当/アクティブ/予約量
- 設定: アクティベーションチェックポイント + BF16 混合精度 + Adam オプティマイザ(本番ワークロードを模倣)
## 実験結果
**モデルスケール (T5)**
- T5-611M および 2.28B では FSDP と DDP の性能は同等(TFLOPS 差は 1% 未満)。
- DDP は 2.28B 超でメモリ不足。FSDP は T5-11B をフルシャーディング + BF16 で問題なく処理。
- T5-11B: 8 GPU で約 154 TFLOPS/GPU、512 GPU でも約 148 TFLOPS/GPU(7% の微回帰)。
**後退プリフェッチ (GPT-175B)**
- 128–512 A100 の全規模で約 18% のスループット向上(175 TFLOPS/GPU → 約 175.8 TFLOPS/GPU 台)。
**レートリミッター**
- T5-11B (4 マシン): 制限なし 21.81 s/バッチ → 制限あり 15.33 s/バッチ(約 30% 短縮)。
- T5-11B (2 マシン): 制限なし 18.61 s/バッチ → 制限あり 14.81 s/バッチ(約 20% 短縮)。
- ただし効果はモデルに依存し、RegNet 9B では改善なし、DeepViT 8B では 5% 悪化。
**大規模モデル**
- GPT-175B (512 A100): バッチサイズ 1/2 でそれぞれ 173/186 TFLOPS/GPU、A100 の理論ピーク (312 TFLOPS BF16) の 55%/60% に相当。128–512 GPU で線形スケーラビリティ。
- DHEN 推薦モデル (512 A100): フルシャーディング(RAF)はメモリ最小だが QPS も最小、ハイブリッドシャーディング(NRAF)は逆のトレードオフ。
## 考察
- **FlatParameter の粒度は精度とスループットのトレードオフ**: FlatParameter を細かく構築するとピークメモリは下がるが、コレクティブ数が増えスループットが低下する。ユーザーはラッピング方針で制御可能。
- **ハイブリッドシャーディングのトポロジ適応**: Fat-Tree の過剰購読環境でシャーディング係数を NIC 数に合わせることで、クロスホストトラフィックをフル複製の 2M(W-1)/W から 2M(W-1)/(GW) に削減できる。
- **レートリミッターの適用条件**: cudaMalloc リトライが発生している場合にのみ有効。`torch.cuda.memory_stats()` の `num_alloc_retries` で判断する。
- **他の並列化との相互運用**: パイプライン並列とは各ステージを FSDP でラップ可能。マイクロバッチ毎の AllGather を避けるには NRAF 設定が必要。テンソル並列とは `parallelize_module` (DTensor) との 2D メッシュ構成が可能で、TP をノード内、FSDP をノード間に配置する。
- **既知の制限**: (a) シャード境界がパラメータ境界と一致しないためオプティマイザによっては数学的同値性が崩れる。(b) 共有パラメータは最下位共通祖先ユニットに属させる必要があり、そうしないと未シャードのまま長い区間が保持されうる。
## 強み / 弱点・課題
**強み**
- PyTorch コアとの共設計により、フレームワーク内部の変更に強く、DDP ライクな API でユーザー体験を維持する。
- 単一フレームワークでフル複製〜フルシャード〜ハイブリッドをシャーディング係数 F の調整だけでカバー。
- 通信効率(FlatParameter)、オーバーラップ(別 CUDA ストリーム + プリフェッチ)、メモリ制御(レートリミッター)の 3 軸を統合的に最適化。
- near-linear スケーラビリティを最大 512 A100 で実証。
**弱点・課題**
- レートリミッターの有効性はモデル依存で、通信主体のワークロードでは逆効果になる可能性がある。
- T5-11B で GPUs 数増加に伴う 7% の MFU 低下は通信の計算への露出を示す。完全なオーバーラップは大規模では達成困難。
- オプティマイザが個別パラメータのアンシャード値(ベクトルノルム等)や構造(近似二次オプティマイザ)に依存する場合、数学的同値性が成立しない。
- 共有パラメータの取り扱いは手動のユニット境界調整が必要でエラーが発生しやすい。