Open · yaroslavvb opened 2 months ago
My understanding is that 8-bit all_gather is something you would only do if you're already doing compute in 8 bit. We have an experimental branch for 8-bit compute here: https://github.com/allenai/OLMo-core/tree/epwalsh/float8-investigation. It uses the new, faster trainer. So far it is faster by an impressive margin (though not 50%), but we have not vetted it at larger scales.
8-bit all_gather would be a further step after that.
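For reference, here is roughly what the compute side looks like with torchao's float8 support. This is a minimal sketch, not OLMo-core code: the model, the filter function, and the exact API names (`convert_to_float8_training`) are from recent torchao releases and may differ in older versions.

```python
# Minimal sketch of enabling float8 compute with torchao (not OLMo-core code).
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)
).cuda()

def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # fp8 matmul kernels want dims that are multiples of 16; skip anything else.
    return (
        isinstance(mod, nn.Linear)
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )

# Swap eligible nn.Linear layers for float8 training variants (dynamic scaling).
convert_to_float8_training(model, module_filter_fn=module_filter_fn)
```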
❓ The question
Is there a plan, or any partial work done, towards supporting 8-bit all_gather in OLMo? https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359
The authors observe a 50% throughput improvement when training Llama 70B with on-par numerics, which seems significant (depending on what "on-par numerics" means).
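For context, the recipe in the linked post amounts to converting the model's linear layers to float8 and telling FSDP2 to all-gather the fp8 weight shards directly. A hedged sketch using torchao follows; the flag name (`enable_fsdp_float8_all_gather`), the `fully_shard` import path, and the `model.layers` attribute are assumptions that vary by torch/torchao version and model definition.

```python
# Hedged sketch of float8 all-gather with FSDP2 + torchao; flag and import
# names vary by version, and `model.layers` is a placeholder attribute.
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torch.distributed.fsdp import fully_shard  # torch >= 2.6; older versions: torch.distributed._composable.fsdp

config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# With the flag set, FSDP2 all-gathers fp8 weight shards instead of bf16,
# roughly halving all-gather communication volume for the converted layers.
for layer in model.layers:  # assumed per-transformer-block sharding
    fully_shard(layer)
fully_shard(model)
```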