jonatanklosko closed this 4 months ago
Or we don't use `preallocate_params` there?
We don't :) Also, when Livebook adds EXLA, it sets `EXLA.Backend` (not `:host`), so params will be loaded onto the GPU if one is available. I think `:host` can still be a good default for production usage: that way we are explicit about what runs on the GPU and don't accidentally block other side computations. But defaulting to the GPU in most notebooks is fine, even more so with this change :D
Currently we first load the PyTorch params map (either from `.bin` or `.safetensors`) with all tensors materialized, which already takes as much memory as all the params. Then, when building the Axon params map, we look up tensors in the PyTorch map and oftentimes need to apply transformations (mostly `Nx.transpose/1`), so each such tensor uses further memory. Consequently, the memory peak can be almost double the size of the parameters.

Because of this behaviour, loading large models directly onto the GPU could result in OOM. In such cases we recommended loading the params onto the CPU first and only then transferring them, but this (a) assumes we have enough RAM (which for really large models is not necessarily the case!); (b) puts high pressure on RAM; (c) is slower, since those `Nx.transpose/1` calls run on the CPU.

With this PR, instead of loading `%{"param" => %Nx.Tensor{}}` from `.bin`, we load `%{"param" => %FileTensor{}}`, where `FileTensor` is a lazy container. I also added an option to `safetensors` to do the same (https://github.com/elixir-nx/safetensors/pull/9). So now, when building an Axon param tensor, we look up the relevant PyTorch lazy containers, call `Nx.to_tensor/1` to materialize them, and apply the necessary transformations. Then we proceed to the next param, and the previous intermediate tensors can already be garbage collected. This way there is barely any memory footprint other than the params themselves.

I tested loading Llama2 onto the GPU with different values of the EXLA `:memory_fraction` option to force an upper memory limit. The parameters are 13.5GB; prior to this change loading required 24.6GB, while now 13.6GB was enough.
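To make the memory argument concrete, here is a small language-agnostic simulation in Python. It is not Bumblebee, Nx, or safetensors code. It assumes, for simplicity, that every parameter needs one same-sized transform (like a transpose), and it tracks byte counts by hand instead of measuring real allocations. `LazyTensor`, `MemoryTracker`, `load_eager`, and `load_lazy` are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Dict


class MemoryTracker:
    """Counts bytes currently held and the highest total seen so far."""

    def __init__(self) -> None:
        self.live = 0
        self.peak = 0

    def alloc(self, nbytes: int) -> None:
        self.live += nbytes
        self.peak = max(self.peak, self.live)

    def free(self, nbytes: int) -> None:
        self.live -= nbytes


@dataclass
class LazyTensor:
    """Hypothetical lazy container: knows its size, reads data only on demand."""

    nbytes: int

    def materialize(self, mem: MemoryTracker) -> int:
        # Reading the tensor from disk allocates a full buffer.
        mem.alloc(self.nbytes)
        return self.nbytes


def transform(nbytes: int, mem: MemoryTracker) -> int:
    # Models a transpose: the result is a new buffer of the same size.
    mem.alloc(nbytes)
    return nbytes


def load_eager(checkpoint: Dict[str, LazyTensor]) -> int:
    """Old flow: materialize the whole source map, then build the params."""
    mem = MemoryTracker()
    source = {name: t.materialize(mem) for name, t in checkpoint.items()}
    params = {name: transform(n, mem) for name, n in source.items()}
    # The source map is only released once every param has been built.
    for n in source.values():
        mem.free(n)
    assert sum(params.values()) == mem.live
    return mem.peak


def load_lazy(checkpoint: Dict[str, LazyTensor]) -> int:
    """New flow: materialize, transform and release one param at a time."""
    mem = MemoryTracker()
    params: Dict[str, int] = {}
    for name, t in checkpoint.items():
        raw = t.materialize(mem)
        params[name] = transform(raw, mem)
        # The intermediate can be collected before the next param is loaded.
        mem.free(raw)
    assert sum(params.values()) == mem.live
    return mem.peak


checkpoint = {f"layer{i}": LazyTensor(nbytes=100) for i in range(4)}
print("eager peak:", load_eager(checkpoint))  # 800: source map + transformed params
print("lazy peak:", load_lazy(checkpoint))    # 500: params + one intermediate
```

With equally sized params, the eager peak in this model is twice the total, while the lazy peak is the total plus a single intermediate. Real checkpoints fall short of a full 2x because not every tensor needs a transform, which is consistent with the reported 24.6GB for 13.5GB of parameters (about 1.8x).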