elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Load parameter tensors lazily #344

Closed jonatanklosko closed 4 months ago

jonatanklosko commented 4 months ago

Currently we first load the PyTorch params map (either from .bin or .safetensors) with all tensors materialized, which already takes as much memory as all the params. Then, when building the Axon params map, we look up tensors in the PyTorch map and oftentimes need to apply transformations (mostly Nx.transpose/1), so each such transformed tensor takes additional memory. Consequently, the memory peak can be almost double the size of the parameters.
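A minimal sketch of the eager path described above (the param names and shapes here are made up for illustration): the full PyTorch map is materialized first, and the transposed copies for the Axon map are built while the original map is still referenced, so both copies of every tensor are alive at the peak.

```elixir
# Hypothetical eager loading: the entire PyTorch map is materialized up front.
pytorch_params = %{
  "encoder.weight" => Nx.iota({2, 3}, type: :f32)
}

# Building the Axon map allocates a transposed copy of each tensor while
# `pytorch_params` is still held, roughly doubling peak memory.
axon_params = %{
  "encoder" => %{"kernel" => Nx.transpose(pytorch_params["encoder.weight"])}
}

Nx.shape(axon_params["encoder"]["kernel"])
#=> {3, 2}
```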

Because of this behaviour, loading large models directly onto the GPU could result in OOM. In such cases we recommended loading the params onto the CPU first and only then transferring them, but this (a) assumes we have enough RAM (which for really large models is not necessarily the case!); (b) puts high pressure on RAM; (c) is slower, since those Nx.transpose/1 calls run on the CPU.
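The CPU-first workaround could look roughly like this (the model repo name is a placeholder, and the exact client names depend on your EXLA setup): load onto the host client, then transfer the whole params container to the device in one go.

```elixir
# Sketch of the previous workaround: load params on the CPU (host client)
# first. This assumes the full parameter map fits in RAM.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "meta-llama/Llama-2-7b-hf"},
    backend: {EXLA.Backend, client: :host}
  )

# Then transfer the materialized params to the GPU client.
params = Nx.backend_transfer(model_info.params, {EXLA.Backend, client: :cuda})
```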

With this PR, instead of loading %{"param" => %Nx.Tensor{}} from .bin, we load %{"param" => %FileTensor{}}, where FileTensor is a lazy container. I also added an option to safetensors to do the same (https://github.com/elixir-nx/safetensors/pull/9). Now, when building an Axon param tensor, we look up the relevant PyTorch lazy containers, call Nx.to_tensor/1 to materialize them, and apply the necessary transformations. Then we proceed to the next param, and the intermediate tensors from the previous step can already be garbage collected. This way there is barely any memory footprint beyond the params themselves.
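The idea behind a lazy container can be sketched as follows (this is an illustration, not Bumblebee's actual FileTensor: the struct fields and `materialize/1` helper are made up). Instead of holding tensor data, the struct holds the file location and dtype/shape metadata, and the bytes are read and decoded only when the param is actually needed:

```elixir
defmodule LazyTensor do
  @moduledoc "Hypothetical lazy tensor handle: data stays on disk until materialized."
  defstruct [:path, :offset, :byte_size, :type, :shape]

  # Read just this tensor's bytes from the file and decode them into an Nx tensor.
  def materialize(%__MODULE__{} = lazy) do
    {:ok, io} = File.open(lazy.path, [:read, :raw, :binary])
    {:ok, binary} = :file.pread(io, lazy.offset, lazy.byte_size)
    File.close(io)

    binary
    |> Nx.from_binary(lazy.type)
    |> Nx.reshape(lazy.shape)
  end
end

# Params are then processed one at a time: materialize, transform, move on.
# The intermediate binary and tensor for each param become garbage as soon as
# the next param is processed, so peak memory stays close to the params' size.
#
# lazy = %LazyTensor{path: "model.bin", offset: 0, byte_size: 24, type: :f32, shape: {2, 3}}
# tensor = lazy |> LazyTensor.materialize() |> Nx.transpose()
```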

I tested loading Llama2 onto the GPU with different values of the EXLA :memory_fraction option to force an upper memory limit. The parameters are 13.5GB; prior to this change loading required 24.6GB, now 13.6GB was enough.
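For reference, forcing such an upper limit can be done through EXLA's client configuration, along these lines (the exact client name and fraction value are assumptions for illustration):

```elixir
# config/config.exs — cap how much GPU memory the XLA allocator may claim,
# so that loading is forced to stay under a chosen limit.
import Config

config :exla, :clients,
  cuda: [platform: :cuda, memory_fraction: 0.6]
```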

jonatanklosko commented 4 months ago

> Or we don't use `preallocate_params` there?

We don't :) Also, when Livebook adds EXLA, it sets EXLA.Backend (not :host), so params will be loaded onto the GPU if available. I think :host can still be a good default for production usage, so that we are explicit about what runs on the GPU and don't accidentally block other computations, but defaulting to the GPU in most notebooks is fine, even more so with this change :D