Closed grzuy closed 1 year ago
Thanks for the PR, a couple minor comments :)
FWIW I plan to explore in a follow up PR also supporting safetensors sharded params files.
Thanks a lot!
Mmm, correcting myself.
I think with the changes in this PR one would be able to load sharded safetensors param files.
For example, for
https://huggingface.co/stabilityai/StableBeluga-7B/tree/main
which contains the following files:

- `model-00001-of-00002.safetensors`
- `model-00002-of-00002.safetensors`
- `model.safetensors.index.json`
using

`Bumblebee.load_model({:hf, "stabilityai/StableBeluga-7B"}, params_filename: "model.safetensors")`
should use the existing sharded params loading logic, look at the index file and "just work".
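For reference, the index file is what makes this work: it maps each tensor name to the shard file that contains it. A minimal sketch of what `model.safetensors.index.json` might look like (the tensor names and sizes here are illustrative, not taken from the actual repo):

```json
{
  "metadata": { "total_size": 13476839424 },
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}
```

The loader can resolve the set of shard files from `weight_map` and fetch only the shards it needs.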
A good improvement might be adding some decent auto-selection of the preferred file format based on what's available in the model repo, without requiring the user to explicitly provide the file name.
Good call. So far most repos have had the pytorch file and optionally other formats, but as safetensors becomes more popular there may be cases where it's just safetensors. Currently we do fallbacks, that is, we request one file, and if it doesn't exist we request another, and so on. I checked and it looks like the HF API now allows listing files, so I will reevaluate later whether we can improve :)
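Given such a file listing, the selection step could be as simple as scanning a priority list of candidate filenames. A minimal sketch, assuming we already have the repo's filenames (e.g. from the HF API listing); the module name, candidate order, and helper are hypothetical, not Bumblebee's actual implementation:

```elixir
defmodule ParamsFileSelection do
  # Hypothetical priority order: prefer the pytorch params (single-file,
  # then sharded), falling back to safetensors. Real filename conventions
  # on the Hub, but the ordering here is an assumption.
  @candidates [
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
    "model.safetensors",
    "model.safetensors.index.json"
  ]

  @doc """
  Given the list of filenames in a model repo, returns the first
  candidate params file that is present, or nil if none match.
  """
  def pick_params_file(repo_files) do
    Enum.find(@candidates, &(&1 in repo_files))
  end
end

ParamsFileSelection.pick_params_file([
  "model-00001-of-00002.safetensors",
  "model-00002-of-00002.safetensors",
  "model.safetensors.index.json",
  "config.json"
])
#=> "model.safetensors.index.json"
```

This replaces the request-and-retry fallback with a single listing call followed by local selection, which also makes it easy to surface a clear error when no known params file exists in the repo.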
FTR as of #256 we automatically detect if there are no parameters in the pytorch format, but safetensors one is available :)
closes #96
Opening this proof of concept as a draft while I continue working on some improvements and test coverage, and to collect any other feedback folks may have :-)