ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Documentation request: saving a model and loading after training #314

Open sandeepimpressico opened 9 months ago

sandeepimpressico commented 9 months ago

Could you please add documentation on how to save a model with MLX once training is complete, so that the final model can be used for inference?

Perhaps also add sample code to one of the mlx-examples, e.g. transformer_lm. Guidance on saving checkpoints would be useful too.

I see multiple ways to save (e.g. mx.savez(), model.save_weights()), and it is unclear which one saves all the required state, and which methods load it back from disk.

mzbac commented 9 months ago

There may not be official documentation yet, but as I understand it, MLX originally used savez. Support for save_safetensors has since been added, so there are some inconsistencies due to the framework's rapid development. The lora example should have the most up-to-date method for saving and loading model weights. I agree that official documentation on this would be great.

bigsnarfdude commented 9 months ago

EXAMPLE: https://github.com/ml-explore/mlx-examples/tree/main/lora

SAVING: in this file you can see mx.savez being used inside the training loop: https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L327. Docs: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.savez.html

LOADING: for loading weights from the npz format, see https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L335. Docs: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html?highlight=load+weights

sandeepimpressico commented 9 months ago

@bigsnarfdude Would this work for checkpointing as well, or do I need to save additional data?

bigsnarfdude commented 9 months ago

@sandeepimpressico It looks like new checkpointing code has been added. The npz file is all that is needed: save checkpoints to it and reload them with load_weights. Here is the code:

https://github.com/ml-explore/mlx-examples/commit/d8680a89f986492dbc27c36af3294034db26458f