Currently, our model does one forward pass and uses the intermediate states (the stored activations) to do one backward pass. However, a backward pass costs roughly twice as much compute as a forward pass (so a full training step costs about 3x a forward pass), so changing the ratio of forward to backward passes could speed up training.\
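One such approach is MESA, which adds a KL-divergence term between the outputs of the model and of an exponential moving average (EMA) of its weights, `KL(model(x), ema_model(x))`, to the loss; this costs one extra forward-only pass (through the EMA model) per backward pass. Another is RHO-Loss, which prioritizes some samples over others: it runs forward passes on a large candidate batch and backpropagates only through the points with the highest reducible loss, `(loss(model, x) - loss(oracle, x)).topk(k)`, where the oracle is a model trained on held-out data. Both methods report gains over standard training; RHO-Loss, for example, reports reaching the same accuracy in up to 18x fewer training steps.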
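The two objectives can be sketched in plain Python on per-example logits and losses. This is a minimal illustration under stated assumptions, not either paper's implementation: the helper names (`mesa_style_loss`, `rho_select`, `ema_update`), the EMA decay, the KL weight, and the KL direction (taken from the `KL(model(x), ema_model(x))` notation above) are hypothetical choices, and the temperature MESA applies to the logits is omitted. No gradients are computed; the point is which quantities each method needs from forward passes.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    # KL(p || q) between two discrete distributions over the same classes.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(logits, label):
    return -math.log(softmax(logits)[label])


def ema_update(ema_weights, weights, decay=0.999):
    # Hypothetical EMA step: the EMA copy tracks the weights but never gets gradients.
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


def mesa_style_loss(logits, ema_logits, label, kl_weight=1.0):
    # Task loss plus a KL term between the model and its EMA copy, using the
    # argument order of KL(model(x), ema_model(x)); kl_weight is an assumption.
    # Producing ema_logits is the extra forward-only pass this method needs.
    kl = kl_divergence(softmax(logits), softmax(ema_logits))
    return cross_entropy(logits, label) + kl_weight * kl


def rho_select(train_losses, holdout_losses, k):
    # Reducible loss: the current model's loss minus the loss of the "oracle"
    # (a model trained on held-out data). Both come from forward passes only.
    reducible = [t - h for t, h in zip(train_losses, holdout_losses)]
    # Keep the k candidates with the largest reducible loss; only these
    # would be backpropagated through.
    ranked = sorted(range(len(reducible)), key=lambda i: reducible[i], reverse=True)
    return ranked[:k]


# Score 6 candidates with forward passes, then train on the top 2.
train_losses = [2.3, 0.1, 1.7, 0.9, 2.5, 0.4]
holdout_losses = [2.2, 0.1, 0.3, 0.8, 2.4, 0.1]
print(rho_select(train_losses, holdout_losses, k=2))  # [2, 5]
```

Points 0 and 4 have the highest training loss, but the holdout-trained model finds them almost as hard, so they look like noise and are skipped; point 2 is learnable and not yet learnt, so it is selected.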
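Whether trading extra forward passes for better-chosen backward passes pays off comes down to simple arithmetic. The sketch below counts cost in units of one forward pass on one example and assumes a backward pass costs `backward_ratio` forward passes (2.0 is the usual FLOP estimate; it is a parameter here, not a measurement of our model). The function names are hypothetical.

```python
def step_cost(n_forward_only, n_trained, backward_ratio=2.0):
    """Cost of one training step, in units of one forward pass on one example.

    n_forward_only: examples that only get a forward pass (scoring, a teacher/EMA model)
    n_trained:      examples that get a forward pass and a backward pass
    backward_ratio: assumed cost of a backward pass relative to a forward pass
    """
    return n_forward_only + n_trained * (1.0 + backward_ratio)


def break_even_speedup(n_forward_only, n_trained, backward_ratio=2.0):
    """Factor by which a scheme must cut the number of steps to cost no more
    than plain training on n_trained examples per step."""
    plain = step_cost(0, n_trained, backward_ratio)
    return step_cost(n_forward_only, n_trained, backward_ratio) / plain


print(step_cost(0, 32))                  # 96.0: plain training, batch of 32
print(step_cost(320, 32))                # 416.0: score 320 candidates, train on 32
print(break_even_speedup(320, 32))       # ~4.33
print(break_even_speedup(320, 32, 3.0))  # 3.5
print(break_even_speedup(32, 32))        # ~1.33: one extra forward-only pass per example
```

Scoring a 10x larger candidate batch makes each step about 4.33x as expensive as plain training, so the selection has to cut the number of steps by more than that to come out ahead. With `backward_ratio=3.0` the break-even point drops to 3.5x, because forward-only work is a smaller share of each step when backward passes are more expensive.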