AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Converting Llama 3.1 405B checkpoint: requesting multi-pass checkpoint conversion #864

Closed: shivajid closed this 1 week ago

shivajid commented 3 weeks ago

While trying to convert the Llama 3.1 405B model checkpoints, I ran into scaling issues with the https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py#L1 code. The script currently appears to be single-threaded, and there is a comment in it about changing the convert() method to work in multiple passes. I think that would be great to have :).
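One way a multi-pass conversion could bound memory is to process the model a batch of layers at a time: load only those layers, transform them, write them out, and drop them before the next pass. This is a minimal sketch under stated assumptions, not the interface of llama_or_mistral_ckpt.py: `convert_multipass`, `layers_per_pass`, and the loader, transform, and writer callables are hypothetical names, and plain Python lists stand in for real arrays.

```python
from typing import Callable, Dict, Iterator, List

Tensor = List[float]            # stand-in for a real array type (assumption)
LayerWeights = Dict[str, Tensor]


def layer_batches(num_layers: int, layers_per_pass: int) -> Iterator[range]:
    """Yield consecutive ranges of layer indices, one range per pass."""
    for start in range(0, num_layers, layers_per_pass):
        yield range(start, min(start + layers_per_pass, num_layers))


def convert_multipass(
    num_layers: int,
    layers_per_pass: int,
    load_layer: Callable[[int], LayerWeights],              # hypothetical reader
    convert_tensor: Callable[[str, Tensor], Tensor],        # hypothetical transform
    save_batch: Callable[[Dict[int, LayerWeights]], None],  # hypothetical writer
) -> int:
    """Convert a checkpoint one batch of layers at a time; return the pass count.

    Only the current batch is held in memory, so peak usage scales with
    layers_per_pass instead of with the full model.
    """
    passes = 0
    for batch in layer_batches(num_layers, layers_per_pass):
        converted: Dict[int, LayerWeights] = {}
        for idx in batch:
            source = load_layer(idx)
            converted[idx] = {name: convert_tensor(name, t) for name, t in source.items()}
        save_batch(converted)
        converted.clear()  # drop this pass's tensors before starting the next one
        passes += 1
    return passes


# Usage with fake data: 10 layers, 4 layers per pass.
saved: Dict[int, LayerWeights] = {}
max_in_memory = 0


def fake_load(idx: int) -> LayerWeights:
    return {"wq": [float(idx)], "wk": [float(idx) + 0.5]}


def double(name: str, t: Tensor) -> Tensor:
    return [2.0 * x for x in t]


def record(batch: Dict[int, LayerWeights]) -> None:
    global max_in_memory
    max_in_memory = max(max_in_memory, len(batch))  # layers resident in this pass
    saved.update(batch)


passes = convert_multipass(10, 4, fake_load, double, record)
print(passes, max_in_memory, saved[9]["wk"])  # 3 4 [19.0]
```

Choosing `layers_per_pass` trades peak memory for the number of passes. If each source shard holds a slice of every layer, as tensor-parallel shards often do, each pass may have to read from every shard, so lower memory would come at the cost of extra I/O.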

Llama 3.1 405B has two checkpoints, one in int8 and the other in bf16. The int8 checkpoint is 476 GB, and that is the one I have been trying to convert. The conversion failed on an A3+, which has 1.8 TB of memory. I finally got quota approval for a memory-optimized machine (M1 ultramem) with about 3.8 TB of memory, and on it I was able to complete the conversion. It took a very long time (multiple hours) to finish, and peak memory usage was 2.9 TB.

The bf16 checkpoint is about 700 GB, so I am looking for a bigger machine.
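The numbers above imply a rough peak-memory-to-checkpoint ratio for the current script. The sketch below reproduces that arithmetic; the helper names are hypothetical, decimal units (1 TB = 1000 GB) are assumed, and scaling the int8 ratio to the bf16 checkpoint assumes peak memory grows linearly with checkpoint size, which the issue does not confirm.

```python
def peak_ratio(peak_gb: float, checkpoint_gb: float) -> float:
    """Observed peak host memory divided by on-disk checkpoint size."""
    return peak_gb / checkpoint_gb


def projected_peak_gb(checkpoint_gb: float, ratio: float) -> float:
    """Extrapolated peak memory, assuming it scales linearly with checkpoint size."""
    return checkpoint_gb * ratio


INT8_CKPT_GB = 476.0   # int8 checkpoint size reported in the issue
INT8_PEAK_GB = 2900.0  # peak usage reported on the M1 ultramem machine (2.9 TB)
BF16_CKPT_GB = 700.0   # approximate bf16 checkpoint size reported in the issue

ratio = peak_ratio(INT8_PEAK_GB, INT8_CKPT_GB)
bf16_peak = projected_peak_gb(BF16_CKPT_GB, ratio)

print(f"ratio ~ {ratio:.1f}x")                         # ratio ~ 6.1x
print(f"bf16 projection ~ {bf16_peak / 1000:.1f} TB")  # bf16 projection ~ 4.3 TB
```

Under that linear assumption, the bf16 conversion would exceed the roughly 3.8 TB of the M1 ultramem machine, and a 2.9 TB peak for int8 is consistent with the failure on the 1.8 TB A3+.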

In short, it would be nice to make the https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py code work in multiple passes; that's all I wanted to say :).

rdyro commented 3 weeks ago

Hey, thanks for the question! I'm working on a PR that reduces the memory usage of llama_or_mistral_ckpt.py to 2x the checkpoint size, so it should only need about 1.6 TB of memory. I'm hoping to merge the PR soon!
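Taking the "2x the checkpoint size" estimate at face value, a quick check of whether a conversion fits on a given host might look like the sketch below. `estimated_peak_gb`, `fits_in_memory`, and the `headroom` parameter are hypothetical helpers, not part of MaxText, and the checkpoint sizes and A3+ memory are the figures quoted earlier in this thread.

```python
def estimated_peak_gb(checkpoint_gb: float, multiplier: float = 2.0) -> float:
    """Peak host memory under the '2x the checkpoint size' estimate."""
    return checkpoint_gb * multiplier


def fits_in_memory(checkpoint_gb: float, host_gb: float,
                   multiplier: float = 2.0, headroom: float = 0.1) -> bool:
    """True if the estimated peak leaves `headroom` (a fraction of host RAM) free."""
    return estimated_peak_gb(checkpoint_gb, multiplier) <= host_gb * (1.0 - headroom)


A3_PLUS_GB = 1800.0  # host memory of the A3+ machine mentioned in the issue (1.8 TB)

for name, size_gb in [("int8", 476.0), ("bf16", 700.0)]:
    peak = estimated_peak_gb(size_gb)
    print(name, peak, fits_in_memory(size_gb, A3_PLUS_GB))
# int8 952.0 True
# bf16 1400.0 True
```

With the ~700 GB figure the estimate comes to 1.4 TB; the 1.6 TB quoted above corresponds to a checkpoint of roughly 800 GB, so the exact bf16 size matters when sizing a machine.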

rdyro commented 2 weeks ago

The new script should now be live: https://github.com/google/maxtext/blob/main/MaxText/llama_or_mistral_ckpt.py !

@shivajid

gobbleturk commented 1 week ago

Thanks for the fix, Robert! Let us know if you still have any issues, Shivaji.