# ProTrain: Efficient LLM Training via Automatic Memory Management
> [!info] Talk metadata
> - **会議:** [[MLSys2026]]、Research Track Oral "LLM Training 1" セッション(08:30 PDT 開始の LLM Serving 2 & LLM Training 1 枠)
> - **論文タイトル:** ProTrain: Efficient LLM Training via Automatic Memory Management
> - **著者:** Hanmei Yang, Jin Zhou, Hui Guan, Tongping Liu(University of Massachusetts Amherst)/ Yao Fu, Xiaoqun Wang, Ramine Roane(Advanced Micro Devices, Inc.)。連絡先は Hanmei Yang(
[email protected])
> - **所属:** University of Massachusetts Amherst / Advanced Micro Devices, Inc.
> - **掲載:** Proceedings of the 9th MLSys Conference, Bellevue, WA, USA, 2026
> [!abstract] 概要(論文 PDF アブストラクト・忠実日本語訳)
> メモリ圧迫は、特に資源制約のある環境において、大規模言語モデル (LLM) の学習をスケールさせるうえで支配的な制約となっている。現代のフレームワークは様々なメモリ節約技法を備えるが、それらはしばしば手動チューニングと専門的なシステム知識を要する低レベルの設定ノブを露出させる。これはエンジニアリングのオーバーヘッドを増やすだけでなく、設定を誤るとハードウェア利用が最適でなくなるリスクをもたらす。本論文は **ProTrain** を提案する。これはメモリ管理ポリシーをモデルアーキテクチャと基盤ハードウェア資源に対して自動的に適合させ、手動介入の必要を排除する新しい学習システムである。ProTrain の中核は、複雑なメモリ管理戦略を少数のチューナブルな設定パラメータに抽象化し、コストモデルを用いて最適なパラメータ設定の探索を可能にする自動メモリ管理である。ProTrain には、レイテンシ・メモリ使用量・I/O 帯域の精密な推定を提供して高忠実度のコストモデルを構築するランタイムプロファイラが備わっている。ProTrain は学習アルゴリズムを変更しないため精度を損なわない。実験により、ProTrain は最先端の学習システムと比較して学習スループットを **1.43×〜2.71×** 向上させることを示す。
> [!note] 出典に関する注記
> 本ノートは**論文 PDF とスライド PDF のみ**に基づき、音声文字起こしは存在しない。したがって Q&A は記載しない。公式タイトルは論文/スライド表紙の "ProTrain: Efficient LLM Training via Automatic Memory Management"(ファイル名では `:` を ` -` に置換)。**発表者個人(誰が登壇したか)は提供資料からは特定できない**。著者は UMass Amherst(Hanmei Yang, Jin Zhou, Hui Guan, Tongping Liu)と AMD(Yao Fu, Xiaoqun Wang, Ramine Roane)の共同で、連絡先は Hanmei Yang。本文の数値・固有名・式・図はスライド/論文を権威とした。
## 問題設定: メモリの壁と手動チューニングの限界
- **メモリがボトルネック**: LLM 学習では model states が支配的なメモリ消費源。混合精度学習では1パラメータあたり **16 bytes**(fp16 weights + fp16 grads + fp32 master weights + fp32 momentum + fp32 variance、activation は別)を要する。スライドの例では **LLaMA-2 70B × 16 bytes = 1.12 TB** で、単一の 80GB A100 GPU の 14× のメモリが必要(activation を除いてなお)。論文によれば既存研究 (Ren et al., 2021) に従えば各パラメータ増加は一般に 16× のメモリ増を伴う。
- **3つの主要なメモリ最適化技法**(スライド "Memory Optimization Techniques"):
- **ZeRO / FSDP**: model states を GPU 間で分割し冗長性を削減(sharded data parallelism)。
- **Offloading / Swapping**: CPU メモリを活用して GPU 容量を拡張。テンソル(model states または activations)を CPU DRAM や NVMe へ退避。
- **Gradient Checkpointing**: forward で選択テンソルを破棄し backward で再計算することで、計算をメモリと引き換える。
- **これらの技法は競合・対立する**(スライド "These Techniques Conflict and Compete")。論文も「技法はしばしば相互排他的」「補完的技法でさえ共有ハードウェア資源を奪い合う」と述べる。
- 技法選択: model states では **ZeRO Sharding vs. Param. Offloading**(GPU にフル保持か、分割して CPU へ退避か)。activation memory では **Grad. Checkpointing vs. Act. Swapping**(再計算か CPU へ退避か)。
- 資源競合: PCIe 帯域で **Act. Swapping vs. Param. Offloading**(両者とも PCIe を必要とし一方を飽和させると他方を飢餓状態にする)。GPU メモリで **Resident Param. vs. Prefetch Buffers**(パラメータを多く常駐させると prefetch の余地が減る)。
- **手動チューニングが破綻する**: 例えば DeepSpeed は ZeRO partitioning・tensor swapping・gradient checkpointing にまたがる **18 以上のチューナブルパラメータ**を露出し、相互に密結合する。`stage3_max_reuse_distance` と `stage3_max_live_parameters` のように互いに対立するノブもある。論文の評価では、RTX 3090 上で 10B GPT-2 を**デフォルト設定**で学習すると GPU メモリの **35.6%** しか使えず、最適化設定より **1.18× 遅い**。さらに RTX 3090 向け最適設定は A100 を活用しきれず、A100 向け設定は RTX 3090 で OOM を起こすため、ハードウェアが変わるたびに手動で再チューニングが必要になる。
- **本質的な問題は「協調 (Coordination)」**(スライド "The Real Problem: Coordination")。LLM 学習のメモリ最適化は、独立した技法の寄せ集めではなく、**compute(再計算コスト)・GPU memory(容量割当)・communication(I/O 帯域)** にまたがる**結合的な協調問題**である。これを3つの技術課題に分解:
1. **クリーンな抽象化の欠如** → Structured Memory Strategies で解決(既存は複雑な相互依存を持つ非構造的なパラメータ空間を露出)。
2. **不正確なメモリ推定** → Memory-Aware Profiler で解決(従来のプロファイラは実メモリ使用量を過小評価)。
3. **手動チューニングのスケール不能** → Automatic Memory Management で解決(探索空間が手動探索には大きすぎ、最適設定はハードウェアで変わる)。
## ProTrain の設計: 3 コンポーネントと最適化目的
- ProTrain は3つのコアコンポーネントからなる(論文 Figure 1, スライド "ProTrain: System Overview"):
1. **Structured Memory Strategies**: model states と activations を、それぞれの最適化技法に適した戦略で管理。
2. **Memory-Aware Profiler**: ランタイム情報とメモリデータを収集してメモリ管理判断を導く。
3. **Automatic Memory Management**: 対象 LLM をその対象ハードウェア上で学習する最適なメモリ最適化戦略を動的に特定。
- **最適化目的**(スライド "Optimization Objective"、論文 §3.3): per-iteration の実行時間を最小化しつつ peak memory を GPU 容量内に収める制約付き最適化。
- **minimize** `T(n_persist, n_buffer, n_swap, n_ckpt)`(per-iteration time)
- **subject to** `M(n_persist, n_buffer, n_swap, n_ckpt) ≤ C`(GPU memory capacity)
- **4 つのチューナブル設定パラメータ**(探索対象)。スライドの表記で:
- **n_persist**: persistent chunks の数。
- **n_buffer**: chunk buffers の数。
- **n_swap**: swapping blocks の数。
- **n_ckpt**: checkpointing blocks の数。
- 論文ではこれらに加えて、設定探索前に独立に決まる補助パラメータ `S_chunk`(chunk size)・`N_chunk`(number of chunks)・`N_interval`(swapping interval)・`N_block`(number of blocks) を定義する。探索では `{n_persist, n_buffer, n_swap, n_ckpt}` をチューニングする。
- ProTrain は PyTorch 上に**約 7,600 行**で実装され、ユーザは model と optimizer を提供インタフェースでラップするだけで学習ループは変更不要(論文 §4、Figure 1 左上のコード例)。学習アルゴリズム自体は変えないため精度を損なわない。評価対象アーキテクチャは GPT-2・OPT・Mistral・LLaMA。
### Structured Memory Strategies
- **Model States → Hierarchical Chunk Management**(階層的チャンク管理): ZeRO Sharding と CPU Offloading を統合。model states(parameters, gradients, optimizer states)を効率的な I/O 転送のための**均一なチャンク**に組織。chunk を **persistent chunk**(GPU 常駐、直接 GPU でパラメータ更新)と **non-persistent chunk**(CPU/他ランクへ退避、計算時に GPU へ gather、更新後に CPU へ offload)に分ける。persistent chunk 数を Automatic Memory Management で調整することで、複雑な協調なしにメモリ使用量と通信コストの単純かつ効果的なトレードオフを達成。
- **Intra-Chunk Level**(スライド "Hierarchical Chunk Management"): model-defined order で chunk 内のレイヤーを並べると ping-pong access(Chunk1→Chunk2→Chunk1→Chunk2…)が起きキャッシュスラッシングを誘発。ProTrain は**実行順 (execution order)** でチャンク内を並べ替え、sequential access と early release を実現(gradient checkpointing 使用時は逆順アクセスにも適応)。これは forward/backward の operator 実行順をランタイムフックで捕捉して実現。
- **Inter-Chunk Level**: persistent chunk を**モデル先頭から順に**割り当て、残りを non-persistent にする。FWD では早期レイヤーを persistent にすることで prefetch window を長く取り、non-persistent chunk の prefetch を後続レイヤーの計算と重ねて隠す。BWD では non-persistent chunk の **CPU optimizer update を GPU backward 計算と重ねる**(effective use of idle CPU cycles)。事前確保した chunk buffer と決定的 prefetch で eviction policy を不要にし、安定したランタイム挙動と正確なモデリングを可能にする。
- **Activations → Interleaved Block Management**(インターリーブブロック管理): Gradient Checkpointing と Activation Swapping を統合。transformer block 単位で各ブロックに独立な activation 戦略を割り当てる **block-wise activation management**。各ブロックは3戦略から選択(スライド "Interleaved Block Management"):
- **S (Swap)**: CPU へ offload。
- **C (Checkpoint)**: 再計算。
- **N (None)**: GPU に保持(最適化なし)。
- block 単位は tensor-level 管理(探索空間が巨大)より粗く、coarse-grained な一律適用より柔軟。計算グラフの再構築を要さず既存 transformer 実装に適応。
- **Interleaved Layout**(論文 Figure 2、スライド例は 8 ブロック): FWD で `S1 C2 C3 S4 C5 C6 N7 N8`(S を間隔を空けて配置、C が隙間を埋め、N を末尾に置く)。Offload(O1 は C2,C3 計算中、O4 は C5,C6 計算中)が C の計算と重なる。BWD は `N8 N7 C6 C5 S4 C3 C2 S1`(N を先に走らせメモリを解放して prefetch 余地を作る)。Prefetch は just-in-time(P4/P1 ready)で OOM を防ぎ BWD に間に合わせる。
- レイアウトの利点(論文): (1) swapping block を早期に置くほど swapping を計算と重ねる機会が増える。(2) swapping と checkpointing を交互配置することで activation 蓄積による OOM リスクを減らし peak memory を抑える。(3) 最適化なしのブロックを後段に置くことで、その activation を早く消費し swapping block の prefetch を間に合わせる。
### Memory-Aware Profiler
- **従来プロファイリングが失敗する理由**(スライド "Memory-Aware Profiler"、論文 §3.2): 静的プロファイリングやレイヤー単位ランタイムプロファイリングは2つのソースのメモリを捕捉できない。
1. **Intra-op Delta**: operator 実行中の transient tensors(中間結果)が peak memory を押し上げるが捕捉されない。
2. **Inter-op Delta**: `nn.functional.softmax` や `nn.functional.layer_norm` のような functional API 経由の **unhookable operators**(`nn.Module` でなくフックを回避するもの)のメモリ。
- これらの見落としが 10B GPT-2(batch size 16)で peak memory の **17.2%(3.06 GB)** を占め、不正確なメモリプランニングと OOM リスクの増大を招く。
- **Memory-Aware Profiling のアプローチ**: operator を孤立して個別にプロファイルするのではなく、**完全なモデル実行トレース内**でプロファイルする。
- **Full Model Execution Trace**: 単一の完全な forward pass で全 operator をプロファイル。
- **Drop Model States & Activations**: on-demand tensor management(テンソルを使用直前に確保し直後に解放)で peak memory を「単一モデル全体」ではなく「最大の単一 operator」のレベルまで下げ、単一 GPU に収める。model states や activations の固定的で予測可能なサイズは静的解析で peak memory への寄与を再構築。
- **Measure Both Memory Deltas**: intra-op(transient)と inter-op(unhookable)の双方のメモリデルタを計測。PyTorch の CUDA caching allocator 統計と pre/post operator hooks を用いる。
- **Reconstruct Actual Peak Memory**: 静的解析と operator-level トラッキングを組み合わせ、多様な設定で正確な peak memory 推定を実現。
- プロファイラはユーザ指定 batch size で**単一 GPU**上のサンプリングを行い、memory・latency・I/O(memory transfer bandwidth, collective communication operation durations)等のハードウェアメトリクスを収集。プロファイリングは**単一の再利用可能なパス**で済む(論文 Figure 1: "Single Profiling Pass (Reusable)")。
### Automatic Memory Management(コストモデルと探索)
- **2 つの解析的コストモデル**(スライド "Automatic Memory Management"、論文 §3.3, Appendix A.1/A.2): プロファイラ出力(memory usage per op, runtime per op, memory transfer bandwidth, collective comm. durations 等)を入力に、設定パラメータの関数として:
- **Runtime Model**: `T = f(n_persist, n_buffer, n_swap, n_ckpt)` で iteration time を予測。
- **Memory Model**: `M = g(n_persist, n_buffer, n_swap, n_ckpt)` で peak GPU memory を予測。
- 実際に各候補で学習イテレーションを走らせる手法と異なり、コストモデルで全設定を解析的に評価するため探索空間が大きくても扱える。非同期 tensor swapping のようにランタイム変動でモデル化が難しい挙動は、決定的なスケジューリング構造を露出するよう memory strategy を再設計することで、相互作用を予測可能にし bandwidth contention や overlap といった複合効果も推論可能にした。
- **Configuration Search**(スライド同上): `min T(config) s.t. M(config) ≤ GPU cap` を**枝刈り付き網羅探索**で解く。
- **Prune by bandwidth**: `n_swap` は swapping interval `N_interval` と利用可能帯域で制約され、実行可能値が小集合に限定される。
- **Discard OOM early**: 設定をメモリ使用量の昇順で評価し、GPU 容量を超えるものを早期に破棄。
- メモリ制約を満たしつつ最短 iteration runtime を達成する `{n_persist, n_buffer, n_swap, n_ckpt}` を最終設定に選ぶ。
- **プロファイリング・探索オーバーヘッド**(論文 §5.3.4): プロファイリングは単一学習イテレーションに線形にスケールし軽量。RTX 3090 上で 7B Mistral のプロファイリングは **3.09 秒**、20B GPT-2 は **5.38 秒**。最適設定探索はさらに高速で平均 **0.06 秒**。
## 評価
### 実験セットアップ(論文 §5.1, スライド "Evaluation Setup")
- **モデル**: Mistral 7B、OPT 13B/30B、LLaMA-2 13B/34B、GPT-2 10B/15B/20B/30B/40B(HuggingFace 実装、sequence length は 1024 固定)。
- **ベースライン**: DeepSpeed(ZeRO-3 + offloading、v0.12.1)、Colossal-AI(Gemini Plugin chunk-based ZeRO-3、v0.3.3)、PyTorch FSDP(ZeRO-3 + CPU offloading + selective gradient checkpointing、PyTorch 2.0.1)。比較は ZeRO-3・CPU offloading・gradient checkpointing の共通機能に絞る。
- **ハードウェア**(2 環境):
- **Server 1**: 4× RTX 3090(24GB)、Xeon Silver 4214R(CPU DRAM 384GB)、CPU-GPU/GPU-GPU とも PCIe 3.0(15.8 GB/s)。NVLink なし。
- **Server 2**: 4× A100(80GB)、Xeon Platinum 8480+(CPU DRAM 1TB)、CPU-GPU は PCIe 4.0(31.5 GB/s)、GPU-GPU は NVLink 3.0(300 GB/s)。
### 最大学習可能モデルサイズ(論文 Table 2, スライド "Maximum Trainable Model Size")
| ハードウェア | ProTrain | DeepSpeed | Colossal-AI | FSDP |
|---|---|---|---|---|
| 1× RTX 3090 | **34B** | 15B | 25B | 3B |
| 4× RTX 3090 | **37B** | 15B | 25B | 15B |
| 1× A100 | **75B** | 34B | 53B | 10B |
| 4× A100 | **87B** | 37B | 53B | 55B |
- ProTrain は 4×A100 で **87B** までスケールし、DeepSpeed 比 **2.35×**、Colossal-AI 比 **1.64×**、FSDP 比 **1.58×** 大きいモデルを学習できる(スライド)。RTX 3090 では DeepSpeed 比 2.47×、Colossal-AI 比 1.48×(論文)。
### 学習スループット(論文 §5.2.2, Figure 3)
- **4× RTX 3090**(スライド "Training Throughput (4× RTX 3090)"): ProTrain は全モデルで最高スループット。平均 **2090 tokens/s** で、**DeepSpeed 比 1.97×・Colossal-AI 比 1.77×・FSDP 比 2.71×**。モデル別(対 FSDP)では opt-13b で 2.62×、gpt2-15b で **5.05×**(=対 DeepSpeed)等。gpt2-20b では DeepSpeed/FSDP が OOM で失敗(×)。
- **4× A100**(論文 §5.2.2, Figure 3 下): DeepSpeed 比 **1.85×**、Colossal-AI 比 **1.43×**、FSDP 比 **2.22×**。34B LLaMA で対 FSDP **2.78×**。
- 全体として論文要旨どおり **1.43×〜2.71×** の範囲のスループット向上(下限は A100 の対 Colossal-AI、上限は RTX 3090 の対 FSDP)。
### スケーラビリティ・性能分解(論文 §5.2.3, §5.2.4)
- **GPU 数スケーラビリティ**(Figure 4(a)): 10B GPT-2 で RTX 3090 を 4 枚使い 2493 tokens/s、単一 GPU ベースライン比 **3.5×** 向上。34B LLaMA を 4×A100 で単一 GPU 比 **2.49×〜3.58×**(Figure 7(a))。
- **性能分解**(Figure 4(b)): hierarchical chunk management と interleaved block management が data movement と I/O-compute overlap を改善。CPU parameter update を GPU backward と重ねることで寄与をほぼ無視できる水準に隠す。batch size 増で ProTrain はメモリ節約技法を強める。
### Ablation・推定器精度・オフロード比較
- **最適化戦略の Ablation**(Figure 5、10B GPT-2 / RTX 3090): 各最適化を無効化した相対ランタイム。
- hierarchical chunk management 無効化(persistent chunk を 3 chunk buffer に置換): 1.02×〜1.19× の slowdown。
- overlapped parameter update 無効化: 平均 1.22× slowdown(CPU 更新が GPU に隠れなくなる)。
- interleaved block management 無効化(全ブロックに checkpointing 適用): 平均 1.04× slowdown(RTX 3090 は PCIe 3.0 15.8 GB/s で通信律速のため checkpointing 寄りが既に最適に近い)。GH200(NVLink-C2C 450 GB/s、A100 比 ~14.3×)のような高帯域では swapping が有利になり interleaved block management の利得が大きくなると論文は推定。
- **推定器の精度**(Figure 6): 10B GPT-2 の多様な設定で予測 vs 実測の runtime・peak memory がよく一致。**推定誤差は 4% 未満**で信頼できる自動戦略選択を可能にする。
- **オフロード有無の比較**(Table 3、4×A100 tokens/s): ProTrain は automatic 設定で Mistral-7B 11060.92・GPT2-10B 8266.40・LLaMA-13B 6471.32・GPT2-20B 5043.75。13B LLaMA で Colossal-AI は offloading 有で batch 220 まで(無効だと 32 までで 13% slowdown)、ProTrain は batch 228 を支え best Colossal-AI 比 1.3× 高スループット。FSDP は offloading 無効で全モデル OOM。offloading を適切に統合すれば並列化の代替として cost-effective に大 batch・高スループットを実現できることを示す。
- **探索された最適設定**(Table 4): モデル/batch/ハードで設定が変わることを例示。GPT2-1B/batch64/RTX 3090 は `N_block=32, n_ckpt=24, n_swap=2, N_chunk=12, n_persist=2, n_buffer=3`。GPT2-10B/batch8/RTX 3090 は `N_block=48, n_ckpt=48, n_swap=0, N_chunk=49, n_persist=3, n_buffer=46`、同 A100 では `n_ckpt=0, n_persist=15, n_buffer=3`。batch を 8→64 に上げると n_swap・n_ckpt・n_buffer が増え n_persist が減る。大モデルほど小さい n_persist/n_buffer と大きい n_swap/n_ckpt を要する。
- **マルチノードスケーラビリティ**(論文 §5.5): 175B GPT(batch 256)を 16×A100(4 ノード、100Gb InfiniBand)で学習し **2218.70 tokens/s**、DeepSpeed 比 2.58×・Colossal-AI 比 2.09×。FSDP は OOM で失敗。
## 結論
- **問題**(スライド "Conclusion"): 既存フレームワークは密結合した多数のメモリ関連パラメータの広範な手動チューニングを要し、システム実装が統一的最適化を欠く。
- **解法**: ProTrain はメモリ戦略を構造化された **4 パラメータ空間**に抽象化し、memory-aware profiler とコストモデルで自動設定探索を可能にする。学習アルゴリズムは変えないため精度を損なわない。
- **成果**: 4×A100 で最大 **87B** パラメータを学習(ベースライン比 1.58×〜2.35× 大)。全モデル・全ハードウェア構成で **1.43×〜2.71×** の高スループット。論文の結論は単一 A100 で SOTA 比最大 **5×** の性能・最大 75B パラメータの学習にも言及し、追加投資なしに既存ハードウェアをより効率的に使う道を提供すると主張する。
- **位置付け**(論文 §6): ProTrain は ZeRO を用いた data-parallel 学習に注力し、並列化最適化(Alpa, Galvatron, nnScaler 等)と相補的。並列化とメモリ最適化を同時探索する Mist(Zhu et al., 2025)とは対照的に、並列化戦略を固定する代わりに operator レベルの精密なメモリモデリング(intra-op transient tensors・unhookable operators の捕捉)で深く最適化する narrow but deep なアプローチを取る。