[Transformer FLOPs](https://www.adamcasson.com/posts/transformer-flops) Transformerモデルの計算量(FLOPs)を推定し、計算資源の要件や効率を測定するための手法のまとめ。 ## 1. Transformer 学習における FLOPs 計算法 ### OpenAI Scaling Law 方式 (FLOPs per token) トークンあたりの計算量を非埋め込みパラメータ数 $N$ から簡便に推定する手法。 - **全体の学習 FLOPs**: トークンあたり $C_{forward+backward} \approx 6N$ - **学習総 FLOPs**: $C = 6DN$ ($D$ は総トークン数) - **Forward Pass FLOPs ($C_{forward}$)**: - $C_{forward} \approx 2N$ と簡略化される。 - これは行列演算(matmuls)が積和演算(2 FLOPs)で構成されることに由来。 - コンテキスト長に依存する項($2 \cdot n_{layer} \cdot n_{ctx} \cdot d_{attn}$)は、$d_{model} > n_{ctx} / 12$ の場合に無視可能。 - **Backward Pass FLOPs**: Forward Pass の約2倍($4N$)と見積もる。 ### DeepMind Chinchilla Scaling Law 方式 (FLOPs per sequence) 埋め込み、ロジット、ソフトマックス、アテンションパターンの適用など、より詳細なコンポーネントごとの計算量を算出する手法。 - **Forward Pass FLOPs per Sequence (各要素の合計)**: - **Embeddings**: $2 \cdot n_{ctx} \cdot n_{vocab} \cdot d_{model}$ - **Attention: QKV**: $2 \cdot n_{ctx} \cdot 3 \cdot d_{model} \cdot (d_{key} \cdot n_{heads})$ - **Attention: QK logits**: $2 \cdot n_{ctx} \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads})$ - **Attention: Softmax**: $3 \cdot n_{heads} \cdot n_{ctx} \cdot n_{ctx}$ - **Attention: Reduction**: $2 \cdot n_{ctx} \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads})$ - **Attention: Project**: $2 \cdot n_{ctx} \cdot (d_{key} \cdot n_{heads}) \cdot d_{model}$ - **Feedforward**: $4 \cdot n_{ctx} \cdot (d_{model} \cdot d_{ff})$ - **Logits**: $2 \cdot n_{ctx} \cdot d_{model} \cdot n_{vocab}$ - **合計 Forward Pass**: `Embeddings + n_layers * (Total Attention + Feedforward) + Logits` - **Backward Pass**: Forward Pass の2倍と仮定。 ## 2. 効率の測定指標 - **FLOPS (Floating Point Operations Per Second)**: 演算速度の単位。 - **Hardware FLOPS Utilization (HFU)**: 実際に実行された全 FLOPs(アクティベーション・チェックポインティングによる再計算等を含む)と理論上のピーク FLOPS の比率。冗長な計算が含まれるため誤解を招く可能性がある。 - **Model FLOPS Utilization (MFU)**: 推奨される指標。モデルの実行に必要な最小限の FLOPs に焦点を当てる。 - **式**: $MFU = (C \cdot D) / P$ - $C$: トークンあたりの FLOPs(OpenAI方式の $6N$ など) - $D$: 観測されたスループット(tokens/sec) - $P$: ハードウェアの理論ピーク FLOPS - 言語モデルにおける MFU は通常 10〜65% の範囲。 ## 3. 計算量のスケーリング特性 - モデルが大規模化するにつれ、Embeddings と Logits の計算量は無視可能になる。 - アテンションの QKV 生成と Feedforward の行列演算が支配的になる。 - **線形 vs 二次スケーリング**: - **コンテキスト長 ($n_{ctx}$) に対して二次(Quadratic)**: QK logits, Softmax, Reduction。 - **コンテキスト長 ($n_{ctx}$) に対して線形(Linear)**: Embeddings, QKV, Project, Feedforward, Logits。 - 小規模モデルでは二次項が 30% 以上を占めることもあるが、パラメータ数が増えると比率は低下する。 - 175B パラメータ級の巨大モデルでは、コンテキスト長が非常に長い場合(8192超など)にのみ二次項が顕著(10%超)になる。 ## 4. Vision Transformers (ViT) への適用 DeepMind の手法を画像分類タスクの ViT に拡張可能。 - **Embeddings**: パッチ分割に伴う計算を考慮($2 \cdot n_{patches} \cdot d_{patch}^2 \cdot n_{channels} \cdot d_{model}$)。 - **$n_{ctx}$**: `[CLS]` トークンを含めて $n_{patches} + 1$ とカウント。 - **Logits**: 分類ヘッドの計算($2 \cdot d_{model} \cdot n_{classes}$)。 ## 関連 - [[Transformer]] - [[LLMのScaling Laws]]