# FlashAttention ## 定義 FlashAttention は、GPU のメモリ階層を明示的に考慮した IO 認識型(IO-aware)の厳密アテンションアルゴリズム群である。標準的なアテンション実装が $N \times N$ のスコア行列を HBM(高帯域メモリ)に実体化するのに対し、FlashAttention はタイリングとオンライン softmax を組み合わせて中間行列の HBM 読み書きを排除し、すべてのアテンション演算を単一の GPU カーネルに融合する。近似ではなく厳密なアテンション出力を維持しながら、メモリ使用量を $O(N^2)$ から $O(N)$ へ削減し、2-4 倍の実時間高速化を実現する。(Source: [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]]) 核となる原理は 2 つある。(1) **タイリング**: 入力をブロックに分割し、各ブロックを高速な SRAM に読み込んでアテンションを計算する。オンライン softmax により、行全体へのアクセスなしに softmax の逐次的更新が可能になる。(2) **再計算**: 順伝播で中間アテンション行列を保存せず、逆伝播時にオンチップで再計算する。FLOP は増加するが、HBM アクセスの大幅削減により正味の高速化を達成する。 ## 世代別進化 FlashAttention は 4 世代にわたって、GPU アーキテクチャの進化に合わせたアルゴリズム・カーネル協調設計を展開している。 ### FA1(2022、A100) IO 認識型アテンションの原型。タイリング+オンライン softmax+再計算の 3 技法で HBM アクセスを $O(N^2 d^2 M^{-1})$ に削減(標準アテンションは $\Omega(Nd + N^2)$)。A100 で BERT-large 15% 高速化、GPT-2 で 3 倍高速化。Path-X(16K)と Path-256(64K)で初のランダム超え精度を達成した。(Source: [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]]) ### FA2(2023、A100) FA1 の GPU 利用率(25-40%)を改善。(1) 非行列積 FLOP の削減、(2) シーケンス長次元への並列化によるオキュパンシー向上、(3) ワープ間の仕事分配最適化による共有メモリ通信の削減。A100 で理論最大 FLOP/秒の 50-73% を達成し、FA1 比 2 倍高速化、GPT 訓練で 225 TFLOP/秒(72% MFU)に到達した。(Source: [[@2023__arXiv__FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning]]) ### FA3(2024、H100 Hopper) FA2 は H100 で 35% 利用率にとどまった。Hopper 固有の機能を活用する 3 技法: (1) TMA と非同期テンソルコアを用いたワープ特化による計算・データ移動のオーバーラップ、(2) ブロック単位の GEMM と softmax のインターリーブ、(3) FP8 ブロック量子化とインコヒーレント処理。FP16 で 740 TFLOP/秒(75% 利用率)、FP8 で約 1.2 PFLOP/秒を達成。FP8 の数値誤差はベースライン FP8 比 2.6 倍低い。(Source: [[@2024__arXiv__FlashAttention-3 - Fast and Accurate Attention with Asynchrony and Low-precision]]) ### FA4(2026、B200 Blackwell) Blackwell の非対称ハードウェアスケーリング(テンソルコアスループット 2 倍、共有メモリ帯域・指数関数ユニットは据え置き)に対応。(1) 完全非同期 MMA と拡大タイルサイズによるパイプライン再設計、(2) 多項式近似によるソフトウェアエミュレート指数関数と条件付き softmax リスケーリング、(3) テンソルメモリ(TMEM)と 2-CTA MMA モードによる共有メモリトラフィック削減と逆伝播のアトミック加算半減。BF16 で cuDNN 9.13 比 1.3 倍、Triton 比 2.7 倍高速化、1613 TFLOP/秒(71% 利用率)。CuTe-DSL(Python)による実装でコンパイル時間を従来の C++ テンプレート比 20-30 倍短縮。(Source: [[@2026__arXiv__FlashAttention-4 - Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling]]) ### 性能推移 | 世代 | GPU | FP16/BF16 TFLOP/秒 | 利用率 | 対前世代比 | |------|-----|---------------------|--------|-----------| | FA1 | A100 | — | 25-40% | — | | FA2 | A100 | 225 (訓練) | 50-73% | 2× | | FA3 | H100 | 740 | 75% | 1.5-2× | | FA4 | B200 | 1613 | 71% | 1.3× (vs cuDNN) | ## 横断的知見 - **ボトルネックは世代ごとに移動する**: FA1 は HBM 帯域がボトルネック、FA2 はスレッドブロック間の仕事分配、FA3 は GEMM と非 GEMM 演算の非同期実行、FA4 は共有メモリ帯域と指数関数ユニットである。テンソルコアスループットが世代ごとに倍増する一方で他のハードウェアユニットの改善は緩やかなため、最適化対象が「演算の隠蔽」から「非演算リソースの回避」へ体系的に遷移している。(Source: [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]], [[@2026__arXiv__FlashAttention-4 - Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling]]) - **利用率は一定の天井に収束する**: FA2 50-73%、FA3 75%、FA4 71% と、各世代が当該 GPU の理論最大 FLOP/秒の 70-75% 付近に落ち着く。GEMM カーネルの 80-90% には到達しない。これはアテンションが softmax というメモリバウンドな非線形演算を含むためであり、GEMM に帰着できない構造的制約が存在する。(Source: [[@2023__arXiv__FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning]], [[@2024__arXiv__FlashAttention-3 - Fast and Accurate Attention with Asynchrony and Low-precision]], [[@2026__arXiv__FlashAttention-4 - Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling]]) - **[[カーネルフュージョン]]の成功事例として最も影響力が大きい**: FlashAttention は GPU 最適化文献が論じるカーネルフュージョンの原理(中間結果の HBM 書き出し排除、データのオンチップ保持)を、Transformer アテンションという具体的かつ広く使われる演算に適用した実例である。MLPerf 訓練ベンチマークの大半が FlashAttention を採用しており(FA1 時点で MLPerf 1.1 の FMHA 実装が土台)、理論的最適性と産業的採用の両方を達成した稀有なケースである。(Source: [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]], [[@2023__CSUR__Optimization Techniques for GPU Programming]]) - **FP8 精度のアテンションにおける数値安定性は解決可能**: FA3 はブロック量子化とインコヒーレント処理により、ベースライン FP8 アテンション比で 2.6 倍低い数値誤差を達成した。これは [[テンソルコア]] の HPC 応用(蓄積精度が安定性を決める)と同じ原理——FP8 入力 + FP32 蓄積——の検証であり、低精度テンソルコアのアテンション応用が実用段階にあることを示す。(Source: [[@2024__arXiv__FlashAttention-3 - Fast and Accurate Attention with Asynchrony and Low-precision]], [[@2018__SC__Harnessing GPU Tensor Cores for Fast FP16 Arithmetic to Speed up Mixed-Precision Iterative Refinement Solvers]]) ## 未解決の問い - FA4 の CuTe-DSL 実装は C++ テンプレートと同等の性能を達成しているが、次世代 GPU(B300 以降)で指数関数ユニットのスループットが倍増した場合、ボトルネックはどこへ移るか - FlashAttention は厳密アテンションだが、Linear Attention・Mamba 等のサブ二次アーキテクチャとの性能交差点は文脈長何トークンか - GQA/MQA(グループ化クエリアテンション)における FlashAttention の最適化は FA2 以降に進んでいるが、MoE モデルのエキスパート単位アテンションとの組み合わせはどう設計すべきか - 決定論的逆伝播モード(FA4 §3.2.4)の性能オーバーヘッドは強化学習の訓練ワークロードでどの程度か ## 関連 - 隣接 concept: [[カーネルフュージョン]] / [[テンソルコア]] / [[GPU最適化]] / [[LLM推論]] / [[アテンションヘッド]] - 関連 entity: [[Tri Dao]] / [[Jay Shah]] / [[Together AI]] - MOC: [[分散深層学習 - MOC]] - ライブラリ: https://github.com/Dao-AILab/flash-attention ## 出典 - [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]] — FA1。Dao ほか。タイリング+再計算による IO 認識型厳密アテンション - [[@2023__arXiv__FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning]] — FA2。Dao。シーケンス並列化+ワープ仕事分配最適化 - [[@2024__arXiv__FlashAttention-3 - Fast and Accurate Attention with Asynchrony and Low-precision]] — FA3。Shah ほか。Hopper ワープ特化+FP8 ブロック量子化 - [[@2026__arXiv__FlashAttention-4 - Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling]] — FA4。Zadouri ほか。Blackwell 非対称スケーリング協調設計