# CAGE: Curvature-Aware Gradient Estimation For Accurate Quantization-Aware Training > [!info] Talk metadata > - **会議:** [[MLSys2026]] Day 3 (May 20 / Wed)、Grand Ballroom 2、17:00 - 17:15 PDT > - **セッション名:** Research Track Oral: Model Compression > - **登壇者:** Soroush Tabesh (ISTA) > - **URL:** https://mlsys.org/virtual/2026/oral/3841 > - **OpenReview:** https://openreview.net/forum?id=Fubm1TtWeo > - **共著者:** Soroush Tabesh$^{*1}$, Mher Safaryan$^{*1}$, Andrei Panferov$^{1}$, Alexandra Volkova$^{1}$, Dan Alistarh$^{1,2}$($*$ 等貢献。$^{1}$Institute of Science and Technology Austria (ISTA), $^{2}$Red Hat AI) > - **コード:** https://github.com/IST-DASLab/CAGE > [!abstract] 概要(論文アブストラクトの忠実な日本語訳) > 低ビット量子化認識訓練(QAT)に関する研究は数多いが、こうした手法とネイティブ訓練との間には依然として精度差が存在する。この課題に対し、本研究では CAGE(Curvature-Aware Gradient Estimation)を提案する。CAGE は STE(Straight-Through Estimator)勾配に対し、量子化による損失増加を打ち消す曲率認識補正項を付加する新たな QAT 手法である。CAGE は QAT を多目的最適化の観点から再定式化し、タスク損失の最小化と量子化制約の充足を同時にバランスさせるパレート最適解の枠組みから導出される。理論面では、滑らかな非凸設定において CAGE が強いエルゴード収束保証を達成することを証明する。実装面では、本手法はオプティマイザ非依存であり、Adam の統計量を活用した高効率な実装を提供する。CAGE は最先端手法に対し同等の計算コストで精度を大幅に改善する。QAT ファインチューニングでは、先行手法に比べ圧縮に伴う精度劣化を約半分に削減する。Llama の QAT 事前学習では、3 ビット重み・活性化(W3A3)の精度が先行最良手法の 4 ビット(W4A4)に匹敵する。 ## 問題設定: QAT と STE の限界 - 大規模言語モデル(LLM)の効率的な推論・デプロイに向け、量子化は標準的な手法となっている。Llama、Gemma、GPT-OSS といったオープンモデルは圧縮形式でリリースされており、事後量子化(PTQ)と量子化認識訓練(QAT)が二大アプローチである(論文 1 節)。 - PTQ はキャリブレーション用の小規模データセットに対し数値アルゴリズムを適用して量子化する手法であり、計算コストが低い反面、低ビット域での精度に限界がある。QAT は訓練過程そのものに量子化を組み込み、モデルが量子化制約を学習する手法であり、より高い精度を達成しうる(論文 1 節)。 - QAT の中核的課題は量子化関数の非微分可能性にある。量子化演算子 $Q$ はほぼ至る所で勾配がゼロであり、逆伝播時に有用な勾配情報が得られない。従来の標準的対処法は STE であり、これは量子化演算子のヤコビアンを恒等行列で近似する、すなわち $\nabla_x[f(Q(x))] = JQ(x)^\top \cdot \nabla f(Q(x)) \approx \nabla f(Q(x))$ とする手法である(Hinton, 2012; Bengio et al., 2013b。論文 3.1 節、スライド 2 ページ)。 - STE は汎用的であるが、収束の遅さと不安定性が知られている。既存の改良手法(LSQ、LSQ+、AdaSTE、EWGS 等)はヒューリスティックなスケーリングや学習可能な量子化パラメータを導入するが、いずれも収束保証を提供しない(論文 2 節)。 - 従来の QAT に対する収束解析は、量子化点 $Q(x^*)$ でのタスク損失の停留点 $\nabla f(Q(x^*)) = 0$ への収束を目標としてきた。しかしスライド 3 ページの図示例が示すように、量子化関数 $Q(x) = \lfloor x \rfloor$ のもとで $f(x) = \frac{1}{2}(x - \frac{1}{2})^2$ を考えると、$\nabla f(Q(x))$ は任意の点で非ゼロとなり、収束界に消去できない量子化誤差項 $\mathcal{O}(\text{quant.error})$ が残存する(論文 3.1 節)。 ## CAGE の理論的基盤: 多目的最適化とパレート最適性 ### QAT の多目的最適化としての再定式化 - CAGE は QAT を 2 つの目的の同時最適化問題として捉える(スライド 4 ページ)。 - **目的 1:** タスク損失 $f(x)$ の最小化(停留点 $\nabla f(x^*) = 0$)。 - **目的 2:** 量子化誤差 $Q(x^*) - x^* = 0$ の最小化(パラメータが量子化グリッド点に一致)。 - 一般にこの 2 目的は同時に達成できない。損失の勾配方向 $-\nabla f(x^*)$ と量子化グリッドへの引き付け方向 $Q(x^*) - x^*$ は異なるため、一方を改善すれば他方が悪化しうる(論文 3.1 節)。 ### $\lambda$-パレート最適条件 - パレート最適解 $x^*$ を、あるスカラー $\lambda > 0$ に対して以下を満たす点と定義する(論文 Equation 2、スライド 4 ページ)。 $\nabla_{\lambda\mathrm{P}} f(Q(x^*)) := \nabla f(x^*) + \lambda(x^* - Q(x^*)) = 0$ - この条件は、十分小さな更新を $x^*$ に施した場合、2 目的のうち少なくとも一方が悪化することを保証する。パラメータ $\lambda$ は 2 目的間のトレードオフを制御する(論文 3.1 節)。 ### 収束定理 - SGD をベースオプティマイザとした場合、CAGE の更新則は以下のとおりである(スライド 10 ページ)。 $x_{t+1} = x_t - \alpha(\widetilde{\nabla}f(x_t) + \lambda(x_t - Q(x_t)))$ - 損失関数 $f$ の $L_f$-滑らかさ、確率的勾配の不偏性と有界分散、および量子化演算子 $Q$ のリプシッツ連続性に相当する滑らかさ仮定のもとで、適切なステップサイズ $\alpha$ を選べば以下が成立する(論文 Theorem 1、スライド 10 ページ)。 $\mathbb{E}\left[\|\nabla_{\lambda\mathrm{P}} f(Q(\hat{x}_T))\|^2\right] = \mathcal{O}\left(\frac{1}{\sqrt{T}}\right)$ - ここで $\hat{x}_T$ は訓練履歴 $\{x_0, \ldots, x_{T-1}\}$ から一様ランダムに選んだ反復点である。収束レートは非量子化の場合と同一の最適レートであり、上界に量子化誤差起因の消去不能項が**存在しない**。これは STE ベースの先行手法に欠けていた理論的保証である(論文 Theorem 1 後の Discussion)。 ## CAGE の実装: アルゴリズム設計 ### 結合型と脱結合型の補正 - CAGE の核心は、STE 勾配に量子化残差 $e_t = x_t - Q(x_t)$ を用いた補正項を付加することである。オプティマイザ非依存の設計として、結合型(coupled)と脱結合型(decoupled)の 2 つの変種を提供する(論文 3.1 節、スライド 5-6 ページ)。 - **結合型:** 補正項 $\lambda_t e_t$ を STE 勾配に加算してからオプティマイザに渡す。Adam の場合、モーメント処理を経て $(1 - \beta_1)\lambda_t e_t / (\sqrt{v_t} + \varepsilon)$ という対角プリコンディショニングされた曲率認識補正となる。 - **脱結合型:** オプティマイザの更新 $\Delta_t$ の計算後に補正項 $\lambda_t e_t$ を単位スケールで加算する。プリコンディショニングパスを経由しないため、低ビット域でパーコーディネートの分散推定が不安定化する場合でもパレート停留性方向を保存し、実用上安定である。 - 両変種は同一のパレート由来の補正方向 $x - Q(x)$ を用いる。主な差異はスケーリングのみであり、AdamW 下で両者はモデルサイズ・量子化設定にわたりほぼ同一の最終パープレキシティを達成する(論文 Appendix D)。 ### ウォームアップスケジュール(サイレンス期間) - 訓練初期に補正を適用すると、モデルが損失ランドスケープ上を大きく移動する段階で量子化セル境界を頻繁に横切るため、量子化残差がほぼゼロ平均の有界摂動として振る舞い、累積補正が $O(\sqrt{T})$ のノイズ的変動となり訓練を不安定化させうる(論文 3.3 節、スライド 7 ページ)。 - これに対し CAGE は**サイレンス比率** $s \in [0, 1)$ を導入する。訓練進捗比 $r_t = t/T$ が $s$ 以下の間は $\lambda_t = 0$(補正なし)、$r_t > s$ 以降は $\lambda_t = \lambda \cdot \frac{r_t - s}{1 - s}$ で線形にランプアップする(論文 Algorithm 1、スライド 5, 7 ページ)。 - 収束近傍では反復点が量子化アンカー $Q(x_t)$ の小近傍に留まり、残差が多くの連続ステップにわたり**コヒーレント**になる。この段階で補正が有効に機能し、パラメータを量子化サポートへ押し込む(論文 3.3 節)。 - 実用上、$s \in [0.8, 0.95]$ および $\lambda \in [1, 10]$ がモデルに応じてロバストに機能する。ハイパーパラメータスイープにより、CAGE はこの範囲内の $s$ と $\lambda$ の具体的な値に対し非感受的であることが確認されている(論文 Appendix C)。 ### 実装の計算コスト - CAGE の追加計算は、量子化残差 $e_t = x_t - Q(x_t)$ の算出に必要な 1 回の追加量子化呼び出しと、オプティマイザステップ内の少数の要素単位演算のみである。LLM の訓練コストの大部分を占める行列乗算には手を加えないため、エンドツーエンドのオーバーヘッドは無視できる水準である(論文 4.3 節後段)。 - 壁時計反復時間の実測では、H100 上で 100M および 430M パラメータモデルにおいて QuEST と CAGE は統計的に区別できない(Hadamard 変換なし: 100M で 101.6 $\pm$ 1.7 vs 101.1 $\pm$ 1.8 ms/iter、430M で 282.7 $\pm$ 3.1 vs 283.1 $\pm$ 3.0 ms/iter。論文 4.3 節)。 - ピーク GPU メモリも変化しない。CAGE は追加の永続的状態を保持せず、残差 $e_t$ をオンザフライで計算する(論文 4.3 節)。 ## 量子化パイプラインの詳細 - 特に断りのない限り、行方向(row-wise)行列量子化器を使用する。具体的には QuEST の QAT インスタンシエーションに基づく(論文 4 節)。 - テンソル $x \in \mathbb{R}^d$ に対し、直交アダマール変換 $H \in \{\pm \frac{1}{\sqrt{d}}\}^{d \times d}$ で回転した $z = Hx$ を、対称整数量子化する。ビット幅 $b$ に対し $q_{\max} = 2^{b-1} - 1$、$q_{\min} = -2^{b-1}$ で、MSE 最適ガウシアンクリッピング係数 $k_b$ を用いてスケール $s = k_b \sigma / q_{\max}$ を設定する(論文 4 節)。 - CAGE は**量子化器非依存**である。量子化残差 $e_t = x_t - Q(x_t)$ のみを必要とし、$Q$ の内部実装やビット精度には依存しない。QuEST を選択したのは低ビット訓練における SOTA ベースラインであるためであり、LSQ 等の他の量子化器とも組み合わせ可能である(論文 4 節)。 ## 実験結果 ### 合成損失実験(論文 4.1 節) - 非等方的ヘッシアンを持つ二次目的 $f(x) = \frac{1}{2}x^\top A x - b^\top x$(条件数 $\kappa(A) \in \{1, 10, 100\}$)を 4 ビット量子化下で最適化する合成実験を実施。STE-SGD、STE-Adam、CAGE-Adam(脱結合型)を比較した(論文 4.1 節、論文 Figure 1-2)。 - CAGE は全条件数において統計的に有意に誤差を削減する。主成分上の軌跡を 2 次元に射影すると、CAGE は曲率に追従して最適解に最も近い量子化解に収束する(論文 Figure 2)。 ### QAT ファインチューニング(論文 4.2 節) - **設定:** Llama-3.2-3B モデルを MXFP4(4 ビット浮動小数点、OCP Microscaling 仕様)形式で QAT ファインチューニング。Tulu-SFT データセットを使用し、マスター重みは FP32、フォワードパスは重み・活性化ともに MXFP4 で量子化。 - **評価指標:** GSM8K(ゼロショット、exact-match)、HellaSwag(精度)、WinoGrande(精度)のスコア。 - **結果(論文 Figure 3):** CAGE は全ベンチマークで一貫して MXFP4 QAT ベースライン(QuEST)の精度を上回る。具体的には、MXFP4 への量子化に伴う精度劣化(BF16 基準からの低下幅)を約**半分に削減**する。 ### 事前学習実験(論文 4.3 節) - **設定:** Llama アーキテクチャのトランスフォーマーをパラメータ数 $N \in \{30\text{M}, 50\text{M}, 100\text{M}, 200\text{M}, 430\text{M}, 800\text{M}\}$ でスクラッチから QAT 事前学習。重み/活性化ビット幅 $b \in \{2, 3, 4\}$。C4 データセットを使用し、トークン予算 $D = 100 \times N$(TPP = 100)。マスター重みは FP32、AdamW で更新。8 基の H100 GPU で実施(論文 4.3 節、論文 Appendix B)。 - **ベースライン:** QuEST(アダマール変換 + 行方向量子化。STE や LSQ より優れることが示されている SOTA 手法)および BF16(高精度リファレンス)。 #### W4A4 事前学習パープレキシティ(論文 Table 1) | 手法 | 30M | 50M | 100M | 200M | 430M | 800M | $\mathrm{eff}(P)$ | |---|---|---|---|---|---|---|---| | CAGE + HT | **26.277** | **22.747** | **18.944** | **16.166** | **13.789** | **12.182** | **0.797** | | QuEST + HT | 26.475 | 23.062 | 19.123 | 16.311 | 14.169 | 12.482 | 0.733 | | CAGE (HT なし) | 27.287 | 23.781 | 19.630 | 16.596 | 14.024 | 12.341 | 0.705 | | QuEST (HT なし) | 27.401 | 23.991 | 19.799 | 17.093 | 14.375 | 12.797 | 0.620 | | BF16(リファレンス) | 24.715 | 21.491 | 17.923 | 15.422 | 13.176 | 11.698 | 1.0 | - CAGE は HT の有無にかかわらず、全モデルサイズで QuEST を有意に上回る。 #### スケーリング則とパレートフロント(論文 Figure 5、スライド 8 ページ) - 検証損失 vs. モデルサイズ(バイト)のパレートフロントを W2A2、W3A3、W4A4 について比較すると、CAGE は全ビット精度・全モデルサイズで QuEST に対しパレート優位である。 - **W3A3 の CAGE で訓練したモデルは、W4A4 の QuEST(先行最良手法)より低い損失を達成する。** すなわち 3 ビット精度が先行手法の 4 ビット精度を上回る(論文 4.3 節、スライド 8 ページ)。 ### 精度スケーリング則(論文 Figure 6、スライド 9 ページ) - 検証損失に対し $\mathcal{L}(N, D, P) = \frac{A}{(N \cdot \mathrm{eff}(P))^\alpha} + \frac{B}{D^\beta} + E$ というスケーリング則を当てはめる。$N$ はパラメータ数、$D$ はトークン数、$P$ はビット精度、$\mathrm{eff}(P)$ は量子化による有効容量ペナルティを捉える係数である(標準精度で $\mathrm{eff}(\text{FP}) = 1$)。 - CAGE は QuEST 比で $\mathrm{eff}(P)$ を一貫して向上させる。4 ビット精度で 10% 以上、2 ビット精度で 20% のパラメータ効率改善が得られる(論文 4.3 節)。 - CAGE 使用時の「最適」ビット幅は 4 ビット付近であるが、3 ビットに対する優位はわずかである(スライド 9 ページ)。 ### 結合型 vs. 脱結合型の比較(スライド 6 ページ) - W4A4 事前学習の 30M、50M、100M パラメータにおいて、結合型(CAGE$_C$)と脱結合型(CAGE$_D$)はほぼ同一のパープレキシティを達成する(例: 100M HT あり で CAGE$_D$ = 18.944、CAGE$_C$ = 18.948)。主要な改善は補正方向の共有に由来し、スケーリングの差異は二次的である。 ### オプティマイザ横断的汎用性(論文 4.3 節) - AdamW に加え Muon、Shampoo、SOAP の 3 オプティマイザでも検証。OLMo2 ファミリの 50M モデルを ClimbMix データセットで訓練(TPP = 100)。CAGE のハイパーパラメータは $s = 0.9$、$\lambda = 5$ で固定(論文 4.3 節)。 - 全オプティマイザにわたり CAGE は一貫して検証損失を改善する(論文 Figure 4)。追加のハイパーパラメータ調整なしに改善が得られることから、CAGE がオプティマイザ非依存な手法であることが実証される。 ### LOTION との比較(論文 4.3 節) - 併行研究 LOTION は、量子化損失を一様ランダム丸め(RR)により平滑化し、そのサロゲート損失の 2 次展開から正則化項を導入する手法である(Kwun et al., 2025)。 - 100M パラメータの Llama モデルを C4 で W4A4 QAT した場合、初期段階では両手法の損失曲線は類似するが、CAGE は訓練後半で引き続き改善し、LOTION はプラトーに達する。CAGE が明確に低い最終検証損失を達成する(論文 Figure 8)。 - CAGE との設計上の相違点: LOTION は重みのみの量子化を仮定し、フル精度のフォワードパスを要求する。CAGE は重みと活性化の両方を量子化した状態で訓練でき、高性能低精度 GEMM カーネルの恩恵を受けられる(論文 3.1 節 "Relationship to LOTION")。 ## 低ビットハードウェアサポートの展望 - **W4A4:** NVIDIA Blackwell アーキテクチャの第 5 世代テンソルコアが MXFP4 のネイティブアクセラレーションを提供し、OCP Microscaling(MX)仕様に準拠する。ファインチューニング実験はこの MXFP4 フォーマットを対象としている(論文 4.2 節)。 - **W3A3 および W2A2:** 論文執筆時点では広く利用可能なハードウェアサポートは存在しないが、QAT 研究としてアルゴリズム的フロンティアを押し広げることが重要であり、新興ハードウェアがこれらの精度域をターゲットし始めている(論文 4 節)。 - CAGE の主貢献は新たなカーネルではなく、オプティマイザステップに付加するほぼゼロコストの補正項であり、任意の量子化器上で精度を改善する汎用的な手法である(論文 4.1 節後段)。 ## 議論と位置づけ - CAGE は QAT を多目的最適化問題として再定式化し、パレート最適性の条件から原理的に補正項を導出した点が理論的新規性である。従来の STE 改良手法がヒューリスティックであったのに対し、滑らかな非凸設定での $\mathcal{O}(1/\sqrt{T})$ 収束保証を初めて提供する。 - 実用面では、オプティマイザ非依存・量子化器非依存・追加メモリ不要・計算オーバーヘッド無視可能という 4 条件を同時に満たし、既存の QAT パイプラインにプラグインできる軽量な手法である。 - W3A3 CAGE が W4A4 QuEST のパレートフロントを上回るという結果は、QAT 手法の改善がビット幅 1 ビット分のハードウェア世代前倒しに相当するインパクトを持つことを示唆する。Blackwell 以降で 3 ビットハードウェアサポートが登場した場合、CAGE の実用的価値はさらに増大する。 - 今後の方向性として、超低ビット域(1 ビット近傍)への適用、スパース化やベクトル量子化など他の圧縮手法との統合が挙げられている(論文 5 節)。