> [!abstract] 概要
> 大規模トランスフォーマーモデルの訓練は現代 AI の最重要計算課題の一つである。本論文では、アクティベーション再計算を削減することで大規模トランスフォーマーモデルの訓練を大幅に高速化する方法を示す。アクティベーション再計算はメモリ容量制約への対処として一般的に用いられるが、アクティベーションをメモリに保持する代わりに逆伝播時に再計算する従来手法はメモリは節約できるものの冗長な計算を引き起こす。本研究では、この冗長計算の大部分は不要であること——十分にメモリ消費を削減できるため——を示す。シーケンス並列化(sequence parallelism)と選択的アクティベーション再計算(selective activation recomputation)という 2 つの新規かつ非常に単純な手法を提案する。テンソル並列化との組み合わせにより、これらの手法はアクティベーション再計算の必要性をほぼ排除する。最大 1 兆パラメータの言語モデルで評価を行い、アクティベーションメモリを 5 倍削減しながらアクティベーション再計算による実行時間オーバーヘッドを 90% 以上削減できることを示す。例として、530B パラメータ GPT-3 スタイルモデルを 2,240 NVIDIA A100 GPU で訓練した際、MFU 54.2% を達成し、完全再計算の MFU 42.1% に比べて 29% 高速化した。
## 論文情報
| 項目 | 内容 |
|---|---|
| タイトル | Reducing Activation Recomputation in Large Transformer Models |
| 著者 | Vijay Korthikanti, Jared Casper, Sangkug Lym, Lawrence McAfee, Michael Andersch, Mohammad Shoeybi, Bryan Catanzaro |
| 所属 | NVIDIA |
| 会議 | MLSys 2023(第 6 回 MLSys Conference、マイアミビーチ) |
| URL | https://proceedings.mlsys.org/paper_files/paper/2023/file/80083951326cf5b35e5100260d64ed81-Paper-mlsys2023.pdf |
| 関連フレームワーク | [[Megatron-LM]]([[NVIDIA]]) |
## 概要
大規模トランスフォーマーの訓練では、GPU メモリに収まりきらないアクティベーションを逆伝播時に再計算する「完全アクティベーション再計算(full activation recomputation)」が広く用いられてきた。しかしこの手法は 30〜40% の計算時間オーバーヘッドを発生させる。
本論文は 2 つの手法を組み合わせてこの問題に対処する。
1. **シーケンス並列化(sequence parallelism)**: テンソル並列化の適用外領域(LayerNorm・Dropout)をシーケンス次元に沿って分割し、冗長複製を排除する。
2. **選択的アクティベーション再計算(selective activation recomputation)**: 計算コストは低いがメモリ消費が大きい演算(アテンション内部の QKᵀ 行列積・Softmax・Dropout 等)の出力を保存せず再計算し、それ以外の演算の出力は保持する。
両手法の組み合わせにより、アクティベーションメモリを完全保持比で 5 倍削減しながら、完全再計算で達成できるメモリ削減効果の 90% 以上を、わずか約 2% の追加メモリ使用のみで実現する。
## 問題設定
### テンソル並列化の限界
[[Megatron-LM]] 式のテンソル並列化はアテンション/MLP ブロック内部のアクティベーションを分割するが、各ブロックへの入力(Q/K/V 計算への入力や h→4h 線形層への入力)は分割されず、テンソル並列グループ全体で複製される。LayerNorm と Dropout についても同様に複製が発生する。
数値的に整理すると、テンソル並列度 t のとき 1 トランスフォーマー層あたりのアクティベーションメモリは:
$\text{テンソル並列(Eq.2)} = sbh\left(10 + \frac{24}{t} + 5\frac{as}{ht}\right) \text{ [バイト]}$
ここで s = シーケンス長、b = マイクロバッチサイズ、h = 隠れ次元、a = アテンションヘッド数。
非テンソル並列化領域に相当する 10sbh の項は t で割れていない——これが冗長複製の根本問題である。
### パイプライン並列化の追加制約
パイプライン並列化(並列度 p)は 1F1B スケジュールにより第 1 ステージが p マイクロバッチ分のアクティベーションを保持しなければならない。結果として実効的なアクティベーション保持層数は常に L(全層数)に相当し、パイプライン並列化はアクティベーションメモリ削減に直接貢献しない。
## 提案手法
### シーケンス並列化
テンソル並列化が適用されない領域(LayerNorm・Dropout)は、シーケンス次元に沿って独立に処理できる。これを利用し、テンソル並列化の前後に新しい通信演算 g と ḡ を挿入する。
- **g**(順伝播: All-Gather、逆伝播: Reduce-Scatter)
- **ḡ**(順伝播: Reduce-Scatter、逆伝播: All-Gather)
テンソル並列化の既存の all-reduce は reduce-scatter + all-gather の 2 段階であるため、追加の通信帯域使用量はゼロである(総通信量が等価)。
結果として、テンソル並列化 + シーケンス並列化を組み合わせると 1 層あたりのアクティベーションメモリは:
$\text{テンソル+シーケンス並列(Eq.4)} = \frac{sbh}{t}\left(34 + 5\frac{as}{h}\right)$
これは並列化なしの Eq.1 をテンソル並列度 t で割ったものと等しい。アクティベーションがテンソル並列グループ全体で均等分散される。
逆伝播時に All-Gather が 1 回追加で必要になるが、後続の勾配計算と通信をオーバーラップさせることでレイテンシを隠蔽する。
### 選択的アクティベーション再計算
Eq.4 のうち、`5as/h` の項はアテンションスコアに関する演算(アテンション幅が線形層 Q/K/V で拡大した後の QKᵀ 行列積・Softmax・Softmax Dropout・V 上のアテンション)に対応する。これらは以下の特性を持つ:
- **大きなメモリ占有**: 大規模モデルでは `5as/h` は 34(残りの項)を上回る(GPT-3: 80、MT-NLG: 64)。
- **低い再計算コスト**: FLOPs/入力要素比が小さい。
したがって、これらの演算結果を保存せずに再計算するのが最適である。GPT-3 の場合:
- アクティベーション削減率: 70%
- 再計算 FLOPs オーバーヘッド: 2.7%
MT-NLG の場合:
- アクティベーション削減率: 65%
- 再計算 FLOPs オーバーヘッド: 1.6%
選択的再計算使用時のメモリ式:
$\text{選択的再計算(Eq.6)} = \frac{34 \cdot sbhL}{t}$
これはシーケンス長に対して線形にスケールし、アテンションヘッド数に非依存となる。
### 実装上の工夫
- 逆伝播時の追加 All-Gather は後続の計算とオーバーラップさせてレイテンシを隠蔽する。
- パイプライン並列化との組み合わせ時には、利用可能なデバイスメモリに応じてできるだけ多くのマイクロバッチのアクティベーションを完全保持し、残りのみを選択的再計算する追加の最適化も適用可能(Appendix C 詳述)。
## 新規性
1. **シーケンス並列化の再定式化**: Li+ 2021a の先行シーケンス並列化は全デバイスにパラメータとオプティマイザ状態を複製する必要があり大規模モデルには不適であった。本研究は [[Megatron-LM]] のテンソル並列化と相補的に組み合わせることで、追加の計算・通信・メモリオーバーヘッドなしにアクティベーション分散を実現した。
2. **選択的再計算の分析的根拠**: 「どのアクティベーションを選択的に再計算すべきか」をアドホックに決めるのではなく、FLOPs/要素比と占有メモリの分析式から体系的に導出した。
3. **アクティベーションの均等分散保証**: 提案手法はアクティベーション(とパラメータ)がデバイス間で均等に分散されるという意味でメモリ最適な並列化戦略であることを解析的に示した。
## 実験設定
- **クラスタ**: Selene スーパーコンピュータ(A100 80GB NVLink/NVSwitch × 8/ノード、200Gbps InfiniBand × 8 HCA/ノード)
- **精度**: 混合精度(FP16/BF16)
- **モデル構成**:
| モデル | アテンションヘッド | 隠れ次元 | 層数 | TP | PP | GPU 数 |
|---|---|---|---|---|---|---|
| 22B | 64 | 6144 | 48 | 8 | 1 | 8 |
| 175B (GPT-3) | 96 | 12288 | 96 | 8 | 8 | 64 |
| 530B (MT-NLG) | 128 | 20480 | 105 | 8 | 35 | 280 |
| 1T | 160 | 25600 | 128 | 8 | 64 | 512 |
- s = 2048、v = 51200 に固定。データ並列化なし(提案手法はデータ並列化と独立)。
- インターリーブドスケジュール(m=3 インターリーブステージ)を 175B・530B に適用。
## 実験結果
### メモリ使用量
提案手法(テンソル+シーケンス並列化 + 選択的再計算)は、テンソル並列化ベースライン比で 22B モデルで約 20%、大規模モデルでも同程度のメモリに削減する(完全再計算の約 2 倍)。
シーケンス並列化単独・選択的再計算単独それぞれがほぼ同等の約 50% 削減を達成し、組み合わせで 5 倍削減を実現する。
### 1 層あたり実行時間(22B モデル、1 層のみ)
| 手法 | 順伝播(ms) | 逆伝播(ms) | 合計(ms) | オーバーヘッド |
|---|---|---|---|---|
| ベースライン(再計算なし) | 7.7 | 11.9 | 19.6 | — |
| シーケンス並列化 | 7.2 | 11.8 | 19.0 | −3% |
| 完全再計算 | 7.7 | 19.5 | 27.2 | +39% |
| 選択的再計算 | 7.7 | 13.2 | 20.9 | +7% |
| シーケンス + 選択的再計算 | 7.2 | 13.1 | 20.3 | +4% |
完全再計算の 39% オーバーヘッドに対し、提案手法は 4% のみ。シーケンス並列化によりレイヤーノームと Dropout が 1/t のデータで実行されるため順伝播が 6% 高速化する。
### エンドツーエンドイテレーション時間
| モデル | 完全再計算(秒) | 提案手法(秒) | スループット向上 | MFU | HFU |
|---|---|---|---|---|---|
| 22B | 1.42 | 1.10 | +29.0% | 41.5% | 42.2% |
| 175B | 18.13 | 13.75 | +31.8% | 51.4% | 51.8% |
| 530B | 49.05 | 37.83 | +29.7% | 56.0% | 56.4% |
| 1T | 94.42 | 71.49 | +32.1% | 56.3% | 56.5% |
すべての構成で 29〜32% のスループット向上。モデル規模が増すにつれてオーバーヘッド削減効果が拡大し、530B・1T では再計算オーバーヘッドがわずか 2% にとどまる。
530B を 8-way データ並列化(2,240 GPU)に拡張した場合: MFU は 56.0% → 54.2%(データ並列化の勾配 AllReduce オーバーヘッドはわずか)。
## 考察
### シーケンス並列化の意義
テンソル並列化のみでは分割できない領域(LayerNorm・Dropout)が実は非常に小さな通信オーバーヘッドで分散できる。reduce-scatter と all-gather に既存の all-reduce を分解するだけであり、総通信量は等価。この単純な洞察がアクティベーション分散の「穴」を埋める。
### 選択的再計算の一般性
選択する演算の判断基準は「FLOPs/要素比が低く、メモリ占有が大きい」という解析的条件であり、モデルサイズやシーケンス長の変化に対して定量的に適用可能。シーケンス長が増大するにつれて as/h の項が支配的になるため、長コンテキスト訓練では選択的再計算の効果がさらに高まる。
### MFU の定義と hardware FLOPs utilization との関係
HFU/MFU の比 ≈ 1 + s/(18h) が選択的再計算の FLOPs オーバーヘッド上界を与える。530B では s=2048、h=20480 なので比 ≈ 1.0056、つまり実測 MFU 56.0% に対し HFU 56.4% はほぼ等価であり、選択的再計算の再計算コストがいかに軽微かを示す。
## 強み
- **実装容量が極めて小さい**: テンソル並列化の既存演算を reduce-scatter/all-gather の 2 ステップに分解するだけで、追加の通信帯域が発生しない。
- **既存フレームワークとの直交性**: データ並列化(ZeRO 系)・パイプライン並列化と独立に組み合わせ可能。
- **解析的保証**: アクティベーション均等分散がメモリ最適であることを解析的に示している。
- **産業規模での実証**: 530B の大規模モデルで 30% 近い速度向上を実測。
## 弱点・課題
- **シーケンス長増大時の選択的再計算コスト**: as/h の比率が増すと選択的再計算の FLOPs 割合も増加する。論文は「実用的な s/h 比の範囲では問題ない」と主張するが、長コンテキスト(s/h が大きい場合)での実証は限定的。
- **テンソル並列化前提**: シーケンス並列化は [[Megatron-LM]] 式のテンソル並列化と密結合しており、テンソル並列化を使わない ZeRO/FSDP 系の構成への直接適用は別設計が必要。
- **データ並列化非評価**: エンドツーエンド評価でデータ並列化を使わない構成のみ評価しており、大規模本番訓練での総合的な MFU への影響は別途測定が必要。
- **自動化への拡張未解決**: 「どのアクティベーションを選択的に再計算すべきか」を任意アーキテクチャに対して自動探索する手法(Feng & Huang 2021 等)との組み合わせは今後の課題として残されている。