GPUカーネルを完全に制御したい場合、最終的にはPTX(Parallel Thread Execution)を直接記述する必要がある。TritonやCUTLASSはコンパイラが命令スケジューリングやレジスタ割り当てを自動処理するため便利だが、ウォープスペシャライゼーションやWGMMAの命令順序を明示的に指定したいとき、コンパイラの動作を上書きする手段はほとんどない。

pyptx はその課題を解決するPython製OSS。NVIDIA HopperおよびBlackwellアーキテクチャ向けのPTX命令をPythonのDSL(ドメイン固有言語)として記述し、jax.jittorch.compile・PyTorchのeagerモードから呼び出せる。1行1PTX命令という設計により、オプティマイザもオートチューナーも介在しない。

この記事でわかること:

  • pyptxの概要と、それが解決する課題
  • HopperとBlackwellの命令セットへの対応内容
  • H100 SXM5での実測パフォーマンス(GEMMで815 TFLOPS達成)
  • PTXトランスパイラ機能の使い方
  • Triton・CuTe DSLとの使い分け

コンパイラを介さずPTXを書く理由

Tritonは「tl.dot を書けばコンパイラがWGMMAかtcgen05を選ぶ」という設計だ。これは多くのユースケースで十分だが、逆に言えば「どの命令を、どの順序で発行するか」はユーザーが決められない。

カスタムのウォープスペシャライゼーションパターン、3DのTMAマルチキャストスケジュール、Blackwellの tcgen05.mma を使った非標準パイプラインなど、命令レベルの制御が必要な場面でTritonの逃げ道は限られる。inline_asm_elementwise というエスケープハッチはあるが、element-wiseの操作に限定されており、HopperやBlackwellで重要なcollective命令(WGMMA、TMA、tcgen05)には対応していない。

pyptxはこの逆の設計を取る。Pythonの関数呼び出しがそのまま1つのPTX命令にマッピングされ、コンパイラもスケジューラも介在しない。

主な機能

1命令1呼び出しのDSL

ptx.* 名前空間の各関数が、対応するPTX命令を1つ発行する。@kernel デコレータでカーネルを定義し、レジスタ・述語・バリア・共有メモリを明示的に管理する。

from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import bf16, f32

