Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

FP8 Linear and conv with cudnn #486

Open vedaanta opened 4 months ago

vedaanta commented 4 months ago

🚀 Feature

cuDNN provides flexible support for performant gemm/conv with fp8 quantization. If Thunder introduces fp8 casts in its traces, it can benefit from cuDNN fusions.

Motivation

Today, Thunder uses TransformerEngine's fp8 linear, which delegates quantization strategies to TE and makes them opaque to Thunder. If Thunder plans to handle fp8 casts itself, performant and flexible kernels from cuDNN can help.

cuDNN's support is described here: cuDNN's runtime fusion engine

For fp8 specifically, cuDNN can provide the following graph as one fused kernel: [figure: fp8 fusion graph]

The graph is flexible, meaning:

- The corresponding backward graphs are also supported (though they require an offline transpose on Hopper).
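
For concreteness, here is a minimal plain-PyTorch sketch of the cast-gemm-amax pattern that such a fused kernel would replace. The shapes, scales, and the e4m3 format choice are illustrative assumptions, not Thunder's actual trace:

```python
import torch

# Hypothetical shapes and scales, for illustration only.
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
x_scale = torch.tensor(1.0, device="cuda")
w_scale = torch.tensor(1.0, device="cuda")

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

# Quantize: scale, saturate, and cast to fp8 (these are the casts Thunder
# would introduce in its trace).
x_fp8 = (x.float() * x_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
w_fp8 = (w.float() * w_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# Dequantize and matmul in higher precision as a reference; the fused cuDNN
# kernel would instead run the gemm directly on the fp8 inputs.
out = (x_fp8.float() / x_scale) @ (w_fp8.float() / w_scale).t()

# amax of the output, needed to update scales for the next iteration;
# cuDNN can compute this inside the same fused kernel.
amax = out.abs().max()
```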

Pitch

Have the cuDNN executor claim gemm/conv operations along with the fp8 casts around them.
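
A rough sketch of what that claim could look like, assuming Thunder's thunder.extend executor API; the cudnn_fp8_linear operator, its implementation, and the checker below are hypothetical, and exact signatures may differ:

```python
import thunder
import thunder.torch as ltorch
from thunder.extend import OperatorExecutor, register_executor

cudnn_fp8_ex = OperatorExecutor("cudnn_fp8", version="0.1")
register_executor(cudnn_fp8_ex)

def _cudnn_fp8_linear_impl(a, w, bias=None):
    # Placeholder: would build and execute a cuDNN runtime-fusion graph
    # that fuses the fp8 casts, the gemm, and the amax computation.
    raise NotImplementedError

cudnn_fp8_linear = cudnn_fp8_ex.register_operator(
    "cudnn_fp8_linear", like=ltorch.linear, fn=_cudnn_fp8_linear_impl
)

def _cudnn_fp8_linear_checker(a, w, bias=None):
    # A real checker would verify CUDA device, supported dtypes/shapes,
    # and cuDNN version; this sketch accepts everything.
    return True

cudnn_fp8_ex.register_implementation(
    ltorch.linear, cudnn_fp8_linear, checker=_cudnn_fp8_linear_checker
)
```

Claiming the surrounding casts as well, rather than just the linear, would additionally need a fusion-style pass that matches the quantize-gemm-dequantize subgraph in the trace.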

CC @IvanYashchuk @kshitij12345 @Anerudhan

IvanYashchuk commented 4 months ago

Thank you for creating this issue! I think it should be possible to use TransformerEngine as the scaling recipe manager and cuDNN as the performance driver in Thunder.
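
To illustrate that division of labor: TE's delayed-scaling recipe keeps an amax history and derives the per-tensor scales, while cuDNN executes the fused fp8 gemm that consumes those scales and produces the next amax. A toy, framework-free sketch of the scale bookkeeping (history length and margin are illustrative defaults, not TE's API):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

class DelayedScalingState:
    """Toy delayed-scaling recipe: the scale for the current step is derived
    from the amax history of previous steps."""

    def __init__(self, history_len: int = 16, margin: int = 0):
        self.amax_history = torch.zeros(history_len)
        self.margin = margin
        self.scale = torch.tensor(1.0)

    def update(self, new_amax: torch.Tensor) -> None:
        # Roll the history and record the latest amax.
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = new_amax
        amax = self.amax_history.max().clamp(min=1e-12)
        # The scale maps the observed range onto the fp8 representable range.
        self.scale = FP8_MAX / (amax * 2.0 ** self.margin)
```

In this split, TE would own the recipe state above, while the cuDNN executor only consumes `scale` and returns the new `amax` from its fused kernel.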