This PR adds chunking to the cross-entropy loss at the output. Convergence is the same as without this PR, but it allows training models with larger vocabularies, such as 50256 tokens (GPT-2) or 1048576 tokens (our upcoming tokeniser), without using any more memory.
The key difficulty in getting this right was that JAX unrolls Python for loops at trace time, so a loop like
```python
def loop(x):
    a = 0
    for i in range(10):
        a += x[i]  # do something memory-expensive with x[i] here
    return a
```
instantiates all ten copies of the intermediate values at once. To stop JAX from doing this (which is exactly what the chunking is meant to avoid), we have to use `jax.lax.scan`.
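For reference, here is the same toy loop written with `jax.lax.scan` (a minimal sketch, not code from this PR). The loop body is traced once and the accumulator is threaded through as the scan carry, so the compiler sees a single loop instead of ten unrolled copies of the memory-expensive step:

```python
import jax
import jax.numpy as jnp

def loop_scan(x):
    def body(a, x_i):
        # do the memory-expensive work on one slice x_i here
        return a + x_i, None  # (new carry, no per-step output)

    # The carry must have a fixed shape/dtype, so start from a zero
    # array shaped like one slice rather than a Python int.
    a, _ = jax.lax.scan(body, jnp.zeros_like(x[0]), x)
    return a
```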
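Applied to the loss itself, the idea looks roughly like the sketch below. This is an illustration of the technique, not the PR's actual code: the names (`chunked_cross_entropy`, `out_embed`, `num_chunks`) are made up, and it assumes chunking over the sequence dimension. The `[chunk, vocab]` logits are built and reduced inside the scan body, and only a scalar loss accumulator is carried between steps, so the full `[seq, vocab]` logit matrix is never materialised at once.

```python
import jax
import jax.numpy as jnp

def chunked_cross_entropy(hidden, out_embed, labels, num_chunks):
    # hidden: [seq, d_model], out_embed: [d_model, vocab], labels: [seq]
    # Assumes seq is divisible by num_chunks.
    seq_len = hidden.shape[0]
    hidden_chunks = hidden.reshape(num_chunks, seq_len // num_chunks, -1)
    label_chunks = labels.reshape(num_chunks, -1)

    def body(loss_sum, chunk):
        h, y = chunk
        logits = h @ out_embed  # [chunk, vocab], lives only for this step
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        nll = -jnp.take_along_axis(logprobs, y[:, None], axis=-1).squeeze(-1)
        return loss_sum + nll.sum(), None

    total, _ = jax.lax.scan(
        body, jnp.zeros((), dtype=jnp.float32), (hidden_chunks, label_chunks)
    )
    return total / seq_len
```

One caveat worth keeping in mind: the backward pass of a scan saves residuals for every step, so depending on the setup the body may also need `jax.checkpoint` (remat) to stop the per-chunk logits from being stored for the gradient.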