# Flashlight: PyTorch Compiler Extensions to Accelerate Attention Variants > [!info] Talk metadata > - **会議:** [[MLSys2026]] Day 5 (May 22 / Fri)、Grand Ballroom 2、Research Track Oral: Efficient Computation(08:15 PDT 開始、第3発表) > - **登壇者:** Bozhi You(UT Austin, 筆頭著者) > - **全著者:** Bozhi You¹, Irene Wang², Zelal Su Mustafaoglu¹, Abhinav Jangda³, Angélica Moreira³, Roshan Dathathri³, Divya Mahajan², Keshav Pingali¹(¹UT Austin、²Georgia Tech、³Microsoft) > - **URL:** https://mlsys.org/virtual/2026/oral/3763 > - **OpenReview:** https://openreview.net/forum?id=lboOMA8XWr > - **関連研究:** https://github.com/bozhiyou/flashlight/tree/mlsys26-ae > - **ACM バッジ:** Artifacts Available v1.1、Artifacts Evaluated — Functional、Results Reproduced v1.1 > [!abstract] 概要(OpenReview) > アテンションは大規模言語モデル(LLM)の基本構成要素であり、その効率的な実装に多くの取り組みがなされてきた。たとえば FlashAttention はタイリングとカーネル融合を活用してアテンションを最適化する。近年、モデルの品質や効率を向上させるためにアテンションの多数の変種が提案されているが、それらの効率的なサポートは依然として困難である。通常、専用カーネルや手動チューニングされた実装が必要となるためである。FlexAttention は最近、静的プログラミングテンプレートを用いてアテンション変種の一部に対し FlashAttention 風のカーネルをサポートすることで、このギャップの一部に対処した。本論文では、PyTorch エコシステム内のコンパイラネイティブなフレームワークである Flashlight を提案する。Flashlight は静的テンプレートや事前定義されたカーネル特殊化に依存せず、任意のアテンションベースのプログラムに対して融合された FlashAttention スタイルのカーネルを自動生成する。Flashlight は PyTorch のコンパイルワークフローを活用し、アテンション計算の融合とタイリングを透過的に行うことで、多様なアテンションパターンの効率的な実行を可能にする。FlexAttention モデルで表現可能なすべての変種をサポートするだけでなく、FlexAttention の能力を超える、より一般的なデータ依存アテンション定式化も扱う。実験の結果、Flashlight は FlexAttention に対して同等以上の性能を持つカーネルを生成し、同時にネイティブ PyTorch コードの柔軟性を提供することで、開発者が性能を犠牲にすることなく新しいアテンションモデルを迅速に探索できることを示す。 ## 背景と動機 - アテンションは LLM の訓練・推論において計算のボトルネックとなる基本演算である。FlashAttention はタイリングとカーネル融合により、メモリ読み書きの削減・データ局所性の向上・GPU 利用率の改善を実現した。 - しかし近年、差分アテンション(Differential Attention)、AlphaFold の Evoformer における行/列方向ゲート付きセルフアテンション、IPA(Invariant Point Attention)、RSA(Rectified Sparse Attention)など多数の変種が提案されており、FlashAttention のような手書きカーネルではこれらをカバーできない。 - FlexAttention と FlashInfer はテンプレートベースの高階 API で一部の変種に対応するが、テンプレートに適合しない変種(差分アテンション、Evoformer など)は対象外である。 - 既存手法の整理: - **FlashAttention:** Vanilla アテンション専用の手書き融合カーネル。 - **FlexAttention:** `score_mod`(要素ごとのスコア変換)と `mask_mod`(`block_mask` によるスパース実行)をテンプレートとして受け取り、TorchInductor 経由で融合 Triton カーネルを生成。ALiBi、Softcap、Causal、Sliding Window 等をサポート。 - **FlashInfer:** CUDA ベースの JIT コード生成エンジン。テンプレート特殊化 CUDA カーネルを生成。 - **`torch.compile`(デフォルト):** PyTorch コードを複数の Triton カーネルにコンパイルするが、matmul と softmax の融合が不可能なため FlashAttention 級の I/O 効率は達成できない。 ## Flashlight の位置づけ - Flashlight は独立コンパイラではなく、**TorchInductor への汎用拡張**である。`torch.compile` のパイプライン(PyTorch Code → TorchDynamo → AOTAutograd → TorchInductor + Flashlight → Triton カーネル → Triton Compiler → NVIDIA PTX カーネル)に組み込まれる。 - ユーザーは通常の PyTorch コードを `torch.compile(fn, dynamic=False, enable_flashlight=True)` で呼ぶだけでよい。テンプレート API も `block_mask` キャッシュも不要である。 - Flashlight は FlexAttention がサポートする全変種に加え、差分アテンション、Evoformer のゲート付きセルフアテンション、IPA、RSA などデータ依存の複雑な変種も自動融合する。 ## コンパイラ設計: 4 つの中核変換 Flashlight は TorchInductor の中間表現(IR)を拡張し、4 つの組み合わせ可能なコンパイラ変換を導入する。 ### 1. 統一リダクション IR(Unified Reduction IR) - 既存の TorchInductor は matmul(`torch.bmm`)を CUBLAS/ATen 呼び出しに特殊化し、後続の softmax 等のリダクションとは別カーネルにする。この分岐が**人為的な融合境界**を生む。 - Flashlight は matmul を汎化リダクションとしてモデル化する。テンソル次元を p-dimensions(並列/ポイントワイズ)と r-dimensions(リダクション)に分類し、GEMM $C_{mn} = \sum_k A_{mk} B_{kn}$ を p-dimensions $(m, n)$、r-dimension $(k)$ の統一 IR で表現する。 - これにより matmul が融合エンジンに完全参加し、matmul の連鎖や matmul + 要素演算の融合が可能になる。CUBLAS の性能を犠牲にする代わりに融合機会を獲得するトレードオフである。 ### 2. 代数的変換(Algebraic Transformation) - 安定 softmax は 2 パスで計算される: (1) max の算出、(2) `exp(x - max)` の総和。第 2 パスは第 1 パスの最終結果 $m_{\text{final}}$ に依存するため、ナイーブな融合は不可能である。 - Flashlight は指数関数の**準同型性**($\exp(a - b) = \exp(a)/\exp(b)$)を利用して、安定 softmax をオンライン softmax に自動変換する。2 パスの依存関係を「ランニング max とランニング sum の逐次更新」に変換し、単一ループで max と総和を同時計算する。 - この変換は環(ring)上の準同型の一般的定義に基づいて定式化されており、softmax 以外のリダクション融合にも適用可能である。 ### 3. 次元降格(Dimension Demotion) - matmul と softmax は計算スケッチ(ループ階層)が異なるため、p-dimensions が一致しない場合に直接融合できない。 - Flashlight は「並列ループは常にリダクションループとして扱える」という洞察に基づき、融合可能なプロデューサの p-dimension をコンシューマの r-dimension に**降格**する。並列性を犠牲にする代わりに、中間テンソルのグローバルメモリへの実体化を完全に排除する。 - 結果として $A = \text{softmax}(QK^\top / \sqrt{d_k})$ に対する単一融合カーネルが得られる。 ### 4. タイリング対応次元除去(Tiling-Aware Dimension Elimination) - タイリングにより、小さい次元(例: ヘッド次元 $d_k$)のタイルサイズを次元全体に設定すると、そのタイルループは反復 1 回で消滅する。これにより 2 つの連続した matmul($QK^\top$ と $\text{softmax}(\cdot)V$)を含む完全アテンションの単一融合カーネルが生成される。 - TorchInductor はタイリング次元と物理 GPU グリッド次元を剛結合する。Flashlight は**論理グリッド次元**を導入し、タイルの多次元グリッドを線形シーケンスにアンロールして単一の物理グリッド次元にマッピングする。カーネル内部で逆アフィン写像により多次元タイル座標を復元する。 ## 実装の詳細 - **精度保持:** GEMM を統一リダクション IR に下降させる際、FP16/BF16 入力の計算型を FP32 に無条件昇格し、ハードウェアテンソルコアの蓄積動作と一致させる。 - **L2 キャッシュ最適化:** 複数の並列タイリング次元を持つカーネルでは、ブロック反復順序をスウィズリング(GROUP_M 幅のストリップ内で次元を交互に走査)し、空間局所性を向上させる。 - **インデックス順序追跡:** SymPy による式簡約で変数順序が失われる問題に対し、Flashlight はインデックス変数の順序を記録し、コード生成フェーズで正しい N 次元テンソル・マスク・ロードを出力する。 - **ブロックリダクションヒューリスティック:** 融合ブロックリダクションカーネル向けに `blockreduction` テンプレートを導入し、`(XBLOCK, RBLOCK, num_warps, num_stages)` の構成空間をテンプレートベースでオートチューニングする。 - **実体化閾値の緩和:** デフォルトの TorchInductor は融合操作数の上限で中間テンソルをグローバルメモリに実体化するが、Flashlight はこの閾値を引き上げ、ALiBi のような複雑な融合サブグラフの早期実体化を回避する。 - **実験環境:** Python 3.12、PyTorch 2.5.0、Triton 3.1.0、CUDA 12.9、FlashInfer 0.2.5。NVIDIA H100 80GB および A100 80GB で評価。SM 周波数は H100 = 1290 MHz、A100 = 1080 MHz に固定(定常周波数)。20 回のウォームアップ後に 20 回実行の平均を報告(標準偏差 1% 以内)。 ## 評価: FlexAttention 対応変種 - 対象変種: Vanilla、ALiBi、Softcap、Causal、Sliding Window、PrefixLM、Document Mask の 7 種。MHA(ヘッド数 16)と GQA(Q ヘッド 16、KV ヘッド 2)の両構成で評価。シーケンス長 512–16k、バッチサイズはトークン総数 16k を維持するよう設定。ヘッド次元 64。 - **H100 上の MHA(図 2):** - `score_mod` 変種(ALiBi、Softcap)で Flashlight は FlexAttention に対し最大 **1.48 倍**高速。FlexAttention のテンプレートカーネルが full/partial/empty ブロック処理の計算・メモリ命令を含むのに対し、Flashlight の融合 Triton カーネルはより単純なため。 - `block_mask` 変種(Causal、Sliding Window、PrefixLM、Document Mask)では、FlexAttention のカーネル実行自体は Flashlight より常に高速(スパースブロックマスクにより冗長計算をスキップするため)。ただし `block_mask` の構築時間を含めると Flashlight が同等以上となるケースが多い。 - Vanilla では FlexAttention がわずかに高速(バッチサイズ 1 の ALiBi MHA で FlexAttention がわずかに優位)。 - **スライド上の要約数値(H100):** - Vanilla: Flashlight は FlexAttention 比 **0.97x**–**1.00x**(ほぼ同等)。 - Softcap: Flashlight は FlexAttention 比 **1.41x**–**1.44x** 高速。 - Causal: Flashlight は FlexAttention(Kernel のみ)比 **1.22x**–**4.51x** 高速(`block_mask` 構築コストを含む場合)。 - **A100 上(図 3):** H100 と同様の傾向。FlashInfer はほぼすべてのバッチサイズ・シーケンス長で FlexAttention と Flashlight より高速(FlashInfer は CUDA ベースの特殊化カーネンを生成し、block sparsity を API パラメータ経由で直接活用するため)。例外は ALiBi(FlashInfer がバイアスを要素ごとに計算するオーバーヘッド)と、バッチサイズ 16/32 の `block_mask` 変種(FlexAttention の `block_mask` 構築コストが支配的)。 ## 評価: 複雑なアテンション変種(FlexAttention 非対応) - **差分アテンション(DiffAttn):** nhead=32、headdim=64 の構成で、Flashlight は `torch.compile` 比 **4.51x**–**5.78x** 高速(H100)。nhead=16、headdim=128 でも同様に大幅な高速化。 - **Evoformer:** nhead=4、headdim=64 の行/列方向ゲート付きセルフアテンションで、Flashlight は `torch.compile` 比 **5x 以上**高速(スライドでは **7.22x**–**7.90x**、H100)。 - これらの変種は FlexAttention のテンプレートに適合しないため、FlexAttention との比較は不可能。Flashlight のみが融合カーネルを自動生成できる。 ## 評価: エンドツーエンド推論レイテンシ ### AlphaFold2 - OpenFold リポジトリの AlphaFold2 モデル(48 Evoformer レイヤー)。シーケンス長 256、バッチサイズ 1–32。Evoformer は 8 ヘッド・ヘッド次元 32、IPA は 12 ヘッド・ヘッド次元 16。 - PyTorch(eager)と `torch.compile` の推論レイテンシに有意差はない。Flashlight は H100・A100 の両方で推論レイテンシを **6%–9%** 改善。行/列方向ゲート付きセルフアテンションの実行時間を **5 倍以上**短縮した結果がエンドツーエンドに反映される。 ### LLaMA-3.2-1B(vLLM / Mooncake) - LLaMA-3.2-1B の Vanilla アテンションを Causal および Softcap に変更した 2 つのモデル変種を作成。Mooncake 会話トレース(先頭 200 リクエスト)を vLLM で H100 上で実行。 - デフォルトの PyTorch および `torch.compile` はこれらの変種でメモリ不足(中間アテンションテンソルの実体化のため)。 - **Softcap 変種:** Flashlight は FlexAttention 比で TTFT(P99 初回トークンレイテンシ)**1.17x** 高速、ITL(P99 トークン間レイテンシ)**0.88x**(FlexAttention がやや優位)、トークンスループット **1.04x**(図 5)。Flashlight のカーネル実行自体が高速なため(`score_mod` 変種)。 - **Causal 変種:** FlexAttention が TTFT **1.05x**、ITL **0.88x**、スループット **1.04x** で優位。`block_mask` 変種では FlexAttention のスパースブロックスキップが複数回呼び出しで償却されるため。 ## FlexAttention との比較: 汎用性とトレードオフ - **汎用性:** Flashlight は FlexAttention のテンプレートに適合しない変種(差分アテンション、Evoformer、IPA、RSA)も自動融合する。FlexAttention は `score_mod`(スコアへの要素演算)と `mask_mod`(`block_mask`)の 2 パターンに限定される。 - **ブロックスパース性:** FlexAttention は `block_mask` でスパースブロックをスキップし冗長計算を排除する。Flashlight はブロックスパース性を最適化しない。Causal のような `block_mask` 変種では FlexAttention のカーネル実行が高速。 - **プログラミングモデル:** Flashlight はネイティブ PyTorch コードをそのままコンパイルする。FlexAttention はユーザーが `score_mod`/`mask_mod` 関数を定義し、`block_mask` を構築・キャッシュする必要がある。 - **性能の精度:** 構造的融合(次元降格)はメモリアクセスを削減するが並列性を犠牲にする。代数的変換(オンライン softmax 化)は実数上は等価だが、浮動小数点の非結合性により一部のカーネルで不正確となる可能性がある。 ## 関連研究 - **効率的アテンションカーネル:** FlashAttention(タイリング+カーネル融合)、FlexAttention(テンプレート+TorchInductor)、FlashInfer(CUDA JIT 特殊化)。Flashlight はこれらと異なり `torch.compile` 内で自動的に最適化カーネルを生成する。 - **柔軟・特殊アテンションモデル:** Longformer、ALiBi、DiffTransformer など。これらの効率的サポートが汎用コンパイラフレームワークの必要性を動機づける。 - **深層学習コンパイラ基盤:** Triton(GPU DSL、TorchInductor が使用)、PolyBlocks(アフィンアクセス解析による MLIR ベースのアテンション融合、並行研究)、TVM(オートチューニング)、Mirage(プログラム合成)、ThunderKittens(カーネル融合とスケジューリングの DSL)。