HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Reuse ("donate") Buffers #38

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

No complications, code works out of the box.

ClashLuke commented 2 years ago

This PR reduces memory usage nicely, allowing us to double the parameter count or increase the batch size. Unfortunately, with our current architecture the model buffers take up little memory compared to the intermediate states, so this improvement isn't game-changing. We can trade the newly gained memory for compute using MoE, or spend it on an optimiser with more buffers, as in #35.
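For readers unfamiliar with the technique: "donating" a buffer tells XLA that an input array will not be used again after the call, so its device memory can be reused for the outputs instead of allocating fresh buffers. A minimal sketch with `jax.jit` and `donate_argnums` (illustrative only, not this repository's actual training step; the learning rate and tree shapes are made up):

```python
from functools import partial

import jax
import jax.numpy as jnp


# Donating argument 0 (the parameters) lets XLA write the updated
# parameters into the same device buffers, avoiding a second copy.
@partial(jax.jit, donate_argnums=(0,))
def train_step(params, grads):
    # Plain SGD update; the input `params` buffers may be reused in place.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.ones((1024, 1024))}
grads = {"w": jnp.full((1024, 1024), 0.01)}
# After this call the old `params` buffers are invalid; rebind the result.
params = train_step(params, grads)
```

Note that donation only takes effect on accelerator backends (TPU/GPU); on CPU, JAX emits a warning and falls back to copying, so the savings described above are specific to TPU runs.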