elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

feat: supports loading .safetensors params file #231

Closed: grzuy closed this 1 year ago

grzuy commented 1 year ago

closes #96

Opening this proof of concept as a draft while I continue working on some improvements and test coverage, and on any other feedback folks may have :-)

jonatanklosko commented 1 year ago

Thanks for the PR, a couple minor comments :)

grzuy commented 1 year ago

FWIW, I plan to explore also supporting sharded safetensors params files in a follow-up PR.

jonatanklosko commented 1 year ago

Thanks a lot!

grzuy commented 1 year ago

> FWIW, I plan to explore also supporting sharded safetensors params files in a follow-up PR.

Mmm, correcting myself.

I think that with the changes in this PR, one should already be able to load sharded safetensors params files.

For example, the repo

https://huggingface.co/stabilityai/StableBeluga-7B/tree/main

contains the following files:

model-00001-of-00002.safetensors
model-00002-of-00002.safetensors
model.safetensors.index.json

so calling

Bumblebee.load_model({:hf, "stabilityai/StableBeluga-7B"}, params_filename: "model.safetensors")

should go through the existing sharded params loading logic, pick up the index file, and "just work".
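The index file is what ties the shards together. Below is a minimal sketch, assuming the usual Hugging Face layout where `model.safetensors.index.json` holds a `"weight_map"` object mapping each parameter name to the shard file that stores it. The module `ShardIndexSketch` and its functions are hypothetical illustrations, not Bumblebee's actual loading code, and the parameter names in the sample index are made up. It works on an already-decoded map so it runs without a JSON library.

```elixir
defmodule ShardIndexSketch do
  # Hypothetical helper: assumes the index file name is the params filename
  # with ".index.json" appended, e.g. "model.safetensors.index.json".
  def index_filename(params_filename), do: params_filename <> ".index.json"

  # Given a decoded index (a map with a "weight_map" key), return the unique
  # shard files that have to be fetched, sorted for a stable order.
  def shard_files(%{"weight_map" => weight_map}) do
    weight_map
    |> Map.values()
    |> Enum.uniq()
    |> Enum.sort()
  end

  # Group parameter names by the shard file that stores them, so that each
  # shard only needs to be opened once.
  def params_by_shard(%{"weight_map" => weight_map}) do
    Enum.group_by(weight_map, fn {_param, file} -> file end, fn {param, _file} -> param end)
  end
end

# Illustrative index with made-up parameter-to-shard assignments.
index = %{
  "weight_map" => %{
    "model.embed_tokens.weight" => "model-00001-of-00002.safetensors",
    "model.layers.0.mlp.up_proj.weight" => "model-00001-of-00002.safetensors",
    "lm_head.weight" => "model-00002-of-00002.safetensors"
  }
}

IO.inspect(ShardIndexSketch.index_filename("model.safetensors"))
IO.inspect(ShardIndexSketch.shard_files(index))
IO.inspect(ShardIndexSketch.params_by_shard(index))
```

Grouping by shard reflects the general idea that a loader can read each shard file once and collect every parameter it contains, rather than reopening files per parameter.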

grzuy commented 1 year ago

A good improvement might be automatically selecting the preferred file format based on what's available in the model repo, so that the user doesn't need to provide the file name explicitly.

jonatanklosko commented 1 year ago

Good call. So far most repos have had the PyTorch file and optionally other formats, but as safetensors becomes more popular there may be repos that only ship safetensors. Currently we do fallbacks: request one file, if it doesn't exist request another, and so on. I checked and it looks like the HF API now allows listing files, so I will later reevaluate whether we can improve on this :)
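The fallback approach described in the comment can be sketched as trying candidate filenames in order and stopping at the first one that exists. This is a minimal illustration, not Bumblebee's implementation: `ParamsFallbackSketch` is hypothetical, `exists?` stands in for whatever check the real code performs (for example one HTTP request per file), and the candidate names and their preference order are assumptions. If a file listing from the HF API is available, the same function can be given a cheap set-membership check instead.

```elixir
defmodule ParamsFallbackSketch do
  # Hypothetical preference order: PyTorch weights first, then safetensors,
  # mirroring "request one file, if it doesn't exist request another".
  @default_candidates ["pytorch_model.bin", "model.safetensors"]

  # `exists?` is a 1-arity function standing in for a remote existence check.
  # Returns {:ok, filename} for the first candidate that exists, or :error.
  def resolve(exists?, candidates \\ @default_candidates) do
    case Enum.find(candidates, exists?) do
      nil -> :error
      filename -> {:ok, filename}
    end
  end
end

# Simulate a repo that only ships safetensors weights, using a file listing
# (a MapSet) instead of one request per candidate.
repo_files = MapSet.new(["config.json", "model.safetensors"])
IO.inspect(ParamsFallbackSketch.resolve(&MapSet.member?(repo_files, &1)))
```

Passing the existence check as a function keeps the selection logic identical whether it is backed by sequential requests or by a single listing call.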

jonatanklosko commented 1 year ago

FTR, as of #256 we automatically detect when there are no parameters in the PyTorch format but a safetensors file is available :)