Closed: shivajid closed this issue 1 week ago
Hey, thanks for the question! I'm working on a PR that reduces the memory usage of llama_or_mistral_ckpt.py to 2x the checkpoint size, so it should only need about 1.6 TB of memory. I'm hoping to merge the PR soon!
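The basic idea behind capping peak memory near 2x the checkpoint size can be sketched as follows. This is a hypothetical illustration, not the actual PR: it pops each tensor out of the source state dict as it is converted, so the original copy can be freed immediately instead of both full copies living in memory until the end. The `convert_state_dict` name and the float32 cast are stand-ins for the real transformation.

```python
# Hypothetical sketch: release each source tensor as soon as it has been
# converted, so only one extra full-size copy ever exists at a time.
import numpy as np

def convert_state_dict(src):
    """Convert tensors one at a time. The astype(float32) call is a
    stand-in for whatever layout/dtype transform the real script does."""
    out = {}
    for name in list(src.keys()):
        tensor = src.pop(name)                 # drop the dict's reference
        out[name] = tensor.astype(np.float32)  # converted copy
        del tensor                             # source memory can be reclaimed
    return out

# usage: the source dict is drained as conversion proceeds
src = {"w": np.ones((4, 4), dtype=np.float16)}
converted = convert_state_dict(src)
```

Popping from the source dict (rather than iterating over it read-only) is what keeps peak usage near source-plus-one-converted-tensor instead of source-plus-full-output.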
The new script should now be live: https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py!
@shivajid
Thanks, Robert, for the fix! Let us know if you still have any issues, Shivaji.
While attempting to convert the Llama 405B model checkpoints, I ran into scaling issues with https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py#L1. The code currently appears to be single-threaded, and there is a comment suggesting the convert() method be changed to run in multiple passes; I think that would be great to have :).
Llama 405B has two checkpoints, one in int8 and the other in bf16. The int8 checkpoint is 476 GB, and that is the one I have been attempting to convert. The conversion failed on an A3+ machine, which has 1.8 TB of memory. I finally got quota approval for a memory-optimized machine (M1 ultramem) with about 3.8 TB of memory, and with that I was able to complete the conversion. It took multiple hours to finish, and peak memory usage was 2.9 TB.
The bf16 checkpoint is about 700 GB, so for that I am looking for an even bigger machine.
In short, I think it would be nice to make the https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py code work in multiple passes :).
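A multi-pass conversion could look roughly like the sketch below. This is hypothetical and does not match the real checkpoint format of llama_or_mistral_ckpt.py: it assumes the checkpoint is split into per-shard `.npy` files and processes one shard at a time, memory-mapping the input so peak usage stays near the size of a single shard rather than the whole checkpoint. The function name and the float32 cast are illustrative.

```python
# Hypothetical multi-pass sketch: convert one shard file at a time so
# the whole checkpoint is never resident in memory at once.
import os
import tempfile
import numpy as np

def convert_shard_by_shard(in_dir, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    for fname in sorted(os.listdir(in_dir)):
        if not fname.endswith(".npy"):
            continue
        # mmap_mode="r" maps the shard instead of reading it fully into RAM
        shard = np.load(os.path.join(in_dir, fname), mmap_mode="r")
        converted = np.asarray(shard, dtype=np.float32)  # stand-in transform
        np.save(os.path.join(out_dir, fname), converted)
        del shard, converted  # release before moving to the next shard

# usage with a tiny synthetic shard
in_dir = tempfile.mkdtemp()
out_dir = tempfile.mkdtemp()
np.save(os.path.join(in_dir, "shard0.npy"), np.ones((2, 2), dtype=np.float16))
convert_shard_by_shard(in_dir, out_dir)
```

For the 476 GB int8 checkpoint, per-shard processing like this would bound peak memory by the largest shard plus its converted copy, which is what would let the conversion fit on the 1.8 TB A3+ machine.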