AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Support for safetensor checkpoints #1027

Closed: richjames0 closed this 1 week ago

richjames0 commented 1 week ago

Prior to this PR, our checkpointing code only supported loading .pth files. To support Mixtral 8x22B, however, we needed to load a safetensors file (v0.3 of the Instruct checkpoint, published by Mistral). We also noted that, as with .pth files, safetensors checkpoints can be split across multiple files. This PR addresses both: it adds safetensors loading, and it handles safetensors checkpoints split across multiple files.
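The safetensors layout described above is simple enough to sketch. This is not the PR's code. In practice the `safetensors` package (for example `safe_open(path, framework="np")`) does this work. What follows is a stdlib-only illustration of the published file format: an 8-byte little-endian header length, a JSON header, then raw tensor bytes. It also shows one way to merge a checkpoint split across several files. The helper names (`write_safetensors`, `read_safetensors`, `load_sharded_checkpoint`) and the shard filenames in the usage note are hypothetical. Tensors are returned as raw bytes, and dtype decoding is left out.

```python
import json
import struct

# Stdlib-only illustration of the safetensors layout (not the PR's implementation):
#   [u64 little-endian header length N][N bytes of JSON header][raw tensor bytes]
# Each header entry maps a tensor name to {"dtype", "shape", "data_offsets": [begin, end]},
# with offsets relative to the start of the raw-bytes section.


def write_safetensors(path, tensors):
    """Write {name: (dtype, shape, raw_bytes)} to a single .safetensors file."""
    header, offset = {}, 0
    for name, (dtype, shape, data) in tensors.items():
        header[name] = {"dtype": dtype, "shape": list(shape),
                        "data_offsets": [offset, offset + len(data)]}
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        # Dicts preserve insertion order, so data is written in offset order.
        for _, _, data in tensors.values():
            f.write(data)


def read_safetensors(path):
    """Return {name: (dtype, shape, raw_bytes)} for every tensor in one file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        data_start = 8 + header_len
        tensors = {}
        for name, info in header.items():
            if name == "__metadata__":  # optional string-to-string metadata, not a tensor
                continue
            begin, end = info["data_offsets"]
            f.seek(data_start + begin)
            tensors[name] = (info["dtype"], info["shape"], f.read(end - begin))
    return tensors


def load_sharded_checkpoint(paths):
    """Merge tensors from one or more shard files; each tensor must be in exactly one shard."""
    merged = {}
    for path in paths:
        for name, value in read_safetensors(path).items():
            if name in merged:
                raise ValueError(f"tensor {name!r} appears in more than one shard")
            merged[name] = value
    return merged
```

A single-file checkpoint is just the one-element case: `load_sharded_checkpoint(["model.safetensors"])`. For a split checkpoint, pass every shard, e.g. `["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]`.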

Note that we also introduce a new required command-line parameter, checkpoint-type, which takes the value pth or safetensors.
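Here is a minimal sketch, using Python's standard `argparse`, of how a required `--checkpoint-type` flag restricted to `pth` or `safetensors` could select which files to load. Everything apart from the flag's name and its two values is hypothetical and chosen for illustration, not taken from the PR: the `--base-model-path` flag, the `CHECKPOINT_PATTERNS` mapping, and `find_checkpoint_files`.

```python
import argparse
import glob
import os

# Hypothetical mapping from --checkpoint-type values to shard filename patterns.
CHECKPOINT_PATTERNS = {"pth": "*.pth", "safetensors": "*.safetensors"}


def build_parser():
    parser = argparse.ArgumentParser(description="Checkpoint conversion (illustrative sketch)")
    # Hypothetical input-location flag, used here only for illustration.
    parser.add_argument("--base-model-path", required=True,
                        help="directory containing the checkpoint file(s)")
    # Required flag: argparse exits with status 2 if it is missing or not one of the choices.
    parser.add_argument("--checkpoint-type", required=True,
                        choices=sorted(CHECKPOINT_PATTERNS),
                        help="format of the input checkpoint: pth or safetensors")
    return parser


def find_checkpoint_files(base_dir, checkpoint_type):
    """Return matching checkpoint files in sorted order; a single file is one shard."""
    pattern = os.path.join(base_dir, CHECKPOINT_PATTERNS[checkpoint_type])
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"no {checkpoint_type} files found in {base_dir}")
    return files
```

Because the flag is required and limited by `choices`, argparse rejects a missing or misspelled value before any checkpoint loading starts. Note that `argparse` exposes `--checkpoint-type` as `args.checkpoint_type`.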

Finally, a couple of minor fixes: