# FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Navigation: [[index]] | [[hot]]
> [!abstract] 概要
> Transformer はシーケンス長に対して時間・メモリ計算量が 2 乗のため、長いシーケンスでは低速かつメモリ消費が大きい。近似アテンション手法はモデル品質を犠牲にして計算量を削減しようとしてきたが、実測速度の改善を達成できないことが多い。本論文は、不足していた原理がアテンションアルゴリズムを IO 非対応(IO-aware)にすること——すなわち GPU メモリの各階層間における読み書きを厳密に計上すること——であると主張する。提案手法 FlashAttention は、タイリングを用いることで GPU HBM(高帯域幅メモリ)とオンチップ SRAM 間のメモリ読み書き回数を削減する、IO 対応の厳密なアテンションアルゴリズムである。FlashAttention の IO 複雑度を解析し、標準アテンションより HBM アクセス数が少なく、あらゆる SRAM サイズに対して最適であることを示す。さらに FlashAttention をブロックスパースアテンションへ拡張し、既存のあらゆる近似アテンション手法より高速な近似アテンションアルゴリズムを実現する。FlashAttention は既存ベースラインより Transformer の訓練を高速化する:MLPerf 1.1 の訓練速度記録と比べて BERT-large(系列長 512)でエンドツーエンドの実測速度 15% 向上、GPT-2(系列長 1K)で 3× の高速化、ロングレンジアリーナ(系列長 1K–4K)で 2.4× の高速化。FlashAttention とブロックスパース FlashAttention は Transformer に長いコンテキストを可能にし、より高品質なモデル(GPT-2 でパープレキシティ 0.7 改善、長文書分類で 6.4 ポイント向上)と、全く新しい能力——Path-X チャレンジ(系列長 16K、61.4% 精度)と Path-256(系列長 64K、63.1% 精度)でランダム以上の性能を達成した史上初の Transformer——を実現する。
## 論文情報
| 項目 | 内容 |
|---|---|
| タイトル | FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness |
| 著者 | Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré |
| 所属 | Stanford University(Dao・Fu・Ermon・Ré)、University at Buffalo SUNY(Rudra) |
| 会場 | arXiv プレプリント(後に NeurIPS 2022 採録) |
| 投稿日 | 2022-05-27(v2: 2022-06-23) |
| arXiv ID | 2205.14135 |
| コード | https://github.com/HazyResearch/flash-attention |
## 概要
現代 GPU では計算速度がメモリ帯域幅を大幅に上回っており、標準アテンションの実行時間は FLOP 数ではなく HBM アクセス数に律速される。FlashAttention は Q・K・V をブロック分割して SRAM 上でタイル処理し、N×N のアテンション行列を HBM に書き出すことなく厳密なアテンション出力を計算する。再計算によって逆伝播時の中間行列保存を省略し、系列長に線形なメモリ使用量と実測速度の大幅な向上を同時に達成した。
## 問題設定
**入力**: Q, K, V ∈ R^{N×d}(N: 系列長、d: ヘッド次元)が HBM に配置されている。
**出力**: O = softmax(QK^T)V ∈ R^{N×d} を HBM に書き込む。
**前提**: GPU の SRAM(A100 では 108 ストリーミングマルチプロセッサあたり 192 KB、帯域幅約 19 TB/s)は HBM(40 GB、帯域幅 1.5 TB/s)より 1 桁以上高速だが容量が極小。標準アテンションは S = QK^T と P = softmax(S) を HBM に素材化(materialize)するため、N が大きいと HBM アクセスが 2 乗のオーダーで増加する。
**課題**: 既存の近似アテンション手法(Reformer, Linformer, Performer 等)は FLOP 削減に注力するが、HBM アクセスのオーバーヘッドを無視しているため、実測速度が改善されないことが多い。
## 提案手法
### タイリング(tiling)
ソフトマックスの safe-softmax 分解を利用し、Q・K・V をブロックに分割して SRAM 上でブロック単位に処理する。数値安定性のため、ローカルな最大値 m(x) とスケール因子 ℓ(x) を保持し、ブロックをまたいで漸進的に統計量を更新することで、N×N のアテンション行列を HBM に書き込まずに最終出力を得る。ブロックサイズは Bc = ⌈4M/d⌉、Br = min(⌈4M/d⌉, d) と設定する(M は SRAM サイズ)。
### 再計算(recomputation)
逆伝播時に通常必要な S, P ∈ R^{N×N} の中間行列を保存しない代わりに、順伝播で出力 O とソフトマックス統計量(m, ℓ)だけを保持し、逆伝播の際に S と P をブロック単位で再計算する。この選択的な勾配チェックポイントは再計算による追加 FLOP を生じさせるが、HBM アクセス削減の効果がそれを上回り、逆伝播も高速化される。
### カーネル融合(kernel fusion)
タイリングにより、行列乗算・ソフトマックス・マスキング・ドロップアウト・行列乗算をすべて単一 CUDA カーネルに融合できる。HBM への中間書き込みを排除し、SRAM 上ですべての演算を完結させる。
### ブロックスパース FlashAttention
事前定義のブロックスパースマスク M ∈ {0,1}^{N/Br × N/Bc} を用いて非ゼロブロックのみを処理する拡張。IO 複雑度は Θ(Nd + N²d²M⁻¹ s)(s: 非ゼロブロック比率)となり、スパース比に比例した直接的な削減が得られる。バタフライスパースパターンを採用し、任意のスパース構造を近似できることが示されている。
### IO 複雑度の解析
| 手法 | HBM アクセス数 |
|---|---|
| 標準アテンション | Θ(Nd + N²) |
| FlashAttention | Θ(N²d²M⁻¹) |
典型値 d = 64、M ≈ 100 KB では d² ≪ M となり、FlashAttention は標準アテンションの最大 9× 少ない HBM アクセスで済む。また、Proposition 3 により、いかなる厳密アテンションアルゴリズムも漸近的に HBM アクセスを改善できないことが証明されており、FlashAttention の最適性が示されている。
## 新規性
- **IO 非対応問題の定式化**: 近似アテンション手法群が FLOP は削減できても実測速度を改善できない原因を「HBM アクセスを無視しているため」と明確に定式化した最初の研究。
- **厳密アテンションのタイリング実現**: Rabe & Staats [66] は N² メモリを削減する手法を示したが、HBM アクセス削減(速度改善)は達成していなかった。FlashAttention は HBM アクセス削減と厳密アテンションの維持を同時に達成した。
- **逆伝播の効率化**: 再計算による選択的チェックポイントで逆伝播の HBM アクセスも同様に削減し、標準実装と同等の速度の逆伝播を実現した。
- **漸近最適性の証明**: HBM アクセス数の下限を証明し、自手法がすべての SRAM サイズに対して最適であることを示した。
## 実験設定
**ハードウェア**: NVIDIA A100 GPU(HBM 40 GB、帯域幅 1.5 TB/s、SRAM 192 KB/SM × 108 SM)
**モデル・タスク**:
- BERT-large(系列長 512): Wikipedia データで訓練、8×A100
- GPT-2 small/medium(系列長 1K): OpenWebText データセットで訓練、8×A100
- Long-Range Arena(LRA)(系列長 1K–4K): ListOps / Text / Retrieval / Image / Pathfinder
- 長文書分類: MIMIC-III、ECtHR(RoBERTa ファインチューニング)
- Path-X(16K)・Path-256(64K): 画像の 2 点間経路有無の分類
**ベースライン**:
- 訓練速度: Nvidia MLPerf 1.1(BERT)、HuggingFace・Megatron-LM(GPT-2)
- アテンション実装: PyTorch 標準、Linformer、Linear Attention、Performer、Reformer、Smyrf、OpenAI Sparse Attention
## 実験結果
### 訓練速度
**BERT-large**: FlashAttention は 17.4 ± 1.4 分、Nvidia MLPerf 1.1 記録は 20.0 ± 1.5 分。**15% 高速**。
**GPT-2(8×A100)**:
| 実装 | パープレキシティ | 訓練時間 | 速度向上 |
|---|---|---|---|
| GPT-2 small - HuggingFace | 18.2 | 9.5 日 | 1.0× |
| GPT-2 small - Megatron-LM | 18.2 | 4.7 日 | 2.0× |
| GPT-2 small - FlashAttention | 18.2 | 2.7 日 | **3.5×** |
| GPT-2 medium - HuggingFace | 14.2 | 21.0 日 | 1.0× |
| GPT-2 medium - Megatron-LM | 14.3 | 11.5 日 | 1.8× |
| GPT-2 medium - FlashAttention | 14.3 | 6.9 日 | **3.0×** |
**Long-Range Arena**: FlashAttention は標準アテンション比 **2.4× 高速**、ブロックスパース FlashAttention は **2.8×** 高速。精度は標準アテンション(平均 59.3)と同等(FlashAttention: 59.8、ブロックスパース: 59.6)。
### アテンション単体のランタイム・メモリ
**ランタイム(順伝播 + 逆伝播、GPT-2 medium, 系列長 1024)**:
| 指標 | 標準アテンション | FlashAttention |
|---|---|---|
| GFLOP | 66.6 | 75.2 |
| HBM 読み書き | 40.3 GB | 4.4 GB |
| 実行時間 | 41.7 ms | 7.3 ms |
GPT-2 アテンション計算単体では **7.6× の高速化**。系列長 128–2K では PyTorch 比 **最大 3× 高速**。
**メモリ使用量**: 系列長に対して**線形**にスケール。系列長 64K でも A100 上で動作。標準アテンション実装と比べ最大 **20× のメモリ効率改善**。Linformer との比較でも **2× 効率的**。
### 長文脈による品質向上
**GPT-2 + 長コンテキスト**:
| 設定 | 系列長 | パープレキシティ | 訓練時間 |
|---|---|---|---|
| Megatron-LM | 1K | 18.2 | 4.7 日(1.0×) |
| FlashAttention | 1K | 18.2 | 2.7 日(1.7×) |
| FlashAttention | 2K | 17.6 | 3.0 日(1.6×) |
| FlashAttention | 4K | 17.5 | 3.6 日(1.3×) |
系列長 4K の FlashAttention は Megatron-LM の系列長 1K より **30% 高速**で、パープレキシティは **0.7 改善**。
**長文書分類(micro F1)**:
| データセット | 系列長 512 | 最良 | 向上 |
|---|---|---|---|
| MIMIC-III | 52.8 | 57.1(16384) | +4.3 点 |
| ECtHR | 72.2 | 80.7(8192) | **+8.5 点** |
**Path-X・Path-256**:
| モデル | Path-X(16K) | Path-256(64K) |
|---|---|---|
| 標準 Transformer | 7%(ランダム) | 7%(ランダム) |
| Linformer・Performer ほか | 7%(ランダム) | 7%(ランダム) |
| FlashAttention | **61.4%** | 7%(OOM) |
| ブロックスパース FlashAttention | 56.0% | **63.1%** |
FlashAttention はランダム以上の性能を Path-X で達成した史上初の Transformer モデルとなった。
## 考察
HBM アクセス数が実行時間の主要な決定因子であることが実証された(図 2 左)。ブロックサイズを変えた実験(図 2 中)では、HBM アクセスが減るほど実行時間が短縮し、ブロックサイズが大きすぎると SRAM に収まらなくなる上限があることも確認された。FLOP が増加(再計算による逆伝播の追加コスト)しても速度が向上するという反直感的な結果は、現代 GPU におけるメモリバウンドの本質を端的に示している。
精度面では FlashAttention は標準アテンションと数値的に等価(exact attention)であり、LRA での精度・GPT-2 のパープレキシティは完全に同等。長文脈化による品質向上は IO 効率化によって初めて可能になった副産物であり、近似アテンション手法では得られなかった成果である。
## 強み / 弱点・課題
### 強み
- 厳密アテンション(近似なし)を保ちつつ速度とメモリ効率を同時改善する。
- IO 複雑度の漸近的下限を証明し、理論的な最適性を担保している。
- 単一 CUDA カーネルへの融合により、ドロップアウト・マスキングを追加コストなく内包できる。
- ブロックスパース拡張により近似アテンションへの応用も可能。
- 再現可能なコードを公開(HazyResearch/flash-attention)。
### 弱点・課題
- **CUDA への依存**: 新しいアテンション亜種ごとに CUDA カーネルを新規実装する必要があり、エンジニアリングコストが高い。PyTorch レベルの高水準言語で記述して IO 対応実装にコンパイルする仕組み(Halide 相当)が今後の課題として挙げられている。
- **GPU アーキテクチャ依存性**: CUDA カーネル実装はアーキテクチャ間で移植できない可能性がある。
- **単一 GPU 最適化**: 複数 GPU 間での最適な IO 解析は未達成であり、GPU 間データ転送を含む解析が今後の研究課題。
- **本論文時点の IO 対応はアテンションのみ**: Transformer の他のレイヤー(全結合層など)への同様の IO 最適化は今後の方向性として示されるのみ。