🚀 Feature
cuDNN provides flexible, performant support for gemm/conv with fp8 quantization. If Thunder introduces fp8 casts into its traces, it can benefit from cuDNN fusions.
Motivation
Today, Thunder uses TransformerEngine's (TE) fp8 linear, which delegates the quantization strategy to TE and makes it opaque to Thunder. If Thunder is to handle fp8 casts itself, cuDNN's performant and flexible kernels can help.
cuDNN's support is described in the cuDNN runtime fusion engine documentation.
For fp8 specifically, cuDNN can provide a dequantize → gemm/conv → amax → quantize graph as one fused kernel.
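A minimal PyTorch sketch of that pattern, with every intermediate materialized for clarity (the scale/descale tensors and the scaling bookkeeping are illustrative assumptions; cuDNN computes the whole region in one kernel):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quantize_fp8(x, scale):
    # Scale into the fp8 representable range, clamp to avoid overflow, then cast.
    return (x * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)

def fused_fp8_matmul_reference(a_fp8, b_fp8, descale_a, descale_b, scale_out):
    # cuDNN can execute this entire region as one fused kernel; this
    # eager reference materializes every intermediate for clarity.
    a = a_fp8.to(torch.float32) * descale_a  # dequantize A
    b = b_fp8.to(torch.float32) * descale_b  # dequantize B
    out = a @ b                              # gemm in higher precision
    amax = out.abs().amax()                  # amax feeds the next scale update
    return quantize_fp8(out, scale_out), amax
```

In delayed-scaling recipes, `scale_out` is derived from the amax history of previous iterations, which is exactly the bookkeeping that TE manages today.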
The graph is flexible: the pointwise operations around the gemm/conv and their data types can vary. The corresponding backward graphs are also supported (though they require an offline transpose on Hopper).
Pitch
Have the cudnn executor claim gemm/conv operations together with the fp8 casts around them.
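A rough sketch of the shape this could take with Thunder's `OperatorExecutor`/`register_implementation` extension point (the executor name, the meta/impl functions, and the eager fallback below are all hypothetical; a real implementation would dispatch to cuDNN and would also need a fusion-style pass to claim the surrounding casts, not just the gemm):

```python
import torch
import thunder.torch as ltorch
from thunder.core.proxies import TensorProxy
from thunder.extend import OperatorExecutor, register_executor

# Hypothetical executor; registered so Thunder's transforms can use it.
cudnn_fp8_ex = OperatorExecutor("cudnn_fp8_ex", version="0.1")
register_executor(cudnn_fp8_ex)

def _fp8_linear_meta(a: TensorProxy, w: TensorProxy, bias=None) -> TensorProxy:
    # Trace-time bookkeeping only: the output shape of a linear layer.
    return TensorProxy(like=a, shape=a.shape[:-1] + (w.shape[0],))

def _fp8_linear_impl(a, w, bias=None):
    # Placeholder standing in for a cuDNN fp8 fused gemm call.
    return torch.nn.functional.linear(a, w, bias)

cudnn_fp8_linear = cudnn_fp8_ex.register_operator(
    "cudnn_fp8_linear", like=_fp8_linear_meta, fn=_fp8_linear_impl
)

def _linear_checker(a, w, bias=None):
    # A real checker would test device, dtypes, and shapes against what
    # the cuDNN fp8 kernels support; claim everything for the sketch.
    return True

def _linear_transform(a, w, bias=None):
    return cudnn_fp8_linear(a, w, bias)

cudnn_fp8_ex.register_implementation(
    ltorch.linear, checker=_linear_checker, execution_transform=_linear_transform
)
```

Claiming the casts is the interesting part: because Thunder traces make the quantize/dequantize ops explicit, a pattern-matching pass (in the spirit of how the nvFuser fusion executor claims regions) could hand the whole dequantize → gemm → quantize subgraph to one cuDNN runtime-fusion kernel.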
CC @IvanYashchuk @kshitij12345 @Anerudhan
Thank you for creating this issue! I think it should be possible to use TransformerEngine as the scaling-recipe manager and cuDNN as the performance driver in Thunder.