# ZeROパラメータシャーディング ## 定義 ZeRO (Zero Redundancy Optimizer) パラメータシャーディングは、データ並列訓練においてモデルパラメータ・勾配・オプティマイザ状態をデータ並列ランク間でシャードし、各デバイスが全体の 1/W しか保持しないようにすることでメモリフットプリントを削減する手法である。ZeRO-3 はパラメータまでシャードする最も積極的な段階であり、PyTorch の Fully Sharded Data Parallel (FSDP) はこれを PyTorch コアコンポーネントと共設計した産業グレード実装として位置づけられる。 ZeRO は [[DeepSpeed]](Microsoft)が提案し、PyTorch の FSDP はその設計思想を PyTorch のテンソル実装・ディスパッチャ・CUDA メモリキャッシュアロケータと緊密に再実装したものである。 ### 基本アルゴリズム 各ランクはシャードされたパラメータのみをメモリに保持する。前向き・後退計算の直前に AllGather でアンシャードパラメータを再構成し、計算後に即座に破棄する。後退計算後に ReduceScatter で勾配をシャード化する。訓練ループ全体を通じてオプティマイザ状態はシャードされたまま維持される。 ピーク時のパラメータメモリ要件は O(Σψ_i/F + max_i ψ_i) で、第 1 項はシャードの常駐コスト、第 2 項は最大 FSDP ユニットのアンシャードコストに対応する。([[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]]) ### DDP との通信量比較 フルシャーディング(F=W)では DDP の AllReduce と比較して 1.5× の通信量・体積になる。これは前向きの AllGather + 後退の ReduceScatter = 2× AllReduce 相当の通信で、DDP の 1× AllReduce に対して増加しているためである。この追加コストを隠蔽することが FSDP の通信最適化の核心課題となる。 ## 横断的知見 - **FlatParameter による効率化**: PyTorch FSDP は FSDP ユニット内の全パラメータを 1 次元の FlatParameter に連結・パディングすることで均等サイズの大きな AllGather を 1 回発行する設計を採る。これは NCCL の均等入力要件を満たしながら起動オーバーヘッドを最小化する。33M 要素未満の AllGather では総通信時間が急増するという実測から導かれた設計原則である。([[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]]) - **シャーディング係数 F によるメモリ・スループット連続トレードオフ**: F=1 がデータ並列(DDP 相当)、F=W がフルシャード(ZeRO-3 相当)で、1 < F < W のハイブリッドシャーディングが中間のトレードオフを提供する。Fat-Tree ネットワークのノード内/ノード間帯域差を活用するとき、シャーディング係数をノード内 GPU 数に合わせることでクロスホストトラフィックを大幅削減できる。([[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]]) - **DeepSpeed ZeRO-3 との設計上の差異**: ZeRO/DeepSpeed はパラメータをパー・パラメータのシャーディングと Broadcast/Gather 系コレクティブで実装するのに対し、FSDP は FlatParameter への連結と AllGather/ReduceScatter を採用する。前者はデバイス間でワークロードが不均等になりうるのに対し、後者は均等に分割される。また FSDP はフレームワーク内部を書き換えるのではなく公開 API を通じて設計されているため、フレームワーク変更への耐性がある。([[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]]) - **MiCS との勾配通信戦略の差**: MiCS は global AllReduce 後にシャーディングを行うため各ランクが全勾配を保持するのに対し、FSDP の ReduceScatter はシャード単位の勾配のみを保持し、1 レイヤー分のシャード勾配への低メモリ要件を実現する。([[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]]) - **データ並列の複製冗長が復元資源になる**: [[FlashRecovery]] は FSDP などの ZeRO 系でデータ並列度 N の複製を障害復元の冗長コピーとして活用し、チェックポイントを不要化する。ZeRO シャーディングのメモリ削減効果と耐障害性の向上が同一のデータ並列構造で両立する。([[@2025__arXiv__FlashRecovery - Fast and Low-Cost Recovery from Failures for Large-Scale Training of LLMs]]) - **DeepSpeed ZeRO Stage 2 の実装上の注意点**: [[PMBS 2025]]([[@2025__PMBS__Pretraining LLMs at Scale - Tuning Strategies and Performance Portability]])は DeepSpeed ZeRO Stage 2 の既定実装が論文上の説明と異なり reduce-scatter ではなく all-reduce で動作することを実測で指摘する。`reduce_scatter=true` に加えて `use_multi_rank_bucket_allreduce=false` の設定が必要。実装と理論記述のずれが実測チューニングで発見される代表例。 ## 未解決の問い - FlatParameter のユニット境界設定は手動または auto_wrap_policy で制御されるが、静的構造と実行順序が乖離するケースでは最適境界の発見が難しい。動的実行順序を反映した自動 FlatParameter 構築の実現可能性。 - シャーディング係数 F の最適値はモデル・ハードウェア・通信トポロジの組み合わせによる。F の自動探索(Auto Parallelism との統合)はどこまで成熟しているか。 - オプティマイザのシャード境界非整合問題(数学的同値性の崩壊)。二次オプティマイザや norm 依存オプティマイザとの共設計は未解決の研究課題とされている。 - パイプライン並列と組み合わせたとき、マイクロバッチごとの AllGather を NRAF で回避できるが、これは 1 パイプラインステージ全体をアンシャードのまま保持するメモリコストを要求する。PP+FSDP の最適構成。 ## 関連 - ソース: [[wiki/sources/@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]] - 概念: [[並列化戦略]] / [[LLM分散学習]] / [[集合通信]] / [[チェックポイント]] - エンティティ: [[DeepSpeed]] / [[Yanli Zhao]] / [[Meta AI]] - 関連 MOC: [[分散深層学習 - MOC]] ## 出典 - [[@@2023__VLDB__PyTorch FSDP Experiences on Scaling Fully Sharded Data Parallel]] — PyTorch FSDP の設計・実装・評価(Meta AI、VLDB 2023) - [[@2025__arXiv__FlashRecovery - Fast and Low-Cost Recovery from Failures for Large-Scale Training of LLMs]] — FSDP でのデータ並列複製による障害復元 - [[@2025__PMBS__Pretraining LLMs at Scale - Tuning Strategies and Performance Portability]] — DeepSpeed ZeRO Stage 2 実装差の実測