HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Chunked Cross-Entropy #25

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

This PR adds chunking to the cross-entropy loss at the output. Convergence is the same as without this PR, but it allows training models with larger vocabularies, such as 50256 tokens (GPT-2) or 1048576 tokens (our upcoming tokeniser), without using any more memory.

The key difficulty in getting this right was that JAX unrolls Python for loops like

def loop(x):
  a = 0
  for i in range(10):
    a += x[i]  # do something memory-expensive with x[i] here
  return a

To stop JAX from instantiating all the intermediate values (which is exactly what the chunking is meant to avoid), we must use jax.lax.scan.
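
With jax.lax.scan, the same accumulation can be written so that only one step's intermediates are live at a time. A minimal sketch for the toy loop above (the names loop_scan and body are illustrative, not taken from this PR):

import jax
import jax.numpy as jnp

def loop_scan(x):
  def body(carry, x_i):
    # carry holds the running sum; x_i is one slice along the leading axis of x
    return carry + x_i, None  # no per-step outputs are collected
  total, _ = jax.lax.scan(body, jnp.zeros(()), x)
  return total

print(loop_scan(jnp.arange(10.0)))  # 45.0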
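
For the loss itself, the same pattern applies to the output projection: split the token dimension into chunks and let scan compute one chunk's logits at a time, so only a [tokens / chunks, vocab] block of logits is computed per step. Below is a hedged, self-contained sketch of that idea; the function name chunked_cross_entropy, the argument names, and the shapes are assumptions for illustration and do not mirror the code in this PR:

import jax
import jax.numpy as jnp

def chunked_cross_entropy(features, out_embedding, targets, chunks=8):
  # features:      [tokens, hidden]  activations before the output projection
  # out_embedding: [hidden, vocab]   output (unembedding) matrix
  # targets:       [tokens]          integer token ids
  feat = features.reshape(chunks, -1, features.shape[-1])
  tgt = targets.reshape(chunks, -1)

  def body(loss_sum, chunk):
    f, t = chunk
    logits = f @ out_embedding  # only one [tokens / chunks, vocab] block of logits
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logprobs, t[:, None], axis=-1).sum()
    return loss_sum + nll, None

  total, _ = jax.lax.scan(body, jnp.zeros(()), (feat, tgt))
  return total / targets.size

feats = jax.random.normal(jax.random.PRNGKey(0), (64, 16))
emb = jax.random.normal(jax.random.PRNGKey(1), (16, 1000))
tgts = jax.random.randint(jax.random.PRNGKey(2), (64,), 0, 1000)
print(chunked_cross_entropy(feats, emb, tgts))

The chunk count trades peak memory against the length of the scan; with chunks=1 this degenerates to the ordinary, fully materialised cross-entropy.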