JAXはLLMの学習に強いフレームワークです。ただし推論フェーズには落とし穴があります——デコードステップが進むにつれてKVキャッシュへのアテンションが膨らみ、ボトルネックは計算量ではなくメモリ帯域幅に移ります。XLAコンパイラではこの領域を最適化しきれません。

FlashInferはこの問題のために設計された手チューニング済みCUDAカーネルライブラリです。2026年4月、Katja Sirazitdinova氏がFlashInferのチュートリアルに2本の新規ノートブックをコントリビュートしました。jax-tvm-ffiブリッジを使い、FlashInferのカーネルをJAXから直接呼び出す方法を解説したものです。

この記事でわかること:

  • FlashInferとjax-tvm-ffiの仕組みと役割の違い
  • 3つのカーネル(silu_and_mul / apply_rope / single_decode)の呼び出し手順
  • Gemma 3 1B Instructへの実応用
  • 動作環境と必要パッケージ

https://github.com/flashinfer-ai/flashinfer

JAX推論にFlashInferが必要な理由

LLM推論のデコードフェーズでは、1トークンを生成するたびにKVキャッシュ全体を参照します。シーケンスが長くなるほどキャッシュサイズが増え、GPUのメモリ帯域幅が律速となります。

FlashInferはこの帯域幅ボトルネックを直接攻略するCUDAカーネル集です。アテンション演算やFFN活性化など、推論で頻出する処理を低レベルで最適化し、既成フレームワークが生成するXLAカーネルより高速に動作します。

GitHubでは5,500以上のスターを獲得しており(2026年4月時点)、PyTorch経由での利用事例は多数ありましたが、JAXから直接呼び出す公式手段はこれまでありませんでした。

jax-tvm-ffiブリッジの仕組み

FlashInferは各CUDAカーネルを .so ファイルとしてビルドし、Apache TVM FFI(Foreign Function Interface)の規約に従ったバイナリインターフェースを持たせています。

NVIDIAが開発した jax-tvm-ffi は、このTVM FFI関数をXLAカスタムコールに変換するブリッジライブラリです。変換後はJAXの @jax.jit 領域の中で普通のJAX関数と同様に呼び出せます。

FlashInfer .so(TVM FFI)→ jax-tvm-ffi → XLA custom call → @jax.jit

この構造により、JITコンパイルやvmap、gradといったJAXの機能と組み合わせながら、手チューニング済みのCUDAカーネルを利用できます。

環境構築

動作には以下が必要です。

  • NVIDIA GPU(SM 7.5以上、TuringアーキテクチャまたはRTX 20xx世代以降)
  • CUDA 12.6以上
  • Python 3.10以上

必要なパッケージは以下のコマンドでインストールします。

pip install 'jax[cuda13]'
pip install flashinfer-python -U jax-tvm-ffi \
    --no-build-isolation \
    --extra-index-url https://flashinfer.ai/whl/cu130/

CUDA 12.xを使う場合は cu130 の部分を cu126 などに変えます。

3つのカーネルを呼び出すまでの流れ

チュートリアル1本目(flashinfer_jax_tvm_ffi.py)では以下の3カーネルを順に実装します。

silu_and_mulはゲートつきFFNの活性化関数で、silu(gate) × up を計算します。FlashInferカーネルをJAXに接続する最小構成として、「ビルド → 登録 → 呼び出し」の基本パターンを学べます。

apply_ropeはパックされたバッチに対してRotary Positional Embedding(RoPE)を適用します。複数の出力を持ち、引数の並び替えが必要なケースを扱います。

single_decodeは単一リクエストのKVキャッシュアテンション(GQA対応)です。型特化JIT、スクラッチバッファ、オプション引数のセンチネル値といった高度なパターンを扱います。

チュートリアルの最後では、3つのカーネルを単一の @jax.jit 領域内で動かすコードを完成させます。実際のLLMデコードループに近い構成です。

Gemma 3への応用

チュートリアル2本目(gemma3_flashinfer_jax.py)ではGoogleのオープンウェイトモデル Gemma 3 1B Instruct を使い、同じカーネルパターンを実モデルに接続します。

Gemma 3固有の実装ポイントが3点あります。

GeGLU活性化:Gemma 3はSiLU-GLUではなくGeGLU(gelu_tanh_and_mul)を使います。FlashInferではカーネル名の1単語を変えるだけで切り替えられます。

QK-norm:各ヘッドのQとKにRMSNormを適用してからドット積を計算します。Gemma 2のlogit soft-cappingに代わる手法です。

デュアルRoPE theta:ローカルアテンション層はtheta=10,000、グローバルアテンション層はtheta=1,000,000を使います。層ごとに正しい値を選んで apply_rope に渡す実装が含まれています。

Gemma 3チュートリアルを動かすには、HuggingFaceでGemma 3のライセンス同意が必要です。また追加パッケージとして以下が必要になります。

pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install safetensors huggingface_hub transformers

チュートリアルの実行方法

Sphinx-Galleryを使った構成のため、同じ .py ファイルからHTML・Pythonスクリプト・Jupyterノートブックの3形式が生成されます。直接実行する場合は以下のコマンドです。

python docs/tutorials/jax_tvm_ffi/flashinfer_jax_tvm_ffi.py
python docs/tutorials/jax_tvm_ffi/gemma3_flashinfer_jax.py

ノートブック形式で試したい場合は、FlashInferのドキュメントサイトからダウンロードできます。

注意点

KVキャッシュが実際に大きいシーケンスでないと、セットアップのオーバーヘッドがプロファイルに紛れ込む場合があります。カーネルのウォームアップを忘れずに行いましょう。

jax-tvm-ffiはXLAカスタムコールとして動作するため、TPUでは使えません。NVIDIA GPU専用の最適化です。また、SM 7.5未満のGPU(GTX 10xxや旧世代)はサポート対象外です。

FlashInferのJITコンパイルは初回起動時に時間がかかります。CIや再現実験では flashinfer-jit-cache ディレクトリを保持するとビルド時間を省けます。

まとめ

JAXユーザーがFlashInferの手チューニングCUDAカーネルを @jax.jit 内で直接使えるようになりました。jax-tvm-ffiブリッジを介した接続方法は、FlashInferのGitHubリポジトリ(docs/tutorials/jax_tvm_ffi/)にあるチュートリアルで一通り学べます。

https://github.com/flashinfer-ai/flashinfer/tree/main/docs/tutorials/jax_tvm_ffi

Gemma 3の例まで動かせば、ローカルアテンション/グローバルアテンションの切り替えやQK-normといった実装上の細部も把握できます。PyTorch以外でFlashInferを活用したいJAX開発者にとって、実用的な出発点です。