GPUカーネルを完全に制御したい場合、最終的にはPTX(Parallel Thread Execution)を直接記述する必要がある。TritonやCUTLASSはコンパイラが命令スケジューリングやレジスタ割り当てを自動処理するため便利だが、ウォープスペシャライゼーションやWGMMAの命令順序を明示的に指定したいとき、コンパイラの動作を上書きする手段はほとんどない。
pyptx はその課題を解決するPython製OSS。NVIDIA HopperおよびBlackwellアーキテクチャ向けのPTX命令をPythonのDSL(ドメイン固有言語)として記述し、jax.jit・torch.compile・PyTorchのeagerモードから呼び出せる。1行1PTX命令という設計により、オプティマイザもオートチューナーも介在しない。
この記事でわかること:
- pyptxの概要と、それが解決する課題
- HopperとBlackwellの命令セットへの対応内容
- H100 SXM5での実測パフォーマンス(GEMMで815 TFLOPS達成)
- PTXトランスパイラ機能の使い方
- Triton・CuTe DSLとの使い分け
コンパイラを介さずPTXを書く理由
Tritonは「tl.dot を書けばコンパイラがWGMMAかtcgen05を選ぶ」という設計だ。これは多くのユースケースで十分だが、逆に言えば「どの命令を、どの順序で発行するか」はユーザーが決められない。
カスタムのウォープスペシャライゼーションパターン、3DのTMAマルチキャストスケジュール、Blackwellの tcgen05.mma を使った非標準パイプラインなど、命令レベルの制御が必要な場面でTritonの逃げ道は限られる。inline_asm_elementwise というエスケープハッチはあるが、element-wiseの操作に限定されており、HopperやBlackwellで重要なcollective命令(WGMMA、TMA、tcgen05)には対応していない。
pyptxはこの逆の設計を取る。Pythonの関数呼び出しがそのまま1つのPTX命令にマッピングされ、コンパイラもスケジューラも介在しない。
主な機能
1命令1呼び出しのDSL
ptx.* 名前空間の各関数が、対応するPTX命令を1つ発行する。@kernel デコレータでカーネルを定義し、レジスタ・述語・バリア・共有メモリを明示的に管理する。
from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import bf16, f32
@kernel(
in_specs=(Tile("M", "K", bf16), Tile("K", "N", bf16)),
out_specs=(Tile("M", "N", f32),),
grid=lambda M, N, K: (N // 64, M // 64),
block=(128, 1, 1),
arch="sm_90a",
)
def gemm(A, B, C):
sA = smem.wgmma_tile(bf16, (64, 16), major="K")
sB = smem.wgmma_tile(bf16, (16, 64), major="MN")
acc = reg.array(f32, 32)
# TMAロード + ptx.wgmma.mma_async(...) — 各呼び出しが1PTX命令を発行
print(gemm.ptx()) を実行すると、書いた通りのPTXがそのまま表示される。最適化による命令の並び替えは一切起きない。
HopperとBlackwell両世代に対応
Hopper(sm_90a)では WGMMA、TMA 2D/3Dマルチキャスト、mbarrier、クラスターラウンチをサポート。Blackwell(sm_100a)では tcgen05.mma / .ld、TMEM、SMEMデスクリプタ、ウォープスペシャライゼーションに対応している。現時点でバージョン 0.1.0(プレリリース)。
1つのカーネルオブジェクトで3つの実行パス
同じカーネルオブジェクトをJAX、PyTorch eager、torch.compile の3つから呼び出せる。
# PyTorch eager
out = gemm(a, b)
# torch.compile
out = torch.compile(gemm)(a, b)
# JAX jit(型付きFFI経由)
out = jax.jit(gemm)(a, b)
内部ではPTXを cuModuleLoadData でJITコンパイルし、約150行のC++ランチシムを経由してディスパッチする。PyTorchのディスパッチオーバーヘッドは、CUDAグラフリプレイで約4μs、キャッシュ済みC++拡張のTurbo eagerモードで約14μsになる。
PTX→Pythonトランスパイラ
既存のPTXファイルをpyptxのPythonコードに変換できる機能も持つ。
python -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py
--sugar オプションは名前のデマングル、スピンループのループ構文への変換、mbarrier待機ブロックの折りたたみなどを行う。CUTLASS・Triton・fast.cu・DeepGEMM・ThunderKittensなど218以上の実際のPTXファイルでバイト同一のラウンドトリップを確認済み(参考)。
nvccやTritonが出力したPTXを一度Pythonに変換して読み、修正してから使う、というワークフローが実現する。H100での815 TFLOPSのGEMMカーネルも、fast.cuのkernel12をこのトランスパイラ経由でPythonに落とし、そこから改変することで作られた。
H100・B200での実測パフォーマンス
Hopper H100 SXM5(bf16/f32)での測定値:
| カーネル | 結果 | 比較 |
|---|---|---|
| GEMM(wgmma、ウォープスペシャライズ、8192³) | 815 TFLOPS | cuBLAS 6K以上で上回る |
| RMS norm(f32、B=2048 N=8192) | 2.6 TB/s(HBM 88%) | torchの3.9倍 |
| Layer norm(f32、B=2048 N=8192) | 2.5 TB/s(HBM 83%) | F.layer_normの1.5倍 |
| SwiGLU(f32、M=2048 F=8192) | 2.8 TB/s(HBM 94%) | F.silu(g)*uの1.6倍 |
| Flash attention(bf16、M=N=4096、HD=64) | 88μs | naive torchの3.0倍 |
Blackwell B200(bf16)での測定値:
| カーネル | 結果 | cuBLAS比 |
|---|---|---|
| GEMM(tcgen05.mma、4ステージ、1SM、8192³) | 1240 TFLOPS | 77% |
| GEMM(1SM、4096³) | 1194 TFLOPS | 78% |
| Grouped GEMM(MoE G=4 M=N=K=2048) | 401 TFLOPS | PyTorch参照の約10倍 |
詳細なベンチマーク表と再現コマンドは pyptx.dev/performance に掲載されている。
インストール
pip install pyptx # DSL・パーサー・エミッター・トランスパイラ(GPUランタイムなし)
pip install 'pyptx[torch]' # + PyTorch eager / torch.compile
pip install 'pyptx[jax]' # + jax.jit(型付きFFI)
pip install 'pyptx[all]' # PyTorchとJAX両方
pip install ninja を追加しておくと、PyTorchのC++拡張をJITビルドする際のディスパッチオーバーヘッドが約34μsから約14μsに下がる。ライセンスはApache-2.0。
TritonやCuTe DSLとの使い分け
pyptxと他ツールの違いは「誰がPTXを決めるか」にある。Tritonと cuTile(NVIDIA製、CUDA 13.1で追加)はコンパイラが命令スケジュールを決定する設計で、オートチューニングや標準的なGEMM/convolutionには向いている。CuTe DSL(CUTLASS 4)はNVIDIAのatom抽象化を通じてHardwareアクセスを提供し、CUTLASS連携のカーネルを作るのに適している。
対してpyptxは「ユーザーが1命令ずつ決める」設計で、次のような場面に向いている。
- WGMMA・TMA 3Dマルチキャスト・mbarrierフェーズ管理など、Hopperの命令列を明示的に制御したいとき
- Blackwellで
tcgen05.mma・TMEM・cta_group::2の2SM協調MMAを手書きしたいとき - nvcc・Triton・CUTLASSが出力したPTXをPythonに変換して読み、修正したいとき
- PTXの学習や既存カーネルの解析に使いたいとき
逆に向いていない場面もある。pyptxはHopperとBlackwell(sm_90a / sm_100a)に特化しており、旧世代GPU向けの汎用コードは書けない。また、命令スケジュールをコンパイラに任せたい標準的なパターンなら、Tritonの方が開発効率が高い。
使い始めるには
ドキュメントサイト(pyptx.dev)にはHopper向けRMS normカーネルを題材にした入門ガイドがあり、TMAロードとウォープreduceの最小構成から始められる。Blackwell向けには examples/blackwell/tcgen05_suite.py に13個の独立したtcgen05プリミティブのサンプルが含まれており、B200での動作確認の起点として使いやすい。
PTX命令を直接制御したい研究者や、パフォーマンスチューニングの限界まで追いたいMLエンジニアにとって、試す価値のあるツールだ。