elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Streamline loading for params variants #309

Closed: jonatanklosko closed this 6 months ago

jonatanklosko commented 6 months ago

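Small UX improvement: instead of spelling out the full parameters filename, you can now pass the variant name with `params_variant`: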

-Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
-  params_filename: "diffusion_pytorch_model.fp16.bin"
-)
+Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, params_variant: "fp16")

If the requested variant does not exist, the error lists the available ones:

iex> Bumblebee.load_model({:hf, repository_id, subdir: "unet"}, params_variant: "f16")
** (ArgumentError) parameters variant "f16" not found, available variants: "fp16", "non_ema"
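
To illustrate how a variant name can map to a file, here is a minimal sketch. It assumes variant files follow the `<name>.<variant>.<ext>` pattern seen in the diff (`diffusion_pytorch_model.fp16.bin`). The `ParamsVariantSketch` module and its functions are hypothetical and are not Bumblebee's implementation. The sketch works on a plain list of filenames rather than querying the Hugging Face Hub, and it builds an error message shaped like the one shown in the PR.

```elixir
defmodule ParamsVariantSketch do
  # Hypothetical module, not part of Bumblebee. It assumes variant files
  # are named "<name>.<variant>.<ext>".

  # Inserts the variant before the file extension, e.g.
  # "diffusion_pytorch_model.bin" + "fp16" -> "diffusion_pytorch_model.fp16.bin".
  def variant_filename(filename, variant) do
    Path.rootname(filename) <> "." <> variant <> Path.extname(filename)
  end

  # Collects variant names from a list of repository filenames,
  # skipping the base (non-variant) file itself.
  def available_variants(filenames, base_filename) do
    prefix = Path.rootname(base_filename) <> "."
    ext = Path.extname(base_filename)

    filenames
    |> Enum.filter(fn name ->
      name != base_filename and String.starts_with?(name, prefix) and
        String.ends_with?(name, ext)
    end)
    |> Enum.map(fn name ->
      name
      |> String.replace_prefix(prefix, "")
      |> String.replace_suffix(ext, "")
    end)
  end

  # Resolves a variant to a filename, or raises an ArgumentError that
  # lists the variants that do exist.
  def resolve!(filenames, base_filename, variant) do
    variants = available_variants(filenames, base_filename)

    if variant in variants do
      variant_filename(base_filename, variant)
    else
      listed = Enum.map_join(variants, ", ", &inspect/1)

      raise ArgumentError,
            "parameters variant #{inspect(variant)} not found, available variants: #{listed}"
    end
  end
end
```

For example, given `["diffusion_pytorch_model.bin", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.non_ema.bin"]`, calling `resolve!(files, "diffusion_pytorch_model.bin", "fp16")` returns `"diffusion_pytorch_model.fp16.bin"`, while `"f16"` raises an error listing `"fp16", "non_ema"`.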