@kernel(
    in_specs=(Tile("M", "K", bf16), Tile("K", "N", bf16)),
    out_specs=(Tile("M", "N", f32),),
    grid=lambda M, N, K: (N // 64, M // 64),
    block=(128, 1, 1),
    arch="sm_90a",
)
def gemm(A, B, C):
    sA = smem.wgmma_tile(bf16, (64, 16), major="K")
    sB = smem.wgmma_tile(bf16, (16, 64), major="MN")
    acc = reg.array(f32, 32)
    # TMAロード + ptx.wgmma.mma_async(...) — 各呼び出しが1PTX命令を発行

print(gemm.ptx()) を実行すると、書いた通りのPTXがそのまま表示される。最適化による命令の並び替えは一切起きない。

HopperとBlackwell両世代に対応

Hopper(sm_90a)では WGMMA、TMA 2D/3Dマルチキャスト、mbarrier、クラスターラウンチをサポート。Blackwell(sm_100a)では tcgen05.mma / .ld、TMEM、SMEMデスクリプタ、ウォープスペシャライゼーションに対応している。現時点でバージョン 0.1.0(プレリリース)。

1つのカーネルオブジェクトで3つの実行パス

同じカーネルオブジェクトをJAX、PyTorch eager、torch.compile の3つから呼び出せる。

# PyTorch eager
out = gemm(a, b)

# torch.compile
out = torch.compile(gemm)(a, b)

# JAX jit(型付きFFI経由)
out = jax.jit(gemm)(a, b)

内部ではPTXを cuModuleLoadData でJITコンパイルし、約150行のC++ランチシムを経由してディスパッチする。PyTorchのディスパッチオーバーヘッドは、CUDAグラフリプレイで約4μs、キャッシュ済みC++拡張のTurbo eagerモードで約14μsになる。

PTX→Pythonトランスパイラ

既存のPTXファイルをpyptxのPythonコードに変換できる機能も持つ。

python -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py

--sugar オプションは名前のデマングル、スピンループのループ構文への変換、mbarrier待機ブロックの折りたたみなどを行う。CUTLASS・Triton・fast.cu・DeepGEMM・ThunderKittensなど218以上の実際のPTXファイルでバイト同一のラウンドトリップを確認済み(参考)。

nvccやTritonが出力したPTXを一度Pythonに変換して読み、修正してから使う、というワークフローが実現する。H100での815 TFLOPSのGEMMカーネルも、fast.cuのkernel12をこのトランスパイラ経由でPythonに落とし、そこから改変することで作られた。

H100・B200での実測パフォーマンス

Hopper H100 SXM5(bf16/f32)での測定値:

カーネル 結果 比較
GEMM(wgmma、ウォープスペシャライズ、8192³) 815 TFLOPS cuBLAS 6K以上で上回る
RMS norm(f32、B=2048 N=8192) 2.6 TB/s(HBM 88%) torchの3.9倍
Layer norm(f32、B=2048 N=8192) 2.5 TB/s(HBM 83%) F.layer_normの1.5倍
SwiGLU(f32、M=2048 F=8192) 2.8 TB/s(HBM 94%) F.silu(g)*uの1.6倍
Flash attention(bf16、M=N=4096、HD=64) 88μs naive torchの3.0倍

Blackwell B200(bf16)での測定値:

カーネル 結果 cuBLAS比
GEMM(tcgen05.mma、4ステージ、1SM、8192³) 1240 TFLOPS 77%
GEMM(1SM、4096³) 1194 TFLOPS 78%
Grouped GEMM(MoE G=4 M=N=K=2048) 401 TFLOPS PyTorch参照の約10倍

詳細なベンチマーク表と再現コマンドは pyptx.dev/performance に掲載されている。

インストール

pip install pyptx               # DSL・パーサー・エミッター・トランスパイラ(GPUランタイムなし)
pip install 'pyptx[torch]'      # + PyTorch eager / torch.compile
pip install 'pyptx[jax]'        # + jax.jit(型付きFFI)
pip install 'pyptx[all]'        # PyTorchとJAX両方

pip install ninja を追加しておくと、PyTorchのC++拡張をJITビルドする際のディスパッチオーバーヘッドが約34μsから約14μsに下がる。ライセンスはApache-2.0。

TritonやCuTe DSLとの使い分け

pyptxと他ツールの違いは「誰がPTXを決めるか」にある。Tritonと cuTile(NVIDIA製、CUDA 13.1で追加)はコンパイラが命令スケジュールを決定する設計で、オートチューニングや標準的なGEMM/convolutionには向いている。CuTe DSL(CUTLASS 4)はNVIDIAのatom抽象化を通じてHardwareアクセスを提供し、CUTLASS連携のカーネルを作るのに適している。

対してpyptxは「ユーザーが1命令ずつ決める」設計で、次のような場面に向いている。

  • WGMMA・TMA 3Dマルチキャスト・mbarrierフェーズ管理など、Hopperの命令列を明示的に制御したいとき
  • Blackwellで tcgen05.mma・TMEM・cta_group::2 の2SM協調MMAを手書きしたいとき
  • nvcc・Triton・CUTLASSが出力したPTXをPythonに変換して読み、修正したいとき
  • PTXの学習や既存カーネルの解析に使いたいとき

逆に向いていない場面もある。pyptxはHopperとBlackwell(sm_90a / sm_100a)に特化しており、旧世代GPU向けの汎用コードは書けない。また、命令スケジュールをコンパイラに任せたい標準的なパターンなら、Tritonの方が開発効率が高い。

使い始めるには

ドキュメントサイト(pyptx.dev)にはHopper向けRMS normカーネルを題材にした入門ガイドがあり、TMAロードとウォープreduceの最小構成から始められる。Blackwell向けには examples/blackwell/tcgen05_suite.py に13個の独立したtcgen05プリミティブのサンプルが含まれており、B200での動作確認の起点として使いやすい。

PTX命令を直接制御したい研究者や、パフォーマンスチューニングの限界まで追いたいMLエンジニアにとって、試す価値のあるツールだ。