pytorch / torchtitan

A native PyTorch library for large model training

enable TP fp8 allgather with PrepareFloat8ModuleInput #393

Closed · wanchaol closed this 3 months ago

wanchaol commented 3 months ago

This is a follow-up PR that enables fp8 allgather in TP after these prerequisite PRs landed:

To train with fp8, you need to update your pytorch/float8_experimental checkout to include those changes.

Since fp8 is not yet enabled in our integration tests, this should cause no issues on CI or in training runs that do not use fp8.
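
For readers who want to see where `PrepareFloat8ModuleInput` fits, here is a minimal sketch of a TP plan using the fp8-aware parallel styles from `float8_experimental`. The mesh size and the module paths (`attention`, `attention.wq`, `attention.wo`) are illustrative assumptions, not taken from this PR:

```python
# Sketch only, assuming the fp8-aware TP styles exposed by
# float8_experimental at the time of this PR; module paths and
# mesh size below are made up for illustration.
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)


def apply_fp8_tp(model: torch.nn.Module) -> torch.nn.Module:
    tp_mesh = init_device_mesh("cuda", (8,))  # assumed 8-way tensor parallel

    plan = {
        # Cast the module input to fp8 *before* redistributing it, so the
        # TP all-gather moves fp8 bytes instead of bf16/fp32 bytes.
        "attention": PrepareFloat8ModuleInput(
            input_layouts=Shard(1),
            desired_input_layouts=Replicate(),
        ),
        # fp8 variants of the usual colwise/rowwise linear styles, which
        # consume the already-fp8 input directly in the fp8 matmuls.
        "attention.wq": Float8ColwiseParallel(),
        "attention.wo": Float8RowwiseParallel(),
    }
    return parallelize_module(model, tp_mesh, plan)
```

The point of the input-preparation style is ordering: without it, each fp8 linear would cast its input after the redistribution, so the collective would still move higher-precision bytes.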