# FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision > [!abstract] 概要 > Transformer アーキテクチャの中核層であるアテンションは、大規模言語モデルと長コンテキストアプリケーションのボトルネックになっている。FlashAttention はメモリ読み書きを最小化することで GPU 上のアテンションを高速化する手法を確立したが、近年のハードウェアの新機能を活用できていない。FlashAttention-2 は H100 GPU 上でわずか 35% の利用率しか達成していない。本論文では Hopper GPU 上のアテンションを高速化する 3 つの主要技術を開発する。(1) テンソルコアと TMA の非同期性を活用し、ワープ特殊化を通じて全体の演算とデータ転送をオーバーラップさせる手法、(2) ブロック単位の行列積とソフトマックスをインターリーブする手法、(3) FP8 低精度のハードウェアサポートを活用するブロック量子化と非コヒーレント処理。本手法 FlashAttention-3 は H100 GPU 上で FP16 において 1.5〜2.0 倍の高速化を達成し(最大 740 TFLOPs/s、利用率 75%)、FP8 では約 1.2 PFLOPs/s に達することを示す。また FP8 FlashAttention-3 がベースライン FP8 アテンションと比べて 2.6 倍低い数値誤差を達成することを検証する。 ## 論文情報 | 項目 | 内容 | |---|---| | 著者 | Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao | | 所属 | Colfax Research, Meta, NVIDIA, Georgia Tech, Princeton University, Together AI | | 発表 | arXiv:2407.08608, July 2024 | | コード | https://github.com/Dao-AILab/flash-attention (MIT ライセンス) | ## 概要 FlashAttention-3(FA3)は、NVIDIA Hopper アーキテクチャの新ハードウェア機能を最大限に活用するために FlashAttention-2 を全面的に再設計したアテンションカーネルである。中心となる着想は「非同期性」と「低精度」という Hopper の二大特徴を直接アルゴリズム設計に織り込むことにある。 FlashAttention-2 は単純化された同期モデルに従っており、TMA や WGMMA が持つ非同期実行の仕組みを積極的に利用していなかった。また H100 上で最適化された行列積(GEMM)カーネルが 80〜90% の利用率を実現する一方、FlashAttention-2 は 35% にとどまっていた。FA3 はこのギャップを埋めるために 3 つの技術革新を導入し、カーネル実装には CUTLASS 3.5 のプリミティブ(WGMMA・TMA 抽象)を用いている。 ## 問題設定 ### FlashAttention-2 の限界 FlashAttention-2 は IO 認識タイリング戦略により中間行列の HBM への書き戻しを省略し、アテンションを単一カーネルに融合させることで高速化を実現した。しかし以下の点で Hopper の能力を活かしきれていない。 - **同期実行モデルへの依存**: Tensor Core の演算終了を待ってから softmax を実行するため、WGMMA の非同期性が無駄になる。 - **TMA の未活用**: HBM〜SMEM 間のデータ転送に専用ハードウェアユニット TMA が使われていない。 - **FP8 対応なし**: Hopper は FP16/BF16 の 2 倍スループットを持つ FP8 テンソルコアを持つが、FlashAttention-2 はこれを対象としていない。 ### Hopper 固有の新機能 H100(Hopper)が提供するハードウェア機能のうち FA3 が利用するもの: - **TMA(Tensor Memory Accelerator)**: HBM〜SMEM 間の非同期 DMA 専用ユニット。データ転送をワープから切り離して発行できる。 - **WGMMA(Warpgroup-wide MMA)**: ワープグループ(128 スレッド)単位で発行する非同期行列積命令。SMEM 上のオペランドを直接読める。前世代(Ampere)の MMA と異なり非同期で完了する。 - **setmaxnreg**: ワープグループ間でレジスタを動的に再配分する命令。TMA 発行専用ワープのレジスタを削減し、演算担当ワープに割り当てる。 - **FP8(e4m3)テンソルコア**: FP16/BF16 比 2 倍のスループットを持つ。ただし FP8 WGMMA のオペランドは k-major レイアウトのみ受け付けるという制約がある。 ## 提案手法 ### 技術 1: ワープ特殊化による Producer-Consumer 非同期 CTA(Cooperate Thread Array、スレッドブロック)内のワープをプロデューサとコンシューマに役割分担し、循環 SMEM バッファを介して連携させる。 - **プロデューサワープグループ**: TMA 命令を発行して Q, K, V ブロックを HBM から SMEM に転送する。レジスタを削減(setmaxnreg)し、転送完了後にバリアコミットで通知する。 - **コンシューマワープグループ**: レジスタを増量して WGMMA を実行する。バリア待機後に SS-GEMM(SMEM ソース)で `S = QKᵀ`、RS-GEMM(レジスタソース)で `O += P̃V` を計算する。 この分離により TMA 転送と GEMM 演算の待機が切り離され、メモリレイテンシを隠蔽できる。 ### 技術 2: ピンポンスケジューリングによる GEMM-softmax インターリーブ H100 SXM5 の FP16 行列積スループットは 989 TFLOPS だが、softmax に必要な指数関数(exp)の専用ユニット(multi-function unit)はわずか 3.9 TFLOPS だ。ヘッド次元 128 の FP16 では指数関数が行列積の 50% ものサイクルを消費し得る。FP8 ではさらに悪化する。 FA3 は 2 つのコンシューマワープグループを用い、ワープグループ 1 が softmax を実行している間にワープグループ 2 が GEMM を実行する「ピンポンスケジューリング」を採用する。bar.sync 命令で GEMM の実行順を強制することで、softmax がテンソルコアの空き時間に実行されるよう誘導する。これにより、たとえばヘッド次元 128 でシーケンス長 8192 の FP16 前向きパスにおいて 570 → 620〜640 TFLOPS の改善が得られる。 **2 段階パイプライン**: アルゴリズム設計として、イテレーション j の第 2 WGMMA(`P̃V`)をイテレーション j+1 の softmax とオーバーラップさせる。これには追加バッファが必要なため、レジスタ消費量が増える。SASS コード解析により、コンパイラが意図通りにオーバーラップを生成していることを確認している。 ### 技術 3: FP8 対応 — レイアウト変換と量子化精度改善 **効率面: レイアウト変換** FP8 WGMMA は k-major レイアウトのみを受け付けるが、アテンション計算の第 2 GEMM(`P̃V`)では V のタイルがヘッド次元方向(mn-major)で連続している必要が通常ある。また第 1 GEMM の FP32 アキュムレータのレイアウトと FP8 オペランド A のレジスタレイアウトが異なる。 FA3 はカーネル内転置を採用し、LDSM/STSM 命令でプロデューサワープグループが V タイルをロードしつつ転置する。また FP32 アキュムレータをバイトパーミュート命令で FP8 オペランド A のレイアウトに変換することで、連続した FP8 WGMMA 呼び出しを可能にする。 **精度面: ブロック量子化と非コヒーレント処理** FP8 (e4m3) は仮数部 3 ビットで精度が低く、LLM に多い外れ値(outlier)によって量子化誤差が増大する。テンソル単位スケーリング(per-tensor scaling)では大きな誤差が生じる。FA3 は 2 つの技術で対処する: 1. **ブロック量子化**: Q, K, V をブロック単位(各 `Bᵣ × d` または `Bₒ × d`)で別々に量子化し、各ブロックに独立したスケールファクタを持たせる。FlashAttention アルゴリズムは元々ブロック単位で演算するため、スケールの適用に追加演算コストは生じない。ロータリーエンベディング等の直前処理への融合も可能。 2. **非コヒーレント処理(incoherent processing)**: Q と K をランダム直交行列 M で乗算してから FP8 に量子化する。M が直交行列であることから `(QM)(KM)ᵀ = QKᵀ` が成り立ちアテンション出力は変化しない。各エントリが元エントリのランダム和になることで外れ値が分散し、量子化誤差が低減される。実装では Chee ら(QuIP)や Tseng ら(QuIP#)に倣い、M を ±1 の対角行列とアダマール行列の積として `O(d log d)` で適用し、ロータリーエンベディングへ融合する。 ## 新規性 FA3 の新規性は FlashAttention-2 との比較で以下の 3 点に集約される。 | 観点 | FlashAttention-2 | FlashAttention-3 | |---|---|---| | 実行モデル | 同期 | ワープ特殊化による非同期 | | softmax 実行 | GEMM の後に直列実行 | GEMM 発行中にインターリーブ | | 低精度 | 未対応 | FP8 + ブロック量子化 + 非コヒーレント処理 | | GPU 命令 | Ampere 命令も混在 | Hopper 専用(TMA・WGMMA・setmaxnreg) | | H100 利用率 | 35% | 75%(FP16) / 理論値 FP8 ≈ 76% | ThunderKittens や cuDNN 9 も Hopper 専用命令を活用しているが、FA3 は非同期 GEMM-softmax パイプラインとブロック量子化による FP8 精度改善を組み合わせた点で独自性を持つ。 ## 実験設定 - **測定対象**: H100 80GB SXM5 GPU(クロック固定 1830MHz)上でシーケンス長 512〜16k を変化させてベンチマーク - **比較手法**: 標準アテンション(PyTorch)、FlashAttention-2、FlashAttention-2(Triton)、cuDNN アテンション(クローズドソース) - **FP16 設定**: バッチサイズ × シーケンス長 = 16k トークン固定。ヘッド次元 64、128、256。コーザルマスクあり/なし - **FP8 設定**: ヘッド次元 256 を中心に計測。シーケンス長 ≥ 4k では SM 数(132)の倍数に揃えてウェーブ量子化を回避 - **ライブラリバージョン**: CUDA 12.3、cuDNN 9.1.1.17、CUTLASS 3.5、FlashAttention 2.5.8、Triton nightly 3.0.0、PyTorch 2.3.0 - **数値誤差実験**: FP64 参照実装との RMSE を比較。LLM の outlier を模擬するため各エントリを `N(0,1) + N(0,100) × Bernoulli(0.001)` で生成 ## 実験結果 ### FP16 前向きパス速度 - FlashAttention-2 比 **1.5〜2.0 倍**の高速化 - 最大 **740 TFLOPs/s**(理論ピークの 75%、ヘッド次元 256、シーケンス長 8k) - シーケンス長 1k 以上では cuDNN(クローズドソース)を上回る ### FP16 後向きパス速度 - FlashAttention-2 比 **1.5〜1.75 倍**の高速化 - 標準アテンションと比べると **3〜16 倍**高速 ### FP8 速度 - 約 **1.2 PFLOPs/s** に達する(ヘッド次元 256、シーケンス長 8k) - ヘッド次元 64 ではシーケンス長全域で cuDNN を上回り、ヘッド次元 128 ではシーケンス長の拡大とともに cuDNN に並ぶ ### アブレーション研究 | 設定 | 時間 | TFLOPs/s | |---|---|---| | FlashAttention-3 (完全版) | 3.538 ms | 661 | | GEMM-softmax パイプラインなし、ワープ特殊化あり | 4.021 ms | 582 | | GEMM-softmax パイプラインあり、ワープ特殊化なし | 4.105 ms | 570 | 両技術が独立して大きな寄与を持ち、組み合わせで相乗効果が生じる。 ### 数値誤差(FP64 比 RMSE) | 手法 | RMSE | |---|---| | 標準アテンション FP16(ベースライン) | 3.2e-4 | | FlashAttention-2 FP16 | 1.9e-4 | | FlashAttention-3 FP16 | 1.9e-4 | | 標準アテンション FP8(テンソル単位量子化) | 2.4e-2 | | FlashAttention-3 FP8(ブロック量子化 + 非コヒーレント) | **9.1e-3** | | ブロック量子化なし | 9.3e-3 | | 非コヒーレント処理なし | 2.4e-2 | FP16 では FlashAttention-2/3 ともに中間結果を FP32 で保持するため標準実装より 1.7 倍 RMSE が小さい。FP8 では 2 技術の組み合わせで **2.6 倍**の精度改善を達成する(非コヒーレント処理の寄与が特に大きい)。 ## 考察 FlashAttention-3 の成果はハードウェア専用命令とアルゴリズムの協調設計(co-design)が持つ重要性を示している。特に以下の洞察が注目に値する。 - **指数関数の実行コストは見落とされがち**: FP8 ではテンソルコアスループットが 2 倍になるが exp のスループットは変わらない。その結果 softmax の相対コストが増大し、GEMM-softmax インターリーブの価値はさらに高まる。 - **ブロック量子化は既存の非同期ブロック単位処理と自然に統合される**: FlashAttention の基本設計(ブロック単位ループ)が FP8 のブロック量子化と相性が良く、スケールの適用に余分な演算が不要だ。 - **非コヒーレント処理の効果が支配的**: アブレーション実験から、ブロック量子化単独より非コヒーレント処理単独の方が大きな精度改善をもたらすことがわかる。 - **Hopper 以外への拡張可能性**: 本論文の手法は非同期実行と低精度のサポートを持つ任意の GPU アーキテクチャに適用可能であり、Blackwell(FP4 対応)での効果も期待される。 ## 強み / 弱点・課題 ### 強み - H100 上で既存の最良オープンソース実装(FlashAttention-2)を大幅に上回り、クローズドソース実装(cuDNN)にも競争力を持つ - FP8 における精度と速度のトレードオフを明示的に制御できる設計 - CUTLASS の高レベル抽象に基づき、コードベースの拡張性が高い - FP16 の数値精度は FlashAttention-2 と同等を維持している - Ring Attention など分散アテンション手法のプリミティブとして上位の高速化に寄与する ### 弱点・課題 - **LLM 推論最適化が未完成**: FP8 カーネルに persistent kernel 設計が組み込まれておらず、小シーケンス長やコーザルマスクありの FP8 では cuDNN に劣る場合がある - **後向きパスの FP8 対応**: 論文では前向きパスの FP8 アルゴリズムを主に記述しており、後向きパスへの FP8 適用は将来課題とされている - **レジスタ圧迫と 3 段階パイプライン**: 3 段階パイプラインは理論上さらなる高速化をもたらすが、レジスタ消費増によりタイルサイズを縮小せざるを得ず、実測では 2 段階に劣る結果となった - **大規模学習での低精度影響の未解明**: FP8 で大規模モデルを学習する際の精度影響についての体系的な調査は今後の課題 ## 関連 - [[@2022__arXiv__FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness]] — IO 認識タイリング戦略の原点 - [[@2023__arXiv__FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning]] — シーケンス次元並列化による先行改良版