JAXはLLMの学習に強いフレームワークです。ただし推論フェーズには落とし穴があります——デコードステップが進むにつれてKVキャッシュへのアテンションが膨らみ、ボトルネックは計算量ではなくメモリ帯域幅に移ります。XLAコンパイラではこの領域を最適化しきれません。
FlashInferはこの問題のために設計された手チューニング済みCUDAカーネルライブラリです。2026年4月、Katja Sirazitdinova氏がFlashInferのチュートリアルに2本の新規ノートブックをコントリビュートしました。jax-tvm-ffiブリッジを使い、FlashInferのカーネルをJAXから直接呼び出す方法を解説したものです。
この記事でわかること:
- FlashInferとjax-tvm-ffiの仕組みと役割の違い
- 3つのカーネル(silu_and_mul / apply_rope / single_decode)の呼び出し手順
- Gemma 3 1B Instructへの実応用
- 動作環境と必要パッケージ
https://github.com/flashinfer-ai/flashinfer
JAX推論にFlashInferが必要な理由
LLM推論のデコードフェーズでは、1トークンを生成するたびにKVキャッシュ全体を参照します。シーケンスが長くなるほどキャッシュサイズが増え、GPUのメモリ帯域幅が律速となります。
FlashInferはこの帯域幅ボトルネックを直接攻略するCUDAカーネル集です。アテンション演算やFFN活性化など、推論で頻出する処理を低レベルで最適化し、既成フレームワークが生成するXLAカーネルより高速に動作します。
GitHubでは5,500以上のスターを獲得しており(2026年4月時点)、PyTorch経由での利用事例は多数ありましたが、JAXから直接呼び出す公式手段はこれまでありませんでした。
jax-tvm-ffiブリッジの仕組み
FlashInferは各CUDAカーネルを .so ファイルとしてビルドし、Apache TVM FFI(Foreign Function Interface)の規約に従ったバイナリインターフェースを持たせています。
NVIDIAが開発した jax-tvm-ffi は、このTVM FFI関数をXLAカスタムコールに変換するブリッジライブラリです。変換後はJAXの @jax.jit 領域の中で普通のJAX関数と同様に呼び出せます。
FlashInfer .so(TVM FFI)→ jax-tvm-ffi → XLA custom call → @jax.jit
この構造により、JITコンパイルやvmap、gradといったJAXの機能と組み合わせながら、手チューニング済みのCUDAカーネルを利用できます。
環境構築
動作には以下が必要です。
- NVIDIA GPU(SM 7.5以上、TuringアーキテクチャまたはRTX 20xx世代以降)
- CUDA 12.6以上
- Python 3.10以上
必要なパッケージは以下のコマンドでインストールします。
pip install 'jax[cuda13]'
pip install flashinfer-python -U jax-tvm-ffi \
--no-build-isolation \
--extra-index-url https://flashinfer.ai/whl/cu130/
CUDA 12.xを使う場合は cu130 の部分を cu126 などに変えます。
3つのカーネルを呼び出すまでの流れ
チュートリアル1本目(flashinfer_jax_tvm_ffi.py)では以下の3カーネルを順に実装します。
silu_and_mulはゲートつきFFNの活性化関数で、silu(gate) × up を計算します。FlashInferカーネルをJAXに接続する最小構成として、「ビルド → 登録 → 呼び出し」の基本パターンを学べます。
apply_ropeはパックされたバッチに対してRotary Positional Embedding(RoPE)を適用します。複数の出力を持ち、引数の並び替えが必要なケースを扱います。
single_decodeは単一リクエストのKVキャッシュアテンション(GQA対応)です。型特化JIT、スクラッチバッファ、オプション引数のセンチネル値といった高度なパターンを扱います。
チュートリアルの最後では、3つのカーネルを単一の @jax.jit 領域内で動かすコードを完成させます。実際のLLMデコードループに近い構成です。
Gemma 3への応用
チュートリアル2本目(gemma3_flashinfer_jax.py)ではGoogleのオープンウェイトモデル Gemma 3 1B Instruct を使い、同じカーネルパターンを実モデルに接続します。
Gemma 3固有の実装ポイントが3点あります。
GeGLU活性化:Gemma 3はSiLU-GLUではなくGeGLU(gelu_tanh_and_mul)を使います。FlashInferではカーネル名の1単語を変えるだけで切り替えられます。
QK-norm:各ヘッドのQとKにRMSNormを適用してからドット積を計算します。Gemma 2のlogit soft-cappingに代わる手法です。
デュアルRoPE theta:ローカルアテンション層はtheta=10,000、グローバルアテンション層はtheta=1,000,000を使います。層ごとに正しい値を選んで apply_rope に渡す実装が含まれています。
Gemma 3チュートリアルを動かすには、HuggingFaceでGemma 3のライセンス同意が必要です。また追加パッケージとして以下が必要になります。
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install safetensors huggingface_hub transformers
チュートリアルの実行方法
Sphinx-Galleryを使った構成のため、同じ .py ファイルからHTML・Pythonスクリプト・Jupyterノートブックの3形式が生成されます。直接実行する場合は以下のコマンドです。
python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py
ノートブック形式で試したい場合は、FlashInferのドキュメントサイトからダウンロードできます。
注意点
KVキャッシュが実際に大きいシーケンスでないと、セットアップのオーバーヘッドがプロファイルに紛れ込む場合があります。カーネルのウォームアップを忘れずに行いましょう。
jax-tvm-ffiはXLAカスタムコールとして動作するため、TPUでは使えません。NVIDIA GPU専用の最適化です。また、SM 7.5未満のGPU(GTX 10xxや旧世代)はサポート対象外です。
FlashInferのJITコンパイルは初回起動時に時間がかかります。CIや再現実験では flashinfer-jit-cache ディレクトリを保持するとビルド時間を省けます。
まとめ
JAXユーザーがFlashInferの手チューニングCUDAカーネルを @jax.jit 内で直接使えるようになりました。jax-tvm-ffiブリッジを介した接続方法は、FlashInferのGitHubリポジトリ(docs/tutorials/jax_tvm_ffi/)にあるチュートリアルで一通り学べます。
https://github.com/flashinfer-ai/flashinfer/tree/main/docs/tutorials/jax_tvm_ffi
Gemma 3の例まで動かせば、ローカルアテンション/グローバルアテンションの切り替えやQK-normといった実装上の細部も把握できます。PyTorch以外でFlashInferを活用したいJAX開発者にとって、実用的な出発点です。