elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Parameter persistence with sharding support #338

Open jonatanklosko opened 4 months ago

jonatanklosko commented 4 months ago

Currently, whenever we load a model, we need to convert the parameters' layout from whatever PyTorch uses to whatever Axon expects (mostly transposing dense and conv kernels). For smaller models this is quick, but for large models it: (a) introduces loading overhead; (b) consumes a lot of memory, which prevents loading params directly onto the GPU, something that would make sense in a single-GPU use case (fixed in https://github.com/elixir-nx/bumblebee/pull/344).
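For a dense layer the conversion boils down to a transpose; a minimal Nx sketch (shapes illustrative, assumes Nx is available):

```elixir
# PyTorch stores a dense (Linear) kernel as {out_features, in_features},
# while Axon expects {in_features, out_features}, so loading transposes it.
pytorch_kernel = Nx.iota({4, 3})           # {out_features, in_features}
axon_kernel = Nx.transpose(pytorch_kernel) # {in_features, out_features}

# The transpose materializes a second copy of the weights, which is where
# the extra time and memory for large models comes from.
```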

Ideally we would have an easy way to persist the loaded parameters, possibly sharded across multiple files (for very large parameter sets). With that, the user could call Bumblebee.load_model/2 once, persist the parameters to disk, and then in production load them directly, without the conversion overhead (possibly straight onto the GPU).
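A hypothetical sketch of that workflow, assuming Nx.serialize/Nx.deserialize round-trip the params container (whether they cover every parameter shape is exactly the open question in this issue):

```elixir
# One-off: load (and layout-convert) the params, then persist them.
{:ok, %{params: params}} = Bumblebee.load_model({:hf, "bert-base-uncased"})
File.write!("params.nx", Nx.serialize(params))

# In production: read the persisted params directly, skipping conversion.
params = "params.nx" |> File.read!() |> Nx.deserialize()
```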

This probably belongs in Axon directly, but we may as well track it here given the use case. I also wonder if we should be using Safetensors rather than term-to-binary, for better portability. One issue with Safetensors is that it only supports a flat map of tensors, but Axon parameters can be any Nx.Container (e.g. LSTM uses tuples), so unless we make Axon parameters more strict we can't really do it.

This also depends on https://github.com/elixir-nx/axon/pull/553, which changes params into a struct, and we likely want to persist the whole struct.

josevalim commented 4 months ago

The flat parameters should not really be a problem, should they? You could convert a nested map with keys “foo” and “bar” into a special flattened key, such as “foo——bar”, no?
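The suggested flattening could be sketched like this (a hypothetical helper, assuming leaves are tensors or other non-map values and `"--"` never appears in a parameter name):

```elixir
defmodule FlattenSketch do
  # Flattens a nested params map into a flat map with "--"-joined keys,
  # the shape a safetensors file could store.
  def flatten(params, prefix \\ nil) do
    Enum.reduce(params, %{}, fn {key, value}, acc ->
      full_key = if prefix, do: "#{prefix}--#{key}", else: to_string(key)

      case value do
        # Struct leaf (e.g. %Nx.Tensor{}): keep as-is.
        %_{} -> Map.put(acc, full_key, value)
        # Plain nested map: recurse with the extended prefix.
        %{} -> Map.merge(acc, flatten(value, full_key))
        # Any other leaf value.
        _ -> Map.put(acc, full_key, value)
      end
    end)
  end
end
```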

jonatanklosko commented 4 months ago

@josevalim the nested map is not a problem, it's other Nx.Containers (currently tuples), so it may make sense to restrict Axon parameters to tensors.

jonatanklosko commented 4 months ago

Sidenote: sharding is a nice-to-have, but with https://github.com/elixir-nx/safetensors/pull/8 we should be able to write all parameters into a single file efficiently.
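With that in place, writing a flat map of params to a single file might look like the following sketch (function names assumed from the elixir-nx/safetensors package; treat them as illustrative):

```elixir
# Hypothetical sketch assuming Safetensors.write!/2 and Safetensors.read!/1
# operate on a flat map of named tensors.
tensors = %{
  "dense--kernel" => Nx.iota({3, 4}, type: :f32),
  "dense--bias" => Nx.broadcast(0.0, {4})
}

Safetensors.write!("params.safetensors", tensors)
loaded = Safetensors.read!("params.safetensors")
```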

jonatanklosko commented 4 months ago

With #344 the main motivation (excessive memory usage) is addressed, so this is less of a priority. It would still remove the time overhead of transforming the params on every load. Either way, we should have a good way of persisting large parameters (again, ideally in Axon).