HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Long-Context Model #33

Closed: ClashLuke closed this issue 2 years ago

ClashLuke commented 2 years ago

At the moment, our models can fit up to 2 million tokens. However, JAX seems to have internal overheads that stop us from training on them as a single sequence with a batch of one sample: that configuration would require 200 GiB of RAM instead of the 14 GiB we need for batch=512 with sequence=4096, even though both cover the same total number of tokens. This issue is about tracking down these overheads and finding a sensible solution.
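To make that comparison concrete, here is a minimal sketch in plain Python (the model dimensions are placeholders, not Olmax's actual configuration) showing that both shapes cover the same 2 Mi token budget, so a first-order activation estimate predicts identical memory for both; the 200 GiB figure must therefore come from overheads that grow with sequence length rather than with token count.

```python
# Minimal sketch: both configurations process the same number of tokens,
# so a naive activation estimate yields the same footprint for each.
# `features`, `layers`, and `dtype_bytes` are placeholder values,
# not the repo's actual model dimensions.
batch_a, seq_a = 512, 4096    # observed: ~14 GiB
batch_b, seq_b = 1, 2 ** 21   # observed: ~200 GiB

assert batch_a * seq_a == batch_b * seq_b  # identical token budgets

def naive_activation_gib(batch, seq, features=1024, layers=64, dtype_bytes=4):
    # Per-layer activations scale with batch * seq * features.
    return batch * seq * features * layers * dtype_bytes / 2 ** 30

print(naive_activation_gib(batch_a, seq_a))  # same estimate ...
print(naive_activation_gib(batch_b, seq_b))  # ... as this one
```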

ClashLuke commented 2 years ago

There is no padding as long as we keep the batch size at 8 or above. That means our highest possible context is 256Ki tokens until we fix this.
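For reference, a quick back-of-the-envelope check of that 256Ki figure, assuming the ~2 Mi total-token budget from the opening comment:

```python
# With the ~2 Mi token budget fixed and a minimum padding-free batch of 8,
# the longest per-sample context works out to 256Ki tokens.
total_tokens = 2 ** 21  # ~2 million tokens (see opening comment)
min_batch = 8           # smallest batch that avoids padding
print(total_tokens // min_batch)  # 262144 == 256 * 1024, i.e. 256Ki
```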

ClashLuke commented 2 years ago

Increasing the context from 4096 to 128Ki characters (with the batch size lowered from 16 to 8) reduces the loss by roughly 20% (loss plot attached to the original comment).

Seems like a success.