Currently, our model allocates one set of buffers for the input parameters and a separate set of buffers for the output parameters. Because both sets must be resident at the same time, the parameters can occupy at most half of a 16 GB GPU's memory. With 8 GB of effective memory and 4 bytes per float32 parameter, we can store at most $8\,\text{GB} / (4\,\text{bytes per parameter}) \approx 2$ billion parameters. However, JAX supports buffer donation, which lets its compiler reuse the memory of donated input buffers for the outputs instead of keeping both copies alive.
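As a rough illustration, the sketch below redoes the memory arithmetic and shows how donation can be requested through the `donate_argnums` argument of `jax.jit`. The helper `max_params`, the update function `sgd_step`, and the parameter shapes are hypothetical stand-ins for the model's real training step, and "16 GB" is taken as 16 * 10**9 bytes. The one-copy budget is an idealized upper bound that assumes every output can reuse a donated input buffer of the same shape and dtype, and it ignores activations and other allocations.

```python
import jax
import jax.numpy as jnp

BYTES_PER_FLOAT32 = 4
GPU_BYTES = 16 * 10**9  # assumption: "16 GB" read as 16e9 bytes


def max_params(gpu_bytes: int, copies: int) -> int:
    """Largest float32 parameter count that fits when `copies` full copies of
    the parameters must be resident at once (activations etc. are ignored)."""
    return gpu_bytes // (copies * BYTES_PER_FLOAT32)


# Without donation, input and output parameters coexist: two copies.
print(max_params(GPU_BYTES, copies=2))  # 2000000000
# Idealized donation: outputs reuse the input buffers, so one copy.
print(max_params(GPU_BYTES, copies=1))  # 4000000000


def sgd_step(params, grads, lr):
    # Hypothetical update rule. Each output leaf has the same shape and dtype
    # as the matching input leaf, so XLA can write it into the donated buffer.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# donate_argnums=0 donates `params` (the first positional argument).
sgd_step_donated = jax.jit(sgd_step, donate_argnums=0)

params = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# Rebind `params`: the arrays passed in may be invalidated by the donation.
params = sgd_step_donated(params, grads, 0.1)
```

Rebinding `params` to the result is the key design choice: after a donated call, the old arrays may be deleted, and reading them raises an error. On backends where a donated buffer cannot be used, JAX typically emits a warning and allocates fresh output buffers instead, so the code still runs but without the memory saving.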