PyTorch has support for float8 matmul kernels, and they appear to be faster than bf16 on Ada Lovelace (SM89) and newer architectures. TorchAO supports training in fp8. This has been explored in a few recent optimization examples for Flux and other large models to achieve real-time image generation. I think we could explore fp8 training for CogVideoX and see how it pans out.
Since this will take some time to profile properly, it is low priority, but it is definitely worth exploring since other training libraries/UIs are looking into this too.
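As a possible starting point, here is a minimal sketch of what enabling fp8 training on a model could look like. It assumes a recent torchao release that exposes `torchao.float8.convert_to_float8_training` and a GPU with compute capability 8.9 or higher. The helper names `fp8_module_filter` and `enable_fp8_training`, and the choice to keep a layer named `proj_out` in bf16, are hypothetical; the divisible-by-16 check mirrors the filter used in torchao's own float8 example.

```python
import torch
import torch.nn as nn

# Hypothetical: substrings of module names to keep in higher precision.
# Assumes the transformer's final output projection is called "proj_out".
SKIP_PATTERNS = ("proj_out",)


def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
    """Return True if this submodule should be swapped to a float8 linear."""
    if not isinstance(mod, nn.Linear):
        return False
    # Keep selected layers (e.g. the final projection) in bf16.
    if any(pattern in fqn for pattern in SKIP_PATTERNS):
        return False
    # float8 matmuls expect dimensions divisible by 16; torchao's example
    # filter skips linears that do not meet this.
    return mod.in_features % 16 == 0 and mod.out_features % 16 == 0


def enable_fp8_training(model: nn.Module) -> nn.Module:
    # Imported lazily so the filter above can be used without torchao installed.
    from torchao.float8 import convert_to_float8_training

    # Swaps eligible nn.Linear modules for torchao's float8 training linear.
    convert_to_float8_training(model, module_filter_fn=fp8_module_filter)
    # torchao recommends torch.compile for competitive float8 performance.
    return torch.compile(model)
```

For CogVideoX this would be applied to the transformer (for example one loaded with diffusers' `CogVideoXTransformer3DModel.from_pretrained(...)`) before the training loop starts. Whether it actually beats bf16 end to end is exactly what would need profiling.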
Relevant links:
@sayakpaul @zRzRzRzRzRzRzR