pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[FP8 options] Float8Linear vs TransformerEngine #462

Open yundai424 opened 1 month ago

yundai424 commented 1 month ago

Hi team, first of all, thanks for this great repo showcasing how to leverage the latest techniques in the torch ecosystem; it's been super useful and insightful :) I have a naive question about FP8 options and would like to know how you view them.

There's NVIDIA's https://github.com/NVIDIA/TransformerEngine for FP8 training on Hopper, and it has started to be integrated into downstream frameworks like HF, Lightning, etc. However, I'm also seeing https://github.com/pytorch-labs/float8_experimental evolving quickly, and the fact that it's more lightweight and potentially more composable with other torch techniques is also important to us. I'm wondering if you have some insight into the pros and cons of each, how Float8Linear's performance compares to TE's, and whether you would recommend TE or Float8Linear for LLM pretraining/finetuning use cases. Thanks a lot!
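For context on what both libraries are doing under the hood: float8 training relies on per-tensor scaling, so that a tensor's largest magnitude (its amax) lands near the top of the narrow float8 range before the cast. Below is a minimal pure-Python simulation of dynamic tensorwise scaling with the e4m3 format, assuming round-to-nearest-even and saturation at ±448. It is an illustration of the technique only, not float8_experimental's or TE's code; `round_to_e4m3`, `tensorwise_scale`, and `fp8_matmul` are hypothetical helper names.

```python
import math
from typing import List

E4M3_MAX = 448.0             # largest finite float8 e4m3fn value
E4M3_MIN_NORMAL = 2.0 ** -6  # smallest normal e4m3fn value
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest e4m3fn-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)
    if mag < E4M3_MIN_NORMAL:
        step = 2.0 ** -9  # spacing of e4m3 subnormals
    else:
        # spacing of representable values within this power-of-two binade
        step = 2.0 ** (math.floor(math.log2(mag)) - MANTISSA_BITS)
    # Python's round() is round-half-to-even, matching IEEE-style rounding
    return sign * min(round(mag / step) * step, E4M3_MAX)


def tensorwise_scale(values: List[float]) -> float:
    """Dynamic scaling: map the tensor's max |value| (amax) onto E4M3_MAX."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / max(amax, 1e-12)  # guard against an all-zero tensor


def fp8_matmul(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Simulate a float8 matmul: scale and cast each input, multiply, unscale."""
    sa = tensorwise_scale([v for row in a for v in row])
    sb = tensorwise_scale([v for row in b for v in row])
    a8 = [[round_to_e4m3(v * sa) for v in row] for row in a]
    b8 = [[round_to_e4m3(v * sb) for v in row] for row in b]
    # Accumulate in Python floats (float64); real fp8 GEMMs also accumulate
    # in a format wider than fp8, then the result is divided by both scales.
    return [[sum(a8[i][p] * b8[p][j] for p in range(len(b))) / (sa * sb)
             for j in range(len(b[0]))]
            for i in range(len(a))]


out = fp8_matmul([[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.25], [0.125, 1.0]])
print(out)
```

In this example the first output row is exact, while the second shows a small error because 3.0 × 112 = 336 falls halfway between the e4m3 values 320 and 352 and rounds to 320. Choosing the scale from the current amax ("dynamic" scaling) costs an extra reduction over each tensor, which is one of the operations that fused kernels, handwritten or compiler-generated, try to make cheap.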

tianyu-l commented 1 month ago

@weifengpy @awgu Any insights or thoughts to share?

weifengpy commented 1 month ago

Good question. torchtitan + float8_experimental (or TorchAO for new dtypes in general) is where we showcase that everything composes well together (fp8, parallelism, torch.compile, activation checkpointing) using PyTorch APIs. We plan to benchmark the perf improvement of fp8 over bf16.

For a perf comparison with TE, we do not have specific numbers yet. TE is more like one of our partners/customers, and we welcome their adoption of PyTorch APIs to better fit their needs.

vkuzo commented 1 month ago

@yundai424, I can speak to PyTorch's float8 modeling plans, but can't comment on the other things you asked about. From the POV of float8_experimental, we care about performance, composability with key PyTorch systems (autograd, distributed, compile), debuggability, and readability. Please feel free to file issues in https://github.com/pytorch-labs/float8_experimental/tree/main if you have more specific questions, and we will be happy to help.

jeromeku commented 1 week ago

@vkuzo @awgu

Any updates on benchmarking well-tuned e2e torchtitan / float8_experimental against a comparable TransformerEngine implementation?

Also, are there any examples of running torchtitan with all the composability benefits (e.g., seamless integration with torch.compile, FSDP2, etc.) but with TransformerEngine instead of float8_experimental?

vkuzo commented 1 week ago

> Any updates on benchmarking well-tuned e2e torchtitan / float8_experimental against a comparable TransformerEngine implementation?

This isn't something the PyTorch team is likely to publish in the near term, but we definitely welcome benchmarks from the community on this topic.
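For anyone in the community picking this up, the main pitfalls in such a comparison are timing before warmup (which would include torch.compile's one-time compilation) and reading the clock while asynchronous GPU work is still in flight. A minimal timing harness sketch in pure Python, assuming each implementation is wrapped in a zero-argument callable that runs one training step; `benchmark` and `compare` are hypothetical helpers, not part of torchtitan.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark(fn: Callable[[], object], *, warmup: int = 3, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock seconds per call of fn.

    Warmup calls absorb one-time costs such as compilation and autotuning.
    If sync is given (e.g. torch.cuda.synchronize on GPU), it is called after
    warmup and after every timed call, so queued device work is counted.
    """
    for _ in range(warmup):
        fn()
    if sync:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def compare(candidates: Dict[str, Callable[[], object]], baseline: str,
            **kwargs) -> Dict[str, float]:
    """Speedup of each candidate relative to the baseline (>1 means faster)."""
    medians = {name: benchmark(fn, **kwargs) for name, fn in candidates.items()}
    return {name: medians[baseline] / max(t, 1e-12) for name, t in medians.items()}
```

Usage would look like `compare({"bf16": step_bf16, "fp8": step_fp8}, baseline="bf16", sync=torch.cuda.synchronize)`, where `step_bf16` and `step_fp8` are hypothetical callables for the two configurations being compared. Reporting the median rather than the mean keeps occasional slow steps (e.g. checkpointing or data-loader stalls) from dominating the result.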

> Also, are there any examples of running torchtitan with all the composability benefits (e.g., seamless integration with torch.compile, FSDP2, etc.) but with TransformerEngine instead of float8_experimental?

I think that would be really nice! It also isn't something the PyTorch team is likely to focus on, but would be great if someone from the community drove this and shared their findings. From what I know, getting a meaningful performance boost from torch.compile + TE might be difficult because TE extensively uses custom CUDA kernels.
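To make the point about custom kernels concrete, here is a toy model of what a compiler can and cannot do with a graph. Ops it can see into (pointwise ops here) can be merged into one kernel; opaque custom kernels are called as-is. This is an illustration under simplifying assumptions, not how TorchInductor actually partitions graphs, and all op names are hypothetical labels.

```python
from typing import List, Tuple


def fusion_groups(graph: List[Tuple[str, bool]]) -> List[List[str]]:
    """Toy compiler pass: merge runs of fusable ops; leave opaque ops alone."""
    groups: List[List[str]] = []
    run: List[str] = []
    for name, fusable in graph:
        if fusable:
            run.append(name)
            continue
        if run:
            groups.append(run)
            run = []
        groups.append([name])  # opaque kernel: called as-is, nothing to fuse
    if run:
        groups.append(run)
    return groups


# Hypothetical op names; True = pointwise op the compiler can see into.
native_graph = [("scale", True), ("clamp", True), ("cast_fp8", True),
                ("fp8_gemm", False), ("unscale", True), ("add_bias", True)]
# A graph built from handwritten fused kernels is already mostly opaque.
te_like_graph = [("fused_cast_transpose", False), ("fp8_gemm", False),
                 ("fused_bias_gelu", False)]

for label, graph in [("native", native_graph), ("TE-like", te_like_graph)]:
    print(label, len(graph), "kernels ->", len(fusion_groups(graph)))
```

In this toy model the native graph drops from 6 kernels to 3, while the TE-like graph stays at 3: the fusions have already been done by hand, so the compiler has little left to improve.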

jeromeku commented 1 week ago

@vkuzo @awgu

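If possible, can you speak to whether:

- registering TE kernels as custom operators + torch.compile with torchtitan would result in performance degradation vs a purely native torch workflow (float8_experimental, etc.)?
- are you aware of any benchmarks from the community on how various fp8 implementations compose with torch (torch.compile, torchtitan, FSDP2, etc.)?

Thanks!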

vkuzo commented 1 week ago

> registering TE kernels as custom operators + torch.compile with torchtitan would result in performance degradation vs a purely native torch workflow (float8_experimental, etc.)?

TE has handwritten kernels for the important float8 fusions, which is why running torch.compile on TE would have a limited benefit. torchao.float8 does not ship with any handwritten kernels, so a compiler is required to recover performance.
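Why a compiler is needed when no handwritten kernels ship can be seen with a rough memory-traffic count. Casting to float8 is a chain of cheap pointwise steps (scale, clamp, round); run eagerly, each step is its own kernel that reads the whole tensor and writes an intermediate, while a fused kernel reads each element once and writes it once. The sketch below counts element reads and writes for both, assuming purely memory-bound pointwise ops; the coarse rounding step is a hypothetical stand-in for a real float8 cast.

```python
from typing import Callable, List, Tuple

Pointwise = Callable[[float], float]


def run_eager(xs: List[float], ops: List[Pointwise]) -> Tuple[List[float], int]:
    """One 'kernel' per op: each reads the whole tensor and writes an intermediate."""
    traffic = 0
    for op in ops:
        xs = [op(x) for x in xs]
        traffic += 2 * len(xs)  # one element read + one element written
    return xs, traffic


def run_fused(xs: List[float], ops: List[Pointwise]) -> Tuple[List[float], int]:
    """All ops in a single 'kernel': read each element once, write it once."""
    def chain(x: float) -> float:
        for op in ops:
            x = op(x)
        return x
    return [chain(x) for x in xs], 2 * len(xs)


# Illustrative scale -> clamp -> round chain standing in for a float8 cast.
scale = 448.0 / 4.0
cast_ops: List[Pointwise] = [
    lambda x: x * scale,
    lambda x: max(-448.0, min(448.0, x)),
    lambda x: round(x / 16.0) * 16.0,  # hypothetical stand-in for fp8 rounding
]
```

Both functions return the same values, but with three ops the eager path moves three times as much data. Since these ops are memory-bound, that traffic ratio is roughly what a fusing compiler (or a handwritten fused kernel, as in TE) recovers.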

> are you aware of any benchmarks from the community on how various fp8 implementations compose with torch (torch.compile, torchtitan, FSDP2, etc.)?

I'm not aware of any, but it would be great if someone helped out with this.