HomebrewNLP / Olmax

HomebrewNLP in a JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Reuse Parameter-Buffers #37

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Currently, our model allocates one set of buffers for the input parameters and another set for the output parameters. On a 16 GB GPU, up to half the memory can therefore be consumed by this duplication, since input and output buffers are kept separate. That leaves 8 GB of effective memory, enough for 8 GB / (4 bytes per parameter) = 2 billion fp32 parameters. However, JAX supports buffer donation, which allows its compiler to deallocate (and reuse) the input buffers.
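A minimal sketch of what this looks like in JAX (this is an illustrative toy update step, not the actual Olmax training code): passing `donate_argnums` to `jax.jit` marks the parameter argument's buffers as donated, so XLA may write the updated parameters into the same memory instead of allocating a second copy.

```python
import jax
import jax.numpy as jnp


# Hypothetical SGD-style update step. Without donation, XLA holds both
# the old `params` buffer and the new output buffer alive at once.
def update(params, grads):
    return params - 0.1 * grads


# Donating argument 0 (`params`) tells the compiler it may reuse that
# buffer for the output, roughly halving peak parameter memory.
# (On CPU, JAX emits a warning and falls back to copying; donation
# takes effect on TPU/GPU backends.)
donated_update = jax.jit(update, donate_argnums=(0,))

params = jnp.ones((4,))
grads = jnp.full((4,), 2.0)
new_params = donated_update(params, grads)
print(new_params)  # 1.0 - 0.1 * 2.0 = 0.8 per element
```

Note that a donated buffer must not be read again afterwards; in a training loop this is natural, since the old parameters are replaced by the returned ones each step.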

ClashLuke commented 2 years ago

Closed by #38