NVIDIAがオープンソースの強化学習ライブラリ「NeMo RL」にFP8精度の完全対応を実装した。線形層だけでなくKVキャッシュとAttentionにまでFP8を拡張することで、BF16ベースラインと比べてロールアウト段階を最大48%高速化する。

この記事でわかること:

  • NeMo RLのFP8対応が解決する技術課題
  • 線形層とKVキャッシュ・Attentionそれぞれの実装アプローチ
  • 精度を落とさずに速度を上げるための工夫
  • 実際のスピードアップ数値と設定方法

NeMo RLとは

NVIDIA NeMo RLは、LLMやVLM(視覚言語モデル)向けの強化学習ポストトレーニングをスケールさせるオープンソースライブラリだ。NVIDIA NeMoフレームワークの一部として提供されており、GRPO(Group Relative Policy Optimization)やDAPO(Decoupled Clip and Dynamic Sampling Policy Optimization)などのアルゴリズムを使って、モデルを推論タスクで継続的に改善できる。

強化学習のトレーニングループは2つのフェーズに分かれる。厳しいレイテンシ要件のある「生成フェーズ」と、高スループットが求められる「学習フェーズ」だ。この2つのフェーズを効率よく回すために、FP8のような低精度データ型の活用が注目されている。

FP8適用の難所:数値のズレ問題

NeMo RLのパイプラインは、ロールアウトにvLLM、学習にNVIDIA Megatron Coreという2つのエンジンを使う。それぞれが独自のCUDAカーネルを持つため、もともと数値的な差が生じやすい。FP8を導入すると、量子化・逆量子化の処理が加わり、この差がさらに拡大する。

NVIDIAチームはこのズレを「トークン乗算確率誤差」という指標で定量化した。スコアが1.0に近いほど2つのエンジンの計算が一致している。許容範囲の目安は1.03〜1.05以下とされている。

検証した3つのレシピのうち、最も問題が大きかったのは「生成フェーズのみFP8、学習はBF16」というパターンだ。逆に生成と学習の両方でFP8を使うエンドツーエンドFP8は、数値のズレがむしろ抑えられる結果が出た。

重要度サンプリングで精度を保つ

エンドツーエンドFP8でも数値差はゼロではない。ここで効くのが重要度サンプリングだ。これはデータを生成したモデルと学習対象のモデルの分布差を補正する手法で、トークンごとの重みとして損失に掛け合わせる。

実験結果では、エンドツーエンドFP8に重要度サンプリングを組み合わせることで、BF16との精度差を完全に埋められることが示されている。Llama 3.1 8B Instructをmath datasetで4000ステップ学習した際の検証精度は、BF16が0.616に対してFP8エンドツーエンドは0.613と、実用上問題ない水準だ。

スループットは15〜25%向上する。理論的なFP8の2倍高速化より低い理由は、Attentionや要素ごとの演算がBF16のままであること、量子化カーネルの追加オーバーヘッドがあることによる。vLLMでの量子化カーネル融合が進めば、1.25倍以上の改善が見込まれるとしている。

KVキャッシュとAttentionにも拡張

さらに出力系列長が長くなるほど、KVキャッシュの増大とAttention計算がロールアウト時間の大半を占めるようになる。そこでNeMo RLはKVキャッシュとAttentionにもFP8を適用した。

実装上の難しさは、強化学習ではモデルの重みがステップごとに変わる点だ。静的な推論と違い、1回の量子化キャリブレーションでは対応できない。NeMo RLが採用したアプローチは次の3ステップだ。

  1. 再キャリブレーション: 学習ステップの終わりに、更新済みの重みを使ってQuery・Key・ValueのスケールをBF16で再計算する
  2. データ選択: キャリブレーションにはトレーニングデータ(プロンプトと生成レスポンス)を使い、現在の分布を正確に反映させる
  3. 同期: 計算したスケールを次のロールアウトフェーズのためにvLLMへ同期する

キャリブレーションのオーバーヘッドはステップ全体の2〜3%に収まっている。

Qwen3-8B-Baseで検証した結果、線形層W8A8に加えてKVキャッシュとAttentionにもFP8を適用することで、ロールアウト段階がさらに約30%高速化し、BF16ベースラインとの比較で合計約48%のスピードアップを達成した。この効果は出力系列長が長くなるほど顕著になる。

設定方法

FP8を有効にするのは設定ファイルの数行を変えるだけだ。

Megatron Core側はprecisionの設定から自動的にFP8が使われる。詳細な設定はやのレシピが参考になる。

上位のトランスフォーマー層だけをBF16のままにするオプション、スケールファクターを2のべき乗に制限するオプションなど、上級者向けの調整パラメータも用意されている。

まとめ

NVIDIA NeMo RLのFP8対応は、強化学習ポストトレーニングを実用的に高速化するための体系的な解答だ。数値誤差の定量的な評価、重要度サンプリングによる精度担保、ステップごとのキャリブレーションによる動的対応という3つの工夫が組み合わさっている。

Qwen3-30BのMoEモデルでも同様の精度を維持できており、モデルアーキテクチャを問わず適用できる汎用性がある。長い推論チェーンを扱うモデルのポストトレーニングコストを下げたい開発者にとって、試す価値のある選択肢だ。