HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

Multiple forward per backward #81

Open ClashLuke opened 1 year ago

ClashLuke commented 1 year ago

Currently, our model does one forward pass and uses the intermediate states to do one backward pass. However, a backward pass is over 3x as expensive as a forward pass, so we could change the ratio of forward to backward passes to speed up training.

One such approach would be MESA, which adds a KL(model(x), ema_model(x)) term to the loss, where ema_model is an exponential moving average of the model. Another method is RHO-Loss, which prioritizes some samples over others by running (model(x) - oracle(x)).topk(), i.e. training only on the samples whose training loss most exceeds the loss of an oracle model trained on holdout data. Both of these methods claim to improve sample efficiency by up to 18x.
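A minimal JAX sketch of both ideas, not tied to our codebase: `per_sample_loss`, the toy linear model, and the `oracle_params` name are all hypothetical placeholders standing in for the real LM loss and the RHO-Loss holdout oracle.

```python
import jax
import jax.numpy as jnp


def per_sample_loss(params, x, y):
    # Hypothetical per-sample loss (squared error of a toy linear model);
    # stands in for the real language-model loss.
    pred = x @ params
    return (pred - y) ** 2


def rho_select(model_params, oracle_params, x, y, k):
    # RHO-Loss-style selection: score each sample by training loss minus
    # oracle (holdout) loss and keep the top-k, i.e. points that are
    # learnable but not yet learned. Only the selected subset would then
    # get the (expensive) backward pass.
    score = per_sample_loss(model_params, x, y) - per_sample_loss(oracle_params, x, y)
    _, idx = jax.lax.top_k(score, k)
    return x[idx], y[idx]


def mesa_kl(logits, ema_logits):
    # MESA-style regularizer: per-sample KL(model || ema_model) between the
    # model's distribution and that of an EMA copy of its weights.
    logp = jax.nn.log_softmax(logits)
    ema_logp = jax.nn.log_softmax(ema_logits)
    return jnp.sum(jnp.exp(logp) * (logp - ema_logp), axis=-1)
```

Selection is cheap because it needs only forward passes over the candidate batch; the backward pass then runs on the k kept samples, which is how the forward:backward ratio shifts above 1:1.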