[Transformer FLOPs](https://www.adamcasson.com/posts/transformer-flops)
Transformerモデルの計算量(FLOPs)を推定し、計算資源の要件や効率を測定するための手法のまとめ。
## 1. Transformer 学習における FLOPs 計算法
### OpenAI Scaling Law 方式 (FLOPs per token)
トークンあたりの計算量を非埋め込みパラメータ数 $N$ から簡便に推定する手法。
- **全体の学習 FLOPs**: トークンあたり $C_{forward+backward} \approx 6N$
- **学習総 FLOPs**: $C = 6DN$ ($D$ は総トークン数)
- **Forward Pass FLOPs ($C_{forward}$)**:
- $C_{forward} \approx 2N$ と簡略化される。
- これは行列演算(matmuls)が積和演算(2 FLOPs)で構成されることに由来。
- コンテキスト長に依存する項($2 \cdot n_{layer} \cdot n_{ctx} \cdot d_{attn}$)は、$d_{model} > n_{ctx} / 12$ の場合に無視可能。
- **Backward Pass FLOPs**: Forward Pass の約2倍($4N$)と見積もる。
### DeepMind Chinchilla Scaling Law 方式 (FLOPs per sequence)
埋め込み、ロジット、ソフトマックス、アテンションパターンの適用など、より詳細なコンポーネントごとの計算量を算出する手法。
- **Forward Pass FLOPs per Sequence (各要素の合計)**:
- **Embeddings**: $2 \cdot n_{ctx} \cdot n_{vocab} \cdot d_{model}$
- **Attention: QKV**: $2 \cdot n_{ctx} \cdot 3 \cdot d_{model} \cdot (d_{key} \cdot n_{heads})$
- **Attention: QK logits**: $2 \cdot n_{ctx} \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads})$
- **Attention: Softmax**: $3 \cdot n_{heads} \cdot n_{ctx} \cdot n_{ctx}$
- **Attention: Reduction**: $2 \cdot n_{ctx} \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads})$
- **Attention: Project**: $2 \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads}) \cdot d_{model}$
- **Feedforward**: $4 \cdot n_{ctx} \cdot (d_{model} \cdot d_{ff})$
- **Logits**: $2 \cdot n_{ctx} \cdot d_{model} \cdot n_{vocab}$
- **合計 Forward Pass**: `Embeddings + n_layers * (Total Attention + Feedforward) + Logits`
- **Backward Pass**: Forward Pass の2倍と仮定。
## 2. 効率の測定指標
- **FLOPS (Floating Point Operations Per Second)**: 演算速度の単位。
- **Hardware FLOPS Utilization (HFU)**: 実際に実行された全 FLOPs(アクティベーション・チェックポインティングによる再計算等を含む)と理論上のピーク FLOPS の比率。冗長な計算が含まれるため誤解を招く可能性がある。
- **Model FLOPS Utilization (MFU)**: 推奨される指標。モデルの実行に必要な最小限の FLOPs に焦点を当てる。
- **式**: $MFU = (C \cdot D) / P$
- $C$: トークンあたりの FLOPs(OpenAI方式の $6N$ など)
- $D$: 観測されたスループット(tokens/sec)
- $P$: ハードウェアの理論ピーク FLOPS
- 言語モデルにおける MFU は通常 10〜65% の範囲。
## 3. 計算量のスケーリング特性
- モデルが大規模化するにつれ、Embeddings と Logits の計算量は無視可能になる。
- アテンションの QKV 生成と Feedforward の行列演算が支配的になる。
- **線形 vs 二次スケーリング**:
- **コンテキスト長 ($n_{ctx}$) に対して二次(Quadratic)**: QK logits, Softmax, Reduction。
- **コンテキスト長 ($n_{ctx}$) に対して線形(Linear)**: Embeddings, QKV, Project, Feedforward, Logits。
- 小規模モデルでは二次項が 30% 以上を占めることもあるが、パラメータ数が増えると比率は低下する。
- 175B パラメータ級の巨大モデルでは、コンテキスト長が非常に長い場合(8192超など)にのみ二次項が顕著(10%超)になる。
## 4. Vision Transformers (ViT) への適用
DeepMind の手法を画像分類タスクの ViT に拡張可能。
- **Embeddings**: パッチ分割に伴う計算を考慮($2 \cdot n_{patches} \cdot d_{patch}^2 \cdot n_{channels} \cdot d_{model}$)。
- **$n_{ctx}$**: `[CLS]` トークンを含めて $n_{patches} + 1$ とカウント。
- **Logits**: 分類ヘッドの計算($2 \cdot d_{model} \cdot n_{classes}$)。
## 関連
- [[Transformer]]
- [[LLMのScaling Laws]]