a-r-r-o-w / cogvideox-factory

Memory-optimized finetuning scripts for CogVideoX using TorchAO and DeepSpeed
Apache License 2.0

float8 matmul for inference + torchao fp8 training #28

Open a-r-r-o-w opened 5 hours ago

a-r-r-o-w commented 5 hours ago

PyTorch has support for float8 matmul kernels, and they appear to be faster than bf16 on Ada and newer architectures. TorchAO supports training in fp8. This has been explored in some of the newer optimization examples for Flux and other large models to achieve real-time image generation. I think we could explore this for CogVideoX training and see how it pans out.
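For reference, a minimal sketch of what this could look like with torchao's float8 APIs (assuming torchao's `convert_to_float8_training` and the Diffusers CogVideoX transformer; the model id and `module_filter_fn` below are illustrative placeholders, not a tested recipe):

```python
# Sketch only: swap nn.Linear layers for torchao Float8Linear so matmuls run
# through fp8 kernels (torch._scaled_mm), then compile. Assumes torchao's
# float8 training API and the Diffusers CogVideoX transformer.
import torch
from diffusers import CogVideoXTransformer3DModel
from torchao.float8 import convert_to_float8_training

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

def module_filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    # Illustrative filter: only convert Linear layers, and skip the final
    # projection whose shapes may not be fp8-friendly.
    return isinstance(module, torch.nn.Linear) and "proj_out" not in fqn

convert_to_float8_training(transformer, module_filter_fn=module_filter_fn)

# fp8 matmuls generally only pay off under torch.compile.
transformer = torch.compile(transformer)

# For inference-only fp8 matmul (no fp8 training), torchao's quantize_ API
# could be used on the loaded transformer instead of the conversion above:
#   from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
#   quantize_(transformer, float8_dynamic_activation_float8_weight())
```

Whether the kernel speedup survives the conversion and scaling overhead on the CogVideoX shapes is exactly what would need profiling.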

Relevant links:

Since this might take some time to profile properly, it is low priority, but it is definitely worth exploring because some other training libraries/UIs are looking into this too.

@sayakpaul @zRzRzRzRzRzRzR