ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Safetensor support #215

Closed: dc-dc-dc closed this 5 months ago.

dc-dc-dc commented 5 months ago

Proposed changes

Adds support for safetensors as specified here and closes #182.

Added a bare-minimum JSON parser for parsing safetensors metadata; there is still some code to clean up / simplify. But I wanted to open a draft PR to get some initial comments on code placement, file structure, and other nuances, as well as any changes the maintainers might want.
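For context, the safetensors layout is simple: an 8-byte little-endian unsigned integer N, then N bytes of UTF-8 JSON header, then the raw tensor bytes. Below is a minimal Python sketch (stdlib only, not the PR's C++ code) of reading that header; `read_safetensors_header` is a hypothetical helper name used only for illustration.

```python
import io
import json
import struct


def read_safetensors_header(fp):
    """Parse the header of a safetensors stream (sketch, stdlib only).

    Returns (header, data_start): the decoded JSON header and the absolute
    offset of the byte buffer that every entry's data_offsets are relative to.
    """
    prefix = fp.read(8)
    if len(prefix) != 8:
        raise ValueError("stream too short for a safetensors header")
    # First 8 bytes: header length N as a little-endian unsigned 64-bit integer.
    (header_len,) = struct.unpack("<Q", prefix)
    raw = fp.read(header_len)
    if len(raw) != header_len:
        raise ValueError("truncated safetensors header")
    # Next N bytes: UTF-8 JSON (trailing space padding is valid JSON whitespace).
    header = json.loads(raw.decode("utf-8"))
    return header, 8 + header_len


# Usage: a tiny in-memory file holding one float32 tensor of shape [2].
hdr = json.dumps({"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}).encode()
buf = io.BytesIO(struct.pack("<Q", len(hdr)) + hdr + struct.pack("<2f", 1.0, 2.0))
header, data_start = read_safetensors_header(buf)
buf.seek(data_start)
values = struct.unpack("<2f", buf.read(8))
```

The JSON part is the only piece that needs a real parser; everything else is fixed-width binary.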


awni commented 5 months ago

Wow very cool! Will take a look soon!

JulianGerhard21 commented 5 months ago

Hi @dc-dc-dc - awesome, thanks a lot. Highly anticipated feature. Can it be used yet? I'd provide some feedback if so.

awni commented 5 months ago

I'm wondering about the decision to write a custom minimal JSON parser versus using something open source. For example, https://github.com/nlohmann/json is header-only and MIT licensed, so we could include it the same way we include Doctest (by fetching it in CMake). Any thoughts on the pros/cons?

dc-dc-dc commented 5 months ago

I saw some people recommending this library when I was first looking into adding JSON support. I originally felt that minimal support for strings / numbers would be all that's needed, but it looks like __metadata__ will require support for more data types.

I will experiment with this library and see how it fits.
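To illustrate why strings and numbers alone fall short: each tensor entry in the header is a JSON object holding a string (dtype) plus arrays of integers (shape, data_offsets), and __metadata__ is itself an object (the safetensors spec restricts it to string-to-string pairs). A hedged Python sketch of the checks a loader needs once the header is parsed; `validate_header` and the dtype size table are illustrative and cover only a subset of dtypes.

```python
from math import prod

# Byte sizes for a subset of safetensors dtype strings (illustrative, not exhaustive).
DTYPE_SIZES = {"BOOL": 1, "U8": 1, "I8": 1, "F16": 2, "BF16": 2, "I16": 2, "U16": 2,
               "F32": 4, "I32": 4, "U32": 4, "F64": 8, "I64": 8, "U64": 8}


def validate_header(header):
    """Check a parsed safetensors header and return its metadata (possibly empty)."""
    metadata = header.get("__metadata__", {})
    # __metadata__ is a nested object, not a scalar: it must map str -> str.
    if not isinstance(metadata, dict) or not all(
        isinstance(k, str) and isinstance(v, str) for k, v in metadata.items()
    ):
        raise ValueError("__metadata__ must be a str -> str map")
    for name, info in header.items():
        if name == "__metadata__":
            continue
        dtype, shape = info["dtype"], info["shape"]  # string + array of ints
        begin, end = info["data_offsets"]            # array of two ints
        if dtype not in DTYPE_SIZES:
            raise ValueError(f"{name}: unsupported dtype {dtype}")
        # prod([]) == 1, so scalar tensors (shape []) are handled too.
        expected = prod(shape) * DTYPE_SIZES[dtype]
        if end - begin != expected:
            raise ValueError(f"{name}: expected {expected} bytes, got {end - begin}")
    return metadata
```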

dc-dc-dc commented 5 months ago

Update: I decided to remove the custom JSON parser and use nlohmann/json. I'm having some issues getting CMake's FetchContent to link the library, so for now I'm using find_library after installing it with Homebrew. If someone with more CMake experience knows how to resolve this, I would appreciate it 😄

awni commented 5 months ago

For the API, what do you think about the following:

A single load function which takes a file name.

mx.load(filename, format=None)

If format is None (the default), it attempts to infer the file type from the extension (.npy, .npz, .safetensors, etc.), which will work nicely for 99% of cases. But if a format specifier is provided (for example `format="np"`), it overrides the file type inference. CC @jagrit06 @angeloskath in case they have thoughts on that.
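The proposed rule (explicit format wins, otherwise look at the extension) is small enough to sketch in Python. `infer_format` is a hypothetical helper, not MLX's implementation, and the extension table only mirrors the formats mentioned in this thread.

```python
import os

# Extension -> format name (hypothetical table, limited to formats discussed here).
_EXT_TO_FORMAT = {".npy": "npy", ".npz": "npz", ".safetensors": "safetensors"}


def infer_format(filename, format=None):
    """Return the explicit format if given, otherwise infer it from the extension."""
    if format is not None:
        return format  # an explicit specifier always overrides inference
    ext = os.path.splitext(filename)[1].lower()
    try:
        return _EXT_TO_FORMAT[ext]
    except KeyError:
        raise ValueError(f"Cannot infer format from extension {ext!r}; pass format=...") from None
```

With this shape, `infer_format("./temp.safe", "safetensors")` returns "safetensors" even though the extension is unknown, while an unknown extension without a format fails loudly instead of guessing.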

Edit: I wonder if we should do anything about saving? Right now we have save, and savez for .npz. The name save_safetensors is kind of a mouthful. Not sure I like these more, but just to consider other options: savest or save_st. Thoughts?

awni commented 5 months ago

I am having some issues trying to get cmake FetchContent to link the lib so for now am using find_library after installing using homebrew.

I added it with CMake and pushed to your branch. I might try to find a way to fetch only the header, since the whole repo is a bit cumbersome, but for now that should work.

julien-c commented 5 months ago

Would be awesome to have a save_safetensors indeed!

dc-dc-dc commented 5 months ago

Updated load to take an optional format; otherwise it infers the format from the file extension.

Added npy, npz, and safetensors as options

import mlx.core as mx
# load safetensors format from different extension
mx.load("./temp.safe", "safetensors")

awni commented 5 months ago

@dc-dc-dc what's the state of this PR? Ready for review? I'm really stoked to get this landed 😄

dc-dc-dc commented 5 months ago

The only thing remaining is a decision on how to handle __metadata__. I added basic support by including it on save, with the format field set to mlx.

It doesn't seem like there is much use for it outside of that, so maybe it's fine to skip it for now and add support later if requested?

Or I can add it now, but I would have to change the signatures a bit, making load return something like:

std::pair<std::unordered_map<std::string, std::string>, std::unordered_map<std::string, array>>

And save_safetensor would take an optional str → str map.

What do you think?

awni commented 5 months ago

To the extent that adding __metadata__ won't break the API, I'm fine adding it later if/when requested.

For save_safetensors it should be OK, since it would be an optional kwarg anyway.

For load I think it should also be fine if the idea is to put it in the dictionary return value with the key _meta_.

If that sounds right, then it makes sense to me to add it in later. The C++ API would probably break but that's much more rarely used.
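A sketch of why the dictionary-key approach keeps the Python return type stable: load still returns one dict, and callers that look tensors up by name are unaffected. `merge_metadata` is hypothetical, and the reserved key name is a placeholder for whatever name gets chosen.

```python
def merge_metadata(tensors, metadata, key="__metadata__"):
    """Return a load()-style dict with metadata stored under one reserved key.

    Hypothetical sketch: `key` is a placeholder name. Lookups by tensor name
    keep working; only code that iterates over every value would see the extra
    str -> str dict.
    """
    if key in tensors:
        raise ValueError(f"tensor name {key!r} collides with the metadata key")
    out = dict(tensors)
    if metadata:
        out[key] = dict(metadata)  # only present when the file carries metadata
    return out
```

The collision check matters because safetensors tensor names are arbitrary strings, so a reserved key should be one the writer never emits as a tensor name.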

dc-dc-dc commented 5 months ago

Yeah, it wouldn't be a breaking change for Python, only for C++, and it's an easy addition down the line too.

With that in mind then this is ready for review 👀

Edit: let me just rebase and make sure it's still passing.

dc-dc-dc commented 5 months ago

@awni ready

dc-dc-dc commented 5 months ago

Also I think it would be good to make a subdir in mlx maybe called mlx/io/ and put the load.* and safetensor.* stuff there. The top-level directory is getting a little crowded. Wdyt?

I like this: if support for other file formats like GGUF / .pt is wanted in the future, they can be grouped together under the io directory.

awni commented 5 months ago

ValueError: [load] Input must be a file-like object, or string

I am getting this failure when running the python tests. Could you tell me which version of python and pybind11 you are using?

dc-dc-dc commented 5 months ago

Python 3.11, pybind11 2.11.1

Which test is failing?

awni commented 5 months ago

Loading from a file object fails with the error message I posted above:

    with open(save_file_mlx, "rb") as f:
        load_dict = mx.load(f)

awni commented 5 months ago

Presumably those tests pass for you?

My pybind is 2.10 and python 3.9.

pybind11-global 2.10.4 py39h48ca7d4_0

I will try a higher version to see if that's the issue.

dc-dc-dc commented 5 months ago

Wait, I am getting that error now after rebuilding. Something I changed in the last commits must have broken it; let me check.

awni commented 5 months ago

Ok I will let you debug it then, lmk!

dc-dc-dc commented 5 months ago

@awni fixed locally. I was experimenting with adding pathlib support, and when I removed it I also removed the istream_check 🤦‍♂️
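For readers hitting the same error: the failure came from the input check being dropped, so a file object fell through to the "else" branch. Here is a Python sketch of what such a check does; `open_for_load` is a hypothetical name, and the real check (the istream_check mentioned above) presumably lives in the C++/pybind11 binding and may differ in detail.

```python
import io


def open_for_load(file):
    """Return (stream, owned) for a path string or a binary file-like object.

    Hypothetical sketch: strings are opened in binary mode (the caller must
    close them); objects with read/seek are used as-is; anything else is
    rejected with the error message quoted in this thread.
    """
    if isinstance(file, str):
        return open(file, "rb"), True
    if hasattr(file, "read") and hasattr(file, "seek"):
        return file, False
    raise ValueError("[load] Input must be a file-like object, or string")


# Usage: an in-memory stream passes through untouched.
stream, owned = open_for_load(io.BytesIO(b"\x00" * 8))
```

Without the second branch, exactly the `with open(...) as f: mx.load(f)` pattern from the failing test would raise.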

awni commented 5 months ago

I changed the name from save_safetensor to save_safetensors. I found the fact that the file extension .safetensors did not match the name to be confusing... they should be consistent.

Also made a few minor docs changes.

awni commented 5 months ago

I tested that it works going from:

save_file(tensors, "model.safetensors")

And loading with mx.load
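The round trip works because both sides agree on a byte layout. As a hedged illustration (stdlib Python, not the safetensors package or MLX code), here is a tiny writer for 1-D float32 tensors following that layout: header length, JSON header padded with spaces to an 8-byte boundary (the spec permits trailing whitespace padding), then the concatenated little-endian data. `save_f32_safetensors` is a hypothetical name; output in this layout should, in principle, be readable by conforming loaders.

```python
import io
import json
import struct


def save_f32_safetensors(fp, tensors, metadata=None):
    """Write 1-D float32 tensors to a binary stream in safetensors layout (sketch).

    `tensors` maps names to lists of floats; real writers handle any dtype/shape.
    """
    header, chunks, offset = {}, [], 0
    for name, values in tensors.items():
        data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": [len(values)],
                        "data_offsets": [offset, offset + len(data)]}
        chunks.append(data)
        offset += len(data)
    if metadata:
        header["__metadata__"] = dict(metadata)  # str -> str pairs
    header_bytes = json.dumps(header).encode("utf-8")
    # Pad with spaces so the data section starts on an 8-byte boundary.
    header_bytes += b" " * (-len(header_bytes) % 8)
    fp.write(struct.pack("<Q", len(header_bytes)))
    fp.write(header_bytes)
    for chunk in chunks:
        fp.write(chunk)


# Usage: write to memory, then decode the pieces by hand.
buf = io.BytesIO()
save_f32_safetensors(buf, {"w": [1.0, 2.0, 3.0]}, metadata={"format": "mlx"})
raw = buf.getvalue()
(n,) = struct.unpack("<Q", raw[:8])
header = json.loads(raw[8:8 + n])
values = struct.unpack("<3f", raw[8 + n:])
```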

dc-dc-dc commented 5 months ago

Looks good. One quick note about the JSON dependency: it looks like they recommend fetching it directly rather than via git, due to the repo's size (seen here). I don't have much experience with CMake builds, so I defer to you on whether to keep the current git method or switch.

dc-dc-dc commented 5 months ago

Also, have you tried loading a PyTorch tensor that was saved using the Hugging Face safetensors Python package? That should work, right?

I was able to load phi-2 safetensors file from huggingface

awni commented 5 months ago

Looks good, one quick note about the json dependency it looks like they recommend fetching directly over using git due to size seen here.

Yeah, worth switching that to one of the smaller options, nice find.

awni commented 5 months ago

Included @angeloskath's suggestions along with the updated CMake (much faster download) 🚀.

Can merge whenever @angeloskath approves.

cchance27 commented 5 months ago

Wow, this is gonna be so helpful, since so many of the image models out there are in safetensors. Hopefully it's merged soon!

awni commented 5 months ago

@angeloskath this should be good to go, can you stamp? Your comments have been addressed, let's merge it unless there is anything else to add.

awni commented 5 months ago

PS @dc-dc-dc I forgot to mention: we would love to have you add your name to our contributors (optional). Feel free to send a PR and I will merge it.