There is no padding as long as we keep a batch size of 8. That means our highest possible context is 256 Ki (2 Mi tokens / 8 samples per batch) until we fix this.
Increasing the context from 4,096 to 128 Ki chars (with the batch size dropped from 16 to 8) reduces the loss by ~20%.
Seems like a success.
At the moment, our models can fit up to 2 million tokens. However, JAX seems to have internal overheads that stop us from using all of them in a single sequence with a batch size of one, as that would require roughly 200 GiB of RAM instead of the 14 GiB we need for batch=512 + sequence=4096.

This issue is about tracking down these overheads and finding a sensible solution.
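One way to start tracking the overheads is to compare XLA's own memory estimates for the two shapes without ever allocating them. Below is a minimal sketch, not our actual model: `toy_forward`, `DIM`, and the weight matrices are hypothetical stand-ins, and it assumes a JAX version where the AOT API (`jax.jit(...).lower(...).compile()`) exposes `memory_analysis()`; on some backends that call returns `None` and the field names differ, hence the `getattr` guards.

```python
import jax
import jax.numpy as jnp

DIM = 256  # assumed feature dimension, not from the issue

key = jax.random.PRNGKey(0)
W1 = jax.random.normal(key, (DIM, 4 * DIM), dtype=jnp.float32)
W2 = jax.random.normal(key, (4 * DIM, DIM), dtype=jnp.float32)


def toy_forward(x):
    # Two dense layers: activation memory should scale with batch * seq,
    # so both configurations below "ought" to need similar amounts.
    h = jax.nn.gelu(x @ W1)
    return h @ W2


def report(batch, seq):
    # Lower against an abstract shape so nothing is actually allocated.
    spec = jax.ShapeDtypeStruct((batch, seq, DIM), jnp.float32)
    compiled = jax.jit(toy_forward).lower(spec).compile()
    analysis = compiled.memory_analysis()  # may be None on some backends
    if analysis is None:
        print(f"batch={batch} seq={seq}: no memory analysis on this backend")
        return
    gib = 1024 ** 3
    temp = getattr(analysis, "temp_size_in_bytes", None)
    args = getattr(analysis, "argument_size_in_bytes", None)
    print(f"batch={batch} seq={seq}: "
          f"temp={temp / gib:.1f} GiB, args={args / gib:.1f} GiB"
          if temp is not None and args is not None else analysis)


report(512, 4096)      # the configuration that fits in ~14 GiB
report(1, 2_097_152)   # one ~2 Mi-token sequence with batch size 1
```

Both calls cover the same total number of tokens, so any gap between the two estimates beyond linear scaling is exactly the kind of overhead this issue is about.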