# PyLO: Towards Accessible Learned Optimizers in PyTorch > [!info] Talk metadata > - 会議: [[MLSys2026]] Day 3 (May 20 / Wed)、Research Track Oral "LLM Training 2" セッション(15:45--16:00 PDT、Grand Ballroom 2) > - 著者: Paul Janson\*, Benjamin Therien\*(\* = equal contribution), Quentin Anthony, Xiaolong Huang, Abhinav Moudgil, Eugene Belilovsky > - 所属: Concordia University, Mila, Universite de Montreal, EleutherAI > - プロジェクトページ: https://belilovsky-lab.github.io/pylo/ > - コード: https://github.com/Belilovsky-Lab/pylo(Artifacts Available / Functional / Reproducible バッジ取得) > - スライド: https://mlsys.org/media/mlsys-2026/Slides/3824_HZUGKXd.pdf > [!abstract] 概要(論文アブストラクトの忠実な日本語訳) > 学習型最適化器(Learned Optimizer)は過去 10 年にわたり活発な研究対象であり、Adam のような広く使われる手法のドロップイン代替となる実用的な汎用最適化器に向けて着実に進歩してきた。しかし、VeLO のような最近の成果は、4000 TPU-month のメタ学習で訓練されたにもかかわらず、JAX 依存とメタ学習後の独立利用を想定したユーザフレンドリなパッケージの欠如から、広範なコミュニティにはほとんど普及していない。このギャップを埋めるため、我々は PyLO を提案する。PyLO は機械学習コミュニティの約 70% が利用する PyTorch に対し、おなじみの `torch.optim.Optimizer` インタフェースを通じて学習型最適化器を提供するライブラリである。限定的な学術タスクに注力した先行研究と異なり、本研究は学習型最適化を実世界の大規模事前学習タスクへ適用することに重点を置く。システム貢献として、`small_fc_lopt` および VeLO の CUDA 高速化実装を含み、ViT-B/16(バッチサイズ 32)での学習スループットは 39.36 から 205.59 サンプル毎秒へ、49.73 から 191.18 サンプル毎秒へとそれぞれ向上した。PyLO は学習率スケジュールや重み減衰といった既存の最適化ツールと学習型最適化器を容易に組み合わせる柔軟性を備え、そうした追加が学習型最適化器の性能を大幅に改善しうることを発見した。コードは https://github.com/Belilovsky-Lab/pylo で公開されている。 ## 背景: 学習型最適化器の可能性と普及の壁 ニューラルネットワーク学習は高度に非凸な最適化問題であり、Adam や AdamW のような手設計の最適化器は証明可能な最適性を保証しない。学習型最適化器(LO)は小規模な MLP をメタ学習し、パラメータ毎の更新量を予測する手法である。代表例の VeLO は 4000 TPU-month のメタ学習を経て、ハイパーパラメータ調整済みの NAdamW を上回る損失を達成した。 しかし LO の実用普及を阻む **4 つの障壁**が存在する。 - **JAX 限定のエコシステム**: リファレンス実装(`google/learned_optimization`)は JAX で書かれているが、ML コミュニティの約 70% は PyTorch を利用する - **メタ学習との密結合**: 既存リポジトリは LO の「訓練」に主眼を置き、「適用」のための分離された実装を提供しない - **重み共有標準の不在**: HuggingFace Hub のようなモデル共有基盤が LO の重みには存在しない - **ステップ毎の計算オーバーヘッド**: LO のステップはパラメータ毎に小規模 MLP を評価するため、ナイーブ実装では Adam の 10--100 倍遅い ## PyLO のライブラリ設計 PyLO は 4 つの主要コンポーネントからなるモジュラーアーキテクチャを採用する。 - **最適化モジュール**(`pylo.optim`): `torch.optim.Optimizer` インタフェースを実装し、パラメータ毎のアキュムレータ管理・特徴量計算・更新ステップ実行を担当する。`torch.compile` やアクティベーションチェックポインティングとも合成可能である - **メタモデルアーキテクチャ**(`pylo.models`): LO のパラメータとフォワードパスをカプセル化し、HuggingFace Hub と完全統合する。`from_pretrained` によるタスク特化型重みのダウンロードとバージョン管理を提供する - **CUDA 高速化**(`pylo.csrc`): 特徴量構築・正規化・MLP 推論を融合したカスタム CUDA カーネルを実装する - **分離された評価**: PyLO-Examples リポジトリとして、言語モデリング(FineWeb-EDU 事前学習)や画像分類(ImageNet)の評価コードを別リポジトリに分離し、本体のコード量を最小化する 使い方は 5 行のコード変更で済む。 ```python from pylo.optim import VeLO optimizer = VeLO(model.parameters(), hf_key="Belilovsky-Lab/velo") scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N) ``` 既存の PyTorch エコシステム(学習率スケジューラ、分離型重み減衰、DDP / FSDP)とそのまま組み合わせられる。 ## 既存ライブラリとの比較 | リポジトリ | メタ学習からの分離 | PyTorch | HuggingFace Hub | CUDA カーネル | |---|---|---|---|---| | Open-L2O | 非対応 | 対応 | 非対応 | 非対応 | | `google/learned_optimization` | 非対応 | 非対応 | 非対応 | 非対応 | | **PyLO** | **対応** | **対応** | **対応** | **対応** | PyLO は唯一、全 4 要件を満たすフレームワークである。 ## CUDA カーネル設計: 2 パス融合 ### ナイーブ実装のボトルネック パラメータテンソル(サイズ $mn$)に対し、ナイーブ PyTorch 実装は以下の問題を抱える。 - サイズ $mn \times d_{\text{feat}}$ の特徴量テンソルをグローバルメモリに実体化する - 74--252 回のカーネル起動(要素単位演算・リダクション・MLP 層それぞれに 1 回) - メモリ帯域律速で GPU 演算ユニットが遊休する ### 融合カーネル戦略 PyLO の CUDA 実装は 2 つの融合カーネルで全処理を完結させる。 **カーネル 1(統計収集パス)**: `construct_features()` と `compute_squared_average()` を融合する。 - 特徴量をレジスタ上でオンザフライに計算し、実体化しない - スレッド毎に $d_{\text{feat}}$ 次元のアキュムレータで $\sum \text{feat}^2$ を蓄積する - ワープシャッフル → ブロックリダクション → アトミック加算の階層的集約 - 出力は 30--39 個のスカラ統計量のみ **カーネル 2(適用パス)**: `construct_features()`・`feature_normalization()`・`apply_lo()` を融合する。 - 正規化統計量を共有メモリにロードしブロードキャストする - 特徴量をレジスタ上で再計算し、正規化後にインラインで MLP を評価する - MLP の重みは `__ldg` でロードし、中間活性化はレジスタに留める - パラメータ更新のみグローバルメモリに書き戻す ### メモリ階層の活用 - **レジスタ**: 特徴量ベクトル($d_{\text{feat}} \approx 39$ for `small_fc_lopt`、$\approx 30$ for VeLO)と MLP 活性化をシングルサイクルアクセスで保持する - **共有メモリ**: $d_{\text{feat}}$ 次元の正規化統計量をブロック内全スレッドにブロードキャストする - **グローバルメモリ**: パラメータ値・勾配・アキュムレータの読み出しと更新済みパラメータの書き戻しのみ この設計によりカーネル起動回数は 74--252 → 30--114 に削減され、$mn \times d_{\text{feat}}$ の中間テンソル割り当てが完全に除去される。算術強度(FLOPs/byte)が非常に低い(4 層 MLP、$d_{\text{feat}} \approx 40$、$d_{\text{hidden}} \approx 32$)ため、MLP 計算をストリーミングロードに隠蔽し、メモリ帯域律速の Adam と同等の挙動に近づける。融合実装では $g, p, m, v$ の各ワードを正確に 2 回読む(カーネル 1 + カーネル 2)。 ## ベンチマーク結果 ### ステップ時間の削減 | モデル | 最適化器 | ステップ時間 (ms) | ナイーブ比削減 | |---|---|---|---| | ViT-B/16 (BS=32) | Adam | 4.90 | -- | | | Adafactor | 18.99 | -- | | | `small_fc_lopt` (ナイーブ) | 756.80 | -- | | | `small_fc_lopt` (CUDA) | **99.59** | **-86%** | | | VeLO (ナイーブ) | 585.11 | -- | | | VeLO (CUDA) | **113.58** | **-80%** | | GPT-2 355M (BS=4) | Adam | 20.12 | -- | | | `small_fc_lopt` (ナイーブ) | 2872.17 | -- | | | `small_fc_lopt` (CUDA) | **319.14** | **-88%** | | | VeLO (ナイーブ) | 2378.93 | -- | | | VeLO (CUDA) | **284.37** | **-88%** | CUDA 融合カーネルにより、LO のステップ時間はナイーブ実装から 86--88% 削減される。 ### JAX リファレンス実装との比較 GPT-2 スタイルのトランスフォーマで `small_fc_lopt` と VeLO のステップ時間をスケーリング評価した結果、PyLO の CUDA 実装は JAX 実装に対し **2 倍以上高速**であった。さらに JAX 実装は 1B パラメータモデルで 80 GB A100 GPU のメモリを超過し OOM となるのに対し、PyLO の CUDA 実装は $mn \times d_{\text{feat}}$ の中間テンソルを実体化しないため 1B パラメータでも動作する。 ### バッチサイズ増大時のオーバーヘッド減少 ViT-B/16 の単一 A100 GPU での学習において、バッチサイズを 32 → 512 に増加させるとフォワード/バックワードの所要時間が支配的になり、LO のステップ時間は固定コストとして相対的に無視できるレベルまで低下する。大規模バッチ学習では LO のオーバーヘッドが実質的に問題とならない。 ### 分散最適化ステップ 125M パラメータの GPT-2 を 4 基の H100 GPU で学習する設定で、all-reduce の代わりに reduce-scatter を用いて最適化ステップを分散化した結果。 | 最適化器 | All-reduce (ms) | Reduce-scatter (ms) | 差分 | |---|---|---|---| | `small_fc_lopt` CUDA | 136.53 | 104.67 | -31.86 | | VeLO CUDA | 127.67 | 108.71 | -18.96 | | Muon | 106.92 | 98.10 | -8.82 | | Adam | 99.93 | 95.93 | -4.00 | 分散ステップにより、Adam に対する LO のオーバーヘッドは `small_fc_lopt` で 9%、VeLO で 13% まで縮小する。さらに ZeRO-1 / FSDP A2A 方式でオプティマイザ状態をシャードすることでメモリ効率も改善される。 ## 実世界タスクでの精度評価 ### ViT-B/16 on ImageNet-1K(480 エポック、150k ステップ) | 最適化器 | Top-1 精度 | |---|---| | VeLO | **78.39%** | | Adam + コサインスケジュール | 77.22% | | $\mu\text{LO}_M$ | 62.14% | VeLO はハイパーパラメータ調整なしでチューニング済み Adam を上回る。$\mu\text{LO}_M$ はメタ学習ホライズンが 1000 ステップと短いため 65,000 ステップ付近で発散した。 ### GPT-2 355M on FineWeb-EDU(10B トークン) | 最適化器 | 最終損失 | |---|---| | VeLO | **2.89** | | Adam + コサインスケジュール | 2.91 | | $\mu\text{LO}_M$ | 3.18 | VeLO は言語モデルの事前学習でもチューニング済み Adam と同等以上の損失を達成する。 ## 学習率スケジュールと重み減衰の効果 PyLO は `torch.optim.Optimizer` インタフェースに準拠するため、コサインアニーリングスケジュールや分離型重み減衰を 5 行のコード追加で適用できる。 - **$\mu\text{LO}_M$**: 明示的なスケジューリングから大きな恩恵を受ける。安定学習区間が 65k → 150k ステップに拡大し、ImageNet Top-1 精度が 62.14% → 71% に、言語モデル損失も改善する - **VeLO**: スケジューリングによる改善は限定的であり、内部に適応的なスケジュール機構を備えていることを示唆する - 重み減衰は $\mu\text{LO}_M$ の GPT 学習に有効だが、VeLO と ViT タスクでは効果が見られなかった ## 制約と今後の展望 ### 現在の制約 - LO のステップ時間は絶対値では依然 Adam の 5--10 倍(ただしバッチサイズ増大で相対的に縮小する) - カーネルはパラメータ毎の特徴量次元が SRAM/SM 容量以下であることを前提とする。非常に大きな $d_{\text{feat}}$ には新たな設計が必要である - $\mu\text{LO}_M$ はメタ学習ホライズンが短い(1000 ステップ)ため長時間学習で発散する。これはライブラリの問題ではなくメタ学習側の課題である ### 今後の方向 - **ハードウェア-オプティマイザ協調設計**: 実測コストを目的関数に含めたメタ学習 - **行動クローニング蒸留**: SOAP や Shampoo のような高コスト二次最適化器の振る舞いを LO に蒸留する - **シャード化 LO ステップ**: all-to-all を $d_{\text{feat}}$ 個のスカラ交換に置き換え、通信量を大幅削減する - **コミュニティ共有**: HuggingFace Hub を通じた LO 重みのエコシステム構築 ## Key Takeaway PyLO は学習型最適化器を PyTorch エコシステムに初めて実用レベルで統合したライブラリである。融合 CUDA カーネルによるステップ時間 86--88% 削減、JAX 実装比 2 倍以上の高速化、1B パラメータモデルへのスケーリング、そして分散ステップによる Adam 比 9--13% のオーバーヘッドまでの圧縮を達成した。`torch.optim.Optimizer` 準拠により、学習率スケジュールや重み減衰との組み合わせが容易であり、VeLO はチューニング済み Adam をハイパーパラメータ調整なしで上回ることを実証した。学習型最適化器の「研究成果」から「実用ツール」への橋渡しとなるシステム基盤である。