# TokenWeave: Efficient Compute-Communication Overlap for Distributed LLM Inference > [!info] Talk metadata > - **会議:** [[MLSys2026]] Day 2 (May 19 / Tue)、Research Track Oral "Agentic AI 2 & LLM Serving 1" セッション(14:45 PDT 開始) > - **登壇者:** Raja Gond(Microsoft Research India) > - **共著者:** Raja Gond、Nipun Kwatra、Ramachandran Ramjee(全員 Microsoft Research India) > - **URL:** https://openreview.net/forum?id=rh2Ylffkq6 > - **ソースコード:** https://github.com/microsoft/tokenweave > - **スライド:** なし(OpenReview ページおよび論文 PDF のみ) > - **Artifacts バッジ:** Artifacts Available / Artifacts Functional / Results Reproduced(3 種取得) > - **ライセンス:** CC BY 4.0 > [!abstract] 概要(OpenReview) > テンソル並列化(tensor parallelism)を用いた大規模言語モデル(LLM)の分散推論(distributed inference)では、NVLink のような高速 GPU インターコネクト経由であっても 20% に達する通信(communication)オーバーヘッドが生じうる。計算をより小さなタスクに分解し、通信とサブタスクを重畳(overlap)する手法が複数提案されてきたが、vLLM・SGLang・TensorRT-LLM 等のシステムではいずれもテンソル並列サービング時に既定で有効化されていない。これは、低遅延サービングを支えるために反復あたりの処理トークン数が小さく保たれ、そのような小規模ワークロードの分解が性能悪化を招くためである。さらに通信自体が多数のストリーミングマルチプロセッサ(SM)を占有し、本来計算に使えるリソースを奪ってオーバーヘッドを増大させる。本論文では、テンソル並列モデル推論において 1024 トークン程度の短いトークン長でも効率的な計算通信重畳を実現する初のシステム TokenWeave を提示する。TokenWeave は、従来見過ごされてきた演算である RMSNorm を重要と特定し、新規の融合 AllReduce–RMSNorm カーネルによって通信と同時に最適化する。このカーネルは、最新 GPU(Hopper、Blackwell 等)で利用可能な NVSHARP/Multimem 機能を活用し、8×H100 NVIDIA DGX システム上でわずか 2〜8 基の SM のみで通信と RMSNorm を効率的に実行する。評価では、複数のモデルとワークロードにわたり遅延で最大 1.28 倍の高速化(ベースライン÷本手法)、スループットで最大 1.19 倍の向上(本手法÷ベースライン)を達成した。いくつかの設定では、TokenWeave は全通信を除去した等価モデルよりも優れた性能を示す。 ## 問題設定 - 大規模モデルの分散推論ではテンソル並列化(TP)が標準だが、NVLink のような高速インターコネクトを用いても通信オーバーヘッドが **9〜23%** に達する(Llama-3.3-70B、Qwen2.5-72B、Mixtral-8x22B を 8×H100 DGX 上の vLLM 0.8.5 V1 エンジンで計測、図1) - 既存の計算通信重畳手法(タイルレベル分解やトークンレベル分解)は、いずれも vLLM・SGLang・TensorRT-LLM で既定無効である - 低遅延サービングでは反復あたりのトークン数が少なく(vLLM 0.8.5 の既定チャンクサイズは 2048)、小規模ワークロードの分解は波量子化効果(wave quantization)による計算効率低下を招く - 通信自体が 16〜20 基以上の SM を消費し、計算リソースを圧迫する - AllReduce を ReduceScatter + AllGather に分割する代替手法は、分割に伴うオーバーヘッドが RMSNorm 計算量の削減を上回り、逆効果となる(図4) - RMSNorm のオーバーヘッドは **4〜9%** と無視できない水準であり(図1)、AllReduce 後に全 GPU で冗長に実行される構造が非効率 ## 提案手法 TokenWeave は 3 つの鍵となる技術から構成される。 ### 1. 粗粒度(coarse-grained)トークン分割と波認識スマート分割(smart-splitting) - 入力バッチを **2 分割**(prefix-split と suffix-split)し、一方の通信(AllReduce + RMSNorm)と他方の計算をパイプラインで重畳する(図7) - 2 分割が最適:3 分割以上は分解オーバーヘッドが増すだけで重畳機会が増えない - **スマート分割(smart-splitting)**: 2 分割後の合計波(wave)数が分割前と同じになるよう分割点を選ぶ。例として 300 CTA・132 SM の GEMM では、132 CTA と 168 CTA に分割し、各分割が丁度 1 full wave(+1 partial wave)で収まる - 均等分割(equal-split)と比較して、特に小バッチで波量子化ジッタを除去(図9) - chunked-prefill とハイブリッド prefill/decode バッチの双方に対応 - 小トークン数(dense モデルで 1K 未満、Mixtral で 4K 未満)ではオーバーヘッドが利得を上回るため、分割を無効化し融合カーネルのみを適用する選択的有効化(図3) ### 2. RMSNorm 並べ替え(reordering) - AllReduce 後、各 GPU は同一のトークン埋め込みを持つため、全 GPU で RMSNorm を冗長に計算する従来実装は非効率 - TokenWeave は AllReduce を ReduceScatter → RMSNorm → AllGather の順に並べ替え、各 GPU が $\frac{1}{N}$($N$ = GPU 数)のシャードのみで RMSNorm を実行することで冗長計算を $N$ 分の 1 に削減 - ただし単純な分割は ReduceScatter + AllGather のオーバーヘッドが RMSNorm 削減を相殺するため、融合カーネルが必須 ### 3. Multimem ベースの融合 AllReduce–RMSNorm カーネル(§4.3) - NVSHARP/Multimem(NVSwitch 世代の SHARP エンジン)の Parallel Thread Execution(PTX)命令 `multimem_ld_reduce_add` および `multimem_st` を活用 - ReduceScatter・RMSNorm・残差加算・AllGather を**単一カーネル**に融合し、HBM アクセスを最小化 - 標準 RMSNorm は HBM 読み出し 2 回 + HBM 書き込み 1 回だが、融合カーネルは ReduceScatter 結果を SM レジスタ上で直接処理し、初回 HBM 読み出しを除去。さらに正規化済みの値を Multimem アドレスへ直接書き込むことで AllGather 用の追加 HBM 書き込みも除去 - **わずか 2〜8 基の SM** で通信帯域を飽和(図5・図10)。従来手法の 16〜20+ SM と比較して大幅に少なく、残りの SM を計算に解放 - 融合カーネル単体で、逐次実行(AllReduce + RMSNorm)に対し **1.34〜1.39 倍**の一貫した性能向上(トークンサイズ 64〜32K、表1)。AllReduce 単体の性能にほぼ匹敵 #### 表1: 融合カーネルのマイクロベンチマーク(hidden size 8192、bf16、8×H100 DGX、単位 µs) | トークン数 | 64 | 128 | 256 | 512 | 1K | 2K | 4K | 8K | 16K | 32K | |---|---|---|---|---|---|---|---|---|---|---| | AllReduce | 16.32 | 20.64 | 28.35 | 43.84 | 74.85 | 136.00 | 257.47 | 500.54 | 986.24 | 1955.71 | | RMSNorm | 8.32 | 9.57 | 12.06 | 18.91 | 29.82 | 52.16 | 96.29 | 185.09 | 361.54 | 716.13 | | AR+RMSNorm | 24.64 | 30.21 | 40.41 | 62.75 | 104.67 | 188.16 | 353.76 | 685.63 | 1347.78 | 2671.84 | | 融合(本手法) | 17.70 (1.39) | 22.53 (1.34) | 30.02 (1.35) | 46.46 (1.35) | 75.71 (1.38) | 137.34 (1.37) | 258.34 (1.37) | 502.24 (1.37) | 990.59 (1.36) | 1960.90 (1.36) | ※ 括弧内は逐次 AR+RMSNorm に対する高速化倍率 ## 実験・結果 ### 実験環境 - **ハードウェア:** 8×H100 NVIDIA DGX(NVSHARP 対応、128 CPU コア、800 GB ホストメモリ)。追加で 8×B200 DGX と 4×H100 構成でも評価 - **ソフトウェア:** vLLM 0.8.5 V1 エンジン上に実装。PyTorch 2.6.0、CUDA 12.4、Triton 3.2.0、FlashAttention-3 - **モデル:** Llama-3.3-70B、Qwen2.5-72B(dense)、Mixtral-8x22B(MoE)。全て instruction-tuned 版 - **ワークロード:** ShareGPT(実トレース)、arXiv summarization(実トレース)、固定入出力長の合成ワークロード - **ベースライン:** vLLM-Default(既定 AllReduce)、vLLM-Multimem(NVSHARP/Multimem 最適化 AllReduce)、vLLM-nocomm(通信除去の理論上限)、TileLink(SOTA の計算通信融合手法)、NanoFlow ### スループット(エンドツーエンド) - **ShareGPT・arXiv トレース(8×H100):** dense モデルで約 **1.19 倍**のスループット向上を一貫して達成(図11)。Mixtral では約 **1.11〜1.17 倍** - **チャンクサイズ変動(1024〜8192):** Llama-3.3-70B で **1.14〜1.26 倍**の一貫した改善(図12) - **NanoFlow との比較:** NanoFlow の改善は **1.04〜1.09 倍**に留まる(図15)。TokenWeave は同条件で約 **1.19 倍**と大幅に優位 ### 遅延(単一イテレーション) - **Llama-3.3-70B(8×H100、prefill-only):** 1K トークンから最大 **1.28 倍**の高速化。短いシーケンス長でも 1.2 倍の改善(図13) - dense モデル全体で 1K 以降 **1.16〜1.28 倍**の範囲 - **Mixtral-8x22B:** MoE 構造のため通信オーバーヘッド自体が小さく、利得は控えめ。1K・2K では分割オーバーヘッドが利得を上回るため融合カーネルのみモードで改善 - **vLLM-nocomm(通信ゼロの理論上限)超え:** シーケンス長 $\geq 4K$ の設定で TokenWeave は vLLM-nocomm を超える性能を示す。通信の回復だけでなく RMSNorm 最適化による追加利得のため ### TileLink との比較(単一レイヤー) - TileLink は 2K トークン未満で**遅延が増加**(ネットオーバーヘッド発生)。改善は 4K 以上で現れ **最大約 1.2 倍**で頭打ち(図14) - TokenWeave は 1K から **1.20 倍**の改善を開始し、最大 **1.35 倍**まで伸びる - TileLink は MLP 層の通信のみ重畳(TileLink-OnlyMLP)の方が効率的な場合もあるが、TokenWeave はアテンション層を含む全通信を重畳可能 ### アブレーション - **融合カーネルのみ(TokenWeave-fuseonly):** 全モデルで **1.04〜1.09 倍**の改善。冗長 RMSNorm 排除と HBM アクセス削減による(図16) - **スマート分割なし(TokenWeave-equalsplit):** 均等分割では波量子化によりシーケンス長ごとの性能にジッタが発生。スマート分割はこのジッタをほぼ完全に除去し安定した改善を実現(図17) - 4×H100 構成でも同様の傾向(Appendix 図23) ## 結論・オープン課題 - テンソル並列 LLM サービングにおける通信コストは、NVLink + NVSHARP のような高速ハードウェアを用いても最大 **20%** に達し、RMSNorm が追加で **4〜9%** のオーバーヘッドを加える - TokenWeave は 3 技術(スマート分割・RMSNorm 並べ替え・融合 AllReduce–RMSNorm カーネル)の組合せにより、1K トークンという小バッチでも計算通信重畳を実現する初のシステム - 8×H100、4×H100、8×B200 の 3 構成で評価し、最大 **1.28 倍**の遅延改善と **1.19 倍**のスループット向上を達成 - 通信ゼロの理論上限を超える性能を複数設定で達成(RMSNorm 最適化の追加効果) - **オープン課題(論文中に明示的な議論なし):** - Blackwell 世代以降での NVSHARP/Multimem の進化に伴うさらなる最適化余地 - エキスパート並列(expert parallelism)を用いる DeepSeek 型 MoE アーキテクチャへの適用(all-to-all 通信は TP の AllReduce とは性質が異なる) - vLLM V2 エンジンや SGLang 等の他サービングフレームワークへの統合 - disaggregated serving(prefill/decode 分離)環境でのさらなる最適化